[PS] BOJ 2239 / 스도쿠

[PS] BOJ 2239 / 스도쿠
문제 링크: https://www.acmicpc.net/problem/2239
Thumbnail: Photo by Richard Bell (Unsplash)

스도쿠를 풀어봅시다. 백트래킹으로요.

풀이

스도쿠의 규칙은 다음과 같습니다.

  • 같은 행에는 1~9까지의 수가 1개씩만 포함됩니다.
  • 같은 열에는 1~9까지의 수가 1개씩만 포함됩니다.
  • 같은 칸(3x3)에는 1~9까지의 수가 1개씩만 포함됩니다.

입력으로 주어지는 스도쿠는 빈칸이 굉장히 많이 있습니다. 따라서, 가능한 정답이 여러개인데, 이 중 사전순으로 가장 빠른 정답을 출력해야 합니다.

절차

  • 각 빈칸에 적절한 수를 채우며 DFS를 진행합니다.
    • 이때, 적절하지 않은 수에 대해서는 더 이상 탐색하지 않습니다 (백트래킹)
  • 만약 모든 빈칸을 채웠다면, 기존에 찾은 정답과 비교해 사전순으로 빠른 경우에만 정답을 갱신합니다.

백트래킹

백트래킹에 사용하는 조건은 스도쿠의 규칙을 그대로 사용합니다. 백트래킹의 각 분기마다 빈칸을 하나씩 채워야 하므로, 다음에 탐색할 분기가 유망한지 판단하려면 해당 수가 스도쿠의 규칙을 만족하는지를 알아야 합니다. 이를 위해서, 행/열/칸 검사를 진행합니다.

행/열의 수 검사는 단순히 2차원 배열을 활용하면 됩니다.

rows = [[False for _ in range(10)] for _ in range(9)]
cols = [[False for _ in range(10)] for _ in range(9)]

행/열 검사

스도쿠는 행/열 이외에도 $3 \times 3$의 같은 칸 내에서 9개의 수를 겹치지 않게 채워야 합니다. 이 칸(grid)를 왼쪽 위부터 $0, 1, \cdots, 8$번이라고 할 때 다음과 같은 식으로 계산할 수 있습니다. r은 현재 스도쿠 칸의 행 좌표, c는 열 좌표입니다.

grids = [[False for _ in range(10)] for _ in range(9)]
grids_idx = (r // 3) * 3 + (c // 3)

3x3 칸 검사

입력 처리하기

이제, 스도쿠를 입력받으며 사전에 필요한 정보들을 저장해둬야 합니다.

  • sudoku: 입력받은 스도쿠를 저장하는 2차원 배열입니다. 이후 결과 출력시 사용합니다.
  • rows, cols, grids: 백트래킹시 분기의 유망성을 판단하기 위해 사용하는 검사 배열입니다. 2차원 배열로 구성됩니다.
  • blanks: 스도쿠의 빈칸 좌표를 저장합니다.
  • values: 백트래킹을 진행하면서 답을 채워나갈 배열입니다.
  • answer: 백트래킹을 통해 찾은 정답을 저장할 배열입니다.
sudoku = [[0 for _ in range(9)] for _ in range(9)]
rows = [[False for _ in range(10)] for _ in range(9)]
cols = [[False for _ in range(10)] for _ in range(9)]
grids = [[False for _ in range(10)] for _ in range(9)]
blanks = []     # [(r, c)] 형태로 빈칸의 좌표를 저장한다.
values = []     # 백트래킹을 진행하면서 답을 채울 배열
answer = []     # 사전순으로 가장 빠른 답을 저장할 배열

for r in range(9):
    for c, v in enumerate(map(int, input().strip())):
        sudoku[r][c] = v
        if v == 0:
            blanks.append((r, c))
        else:
            grids_idx = (r // 3) * 3 + (c // 3)
            rows[r][v] = True
            cols[c][v] = True
            grids[grids_idx][v] = True

blank_count = len(blanks)
values.extend(0 for _ in range(blank_count))
answer.extend(10 for _ in range(blank_count))

입력 처리하기

백트래킹 구현하기

백트래킹 함수는 매개변수로 현재 분기의 깊이(depth)를 받습니다. 이는 몇 번째 빈칸을 탐색하고 있는지 나타냅니다.

정답 배열을 갱신하는 과정은 별도 함수인 update_answer로 분리했습니다.

백트래킹을 진행하면서, 해당 숫자가 스도쿠 규칙에 맞는지를 판단하기 전에 추가로 한가지 더 판단합니다. 결과적으로 사전순으로 제일 빠른 답을 출력해야 하기에, 이전에 발견한 정답보다 사전순으로 뒤처지는 결과는 굳이 더 계산할 필요가 없습니다.

따라서, 현재 탐색중인 빈칸에 이전 정답보다 더 큰 수를 채우는 경우는 더 이상 탐색할 필요가 없습니다.

def update_answer():
    """정답 배열(answer)을 현재 분기에서 찾은 답(values)으로 갱신합니다."""
    for i in range(blank_count):
        if values[i] < answer[i]:
            for j in range(blank_count):
                answer[j] = values[j]
            return

def backtracking(depth):
    if depth == blank_count:
        for i in range(blank_count):
            if values[i] < answer[i]:
                update_answer()
                return
    
    r, c = blanks[depth]
    grids_idx = (r // 3) * 3 + (c // 3)
    for v in range(1, 10):
        # 이전에 발견한 정답보다 사전순으로 뒤에 위치하는 분기는 탐색하지 않는다.
        if v > answer[depth]:
            break
        
        if not rows[r][v] and not cols[c][v] and not grids[grids_idx][v]:
            # 분기 변수 설정
            sudoku[r][c] = v
            rows[r][v] = True
            cols[c][v] = True
            grids[grids_idx][v] = True
            values[depth] = v
            # DFS (백트래킹)
            backtracking(depth + 1)
            # 분기 변수 초기화
            rows[r][v] = False
            cols[c][v] = False
            grids[grids_idx][v] = False
        values[depth] = 0

backtracking(0)

전체 코드

input = open(0).readline

sudoku = [[0 for _ in range(9)] for _ in range(9)]
rows = [[False for _ in range(10)] for _ in range(9)]
cols = [[False for _ in range(10)] for _ in range(9)]
grids = [[False for _ in range(10)] for _ in range(9)]
blanks = []     # [(r, c)] 형태로 빈칸의 좌표를 저장한다.
values = []     # 백트래킹을 진행하면서 답을 채울 배열
answer = []     # 사전순으로 가장 빠른 답을 저장할 배열

for r in range(9):
    for c, v in enumerate(map(int, input().strip())):
        sudoku[r][c] = v
        if v == 0:
            blanks.append((r, c))
        else:
            grids_idx = (r // 3) * 3 + (c // 3)
            rows[r][v] = True
            cols[c][v] = True
            grids[grids_idx][v] = True

blank_count = len(blanks)
values.extend(0 for _ in range(blank_count))
answer.extend(10 for _ in range(blank_count))

def update_answer():
    """정답 배열(answer)을 현재 분기에서 찾은 답(values)으로 갱신합니다."""
    for i in range(blank_count):
        if values[i] < answer[i]:
            for j in range(blank_count):
                answer[j] = values[j]
            return

def backtracking(depth):
    if depth == blank_count:
        for i in range(blank_count):
            if values[i] < answer[i]:
                update_answer()
                return
    
    r, c = blanks[depth]
    grids_idx = (r // 3) * 3 + (c // 3)
    for v in range(1, 10):
        # 이전에 발견한 정답보다 사전순으로 뒤에 위치하는 분기는 탐색하지 않는다.
        if v > answer[depth]:
            break
        
        if not rows[r][v] and not cols[c][v] and not grids[grids_idx][v]:
            # 분기 변수 설정
            sudoku[r][c] = v
            rows[r][v] = True
            cols[c][v] = True
            grids[grids_idx][v] = True
            values[depth] = v
            # DFS (백트래킹)
            backtracking(depth + 1)
            # 분기 변수 초기화
            rows[r][v] = False
            cols[c][v] = False
            grids[grids_idx][v] = False
        values[depth] = 0

backtracking(0)

# 결과 출력하기
for i in range(blank_count):
    r, c = blanks[i]
    sudoku[r][c] = answer[i]

print("\n".join("".join(map(str, row)) for row in sudoku))

solution.py