[PS] BOJ 11404 / 플로이드

[PS] BOJ 11404 / 플로이드
문제 링크: https://www.acmicpc.net/problem/11404
Thumbnail: Photo by Gerson Repreza (Unsplash)

플로이드-워셜(Floyd-Warshall) 알고리즘을 활용하는 정석적인 문제입니다.

풀이

플로이드-워셜(Floyd-Warshall)!

플로이드-워셜 알고리즘은 모든 정점에서 모든 정점으로 가는 최단 거리를 탐색하는 방법입니다. 다익스트라 알고리즘이 한 정점에서 다른 모든 정점으로의 최단 거리를 계산하는 점에서 차이가 있습니다.

​모든 정점에서 모든 정점으로 가는 최단 거리를 탐색하기 때문에, $O(N^3)$의 시간 복잡도를 가지고 있어 정점의 개수가 적은 문제일 경우에만 사용할 수 있습니다.
또한, 알고리즘에서 각 정점 사이의 최단 거리를 인접 행렬의 형태로 관리하기 때문에, 간선 또한 인접 행렬로 받는 편이 편리합니다.

N = int(input())
M = int(input())
INF = int(1e9)

graph = [[INF for _ in range(N + 1)] for _ in range(N + 1)]

for _ in range(M):
    u, v, cost = map(int, input().split())
    if cost < graph[u][v]: # 두 정점이 같은 간선이면 가중치가 적은 쪽을 저장한다.
        graph[u][v] = cost

인접 행렬로 그래프 입력 받기

문제에서, 주어질 수 있는 간선의 조건을 다음과 같이 정의했습니다:

  • 시작 도시와 도착 도시가 같은 경우는 없다. 비용은 100,000보다 작거나 같은 자연수이다.
    • 1->1, 2->2와 같은 간선은 존재하지 않습니다.
  • 시작 도시와 도착 도시를 연결하는 노선은 하나가 아닐 수 있다.
    • 1->2 (3), 1->2 (5) 처럼 출발 정점과 도착 정점이 같은 간선이 여러 개 입력될 수 있습니다.
    • 이 경우, 가중치가 제일 작은 간선만 남겨두고 나머지는 굳이 저장할 필요가 없습니다.

구현

이제 플로이드-워셜 알고리즘을 구현해봅시다.

먼저, 각각의 정점 사이의 최단 거리를 저장할 인접 행렬 dist를 정의합니다.
이후, 기존 그래프의 간선 정보를 토대로 초기화합니다.

dist = [[INF for _ in range(N + 1)] for _ in range(N + 1)]
for i in range(1, N + 1):
    for j in range(1, N + 1):
        if i == j:
            dist[i][j] = 0
        elif graph[i][j] != INF:
            dist[i][j] = graph[i][j]

dist 인접 행렬 초기화

다음으로, 그래프의 모든 두 정점에 대해 최단 거리를 찾아야 합니다. 이는 3중 반복문으로 진행합니다.

  1. 두 정점 사이에 거쳐가는 정점 $k$를 선택합니다.
  2. 첫 번째 정점 $i$를 선택합니다.
  3. 두 번째 정점 $j$를 선택합니다.
    이후, $i$ -> $j$ 로의 거리를 $i$ -> $j$와 $i$ -> $k$ -> $j$ 중 짧은 쪽으로 갱신합니다.
for k in range(1, N + 1):
    for i in range(1, N + 1):
        for j in range(1, N + 1):
            if dist[i][j] > dist[i][k] + dist[k][j]:
                dist[i][j] = dist[i][k] + dist[k][j]

플로이드-워셜 알고리즘

출력하기

이번 문제에서는 최단 거리를 저장한 dist배열을 그대로 출력해주면 됩니다. 단, 실제 정점은 1번부터 사용되었으므로 0번 행/열에 해당하는 부분은 출력하지 않습니다.

for r in range(1, N + 1):
    print(" ".join(map(lambda d: "0" if d == INF else str(d), dist[r][1:])))

결과 출력

전체 코드

input = open(0).readline
N = int(input())
M = int(input())
INF = int(1e9)

graph = [[INF for _ in range(N + 1)] for _ in range(N + 1)]

for _ in range(M):
    u, v, cost = map(int, input().split())
    if cost < graph[u][v]:
        graph[u][v] = cost

dist = [[INF for _ in range(N + 1)] for _ in range(N + 1)]
for i in range(1, N + 1):
    for j in range(1, N + 1):
        if i == j:
            dist[i][j] = 0
        elif graph[i][j] != INF:
            dist[i][j] = graph[i][j]

for k in range(1, N + 1):
    for i in range(1, N + 1):
        for j in range(1, N + 1):
            if dist[i][j] > dist[i][k] + dist[k][j]:
                dist[i][j] = dist[i][k] + dist[k][j]

for r in range(1, N + 1):
    print(" ".join(map(lambda d: "0" if d == INF else str(d), dist[r][1:])))

solution.py