[PS] BOJ 18769 / 그리드 네트워크
문제 링크: https://www.acmicpc.net/problem/18769
Thumbnail: Photo by Nastya Dulhiier (Unsplash)
격자 그래프에서 최소 신장 트리(MST)를 구하는 문제입니다.
풀이
기본적인 풀이 요령은 MST를 구하는 다른 문제들과 동일합니다. 다만, 정점이 번호로 주어지지 않고 격자 그래프의 형태로 주어지니, 간단한 전처리를 통해 일반적인 그래프의 형태로 저장해봅시다.
1) 간선 입력
각 테스트 케이스에서, 간선 정보는 다음과 같이 입력됩니다.
R개의 줄에 걸쳐서 각 줄에 $C-1$개의 정수가 주어지는데,
이 정수는 각 행에 놓인 $C$개의 서버에서 좌우로 연결하는 통신망을 설치하는 비용을 나타낸다.
다음 $R-1$개의 줄에 걸쳐서 각 줄에 $C$개의 정수가 주어지는데,
이 정수는 $i$번째 행과 $i+1$번째 행에 놓인 $C$개의 서버를 상하로 연결하는 통신망을 설치하는 비용을 나타낸다.
입력받는 그래프는 격자 그래프의 형태이므로, 좌측 상단 모서리의 좌표를 $(0, 0)$, 정점 번호는 0이라 하고 다음 기준에 따라 정점 번호를 부여했습니다.
- 좌표는 (row, col) 형태로 읽습니다.
- 정점 번호는 $row \times C + col$ 으로 계산합니다.
def solution():
R, C = map(int, input().split())
edges = []
for r in range(R): # 수평 방향 간선 입력
for idx, cost in enumerate(map(int, input().split())):
u = r * C + idx
v = u + 1
edges.append((u, v, cost))
edges.append((v, u, cost))
for r in range(R - 1): # 수직 방향 간선 입력
for idx, cost in enumerate(map(int, input().split())):
u = r * C + idx
v = u + C
edges.append((u, v, cost))
edges.append((v, u, cost))간선 입력 받기
2) Kruskal 알고리즘으로 MST 길이 계산하기
이제부터는 Kruskal 알고리즘을 사용해 MST를 계산해주기만 하면 됩니다..
💡 Kruskal 알고리즘에 대한 자세한 설명은 이전 글을 참고해 주세요!
구현 상의 편의를 위해, Union-Find 알고리즘에서 사용하는 분리 집합의 정보를 저장하는 배열 roots는 solution함수 내부에서 정의된 뒤 매개변수를 통해 전달됩니다.
def find_root(roots, x):
if roots[x] == x:
return x
roots[x] = find_root(roots, roots[x])
return roots[x]
def union(roots, x, y):
root_x = find_root(roots, x)
root_y = find_root(roots, y)
if root_x != root_y:
roots[root_y] = root_x
return True
return False
def solution():
... # 앞 코드에서 계속
edges.sort(key=lambda x: x[2]) # 비용 기준으로 정렬
roots = [i for i in range(R * C)] # Union-Find 초기화 / 0부터 R x C - 1까지의 정점 번호 사용
total_cost = 0
total_edges = 0
for u, v, cost in edges:
if union(roots, u, v):
total_cost += cost
if total_edges == R * C - 1: # 모든 정점이 연결되면 종료
break
print(total_cost)Kruskal 알고리즘
전체 코드
input = open(0).readline
def find_root(roots, x):
if roots[x] == x:
return x
roots[x] = find_root(roots, roots[x])
return roots[x]
def union(roots, x, y):
root_x = find_root(roots, x)
root_y = find_root(roots, y)
if root_x != root_y:
roots[root_y] = root_x
return True
return False
def solution():
R, C = map(int, input().split())
edges = []
for r in range(R): # 수평 방향 간선 입력
for idx, cost in enumerate(map(int, input().split())):
u = r * C + idx
v = u + 1
edges.append((u, v, cost))
edges.append((v, u, cost))
for r in range(R - 1): # 수직 방향 간선 입력
for idx, cost in enumerate(map(int, input().split())):
u = r * C + idx
v = u + C
edges.append((u, v, cost))
edges.append((v, u, cost))
edges.sort(key=lambda x: x[2]) # 비용 기준으로 정렬
roots = [i for i in range(R * C)] # Union-Find 초기화 / 0부터 R x C - 1까지의 정점 번호 사용
total_cost = 0
total_edges = 0
for u, v, cost in edges:
if union(roots, u, v):
total_cost += cost
if total_edges == R * C - 1: # 모든 정점이 연결되면 종료
break
print(total_cost)
for _ in range(int(input())):
solution()solution.py