[PS] BOJ 1967 / 트리의 지름

[PS] BOJ 1967 / 트리의 지름
문제 링크: https://www.acmicpc.net/problem/1967
Thumbnail: Photo by Gene Gallin (Unsplash)

다양한 방법으로 풀 수 있겠지만, 그래프 탐색을 두 번 하는 것으로 풀었습니다.

풀이

가장 거리가 먼 두개의 정점 찾기

트리(tree)는 사이클이 없는 무방향 그래프이다.
트리에서는 어떤 두 노드를 선택해도 둘 사이에 경로가 항상 하나만 존재하게 된다.

트리에서 어떤 두 노드를 선택해서 양쪽으로 쫙 당길 때, 가장 길게 늘어나는 경우가 있을 것이다.
이럴 때 트리의 모든 노드들은 이 두 노드를 지름의 끝 점으로 하는 원 안에 들어가게 된다.

이런 두 노드 사이의 경로의 길이를 트리의 지름이라고 한다. 정확히 정의하자면 트리에 존재하는 모든 경로들 중에서 가장 긴 것의 길이를 말한다.
트리의 지름 예시 (출처: 백준의 해당 문제에 첨부된 사진)

​요약하자면, 결국 트리의 지름은 트리 내에서 서로 간의 거리가 가장 먼 두 노드(정점) 사이의 거리입니다. 그렇다면, 가장 거리가 먼 두 개의 정점은 어떻게 찾아야 할까요?

여러 가지 방법이 있겠지만, 가장 간단한 것은 루트 노드에서 가장 먼 정점을 찾은 후, 해당 정점에서 가장 먼 정점을 찾는 것입니다. 루트 노드는 위치 상 트리의 지름을 이루는 양 끝 노드보다는 중앙에 있을 가능성이 높다고 생각해, 루트 노드에서 가장 먼 노드를 지름의 한 쪽 끝으로 잡고 반대쪽 노드를 탐색하는 방법입니다.

두 번의 노드를 탐색하는 과정 모두 BFS로 구현했습니다.

먼저, 인접 리스트의 형태로 그래프의 간선 정보를 입력받았습니다.

N = int(input())
edges = [[] for _ in range(N + 1)]

for _ in range(N - 1):
    u, v, w = map(int, input().split())
    edges[u].append((v, w))
    edges[v].append((u, w))

인접 리스트로 구현한 그래프

이후, BFS를 구현했습니다. 이후 계산에 사용하기 위해, 이번 BFS 탐색의 시작 정점(start_node) 에서 가장 먼 정점과 그 거리를 반환합니다.

from collections import deque

def bfs(start_node):
    visited = [-1 for _ in range(N + 1)]
    queue = deque()
    queue.append(start_node)
    visited = [-1 for _ in range(N + 1)]
    visited[start_node] = 0
    max_distance = 0
    max_node = 0
    while queue:
        cur_node = queue.popleft()

        if visited[cur_node] > max_distance:
            max_distance = visited[cur_node]
            max_node = cur_node

        for next_node, distance in edges[cur_node]:
            if visited[next_node] == -1:
                visited[next_node] = visited[cur_node] + distance
                queue.append(next_node)
    return (max_node, max_distance)

BFS

두 번의 BFS를 통해, 트리의 지름을 계산합니다.

# bfs (1) : 루트 정점(1)에서 가장 거리가 먼 정점 찾기 -> max_node
# bfs (2) : max_node에서 가장 거리가 먼 정점 찾기
max_node, _ = bfs(1)
_, max_distance = bfs(max_node)

print(max_distance)

결과 출력

전체 코드

from collections import deque
input = open(0).readline
N = int(input())
edges = [[] for _ in range(N + 1)]

for _ in range(N - 1):
    u, v, w = map(int, input().split())
    edges[u].append((v, w))
    edges[v].append((u, w))

def bfs(start_node):
    visited = [-1 for _ in range(N + 1)]
    queue = deque()
    queue.append(start_node)
    visited = [-1 for _ in range(N + 1)]
    visited[start_node] = 0
    max_distance = 0
    max_node = 0
    while queue:
        cur_node = queue.popleft()

        if visited[cur_node] > max_distance:
            max_distance = visited[cur_node]
            max_node = cur_node

        for next_node, distance in edges[cur_node]:
            if visited[next_node] == -1:
                visited[next_node] = visited[cur_node] + distance
                queue.append(next_node)
    return (max_node, max_distance)

# bfs (1) : 루트 정점(1)에서 가장 거리가 먼 정점 찾기 -> max_node
# bfs (2) : max_node에서 가장 거리가 먼 정점 찾기
max_node, _ = bfs(1)
_, max_distance = bfs(max_node)

print(max_distance)

solution.py