[PS] BOJ 14889 / 스타트와 링크
문제 링크: https://www.acmicpc.net/problem/14889
백트래킹을 활용하는 문제입니다.
풀이
고찰
$N$은 언제나 짝수로 주어지고, $N / 2$명씩 두 팀으로 나눠 두 팀의 능력치 차가 최소가 되는 경우를 찾아야 합니다.
$4 \leq N \leq 20$이므로, 백트래킹으로 충분히 모든 경우를 탐색해볼 수 있습니다.
구현
백트래킹으로 가능한 모든 조합을 시도하며, 구성한 팀원의 수가 $N/2$가 될 때 탐색을 종료하고 두 팀의 능력치 차를 계산합니다.
한 팀의 능력치란, 임의의 두 팀원 $i, j$에 대해 두 사람의 능력치를 $S_ij$라고 하면, 모든 $S_ij$쌍의 합으로 정의됩니다.
전체 코드
input = open(0).readline
N = int(input())
stats = [list(map(int, input().split())) for _ in range(N)]
group = [False for _ in range(N)]
teams = [-1 for _ in range(N // 2)]
min_diff = float("inf")
def backtrack(depth, member_start, team1_stat):
if depth == N // 2:
global min_diff
team2 = [i for i in range(N) if not group[i]]
team2_stat = 0
for player in team2:
for other in team2:
if player != other:
team2_stat += stats[player][other]
diff = abs(team1_stat - team2_stat)
min_diff = min(min_diff, diff)
return
original_stat = team1_stat
for i in range(member_start, N):
if not group[i]:
group[i] = True
teams[depth] = i
for j in range(depth):
team1_stat += stats[i][teams[j]] + stats[teams[j]][i]
backtrack(depth + 1, i + 1, team1_stat)
team1_stat = original_stat
group[i] = False
teams[depth] = -1
backtrack(0, 0, 0)
print(min_diff)solution.py
시행착오
1) 팀의 능력치 계산 문제
백트래킹을 통해 만든 1번 팀의 능력치는 잘 계산되나, 2번 팀의 능력치를 계산할 때 중복으로 계산되는 문제가 있었습니다.
def backtrack(depth, member_start, team1_stat):
if depth == N // 2:
global min_diff
team2 = [i for i in range(N) if not group[i]]
team2_stat = 0
for player in team2:
for other in team2:
if player != other:
team2_stat += stats[player][other] + stats[other][player]
diff = abs(team1_stat - team2_stat)
min_diff = min(min_diff, diff)
return
...잘못 계산하는 코드
아래와 같이, player->other 방향의 시너지만 더해서 중복을 제거했습니다.
team2 = [i for i in range(N) if not group[i]]
team2_stat = 0
for player in team2:
for other in team2:
if player != other:
team2_stat += stats[player][other]수정한 코드
2) 시간 초과 문제
백트래킹을 통해 중복된 경우도 계속 탐색하게 되어 발생한 문제였습니다. (1, 3, 6)이 팀이 되나 (6, 3, 1)이 팀이 되나 팀의 능력치 합은 동일하기 때문에, 언제나 오름차순으로만 팀원을 찾도록 변경해 탐색 횟수를 줄였습니다.
def backtrack(depth, team1_stat):
if depth == N // 2:
...
original_stat = team1_stat
for i in range(N):
if not group[i]:
group[i] = True
teams[depth] = i
for j in range(depth):
team1_stat += stats[i][teams[j]] + stats[teams[j]][i]
backtrack(depth + 1, team1_stat)
team1_stat = original_stat
group[i] = False
teams[depth] = -1
backtrack(0, 0)시간 초과가 나는 코드
def backtrack(depth, member_start, team1_stat):
if depth == N // 2:
...
original_stat = team1_stat
for i in range(member_start, N):
if not group[i]:
group[i] = True
teams[depth] = i
for j in range(depth):
team1_stat += stats[i][teams[j]] + stats[teams[j]][i]
backtrack(depth + 1, i + 1, team1_stat)
team1_stat = original_stat
group[i] = False
teams[depth] = -1
backtrack(0, 0, 0)수정한 코드