[PS] BOJ 2887 / 행성 터널

[PS] BOJ 2887 / 행성 터널
문제 링크: https://www.acmicpc.net/problem/2887
Thumbnail: Photo by Louis Reed (Unsplash)

최소 신장 트리(MST)를 구하는 문제입니다. 하지만, 이번엔 간선도 직접 찾아야 합니다.

풀이

모든 행성을 터널로 연결하는데 필요한 최소 비용을 찾아야 합니다. 즉, 최소 신장 트리(MST)를 구하는 문제입니다.

하지만, 문제 입력에는 간선이 주어지는 대신, 각 행성(3차원 상의 위치)의 좌표가 주어집니다. 이를 이용해, 먼저 간선을 찾아야 합니다.

1) 간선 찾기

먼저, 두 행성을 연결할 때 그 사이 거리는 어떻게 계산되는지 알아봅시다.

행성은 3차원 좌표위의 한 점으로 생각하면 된다. 두 행성 A(xA, yA, zA)와 B(xB, yB, zB)를 터널로 연결할 때 드는 비용은 min(|xA-xB|, |yA-yB|, |zA-zB|)이다.

이제 주어진 모든 행성에 대해 서로 연결하는 모든 간선을 다 찾아보면...

N = int(input())
planets = [(idx,) + tuple(map(int, input().split())) for idx in range(N)]

edges = [[] for _ in range(N)]
for i in range(N):
    for j in range(i + 1, N):
        weight = min(abs(planets[i][axis] - planets[j][axis]) for axis in range(3))
        edges[i].append(j, weight)
        edges[j].append(i, weight)

...짜잔! 메모리 초과를 받게 됩니다! 이 문제는 128MB라는 비교적 적은 메모리 제한을 가지고 있습니다. N개의 원소 중에서 2개를 택할 수 있는 모든 경우의 수를 계산하게되면, $$O(N^2)$$의 공간복잡도를 가지기 때문에, N이 최대 100,000임을 감안하면 적절하지 않습니다.

2) 메모리 초과를 피해 최소한으로 간선 구하기

모든 간선을 저장한 뒤 MST를 만들게되면 메모리 초과가 발생했습니다. 우리가 필요한 간선 이외에는 저장하지 않는다면, 메모리 초과를 피할 수 있지 않을까요?

🙏 아래 고찰은 질문 게시판에서 많은 도움을 받았습니다! 원본 글을 읽고 싶으시다면 여기로 가주세요.

이에 대해, 임의의 차원의 좌표계에 주어진 N개의 정점에 대해 MST가 어떻게 만들어지는지 고찰해볼 필요가 있습니다. 먼저 가장 간단한 1차원부터 시작해봅시다.

1차원 좌표계에서 최단 거리로 각 좌표를 연결하는 방법

1차원 좌표계의 경우, Case 1처럼 인접한 정점끼리 연결한 경우가 MST가 됩니다.
Case 2처럼 인접하지 않은 정점을 연결하게 될 경우, 최단 거리가 아니게 됨을 그림을 통해 간단히 볼 수 있습니다.

이를 이번 문제와 같은 3차원 좌표계로 확장해 봅시다. 결국 3차원 좌표계에서도, 인접한 행성들끼리 연결하는게 1차원의 경우와 같이 MST가 됨을 알 수 있습니다.

이 개념을 토대로, 모든 행성 간의 거리를 간선으로 기록하는 대신에, 서로 인접하는 행성들끼리의 거리만 간선으로 저장해주면 됩니다! "서로 인접하는 행성"이라는 개념이 좀 모호할 수 있는데, 그냥 x/y/z 각 축 좌표에 대해 서로 인접하는 행성들 사이의 거리를 모두 계산해도 무방합니다.

동일한 경우가 중복해서 들어갈 수 있는데, Kruskal 알고리즘의 경우 Union-Find를 통해 걸러지기 때문에 실제 MST 계산 과정에는 영향이 없습니다.

N = int(input())
planets = [(idx,) + tuple(map(int, input().split())) for idx in range(N)]

# Kruskal 알고리즘을 사용하기 위해 간선을 만든 뒤 가중치의 크기 순으로 오름차순 정렬한다.
# https://www.acmicpc.net/board/view/145011
# 결국 최소 거리로 두 행성을 연결하는 경우는, x/y/z 3개 중 1개 축에서 인접하는 행성끼리 연결하는 경우 뿐이다.
# 1차원 좌표들로 MST를 만들면 일직선으로 형성된다는 점을 3차원으로 확장해보면 된다.

edges = []
for axis in range(1, 4):
    axis_sorted = sorted(planets, key=lambda p: p[axis])
    for i in range(N - 1):
        weight = abs(axis_sorted[i][axis] - axis_sorted[i + 1][axis])
        edges.append((axis_sorted[i][0], axis_sorted[i + 1][0], weight))

edges.sort(key=lambda e: e[2]) # 간선을 가중치가 작은 순으로 정렬한다.

Kruskal 알고리즘으로 MST 찾기

자세한 구현은 이전 글을 참고해주세요 😄

전체 코드

input = open(0).readline
N = int(input())
planets = [(idx,) + tuple(map(int, input().split())) for idx in range(N)]

# Kruskal 알고리즘을 사용하기 위해 간선을 만든 뒤 가중치의 크기 순으로 오름차순 정렬한다.
# https://www.acmicpc.net/board/view/145011
# 결국 최소 거리로 두 행성을 연결하는 경우는, x/y/z 3개 중 1개 축에서 인접하는 행성끼리 연결하는 경우 뿐이다.
# 1차원 좌표들로 MST를 만들면 일직선으로 형성된다는 점을 3차원으로 확장해보면 된다.

edges = []
for axis in range(1, 4):
    axis_sorted = sorted(planets, key=lambda p: p[axis])
    for i in range(N - 1):
        weight = abs(axis_sorted[i][axis] - axis_sorted[i + 1][axis])
        edges.append((axis_sorted[i][0], axis_sorted[i + 1][0], weight))

edges.sort(key=lambda e: e[2]) # 간선을 가중치가 작은 순으로 정렬한다.

# Union Find 알고리즘
parents = [i for i in range(N)]

def find_parent(vertex):
    """어떤 정점(vertex)이 속한 분리 집합의 최상위 노드를 반환한다.
    분리 집합(트리)의 깊이를 줄이기 위해, 임의의 노드 x에 대해 최상위 노드를 탐색한 뒤 최상위 노드를 바로 부모 노드로 설정한다.
    """
    if vertex == parents[vertex]:
        return vertex
    parents[vertex] = find_parent(parents[vertex])
    return parents[vertex]

def union(vertex_a, vertex_b):
    """두 정점 vertex_a와 vertex_b가 속한 두 분리 집합을 병합한다."""
    parent_a = find_parent(vertex_a)
    parent_b = find_parent(vertex_b)

    if parent_a != parent_b:
        parents[parent_b] = parent_a

# Kruskal 알고리즘
mst_len = 0
total_cost = 0
for x, y, c in edges:
    if find_parent(x) == find_parent(y):    # 사이클이 생기는 간선은 MST에 추가해선 안된다.
        continue
    mst_len += 1
    total_cost += c
    union(x, y)

    if mst_len == N - 1:   # MST는 항상 N개의 정점을 가지는 그래프에 대해 N-1개의 간선을 가지는 트리이다.
        break

print(total_cost)

solution.py