[PS] BOJ 18436 / 수열과 쿼리 37

[PS] BOJ 18436 / 수열과 쿼리 37
Thumbnail: Photo by 🇸🇮 Janko Ferlič (Unsplash)
문제 링크: https://www.acmicpc.net/problem/18436

세그먼트 트리를 활용해 구간 내 짝수와 홀수의 개수를 각각 저장해주면 됩니다.

풀이

세그먼트 트리?

세그먼트 트리에 대한 설명과 구간 합을 저장하는 세그먼트 이전 글을 참고해 주세요 😄​

세그먼트 트리에 짝수/홀수의 개수 저장하기

​가장 기본적인 세그먼트 트리의 구현은 구간 합을 저장하는 형태였습니다. 홀수와 짝수의 개수를 어떻게 세그먼트 트리에 저장할 수 있을까요?

간단하게, 입력으로 주어진 배열을 통해 세그먼트 트리를 만드는 대신에
배열의 각 수가 홀수인지/짝수인지를 1 또는 0으로 나타낸 배열로 세그먼트 트리를 만들면, 구간 합으로 홀수/짝수의 개수를 구할 수 있습니다!

세그먼트 트리 구현

class SegmentTree:
    """세그먼트 트리 구현체"""
    __slots__ = ("tree", "base_arr", "query_param", "update_param") # 객체를 가볍게 만들기 위해 사용 (Python3)

    def __init__(self, base_arr):
        self.tree = [0 for _ in range(tree_size)]
        self.base_arr = base_arr
        self.init(1, 0, N - 1)
        self.query_param = [0, N - 1]   # start, end
        self.update_param = [0, 0]      # idx, value
    
    def init(self, node, start, end):
        if start == end:
            self.tree[node] = self.base_arr[start]
        else:
            mid = (start + end) // 2
            self.init(node * 2, start, mid)
            self.init(node * 2 + 1, mid + 1, end)
            self.tree[node] = self.tree[node * 2] + self.tree[node * 2 + 1]
    
    def _query(self, node, start, end):
        if start > self.query_param[1] or end < self.query_param[0]: # 범위 밖 (포함 X)
            return 0
        elif start >= self.query_param[0] and end <= self.query_param[1]: # 범위 안에 완전히 포함
            return self.tree[node]
        else:
            mid = (start + end) // 2
            lsum = self._query(node * 2, start, mid)
            rsum = self._query(node * 2 + 1, mid + 1, end)
            return lsum + rsum
    
    def query(self, left, right):
        self.query_param[0] = left
        self.query_param[1] = right
        return self._query(1, 0, N - 1)

    def _update(self, node, start, end):
        if start > self.update_param[0] or end < self.update_param[0]:
            return
        elif start == end and start == self.update_param[0]:
            self.tree[node] = self.update_param[1]
        else:
            mid = (start + end) // 2
            self._update(node * 2, start, mid)
            self._update(node * 2 + 1, mid + 1, end)
            self.tree[node] = self.tree[node * 2] + self.tree[node * 2 + 1]
    
    def update(self, idx, value):
        if self.base_arr[idx] != value: # 홀수/짝수를 구분하는 배열이므로, 수가 바뀌어도 홀짝이 같다면 갱신할 필요가 없다.
            self.base_arr[idx] = value
            self.update_param[0] = idx
            self.update_param[1] = value
            self._update(1, 0, N - 1)

is_odd = [i & 1 for i in arr]
odd_tree = SegmentTree(is_odd)                      # 홀수면 1, 짝수면 0
even_tree = SegmentTree([1 - i for i in is_odd])    # 짝수면 1, 홀수면 0

세그먼트 트리 구현체

​홀수, 짝수를 각각의 세그먼트 트리에 저장하기 위해, 세그먼트 트리를 클래스로 구현 후 사용했습니다.

입력 배열을 변환한 새 배열로 구간 합을 계산하는 방법으로 구현했으니, 세그먼트 트리는 구간 합을 구하고 변경해주는 형태로 구현했습니다.

객체지향의 이점을 살려 재귀 호출의 매개변수를 줄이기 위해, query_paramupdate_param이라는 객체 변수를 만들었습니다.

1) 세그먼트 트리 초기화

class SegmentTree:
    def __init__(self, base_arr):
        self.tree = [0 for _ in range(tree_size)]
        self.base_arr = base_arr
        self.init(1, 0, N - 1)
        self.query_param = [0, N - 1]   # start, end
        self.update_param = [0, 0]      # idx, value
    
    def init(self, node, start, end):
        if start == end:
            self.tree[node] = self.base_arr[start]
        else:
            mid = (start + end) // 2
            self.init(node * 2, start, mid)
            self.init(node * 2 + 1, mid + 1, end)
            self.tree[node] = self.tree[node * 2] + self.tree[node * 2 + 1]

세그먼트 트리를 초기화하는 코드

세그먼트 트리 객체는 참고할 원본 배열(base_arr)을 매개변수로 받아, 세그먼트 트리를 초기화합니다. 객체 생성자(__init__)에서 세그먼트 트리 배열을 초기화해주는 init함수를 함께 호출합니다.

2) 구간 쿼리

class SegmentTree:
    ...
    
    def _query(self, node, start, end):
        if start > self.query_param[1] or end < self.query_param[0]: # 범위 밖 (포함 X)
            return 0
        elif start >= self.query_param[0] and end <= self.query_param[1]: # 범위 안에 완전히 포함
            return self.tree[node]
        else:
            mid = (start + end) // 2
            lsum = self._query(node * 2, start, mid)
            rsum = self._query(node * 2 + 1, mid + 1, end)
            return lsum + rsum
    
    def query(self, left, right):
        self.query_param[0] = left
        self.query_param[1] = right
        return self._query(1, 0, N - 1)

구간 내 최솟값을 구하는 코드

트리의 각 노드에는 해당 구간 내의 최솟값의 인덱스가 저장되어 있습니다.

세그먼트 트리에서 특정 구간에 대한 답을 찾을 때, 발생할 수 있는 경우는 3가지로 나눌 수 있습니다.

  • 현재 찾는 구간 (l, r)에 탐색중인 구간 (s, e)가 온전히 포함되는 경우
    • 현재 트리의 노드 값을 그대로 반환하고 탐색 종료.
  • 현재 찾는 구간 (l, r)에 탐색중인 구간 (s, e)가 아예 포함되지 않는 경우
    • 탐색이 실패했음을 뜻하는 임의의 값을 반환하고 탐색 종료.
    • 구간 합​을 구해야 하므로, 덧셈의 항등원인 0을 반환합니다.
  • 현재 찾는 구간 (l, r)에 탐색중인 구간 (s, e)가 일부만 포함되는 경우
    • 재귀 호출을 통해 탐색 후 결과 종합

query 함수는 입력에서 주어지는 매개변수 2개를 받아서, 실제로 호출해야 하는 query_segment_tree를 호출하도록 구현했습니다.

이때, 탐색할 구간 (left, right)는 객체 변수인 query_param에 저장된 뒤 _query 재귀 호출 내부에서 참조됩니다. (재귀 호출의 매개변수를 줄이기 위함)

3) 요소 변경

class SegmentTree:
    ...

    def _update(self, node, start, end):
        if start > self.update_param[0] or end < self.update_param[0]:
            return
        elif start == end and start == self.update_param[0]:
            self.tree[node] = self.update_param[1]
        else:
            mid = (start + end) // 2
            self._update(node * 2, start, mid)
            self._update(node * 2 + 1, mid + 1, end)
            self.tree[node] = self.tree[node * 2] + self.tree[node * 2 + 1]
    
    def update(self, idx, value):
        if self.base_arr[idx] != value: # 홀수/짝수를 구분하는 배열이므로, 수가 바뀌어도 홀짝이 같다면 갱신할 필요가 없다.
            self.base_arr[idx] = value
            self.update_param[0] = idx
            self.update_param[1] = value
            self._update(1, 0, N - 1)

특정 위치의 요소를 변경하는 코드

변경된 인덱스가 구간에 포함되지 않는다면 재귀 중단, 리프 노드일 경우 트리 노드 자체를 갱신합니다.

그 외의 경우에서는 재귀호출로 자식 노드들을 갱신해준 뒤, 현재 노드는 구간 쿼리와 동일한 방식으로 구간 합 저장합니다.

update 함수는 입력에서 주어지는 매개변수 2개를 받아서, 배열의 원소를 직접 갱신한 뒤 세그먼트 트리를 갱신하는 update_segment_tree를 호출하도록 구현했습니다.
변경된 인덱스와 값은 update_param이라는 객체 변수에 저장되어 _update의 재귀 호출 내부에서 참조합니다.

또한, 홀수/짝수 여부를 저장한 배열 base_arr의 구간 합을 구하고 있으므로 입력 배열 (arr)의 원소가 다른 수로 바뀌었지만 수의 홀짝은 그대로일 수 있습니다. 이 경우, 홀수/짝수의 개수는 변하지 않으므로 _update를 호출하지 않습니다.

전체 코드

from math import ceil, log2

input = open(0).readline

N = int(input())
arr = list(map(int, input().split()))
tree_size = 1 << (ceil(log2(N)) +1)

class SegmentTree:
    """세그먼트 트리 구현체"""
    __slots__ = ("tree", "base_arr", "query_param", "update_param") # 객체를 가볍게 만들기 위해 사용 (Python3)

    def __init__(self, base_arr):
        self.tree = [0 for _ in range(tree_size)]
        self.base_arr = base_arr
        self.init(1, 0, N - 1)
        self.query_param = [0, N - 1]   # start, end
        self.update_param = [0, 0]      # idx, value
    
    def init(self, node, start, end):
        if start == end:
            self.tree[node] = self.base_arr[start]
        else:
            mid = (start + end) // 2
            self.init(node * 2, start, mid)
            self.init(node * 2 + 1, mid + 1, end)
            self.tree[node] = self.tree[node * 2] + self.tree[node * 2 + 1]
    
    def _query(self, node, start, end):
        if start > self.query_param[1] or end < self.query_param[0]: # 범위 밖 (포함 X)
            return 0
        elif start >= self.query_param[0] and end <= self.query_param[1]: # 범위 안에 완전히 포함
            return self.tree[node]
        else:
            mid = (start + end) // 2
            lsum = self._query(node * 2, start, mid)
            rsum = self._query(node * 2 + 1, mid + 1, end)
            return lsum + rsum
    
    def query(self, left, right):
        self.query_param[0] = left
        self.query_param[1] = right
        return self._query(1, 0, N - 1)

    def _update(self, node, start, end):
        if start > self.update_param[0] or end < self.update_param[0]:
            return
        elif start == end and start == self.update_param[0]:
            self.tree[node] = self.update_param[1]
        else:
            mid = (start + end) // 2
            self._update(node * 2, start, mid)
            self._update(node * 2 + 1, mid + 1, end)
            self.tree[node] = self.tree[node * 2] + self.tree[node * 2 + 1]
    
    def update(self, idx, value):
        if self.base_arr[idx] != value: # 홀수/짝수를 구분하는 배열이므로, 수가 바뀌어도 홀짝이 같다면 갱신할 필요가 없다.
            self.base_arr[idx] = value
            self.update_param[0] = idx
            self.update_param[1] = value
            self._update(1, 0, N - 1)

is_odd = [i & 1 for i in arr]
odd_tree = SegmentTree(is_odd)                      # 홀수면 1, 짝수면 0
even_tree = SegmentTree([1 - i for i in is_odd])    # 짝수면 1, 홀수면 0

for _ in range(int(input())): # 쿼리 입력
    op, i, j = map(int, input().split())
    if op == 1:     # update i x
        i -= 1
        arr[i] = j
        d = j & 1 # j가 홀수면 1, 짝수면 0
        odd_tree.update(i, d)
        even_tree.update(i, 1 - d) # d가 1이면 0, 0이면 1이 된다.

    elif op == 2:   # query(even) l r
        print(even_tree.query(i - 1, j - 1))

    else:           # query(odd) l r
        print(odd_tree.query(i - 1, j - 1))

solution.py