[PS] BOJ 18769 / 그리드 네트워크

[PS] BOJ 18769 / 그리드 네트워크
문제 링크: https://www.acmicpc.net/problem/18769
Thumbnail: Photo by Nastya Dulhiier (Unsplash)

격자 그래프에서 최소 신장 트리(MST)를 구하는 문제입니다.

풀이

​기본적인 풀이 요령은 MST를 구하는 다른 문제들과 동일합니다. 다만, 정점이 번호로 주어지지 않고 격자 그래프의 형태로 주어지니, 간단한 전처리를 통해 일반적인 그래프의 형태로 저장해봅시다.

1) 간선 입력

각 테스트 케이스에서, 간선 정보는 다음과 같이 입력됩니다.

R개의 줄에 걸쳐서 각 줄에 $C-1$개의 정수가 주어지는데,
이 정수는 각 행에 놓인 $C$개의 서버에서 좌우로 연결하는 통신망을 설치하는 비용을 나타낸다. 

다음 $R-1$개의 줄에 걸쳐서 각 줄에 $C$개의 정수가 주어지는데,
이 정수는 $i$번째 행과 $i+1$번째 행에 놓인 $C$개의 서버를 상하로 연결하는 통신망을 설치하는 비용을 나타낸다.

입력받는 그래프는 격자 그래프의 형태이므로, 좌측 상단 모서리의 좌표를 $(0, 0)$, 정점 번호는 0이라 하고 다음 기준에 따라 정점 번호를 부여했습니다.

  • 좌표는 (row, col) 형태로 읽습니다.
  • 정점 번호는 $row \times C + col$ 으로 계산합니다.
def solution():
    R, C = map(int, input().split())
    edges = []
    for r in range(R):  # 수평 방향 간선 입력
        for idx, cost in enumerate(map(int, input().split())):
            u = r * C + idx
            v = u + 1
            edges.append((u, v, cost))
            edges.append((v, u, cost))
    for r in range(R - 1):  # 수직 방향 간선 입력
        for idx, cost in enumerate(map(int, input().split())):
            u = r * C + idx
            v = u + C
            edges.append((u, v, cost))
            edges.append((v, u, cost))

간선 입력 받기

2) Kruskal 알고리즘으로 MST 길이 계산하기

​이제부터는 Kruskal 알고리즘을 사용해 MST를 계산해주기만 하면 됩니다..

💡 Kruskal 알고리즘에 대한 자세한 설명은 이전 글을 참고해 주세요!

구현 상의 편의를 위해, Union-Find 알고리즘에서 사용하는 분리 집합의 정보를 저장하는 배열 rootssolution함수 내부에서 정의된 뒤 매개변수를 통해 전달됩니다.

def find_root(roots, x):
    if roots[x] == x:
        return x
    roots[x] = find_root(roots, roots[x])
    return roots[x]

def union(roots, x, y):
    root_x = find_root(roots, x)
    root_y = find_root(roots, y)
    if root_x != root_y:
        roots[root_y] = root_x
        return True
    return False

def solution():
    ... # 앞 코드에서 계속
    edges.sort(key=lambda x: x[2])  # 비용 기준으로 정렬

    roots = [i for i in range(R * C)]  # Union-Find 초기화 / 0부터 R x C - 1까지의 정점 번호 사용
    total_cost = 0
    total_edges = 0
    for u, v, cost in edges:
        if union(roots, u, v):
            total_cost += cost
            
        if total_edges == R * C - 1:  # 모든 정점이 연결되면 종료
            break
    print(total_cost)

Kruskal 알고리즘

전체 코드

input = open(0).readline

def find_root(roots, x):
    if roots[x] == x:
        return x
    roots[x] = find_root(roots, roots[x])
    return roots[x]

def union(roots, x, y):
    root_x = find_root(roots, x)
    root_y = find_root(roots, y)
    if root_x != root_y:
        roots[root_y] = root_x
        return True
    return False

def solution():
    R, C = map(int, input().split())
    edges = []
    for r in range(R):  # 수평 방향 간선 입력
        for idx, cost in enumerate(map(int, input().split())):
            u = r * C + idx
            v = u + 1
            edges.append((u, v, cost))
            edges.append((v, u, cost))
    for r in range(R - 1):  # 수직 방향 간선 입력
        for idx, cost in enumerate(map(int, input().split())):
            u = r * C + idx
            v = u + C
            edges.append((u, v, cost))
            edges.append((v, u, cost))
    edges.sort(key=lambda x: x[2])  # 비용 기준으로 정렬

    roots = [i for i in range(R * C)]  # Union-Find 초기화 / 0부터 R x C - 1까지의 정점 번호 사용
    total_cost = 0
    total_edges = 0
    for u, v, cost in edges:
        if union(roots, u, v):
            total_cost += cost
            
        if total_edges == R * C - 1:  # 모든 정점이 연결되면 종료
            break
    print(total_cost)

for _ in range(int(input())):
    solution()

solution.py