[PS] BOJ 15666 / N과 M (12)
문제 링크: https://www.acmicpc.net/problem/15666
Thumbnail: Photo by Tony Pham (Unsplash)
N과 M은 다시 말하지만 유명한 백트래킹 시리즈입니다.
풀이
기본적인 풀이는 15663번 문제(N과 M (9))와 동일합니다. 하지만, 입력으로 주어지는 수열에 숫자가 들어있기만 하면, 주어진 개수와 무관하게 몇 개든 꺼내 결과 수열에 사용할 수 있습니다. 또, 결과 수열은 비 내림차순 이어야 합니다.
결과 수열의 중복 방지하기
집합(set) 자료구조를 활용하면, 결과 수열을 중복해서 출력하는 것을 방지할 수 있습니다.
arr = [0 for _ in range(M)]
result = set()
def backtracking(depth):
if depth == M:
result.add(tuple(arr))
return
...
집합 사용하기
추가로, 입력 수열을 온전히 보전할 필요가 없습니다. 문제의 조건 상 입력 수열에 1이 한 개만 주어지더라도 1을 2개 이상 사용할 수 있으므로, 우리는 입력 수열에 주어진 숫자의 종류만 파악하면 됩니다.
만약, 입력 수열 안에 동일한 수가 반복되어 등장하는 경우 이를 그대로 backtracking 함수 안에서 반복하게 되면 중복되는 결과 수열을 불필요하게 만들어내게 됩니다.
A = list(map(int, input().split())) # 입력 수열
arr = [0 for _ in range(M)]
result = set()
def backtracking(depth):
if depth == M:
result.add(tuple(arr))
return
for num in A:
if num >= arr[depth - 1]:
arr[depth] = num
backtracking(depth + 1)
arr[depth] = 0
backtracking(0)Input:
4 2
9 9 9 1위와 같은 입력 데이터에서, 입력 수열 A를 그대로 반복하게 되면 backtracking 내부의 for문에서 9를 3번이나 반복하게 됩니다. 입력 수열의 중복을 제거한다면, 불필요한 반복을 줄여 코드의 실행 속도를 개선할 수 있습니다.
입력 수열의 중복 제거 또한 집합을 사용하면 간단히 구현할 수 있습니다.
A = set() # 입력 수열을 저장할 집합
for num in map(int, input().split()):
A.add(num)입력 수열도 집합에 저장해 중복을 제거한다.
결과 정렬하기
이전 N과 M (9)에서처럼, 문자열의 숫자 정렬 기준과 정수의 숫자 정렬 기준은 다릅니다.
sorted([1, 3, 2, 11])
>>> [1, 2, 3, 11]
sorted(["1", "3", "2", "11"])
>>> ["1", "11", "2", "3"]문자열의 사전 순 정렬 방식은 숫자의 정렬 순서와 다르다.
따라서, result 집합에 arr배열의 원소를 담은 tuple(Hashable한 배열)을 저장하고, 이후 result집합을 정렬했습니다.
def dfs(depth):
if depth == M:
result.add(tuple(arr))
return
...
for r in sorted(result):
print(" ".join(map(str, r)))tuple로 arr을 변환 후 저장한다.
전체 코드
input = open(0).readline
N, M = map(int, input().split())
A = set()
for num in map(int, input().split()):
A.add(num)
arr = [0 for _ in range(M)]
result = set()
def backtracking(depth):
if depth == M:
result.add(tuple(arr))
return
for num in A:
if num >= arr[depth - 1]:
arr[depth] = num
backtracking(depth + 1)
arr[depth] = 0
backtracking(0)
result = sorted(result)
print("\n".join(" ".join(map(str, r)) for r in result)) # 전체 출력을 하나의 문자열로 만들고 print문을 한 번만 실행하기.solution.py