[PS] BOJ 3224 / grupe

[PS] BOJ 3224 / grupe
문제 링크: https://www.acmicpc.net/problem/3224
Thumbnail: Photo by Annie Williams (Unsplash)

며칠 고심하면서 이해해봤는데, 그만큼 재밌는 문제였습니다.

풀이

스택 (Stack)

기본적인 풀이는 스택을 이용합니다. 고정 크기의 스택을 먼저 구현해줍시다.

class Stack:
    """Fixed size stack implementation."""
    def __init__(self, max_size, default_value = 0):
        self.stack = [default_value for _ in range(max_size)]
        self.size = 0

    def push(self, item):
        if self.size + 1 > len(self.stack):
            raise OverflowError("push() to full stack")
        self.stack[self.size] = item
        self.size += 1

    def pop(self):
        if self.size == 0:
            raise IndexError("pop() from empty stack")
        self.size -= 1
        return self.stack[self.size]

    def top(self):
        if self.size == 0:
            raise IndexError("top() from empty stack")
        return self.stack[self.size - 1]

    def is_empty(self):
        return self.size == 0

    def clear(self):
        self.size = 0

그룹을 합치는 규칙

두 그룹을 합칠 수 있는 조건은 다음과 같습니다:

  • 두 그룹의 수가 연속할 수 있으면 됩니다
    • [1 2][3 4][1 2 3 4]로, 합친 이후의 그룹도 수가 연속하기 때문에 가능합니다.
    • [2 3][5 4][2 3 5 4] 합칠 수 있습니다. 그룹 내의 수가 연속해서 있을 필요는 없고, 전체 수의 범위가 하나로 연속할 수 있으면 됩니다.
    • [2 3][6 7]은 수의 범위가 연속하지 않으므로 합칠 수 없습니다.

결국 우리가 그룹을 관리하기 위해 필요한 정보는, 그룹 내의 최대값 및 최솟값 뿐입니다.

데이터 구성

기본적으로, 초기 입력으로 주어진 수의 배열과 각 인덱스에 대응하는 그룹의 정보를 저장할 배열이 필요합니다.

  • 숫자 배열(numbers) 의 경우, 초기에 주어진 수를 저장합니다.
  • 그룹 배열(groups) 의 경우, 숫자 배열(numbers) 와 같은 인덱스에 해당 수의 그룹 정보를 저장합니다.
  • 결과 배열(results)의 경우, 두 그룹을 합칠 때 마다 합쳐진 그룹의 "{최소값} {최대값}"을 저장합니다. 이후 결과 출력에 사용합니다.
input = open(0).readline
N = int(input())
numbers = [0 for _ in range(N)]
groups = [[0, 0] for _ in range(N)]
results = []

# 초기 상태에는 모든 숫자가 각자 자신의 그룹에 속해 있다.
for idx, num in zip(range(N), map(int, input().split())):
    numbers[idx] = num
    groups[idx][0] = groups[idx][1] = num

stack = Stack(N, 0)
stack.push(0)   # 스택에 첫번째 원소의 인덱스를 넣는다.

Initial variable declaration & initialization.

그룹 합치기

기본적인 규칙은 한가지입니다.

  • 배열에서 다음 수를 가져옵니다.
    • 스택에서 수를 꺼내, 현재 수와 합칠 수 있다면,
      • 합쳐진 그룹의 최대값, 최솟값을 갱신합니다.
      • 스택에서 꺼낸 수와 현재 배열에서 가져온 수의 그룹 정보를 갱신합니다.
    • 합칠 수 없다면 스택에 현재 수를 집어넣고 다시 반복합니다.
def merge_groups(idx1, idx2, min, max):
    """두 그룹을 병합하고 결과를 기록한다."""
    groups[idx1][0] = groups[idx2][0] = min
    groups[idx1][1] = groups[idx2][1] = max
    results.append(f"{min} {max}")
    # print(f">>> merge! ({min}, {max})")

for idx in range(1, N):   # 먼저 배열을 순회하며 가능한 만큼 병합하고 나머지는 스택에 쌓는다.
    while not stack.is_empty(): # 다음 원소가 현재 그룹과 연속되는 수인 경우 병합할 수 있다.
        min, max = groups[stack.top()]
        # print(f"comp ({numbers[stack.top()]}, {numbers[idx]})")
        # print(f"- stack.top(): [{min}, {max}]")
        # print(f"- from array: {groups[idx]}")
        if abs(min - groups[idx][0]) == 1 or abs(max - groups[idx][1]) == 1:
            if min > groups[idx][0]:
                min = groups[idx][0]
            if max < groups[idx][1]:
                max = groups[idx][1]
            prev = stack.pop()
            merge_groups(prev, idx, min, max)
        elif abs(min - groups[idx][1]) == 1:
            min = groups[idx][0]
            prev = stack.pop()
            merge_groups(prev, idx, min, max)
        elif abs(max - groups[idx][0]) == 1:
            max = groups[idx][1]
            prev = stack.pop()
            merge_groups(prev, idx, min, max)
        else:
            break
    # print(f">>> push! {numbers[idx]} ({groups[idx]})")
    stack.push(idx)

Group merge algorithm.

결과 출력

반복문이 끝난 뒤, 초기 상태의 수를 하나의 그룹으로 합칠 수 있는지 판단하는 기준은 현재 스택의 상태입니다.

  • 현재 스택에 원소가 단 1개만 있다면:
    • 모두 합쳐지고 남은 하나의 그룹입니다. 따라서, 초기 상태를 합칠 수 있던 것이므로 "DA"를 출력한 뒤 results배열에 저장해둔 병합 과정을 출력합니다.
  • 그렇지 않다면:
    • 합칠 수 없는 경우이므로 "NE"를 출력합니다.
if stack.size > 1:
    print("NE")
else:
    print("DA")
    print("\n".join(map(str, results)))

Print result.

전체 코드

class Stack:
    """Fixed size stack implementation."""
    def __init__(self, max_size, default_value = 0):
        self.stack = [default_value for _ in range(max_size)]
        self.size = 0

    def push(self, item):
        if self.size + 1 > len(self.stack):
            raise OverflowError("push() to full stack")
        self.stack[self.size] = item
        self.size += 1

    def pop(self):
        if self.size == 0:
            raise IndexError("pop() from empty stack")
        self.size -= 1
        return self.stack[self.size]

    def top(self):
        if self.size == 0:
            raise IndexError("top() from empty stack")
        return self.stack[self.size - 1]

    def is_empty(self):
        return self.size == 0

    def clear(self):
        self.size = 0

input = open(0).readline
N = int(input())
numbers = [0 for _ in range(N)]
groups = [[0, 0] for _ in range(N)]
results = []

def merge_groups(idx1, idx2, min, max):
    """두 그룹을 병합하고 결과를 기록한다."""
    groups[idx1][0] = groups[idx2][0] = min
    groups[idx1][1] = groups[idx2][1] = max
    results.append(f"{min} {max}")
    # print(f">>> merge! ({min}, {max})")

# 초기 상태에는 모든 숫자가 각자 자신의 그룹에 속해 있다.
for idx, num in zip(range(N), map(int, input().split())):
    numbers[idx] = num
    groups[idx][0] = groups[idx][1] = num

stack = Stack(N, 0)
stack.push(0)   # 스택에 첫번째 원소의 인덱스를 넣는다.
for idx in range(1, N):   # 먼저 배열을 순회하며 가능한 만큼 병합하고 나머지는 스택에 쌓는다.
    while not stack.is_empty(): # 다음 원소가 현재 그룹과 연속되는 수인 경우 병합할 수 있다.
        min, max = groups[stack.top()]
        # print(f"comp ({numbers[stack.top()]}, {numbers[idx]})")
        # print(f"- stack.top(): [{min}, {max}]")
        # print(f"- from array: {groups[idx]}")
        if abs(min - groups[idx][0]) == 1 or abs(max - groups[idx][1]) == 1:
            if min > groups[idx][0]:
                min = groups[idx][0]
            if max < groups[idx][1]:
                max = groups[idx][1]
            prev = stack.pop()
            merge_groups(prev, idx, min, max)
        elif abs(min - groups[idx][1]) == 1:
            min = groups[idx][0]
            prev = stack.pop()
            merge_groups(prev, idx, min, max)
        elif abs(max - groups[idx][0]) == 1:
            max = groups[idx][1]
            prev = stack.pop()
            merge_groups(prev, idx, min, max)
        else:
            break
    # print(f">>> push! {numbers[idx]} ({groups[idx]})")
    stack.push(idx)

if stack.size > 1:
    print("NE")
else:
    print("DA")
    print("\n".join(map(str, results)))

solution.py