[PS] BOJ 1967 / 트리의 지름
문제 링크: https://www.acmicpc.net/problem/1967
Thumbnail: Photo by Gene Gallin (Unsplash)
다양한 방법으로 풀 수 있겠지만, 그래프 탐색을 두 번 하는 것으로 풀었습니다.
풀이
가장 거리가 먼 두개의 정점 찾기
트리(tree)는 사이클이 없는 무방향 그래프이다.
트리에서는 어떤 두 노드를 선택해도 둘 사이에 경로가 항상 하나만 존재하게 된다.
트리에서 어떤 두 노드를 선택해서 양쪽으로 쫙 당길 때, 가장 길게 늘어나는 경우가 있을 것이다.
이럴 때 트리의 모든 노드들은 이 두 노드를 지름의 끝 점으로 하는 원 안에 들어가게 된다.
이런 두 노드 사이의 경로의 길이를 트리의 지름이라고 한다. 정확히 정의하자면 트리에 존재하는 모든 경로들 중에서 가장 긴 것의 길이를 말한다.

요약하자면, 결국 트리의 지름은 트리 내에서 서로 간의 거리가 가장 먼 두 노드(정점) 사이의 거리입니다. 그렇다면, 가장 거리가 먼 두 개의 정점은 어떻게 찾아야 할까요?
여러 가지 방법이 있겠지만, 가장 간단한 것은 루트 노드에서 가장 먼 정점을 찾은 후, 해당 정점에서 가장 먼 정점을 찾는 것입니다. 루트 노드는 위치 상 트리의 지름을 이루는 양 끝 노드보다는 중앙에 있을 가능성이 높다고 생각해, 루트 노드에서 가장 먼 노드를 지름의 한 쪽 끝으로 잡고 반대쪽 노드를 탐색하는 방법입니다.
두 번의 노드를 탐색하는 과정 모두 BFS로 구현했습니다.
먼저, 인접 리스트의 형태로 그래프의 간선 정보를 입력받았습니다.
N = int(input())
edges = [[] for _ in range(N + 1)]
for _ in range(N - 1):
u, v, w = map(int, input().split())
edges[u].append((v, w))
edges[v].append((u, w))인접 리스트로 구현한 그래프
이후, BFS를 구현했습니다. 이후 계산에 사용하기 위해, 이번 BFS 탐색의 시작 정점(start_node) 에서 가장 먼 정점과 그 거리를 반환합니다.
from collections import deque
def bfs(start_node):
visited = [-1 for _ in range(N + 1)]
queue = deque()
queue.append(start_node)
visited = [-1 for _ in range(N + 1)]
visited[start_node] = 0
max_distance = 0
max_node = 0
while queue:
cur_node = queue.popleft()
if visited[cur_node] > max_distance:
max_distance = visited[cur_node]
max_node = cur_node
for next_node, distance in edges[cur_node]:
if visited[next_node] == -1:
visited[next_node] = visited[cur_node] + distance
queue.append(next_node)
return (max_node, max_distance)BFS
두 번의 BFS를 통해, 트리의 지름을 계산합니다.
# bfs (1) : 루트 정점(1)에서 가장 거리가 먼 정점 찾기 -> max_node
# bfs (2) : max_node에서 가장 거리가 먼 정점 찾기
max_node, _ = bfs(1)
_, max_distance = bfs(max_node)
print(max_distance)결과 출력
전체 코드
from collections import deque
input = open(0).readline
N = int(input())
edges = [[] for _ in range(N + 1)]
for _ in range(N - 1):
u, v, w = map(int, input().split())
edges[u].append((v, w))
edges[v].append((u, w))
def bfs(start_node):
visited = [-1 for _ in range(N + 1)]
queue = deque()
queue.append(start_node)
visited = [-1 for _ in range(N + 1)]
visited[start_node] = 0
max_distance = 0
max_node = 0
while queue:
cur_node = queue.popleft()
if visited[cur_node] > max_distance:
max_distance = visited[cur_node]
max_node = cur_node
for next_node, distance in edges[cur_node]:
if visited[next_node] == -1:
visited[next_node] = visited[cur_node] + distance
queue.append(next_node)
return (max_node, max_distance)
# bfs (1) : 루트 정점(1)에서 가장 거리가 먼 정점 찾기 -> max_node
# bfs (2) : max_node에서 가장 거리가 먼 정점 찾기
max_node, _ = bfs(1)
_, max_distance = bfs(max_node)
print(max_distance)solution.py