[PS] BOJ 30641 / 회문 끝말잇기
문제 링크: https://www.acmicpc.net/problem/30641
Thumbnail: Photo by Sophia Richards (Unsplash)
수학 문제는 늘 머리가 아픕니다..
풀이
거듭제곱 속도 개선하기
회문의 길이는 최대 \(10^6\)이므로, 등비수열 꼴의 항을 점화식으로 (반복문으로 직접 곱해가며) 계산하게 되면 상당한 수의 거듭제곱 연산(정확히는 같은 수를 여러번 곱하는 과정)이 수행된다. 이는 분명 비효율적이므로, 개선할 수 있는 방법을 찾아보자.
\(a^n\)꼴의 수를 계산하려면, 단순 곱셈으로는 \(a\)에 \(a\)를 \(n-1\)번 곱해야 한다. 하지만 이 방식의 시간 복잡도는 \(O(N)\)으로, 이번 문제처럼 \(N\)이 큰 경우에는 당연히 비효율적이다.
하지만, '빠른 거듭제곱 연산' 을 사용하면 시간 복잡도를 \(O(\log{N})\)꼴로 개선할 수 있다.
원리는 간단하다. 지수를 2진수로 표현했을 때, 각 자리의 비트에 따라 계산을 달리한다.
- 비트가 0인 자리: 제곱만 한다.
- 비트가 1인 자리: 제곱하고 곱한다.
def fast_square(a, n, m):
"""a^n mod m 꼴의 계산을 square&multiply 방식으로 빠르게 계산한다.
"""
res = 1
while n > 0:
if n & 1 == 1: # 지수의 LSB가 1이라면 곱한다
res = (a * res) % m
a = (a * a) % m # 지수를 제곱한다.
n = n >> 1
return res회문 길이에 따른 개수 구하기
회문의 길이가 $$1, 2, \cdots, N$$일때 각각의 경우 회문의 개수는 다음과 같습니다.
| 회문의 길이 | 1 | 2 | 3 | 4 | 5 | ... | N |
|---|---|---|---|---|---|---|---|
| 회문의 개수 | 1 | 1 | 26 | 26 | 26^2 | ... | 26^((N+1) / 2 - 1) |
회문의 길이에 따른 회문의 개수는 등비수열 꼴로 나타남을 확인할 수 있습니다. (단, 각 항이 두번씩 연달아 나타납니다.)
등비수열의 합 공식을 쓰자니, 아무래도 회문의 길이가 최대 10^6이니 너무 큰 수가 나오며 나눗셈으로 인해 모듈로 연산을 분배할 수 없습니다! 따라서 반복문을 사용해 점화식으로 계산하는 쪽을 선택했습니다.
길이가 \(N\)인 회문의 개수를 \(a_{N}\)이라 할 때, \(a_N = 26^{(N+1) \div 2 - 1\)입니다.
이를 이용해 \(\sum_{i=L}^{N} a_i)\)를 계산해주면 됩니다.
def sum():
"""길이가 L이상 U이하인 모든 회문의 개수를 센다."""
res = 0
for i in range(L, U + 1):
if i <= 2:
res += 1
else:
res = (res % MOD + fast_square((i - 1) // 2) % MOD) % MOD
return res회문 개수를 계산하기
누가 이겼나요?
계산 과정에서 계속해서 \(1,000,000,007\)로 \(\mod\)연산을 진행하므로, 계산된 결과로는 누가 이겼는지 찾을 수 없습니다. 따라서 \(L\)과 \(U\)값만 가지고 누가 이겼는지 미리 계산해야 합니다.
앞서 정리했던 회문의 길이에 따른 회문의 개수를 보면, 길이 \(N\)이 3 이상인 경우는 항상 26의 거듭제곱 꼴로 나타내어지므로 언제나 짝수입니다.
그렇지만, \(N \le 2\)인 경우는 1로 나타내어지므로, \(N \le 2\)인 항이 홀수 개 들어가는 경우에만 회문의 개수의 합이 홀수가 됩니다.
- \(L = 2\)이고 \(U \ge 2\)인 경우
- \(L = U = 1\)인 경우
| 회문의 길이 | 1 | 2 | 3 | 4 | 5 | ... | N |
|---|---|---|---|---|---|---|---|
| 회문의 개수 | 1 | 1 | 26 | 26 | 26^2 | ... | 26^((N+1)//2 - 1) |
항상 호영이부터 시작하므로, 가능한 회문의 개수가 홀수라면 호영이가, 짝수라면 아란이가 승리합니다.
isH = (L == 1 and U == 1) or (L == 2 and U >= 2)
print("H" if isH else "A")누가 이길까?
전체 코드
input = open(0).readline
L, U = map(int, input().strip().split())
MOD = 1_000_000_007
def fast_square(x):
"""26^x mod 1,000,000,007 꼴의 계산을 square&multiply 방식으로 빠르게 계산한다.
a^x에서, x를 2진수로 나타냈을 때 각 자리의 비트값이 1일때만 그 자리에 해당하는 지수의 거듭제곱 꼴을 곱해주는 방식이다
ex) a^9 = a^(1001)2 = a^8 * 1 * 1 * a^1
"""
if x == 0:
return 1
elif x == 1:
return 26
res = 1
a = 26
while x > 0:
if x & 1 == 1: # 지수의 LSB가 1이라면 곱한다
res = (a * res) % MOD
a = (a * a) % MOD # 지수를 제곱한다.
x = x >> 1
return res
def sum():
res = 0
for i in range(L, U + 1):
if i <= 2:
res += 1
else:
res = (res % MOD + fast_square((i - 1) // 2) % MOD) % MOD
return res
isH = (L == 1 and U == 1) or (L == 2 and U >= 2)
print("H" if isH else "A")
print(sum())solution.py
시행착오들...
유난히 많이 시도해보고 많이 틀렸던 문제였습니다! 돌이켜보면 모듈로 연산에 대한 이해가 부족했던 게 원인이었는데, 다시금 공부하겠습니다...
문제였던 부분은 sum함수에서 회문의 개수를 계산할 때, 모듈로 덧셈의 분배법칙을 제대로 구현하지 않았던 부분이었습니다.
def fast_square(x):
if x == 0:
return 1
elif x == 1:
return 26
res = 1
a = 26
while x > 0:
if x & 1 == 1: # 지수의 LSB가 1이라면 곱한다
res = (a * res) % MOD
a = (a * a) % MOD # 지수를 제곱한다.
x = x >> 1
return res
def sum():
res = 0
for i in range(L, U + 1):
if i <= 2:
res += 1
else:
res = (res % MOD + fast_square((i - 1) // 2)) % MOD
return res모듈로 연산의 경우, 덧셈/뺄셈/곱셈에 대해 다음과 같은 방식으로 분배법칙이 적용됩니다.
- 덧셈: \((a + b) \mod M = (a \mod M + b \mod M) \mod M
- 뺄셈: \((a - b) \mod M = (a \mod M - b \mod M) \mod M
- 곱셈: \((a \times b) \mod M = (a \mod M \times b \mod M) \mod M
즉, sum 함수의 else구문에서 fast_square을 이전 res값과 더할 때, res와 fast_square 둘 다 모듈로 연산을 수행해야 합니다.