[PS] BOJ 1167 / 트리의 지름

[PS] BOJ 1167 / 트리의 지름
Thumbnail: Photo by Matheo JBT (Unsplash)
문제 링크: https://www.acmicpc.net/problem/1167

풀이

동일한 내용의 다른 문제가 있습니다.

[PS] BOJ 1967 / 트리의 지름
그래프 탐색 문제인 BOJ 1967을 풀어봤습니다.

트리의 지름?

트리의 지름이란, 트리의 임의의 두 정점 간의 거리 중 가장 긴 것을 말합니다.

1967번 문제의 그림을 참고했습니다.

문제에서는 100,000개 이하의 정점이 주어지고, 각 간선의 가중치(거리)는 10,000 이하의 정수로 주어집니다.

그래프 탐색으로 트리의 지름 구하기

주어진 트리의 지름은 2번의 BFS로 구할 수 있습니다.

  1. 임의의 정점에서 BFS를 통해 가장 먼 정점을 찾습니다.
  2. 1에서 찾은 정점에서 가장 먼 정점까지의 거리를 구합니다.

트리의 지름을 구하는 과정을 단계별로 생각해 봅시다. 트리의 아무 정점을 루트로 잡았을 때,

  1. 루트에 연결된 서브트리 중 가장 깊은 트리가 2개 이상일 경우, 두 서브트리의 가장 깊은 정점 2개를 잡아 그 거리를 구하면 됩니다.
  2. 루트에 연결된 서브트리 중 가장 깊은 트리가 1개라면,
    1. 이 서브트리와 다음으로 깊은 서브트리에서 각각 가장 깊은 정점을 찾아 그 거리가 지름이 되거나
    2. 루트에 연결된 가장 깊은 서브트리의 지름이 전체 트리의 지름이 됩니다.

따라서, 아무 정점이나 루트로 잡고 그 정점에서 가장 멀리 떨어진 정점을 찾고 (1), 1에서 찾은 정점으로부터 가장 멀리 떨어진 정점을 찾으면 그 거리가 전체 트리의 지름임이 보장됩니다.

전체 코드

from collections import deque
input = open(0).readline
V = int(input())
edges = tuple([] for _ in range(100_001))

for _ in range(V):
    edge_inputs = map(int, input().split())
    u = next(edge_inputs)
    while edge_inputs:
        v = next(edge_inputs)
        if v == -1:
            break
            
        w = next(edge_inputs)
        edges[u].append((v, w))
        edges[v].append((u, w))

queue = deque()

def bfs(start_node):
    visited = [-1] * (V + 1)
    visited[start_node] = 0
    queue.clear()
    queue.append(start_node) # (정점, 거리)
    max_distance = 0
    deepest_node = start_node

    while queue:
        cur_node = queue.popleft()
        for next_node, weight in edges[cur_node]:
            if visited[next_node] == -1:
                visited[next_node] = visited[cur_node] + weight
                queue.append(next_node)
                if visited[next_node] > max_distance:
                    max_distance = visited[next_node]
                    deepest_node = next_node
    return deepest_node, max_distance

before_node = 1
root_node = 1

# bfs (1) : 루트 정점(1)에서 가장 거리가 먼 정점 찾기 -> max_node
# bfs (2) : max_node에서 가장 거리가 먼 정점 찾기
max_node, _ = bfs(1)
_, max_distance = bfs(max_node)

print(max_distance)

solution.py