[PS] BOJ 14428 / 수열과 쿼리 16

[PS] BOJ 14428 / 수열과 쿼리 16
Thumbnail: Photo by Markus Winkler (Unsplash)
문제 링크: https://www.acmicpc.net/problem/14428

세그먼트 트리를 활용하는 기본적인 문제입니다. 다만, 세그먼트 트리에서 구간 내 최솟값이 아닌, 최솟값의 "인덱스"를 조회해야 합니다. 이를 위해 트리의 노드에 값을 어떻게 저장할지 고민이 필요했습니다.

풀이

세그먼트 트리?

세그먼트 트리에 대한 설명과 구간 합을 저장하는 세그먼트 이전 글을 참고해 주세요 😄

구간 합 대신에 최솟값 저장하기?

이전에 세그먼트 트리를 통해 구간 내 최솟값을 구하는 문제를 풀 때는 세그먼트 트리의 각 노드에 해당하는 구간 내의 최솟값 자체를 저장해 풀었습니다.

하지만, 이번 문제는 최솟값이 아닌 최솟값의 인덱스를 구해야 합니다. 게다가, 최솟값이 구간 내에 여러 개 중복으로 존재할 수 있으며 이 경우 최솟값의 인덱스 중 가장 작은 값을 찾아야 합니다.

하지만, 트리의 각 노드에 해당 구간 내의 최솟값의 인덱스 (최솟값이 여러 개 있는 경우, 가장 작은 인덱스)를 저장하는 방식으로 구현하면 트리의 노드에 값을 저장하고 읽는 부분만 수정해 기존 세그먼트 트리의 구조를 그대로 사용할 수 있습니다.

1) 세그먼트 트리 초기화

def build_segment_tree(node, start, end):
    if start == end:
        tree[node] = start # 세그먼트 트리에는 인덱스 자체를 저장하기
    else:
        mid = (start + end) // 2
        build_segment_tree(node * 2, start, mid)
        build_segment_tree(node * 2 + 1, mid + 1, end)

        # 항상 구간 내 최솟값의 인덱스를 트리의 노드에 저장해준다.
        if arr[tree[node * 2]] > arr[tree[node * 2 + 1]]:
            tree[node] = tree[node * 2 + 1]
        elif arr[tree[node * 2]] < arr[tree[node * 2 + 1]]:
            tree[node] = tree[node * 2]
        else:
            tree[node] = min(tree[node * 2], tree[node * 2 + 1])

트리 배열을 초기화하는 함수

리프 노드의 경우 인덱스에 해당하는 start값을 그대로 저장하고, 이외의 모든 노드에는 두 자식 노드에 저장된 인덱스에 해당하는 배열 값을 비교해, "최솟값의 인덱스"를 찾아 저장하도록 구현했습니다.

2) 구간 쿼리

def query_segment_tree(node, start, end, left, right):
    if start >= left and end <= right:
        return tree[node]
    elif start > right or end < left:
        return N
    else:
        mid = (start + end) // 2
        left_res = query_segment_tree(node * 2, start, mid, left, right)
        right_res = query_segment_tree(node * 2 + 1, mid + 1, end, left, right)

        # 항상 구간 내 최솟값의 인덱스를 구한다.
        if left_res == N:
            if right_res == N:
                # NO!!!
                return -1
            return right_res
        elif right_res == N:
            return left_res

        if arr[left_res] > arr[right_res]:
            return right_res
        elif arr[left_res] < arr[right_res]:
            return left_res
        else: # 두 인덱스가 같은 수를 가리킬 경우, 더 작은 인덱스를 답으로 사용한다.
            return min(left_res, right_res)

def query(left, right):
    return query_segment_tree(1, 0, N - 1, left - 1, right - 1)

구간 내 최솟값을 구하는 함수

트리의 각 노드에는 해당 구간 내의 최솟값의 인덱스가 저장되어 있습니다.

세그먼트 트리에서 특정 구간에 대한 답을 찾을 때, 발생할 수 있는 경우는 3가지로 나눌 수 있습니다.

  • 현재 찾는 구간 (l, r)에 탐색중인 구간 (s, e)가 온전히 포함되는 경우
    • 현재 트리의 노드 값을 그대로 반환하고 탐색 종료.
  • 현재 찾는 구간 (l, r)에 탐색중인 구간 (s, e)가 아예 포함되지 않는 경우
    • 탐색이 실패했음을 뜻하는 임의의 값을 반환하고 탐색 종료.
    • 구간 합 또는 곱을 구할 때는 해당하는 연산의 항등원을 사용했는데, 이번 문제의 경우 "최솟값의 인덱스"를 구하는 연산이므로 항등원으로 쓸 값을 찾지 못했습니다.
      따라서, 현재 탐색중인 구간에 변경된 인덱스가 포함되지 않았다면 배열의 인덱스를 초과하는 값인 $N$을 반환하고, 이후 재귀 호출로 결과를 종합할 때 한 쪽 결과가 $N$이라면 $N$이 아닌 결과를 답으로 사용하도록 구현했습니다.
  • 현재 찾는 구간 (l, r)에 탐색중인 구간 (s, e)가 일부만 포함되는 경우
    • 재귀 호출을 통해 탐색 후 결과 종합

query 함수는 입력에서 주어지는 매개변수 2개를 받아서, 실제로 호출해야 하는 query_segment_tree에 매개변수를 채워서 호출하도록 구현했습니다.

3) 요소 변경

def update_segment_tree(node, start, end, index):
    if start == end and start == index:
        tree[node] = index
    elif start > index or end < index: # index가 포함되지 않는 구간
        return
    else:
        mid = (start + end) // 2
        update_segment_tree(node * 2, start, mid, index)
        update_segment_tree(node * 2 + 1, mid + 1, end, index)

        # 항상 구간 내 최솟값의 인덱스를 트리의 노드에 저장해준다.
        if arr[tree[node * 2]] > arr[tree[node * 2 + 1]]:
            tree[node] = tree[node * 2 + 1]
        elif arr[tree[node * 2]] < arr[tree[node * 2 + 1]]:
            tree[node] = tree[node * 2]
        else:
            tree[node] = min(tree[node * 2], tree[node * 2 + 1])

def update(index, value):
    arr[index - 1] = value
    update_segment_tree(1, 0, N - 1, index - 1)

특정 위치의 요소를 변경하는 코드

변경된 인덱스가 구간에 포함되지 않는다면 재귀 중단, 리프 노드일 경우 트리 노드 자체를 갱신합니다.

그 외의 경우에서는 재귀호출로 자식 노드들을 갱신해준 뒤, 현재 노드는 구간 쿼리와 동일한 방식으로 최솟값의 인덱스를 찾아 저장합니다.

update 함수는 입력에서 주어지는 매개변수 2개를 받아서, 배열의 원소를 직접 갱신한 뒤 세그먼트 트리를 갱신하는 update_segment_tree를 호출하도록 구현했습니다.

전체 코드

from math import ceil, log2

input = open(0).readline
NUMBER_MAX = 1_000_000_001

N = int(input())
arr = list(map(int, input().split()))
tree = [0] * (1 << ceil(log2(N)) + 1)

def build_segment_tree(node, start, end):
    if start == end:
        tree[node] = start # 세그먼트 트리에는 인덱스 자체를 저장하기
    else:
        mid = (start + end) // 2
        build_segment_tree(node * 2, start, mid)
        build_segment_tree(node * 2 + 1, mid + 1, end)

        # 항상 구간 내 최솟값의 인덱스를 트리의 노드에 저장해준다.
        if arr[tree[node * 2]] > arr[tree[node * 2 + 1]]:
            tree[node] = tree[node * 2 + 1]
        elif arr[tree[node * 2]] < arr[tree[node * 2 + 1]]:
            tree[node] = tree[node * 2]
        else:
            tree[node] = min(tree[node * 2], tree[node * 2 + 1])

def query_segment_tree(node, start, end, left, right):
    if start >= left and end <= right:
        return tree[node]
    elif start > right or end < left:
        return N
    else:
        mid = (start + end) // 2
        left_res = query_segment_tree(node * 2, start, mid, left, right)
        right_res = query_segment_tree(node * 2 + 1, mid + 1, end, left, right)

        # 항상 구간 내 최솟값의 인덱스를 구한다.
        if left_res == N:
            if right_res == N:
                # NO!!!
                return -1
            return right_res
        elif right_res == N:
            return left_res

        if arr[left_res] > arr[right_res]:
            return right_res
        elif arr[left_res] < arr[right_res]:
            return left_res
        else: # 두 인덱스가 같은 수를 가리킬 경우, 더 작은 인덱스를 답으로 사용한다.
            return min(left_res, right_res)

def query(left, right):
    return query_segment_tree(1, 0, N - 1, left - 1, right - 1)

def update_segment_tree(node, start, end, index):
    if start == end and start == index:
        tree[node] = index
    elif start > index or end < index: # index가 포함되지 않는 구간
        return
    else:
        mid = (start + end) // 2
        update_segment_tree(node * 2, start, mid, index)
        update_segment_tree(node * 2 + 1, mid + 1, end, index)

        # 항상 구간 내 최솟값의 인덱스를 트리의 노드에 저장해준다.
        if arr[tree[node * 2]] > arr[tree[node * 2 + 1]]:
            tree[node] = tree[node * 2 + 1]
        elif arr[tree[node * 2]] < arr[tree[node * 2 + 1]]:
            tree[node] = tree[node * 2]
        else:
            tree[node] = min(tree[node * 2], tree[node * 2 + 1])

def update(index, value):
    arr[index - 1] = value
    update_segment_tree(1, 0, N - 1, index - 1)

build_segment_tree(1, 0, N - 1)

for _ in range(int(input())):
    op, *args = map(int, input().split())
    if op == 1:
        update(*args)
    else:
        print(query(*args) + 1)

solution.py