[PS] BOJ 30641 / 회문 끝말잇기

[PS] BOJ 30641 / 회문 끝말잇기
문제 링크: https://www.acmicpc.net/problem/30641
Thumbnail: Photo by Sophia Richards (Unsplash)

수학 문제는 늘 머리가 아픕니다..

풀이

거듭제곱 속도 개선하기

회문의 길이는 최대 \(10^6\)이므로, 등비수열 꼴의 항을 점화식으로 (반복문으로 직접 곱해가며) 계산하게 되면 상당한 수의 거듭제곱 연산(정확히는 같은 수를 여러번 곱하는 과정)이 수행된다. 이는 분명 비효율적이므로, 개선할 수 있는 방법을 찾아보자.

\(a^n\)꼴의 수를 계산하려면, 단순 곱셈으로는 \(a\)에 \(a\)를 \(n-1\)번 곱해야 한다. 하지만 이 방식의 시간 복잡도는 \(O(N)\)으로, 이번 문제처럼 \(N\)이 큰 경우에는 당연히 비효율적이다.

하지만, '빠른 거듭제곱 연산' 을 사용하면 시간 복잡도를 \(O(\log{N})\)꼴로 개선할 수 있다.

원리는 간단하다. 지수를 2진수로 표현했을 때, 각 자리의 비트에 따라 계산을 달리한다.

  • 비트가 0인 자리: 제곱만 한다.
  • 비트가 1인 자리: 제곱하고 곱한다.
def fast_square(a, n, m):
    """a^n mod m 꼴의 계산을 square&multiply 방식으로 빠르게 계산한다.
    """
    res = 1
    while n > 0:
        if n & 1 == 1:  # 지수의 LSB가 1이라면 곱한다
            res = (a * res) % m
        a = (a * a) % m # 지수를 제곱한다.
        n = n >> 1
    return res

회문 길이에 따른 개수 구하기

회문의 길이가 $$1, 2, \cdots, N$$일때 각각의 경우 회문의 개수는 다음과 같습니다.

회문의 길이 1 2 3 4 5 ... N
회문의 개수 1 1 26 26 26^2 ... 26^((N+1) / 2 - 1)

회문의 길이에 따른 회문의 개수는 등비수열 꼴로 나타남을 확인할 수 있습니다. (단, 각 항이 두번씩 연달아 나타납니다.)

등비수열의 합 공식을 쓰자니, 아무래도 회문의 길이가 최대 10^6이니 너무 큰 수가 나오며 나눗셈으로 인해 모듈로 연산을 분배할 수 없습니다! 따라서 반복문을 사용해 점화식으로 계산하는 쪽을 선택했습니다.

길이가 \(N\)인 회문의 개수를 \(a_{N}\)이라 할 때, \(a_N = 26^{(N+1) \div 2 - 1\)입니다.

이를 이용해 \(\sum_{i=L}^{N} a_i)\)를 계산해주면 됩니다.

def sum():
    """길이가 L이상 U이하인 모든 회문의 개수를 센다."""
    res = 0
    for i in range(L, U + 1):
        if i <= 2:
            res += 1
        else:
            res = (res % MOD + fast_square((i - 1) // 2) % MOD) % MOD
    return res

회문 개수를 계산하기

누가 이겼나요?

​계산 과정에서 계속해서 \(1,000,000,007\)로 \(\mod\)연산을 진행하므로, 계산된 결과로는 누가 이겼는지 찾을 수 없습니다. 따라서 \(L\)과 \(U\)값만 가지고 누가 이겼는지 미리 계산해야 합니다.

앞서 정리했던 회문의 길이에 따른 회문의 개수를 보면, 길이 \(N\)이 3 이상인 경우는 항상 26의 거듭제곱 꼴로 나타내어지므로 언제나 짝수입니다.

그렇지만, \(N \le 2\)인 경우는 1로 나타내어지므로, \(N \le 2\)인 항이 홀수 개 들어가는 경우에만 회문의 개수의 합이 홀수가 됩니다.

  • \(L = 2\)이고 \(U \ge 2\)인 경우
  • \(L = U = 1\)인 경우
회문의 길이 1 2 3 4 5 ... N
회문의 개수 1 1 26 26 26^2 ... 26^((N+1)//2 - 1)

항상 호영이부터 시작하므로, 가능한 회문의 개수가 홀수라면 호영이가, 짝수라면 아란이가 승리합니다.

isH = (L == 1 and U == 1) or (L == 2 and U >= 2)
print("H" if isH else "A")

누가 이길까?

전체 코드

input = open(0).readline
L, U = map(int, input().strip().split())
MOD = 1_000_000_007

def fast_square(x):
    """26^x mod 1,000,000,007 꼴의 계산을 square&multiply 방식으로 빠르게 계산한다.
    a^x에서, x를 2진수로 나타냈을 때 각 자리의 비트값이 1일때만 그 자리에 해당하는 지수의 거듭제곱 꼴을 곱해주는 방식이다
    ex) a^9 = a^(1001)2 = a^8 * 1 * 1 * a^1
    """
    if x == 0:
        return 1
    elif x == 1:
        return 26
    
    res = 1
    a = 26
    while x > 0:
        if x & 1 == 1:  # 지수의 LSB가 1이라면 곱한다
            res = (a * res) % MOD
        a = (a * a) % MOD # 지수를 제곱한다.
        x = x >> 1
    return res

def sum():
    res = 0
    for i in range(L, U + 1):
        if i <= 2:
            res += 1
        else:
            res = (res % MOD + fast_square((i - 1) // 2) % MOD) % MOD
    return res

isH = (L == 1 and U == 1) or (L == 2 and U >= 2)
print("H" if isH else "A")
print(sum())

solution.py

시행착오들...

유난히 많이 시도해보고 많이 틀렸던 문제였습니다! 돌이켜보면 모듈로 연산에 대한 이해가 부족했던 게 원인이었는데, 다시금 공부하겠습니다...

문제였던 부분은 sum함수에서 회문의 개수를 계산할 때, 모듈로 덧셈의 분배법칙을 제대로 구현하지 않았던 부분이었습니다.

def fast_square(x):
    if x == 0:
        return 1
    elif x == 1:
        return 26
    
    res = 1
    a = 26
    while x > 0:
        if x & 1 == 1:  # 지수의 LSB가 1이라면 곱한다
            res = (a * res) % MOD
        a = (a * a) % MOD # 지수를 제곱한다.
        x = x >> 1
    return res

def sum():
    res = 0
    for i in range(L, U + 1):
        if i <= 2:
            res += 1
        else:
            res = (res % MOD + fast_square((i - 1) // 2)) % MOD
    return res

모듈로 연산의 경우, 덧셈/뺄셈/곱셈에 대해 다음과 같은 방식으로 분배법칙이 적용됩니다.

  • 덧셈: \((a + b) \mod M = (a \mod M + b \mod M) \mod M
  • 뺄셈: \((a - b) \mod M = (a \mod M - b \mod M) \mod M
  • 곱셈: \((a \times b) \mod M = (a \mod M \times b \mod M) \mod M

즉, sum 함수의 else구문에서 fast_square을 이전 res값과 더할 때,
resfast_square 둘 다 모듈로 연산을 수행해야 합니다.