[PS] BOJ 1717 / 집합의 표현

[PS] BOJ 1717 / 집합의 표현
문제 링크: https://www.acmicpc.net/problem/1717
Thumbnail: Photo by Orevaoghene Ahia (Unsplash)

분리 집합을 구현하는 방법인 Union-Find 알고리즘을 구현하는 문제입니다.

풀이

분리 집합(Disjoint Set)?

분리 집합의 수학적 정의는 다음과 같습니다:

전체 집합 \(U\)에 대해, 서로 겹치는 원소를 가지지 않는 둘 이상의 부분 집합 \(A, B\)를 말합니다. 두 개의 부분 집합으로 나눌 때, 조건은 다음과 같습니다.

  • \(A \subset U\), \(B \subset U\)
  • \(A \cap B = \O\)
  • \(A \cup B = U\)

Union-Find?

Union-Find 알고리즘은 분리 집합 자료구조를 구현하기 위해 사용하는 방법입니다.

Union-Find 알고리즘은 2가지 연산을 구현합니다.

  • find_parent(x) : 임의의 원소 \(x\)가 속한 집합을 구합니다.
  • union(x, y) : 두 원소 \(x, y\)가 속한 두 부분 집합을 병합합니다.
# Union-Find
parents = [i for i in range(N + 1)] # 1부터 N까지의 각 원소에 대해 자신이 속한 집합의 대표 원소를 저장하는 배열

def find_parent(x):
    """임의의 원소 x가 속한 부분 집합을 구합니다.
    반환되는 정수 값은 x가 속한 부분 집합의 루트 노드입니다."""
    while x != parents[x]:
        path.append(x)
        x = parents[x]
    return x

def union(x, y):
    """두 원소 x, y가 속한 두 부분 집합을 병합합니다."""
    parent_x = find_parent(x)
    parent_y = find_parent(y)

    if parent_x != parent_y:
        parents[parent_y] = parent_x

Union-Find 알고리즘의 구현

Union-Find 알고리즘 최적화

위 Union-Find 알고리즘의 구현을 보면, 매번 union 연산을 수행할 때마다 한쪽 부분 집합의 루트 노드가 다른 쪽 부분 집합의 자식 노드로 포함되게 됩니다.
이런 방식으로는 입력 데이터에 따라 집합을 저장하는 트리가 한쪽 방향으로 깊게 늘어질 수 있으며, 이런 상태에서는 find_parent(x)에서 반복하는 횟수가 증가해 전체적인 계산 시간이 느려집니다.

이 문제는 Union-Find의 구현을 개선해서 예방할 수 있습니다.

1) find_root(x) 최적화: 경로 압축(Path Compression)

매번 find_root(x)를 진행할 때 마다, 자식 노드부터 루트 노드까지 반복해서 거슬러 올라가는 과정은 불필요합니다. 한 번 거슬러 올라간 뒤, 방문했던 모든 노드들의 부모 노드를 바로 루트 노드로 변경한다면, 다음 번 find_root(x)는 1회 탐색만으로 답을 찾아낼 수 있습니다.

이미지 출처: Union-Find 최적화
# 재귀 구현
def find_parent(x):
    if x == parents[x]:
        return x
    parents[x] = find_parent(parents[x])
    return parents[x]

# 반복문 구현
def find_parent(x):
    path = []
    while x != parents[x]:
        path.append(x)
        x = parents[x]
    for p in path:
        parents[p] = x
    return x

Path Compression 구현

2) union(x, y) 최적화: Union by Rank

기존 union(x, y)연산은 항상 x가 속한 집합(집합 B)에 y가 속한 집합(집합 A)을 하위 노드로 추가하고 있습니다.

만약 집합 A의 트리의 높이(rank)가 더 크다면, 반대로 집합 B를 집합 A의 하위 노드로 추가하는 편이 탐색 시간을 줄여 전체 성능에 유리합니다.

# Union-Find
parents = [i for i in range(N + 1)] # 1부터 N까지의 각 원소에 대해 자신이 속한 집합의 대표 원소를 저장하는 배열
ranks = [0 for _ in range(N + 1)] # 각 원소가 속한 집합의 트리의 높이를 저장할 배열

def find_parent(x):
    ... # 기존과 동일

def union(x, y):
    """두 원소 x, y가 속한 두 부분 집합을 병합합니다."""
    parent_x = find_parent(x)
    parent_y = find_parent(y)

    if parent_x == parent_y: # 이미 x와 y가 같은 집합에 속해 있다면 합치지 않는다.
        return
    # 항상 높이가 더 낮은 트리를 높이가 높은 트리 밑에 넣는다.
    # 즉, 높이가 더 높은 쪽을 root로 삼음
    if ranks[parent_x] < ranks[parent_y]:
        parents[parent_x] = parent_y
    else:
        parents[parent_y] = parent_x

        if ranks[parent_x] == ranks[parent_y]:
            ranks[parent_x] += 1 # 만약 높이가 같다면 합친 후 (x의 높이 + 1)

Union by Rank 구현

전체 코드

input = open(0).readline
N, M = map(int, input().split())

# Union-Find
parents = [i for i in range(N + 1)]
ranks = [0 for _ in range(N + 1)] # 각 원소가 속한 집합의 트리의 높이를 저장할 배열

def find_parent(x):
    path = []
    while x != parents[x]:
        path.append(x)
        x = parents[x]
    for p in path:
        parents[p] = x
    return x

def union(x, y):
    """두 원소 x, y가 속한 두 부분 집합을 병합합니다."""
    parent_x = find_parent(x)
    parent_y = find_parent(y)

    if parent_x == parent_y: # 이미 x와 y가 같은 집합에 속해 있다면 합치지 않는다.
        return
    # 항상 높이가 더 낮은 트리를 높이가 높은 트리 밑에 넣는다.
    # 즉, 높이가 더 높은 쪽을 root로 삼음
    if ranks[parent_x] < ranks[parent_y]:
        parents[parent_x] = parent_y
    else:
        parents[parent_y] = parent_x

        if ranks[parent_x] == ranks[parent_y]:
            ranks[parent_x] += 1 # 만약 높이가 같다면 합친 후 (x의 높이 + 1)

for _ in range(M):
    cmd, a, b = map(int, input().split())
    if cmd == 0:
        union(a, b)
    else:
        print("YES" if find_parent(a) == find_parent(b) else "NO")

solution.py