[PS] BOJ 20303 / 할로윈의 양아치

[PS] BOJ 20303 / 할로윈의 양아치
문제 링크: https://www.acmicpc.net/problem/20303
Thumbnail: Photo by Frames For Your Heart (Unsplash)

분리 집합 + 배낭 문제의 환상의 콜라보!

풀이

아이디어

아이들 사이의 친구 관계는 그래프로 나타낼 수 있습니다. 이 때, 간선에는 가중치가 없으며 무방향 그래프로 생각할 수 있습니다.

또한, 스브러스는 매우 공평한 사람이기 때문에 한 아이의 사탕을 뺏으면 그 아이 친구들의 사탕도 모조리 뺏어버린다. (친구의 친구는 친구다?!)

이 조건에 따라, 각각의 그래프 전체를 하나의 물건으로 볼 수 있으며, 그에 따라 이 문제는 배낭 문제가 됩니다.

따라서, 문제를 푸는 순서는 다음과 같습니다.

  1. 주어진 친구 관계(간선) 을 토대로 그래프 구성하기
    1. 이때, 분리 집합을 사용해 같은 그래프에 속하는 아이들을 분류할 수 있습니다.
  2. 분리 집합에 저장된 정보를 토대로, 각 그래프를 무게와 가치의 정보를 가진 물건의 형태로 변환하기.
  3. 2에서 변환한 정보를 토대로 0-1 배낭 문제 풀기

전체 코드

input = open(0).readline

# Union Find

roots = [i for i in range(30001)]
ranks = [0 for _ in range(30001)]

def find_root(node):
    path = []
    while node != roots[node]:
        path.append(node)
        node = roots[node]
    for p in path:
        roots[p] = node
    return node

def union(node_x, node_y):
    root_x = find_root(node_x)
    root_y = find_root(node_y)

    if root_x == root_y: # same parent
        return False
    
    if ranks[root_x] < ranks[root_y]:
        roots[root_x] = root_y
    else:
        roots[root_y] = roots[root_x]

        if ranks[root_x] == ranks[root_y]:
            ranks[root_x] += 1
    return True

N, M, K = map(int, input().split())
candies = list(map(int, input().split()))
edges = [[] for _ in range(N + 1)]

# 친구 관계 (간선) 입력 받기
for _ in range(M):
    u, v = map(int, input().split())
    edges[u].append(v)
    edges[v].append(u)

# 1. 친구 관계 (간선) 정보를 토대로 그래프 구성하기 (분리 집합을 통해 각 정점간의 연결 관계 기록)
for i in range(1, N + 1):
    for j in edges[i]:
        union(i, j)

# 2. 1에서 기록한 분리 집합 정보를 토대로 각 그래프를 무게&가치의 물건으로 변환하기
items_total = {}
for i in range(1, N + 1):
    root = find_root(i)
    try:
        items_total[root][0] += 1
        items_total[root][1] += candies[i - 1]
    except KeyError:
        items_total[root] = [1, candies[i - 1]] # (weight, value)
items = [(0, 0), *items_total.values()]

# 3. 2에서 변환한 물건 정보로 0-1 배낭 문제 풀기
DP = [[0 for _ in range(K)] for _ in range(len(items))] # DP[item_idx][weight_left]

for i in range(1, len(items)):
    for w in range(K):
        if items[i][0] <= w:
            DP[i][w] = max(DP[i-1][w], items[i][1] + DP[i-1][w-items[i][0]])
        else:
            DP[i][w] = DP[i-1][w]

print(DP[len(items) - 1][K - 1])

solution.py