[PS] BOJ 2118 / 두 개의 탑

[PS] BOJ 2118 / 두 개의 탑
문제 링크: https://www.acmicpc.net/problem/2118
Thumbnail: Photo by Quentin BASNIER (Unsplash)

투 포인터인데 누적합입니다.

풀이

최초 구상

\(N\)개의 지점이 원형으로 이어진 상태에서, 임의의 두 지점을 골라 탑을 배치할 때 두 탑의 최대 거리를 구하는 문제입니다. 이때, 두 지점 사이의 거리는 시계 방향/반시계 방향 거리 중 짧은 쪽으로 계산합니다.

입력으로 주어지는 거리 배열은 다음과 같습니다:

1번 지점과 2번 지점 사이의 거리,
2번 지점과 3번 지점 사이의 거리,
...
N번 지점과 1번 지점 사이의 거리

DISTANCE 배열 구성

두 지점 사이의 거리는 시계방향과 반시계방향 거리 중 짧은 쪽을 사용하므로, 이를 함수로 구현했습니다.

def calculate_distance(s, e):
    """시작점 s에서 끝점 e까지의 거리 계산"""
    return min(sum(DISTANCE[s:e % N]), sum(DISTANCE[e % N:] + DISTANCE[:s]))

그리고, 시작점을 1번 지점부터 N번 지점까지 바꿔가며, 다른 모든 지점에 대해 거리를 계산해 최대 거리를 계산하려 했습니다.

input = open(0).readline
N = int(input())
DISTANCE = [int(input()) for _ in range(N)] # 1과 2 사이의 거리, 2와 3 사이의 거리, ..., N과 1 사이의 거리 (원형 구조)

def calculate_distance(s, e):
    """시작점 s에서 끝점 e까지의 거리 계산"""
    return min(sum(DISTANCE[s:e % N]), sum(DISTANCE[e % N:] + DISTANCE[:s]))

max_distance = 0
for s in range(N):
    for e in range(s + 1, s + N):
        if e - s == N:  # 원형 구조이므로 시작점과 끝점이 같아지는 경우는 제외
            continue
        distance = calculate_distance(s, e)
        max_distance = max(max_distance, distance)

print(max_distance)  # 최대거리 출력

당연하게도, 시간 초과를 받습니다.

하지만, 이는 \(O(N^3)\)의 풀이로 시간 초과를 받을 수 밖에 없습니다.
어디서 개선해야할까요?

1) 두 지점 사이의 거리를 상수 시간으로 구하기

현재 거리 계산에서는 sum함수를 통해 DISTANCE배열의 일부 구간의 합을 구하고 있습니다. 이는 결국 \(O(N)\)의 동작입니다.

하지만, 누적 합을 사용하면 두 지점 사이의 거리를 선형 시간으로 구할 수 있습니다.

N = int(input())
distance = [int(input()) for _ in range(N)]
total_distance = sum(distance)
prefix_sum = [0 for _ in range(2 * N + 1)]    # 원형 구조를 고려해 2배 크기로 설정.
for i in range(1, 2 * N):
    prefix_sum[i + 1] = prefix_sum[i] + distance[i % N]

누적 합 계산해두기

left_distance = prefix_sum[right] - prefix_sum[left]

누적 합 활용해 거리 계산하기

원형 구조이므로, 배열의 뒤쪽에서 다시 앞쪽 지점으로의 거리를 계산하게 될 상황도 충분히 존재할 수 있으므로 누적 합 배열을 2배 크기로 계산해 두었습니다.

2) 불필요한 경우는 탐색하지 않기

앞선 구현에서는 \(O(N^2)\)으로 임의의 두 지점을 고를 수 있는 모든 경우에 대해 탐색했습니다.

for s in range(N):
    for e in range(s + 1, s + N):
        if e - s == N:  # 원형 구조이므로 시작점과 끝점이 같아지는 경우는 제외
            continue
        distance = calculate_distance(s, e)
        max_distance = max(max_distance, distance)

앞선 구현

하지만, 두 지점 사이의 거리를 계산할 때 시계 방향과 반시계 방향 거리를 모두 고려한다는 점을 생각한다면 반복문에서 실제로 반복하는 횟수를 절반 정도로 줄일 수 있습니다.

결국 원형 구조이므로, 두 지점의 시계 방향 거리와 반시계 방향 거리의 합은 항상 일정합니다.

이런 상황에서, 시계 방향 거리가 반시계 방향 거리보다 짧은 상태에서 점점 시계 방향의 거리가 증가한다면, 반시계 방향 거리는 반대로 점점 감소하다가 결국 시계 방향 거리보다 짧아집니다.

다시 말해, 시계 방향 거리가 반시계 방향 거리보다 작거나 같은 동안에만 반복하면 됩니다.

max_distance = 0
right = 1
for left in range(2 * N):
    while right < 2 * N + 1 and (left_distance := prefix_sum[right] - prefix_sum[left]) <= total_distance - left_distance:
        max_distance = max(max_distance, left_distance)
        right += 1

투 포인터 구현

전체 코드

input = open(0).readline
N = int(input())
distance = [int(input()) for _ in range(N)]
total_distance = sum(distance)
prefix_sum = [0 for _ in range(2 * N + 1)]    # 원형 구조를 고려해 2배 크기로 설정.
for i in range(1, 2 * N):
    prefix_sum[i + 1] = prefix_sum[i] + distance[i % N]

max_distance = 0
right = 1
for left in range(2 * N):
    while right < 2 * N + 1 and (left_distance := prefix_sum[right] - prefix_sum[left]) <= total_distance - left_distance:
        max_distance = max(max_distance, left_distance)
        right += 1

print(max_distance)  # 최대거리 출력

solution.py