[PS] BOJ 18436 / 수열과 쿼리 37
문제 링크: https://www.acmicpc.net/problem/18436
세그먼트 트리를 활용해 구간 내 짝수와 홀수의 개수를 각각 저장해주면 됩니다.
풀이
세그먼트 트리?
세그먼트 트리에 대한 설명과 구간 합을 저장하는 세그먼트 이전 글을 참고해 주세요 😄
세그먼트 트리에 짝수/홀수의 개수 저장하기
가장 기본적인 세그먼트 트리의 구현은 구간 합을 저장하는 형태였습니다. 홀수와 짝수의 개수를 어떻게 세그먼트 트리에 저장할 수 있을까요?
간단하게, 입력으로 주어진 배열을 통해 세그먼트 트리를 만드는 대신에
배열의 각 수가 홀수인지/짝수인지를 1 또는 0으로 나타낸 배열로 세그먼트 트리를 만들면, 구간 합으로 홀수/짝수의 개수를 구할 수 있습니다!
세그먼트 트리 구현
class SegmentTree:
"""세그먼트 트리 구현체"""
__slots__ = ("tree", "base_arr", "query_param", "update_param") # 객체를 가볍게 만들기 위해 사용 (Python3)
def __init__(self, base_arr):
self.tree = [0 for _ in range(tree_size)]
self.base_arr = base_arr
self.init(1, 0, N - 1)
self.query_param = [0, N - 1] # start, end
self.update_param = [0, 0] # idx, value
def init(self, node, start, end):
if start == end:
self.tree[node] = self.base_arr[start]
else:
mid = (start + end) // 2
self.init(node * 2, start, mid)
self.init(node * 2 + 1, mid + 1, end)
self.tree[node] = self.tree[node * 2] + self.tree[node * 2 + 1]
def _query(self, node, start, end):
if start > self.query_param[1] or end < self.query_param[0]: # 범위 밖 (포함 X)
return 0
elif start >= self.query_param[0] and end <= self.query_param[1]: # 범위 안에 완전히 포함
return self.tree[node]
else:
mid = (start + end) // 2
lsum = self._query(node * 2, start, mid)
rsum = self._query(node * 2 + 1, mid + 1, end)
return lsum + rsum
def query(self, left, right):
self.query_param[0] = left
self.query_param[1] = right
return self._query(1, 0, N - 1)
def _update(self, node, start, end):
if start > self.update_param[0] or end < self.update_param[0]:
return
elif start == end and start == self.update_param[0]:
self.tree[node] = self.update_param[1]
else:
mid = (start + end) // 2
self._update(node * 2, start, mid)
self._update(node * 2 + 1, mid + 1, end)
self.tree[node] = self.tree[node * 2] + self.tree[node * 2 + 1]
def update(self, idx, value):
if self.base_arr[idx] != value: # 홀수/짝수를 구분하는 배열이므로, 수가 바뀌어도 홀짝이 같다면 갱신할 필요가 없다.
self.base_arr[idx] = value
self.update_param[0] = idx
self.update_param[1] = value
self._update(1, 0, N - 1)
is_odd = [i & 1 for i in arr]
odd_tree = SegmentTree(is_odd) # 홀수면 1, 짝수면 0
even_tree = SegmentTree([1 - i for i in is_odd]) # 짝수면 1, 홀수면 0세그먼트 트리 구현체
홀수, 짝수를 각각의 세그먼트 트리에 저장하기 위해, 세그먼트 트리를 클래스로 구현 후 사용했습니다.
입력 배열을 변환한 새 배열로 구간 합을 계산하는 방법으로 구현했으니, 세그먼트 트리는 구간 합을 구하고 변경해주는 형태로 구현했습니다.
객체지향의 이점을 살려 재귀 호출의 매개변수를 줄이기 위해, query_param과 update_param이라는 객체 변수를 만들었습니다.
1) 세그먼트 트리 초기화
class SegmentTree:
def __init__(self, base_arr):
self.tree = [0 for _ in range(tree_size)]
self.base_arr = base_arr
self.init(1, 0, N - 1)
self.query_param = [0, N - 1] # start, end
self.update_param = [0, 0] # idx, value
def init(self, node, start, end):
if start == end:
self.tree[node] = self.base_arr[start]
else:
mid = (start + end) // 2
self.init(node * 2, start, mid)
self.init(node * 2 + 1, mid + 1, end)
self.tree[node] = self.tree[node * 2] + self.tree[node * 2 + 1]세그먼트 트리를 초기화하는 코드
세그먼트 트리 객체는 참고할 원본 배열(base_arr)을 매개변수로 받아, 세그먼트 트리를 초기화합니다. 객체 생성자(__init__)에서 세그먼트 트리 배열을 초기화해주는 init함수를 함께 호출합니다.
2) 구간 쿼리
class SegmentTree:
...
def _query(self, node, start, end):
if start > self.query_param[1] or end < self.query_param[0]: # 범위 밖 (포함 X)
return 0
elif start >= self.query_param[0] and end <= self.query_param[1]: # 범위 안에 완전히 포함
return self.tree[node]
else:
mid = (start + end) // 2
lsum = self._query(node * 2, start, mid)
rsum = self._query(node * 2 + 1, mid + 1, end)
return lsum + rsum
def query(self, left, right):
self.query_param[0] = left
self.query_param[1] = right
return self._query(1, 0, N - 1)구간 내 최솟값을 구하는 코드
트리의 각 노드에는 해당 구간 내의 최솟값의 인덱스가 저장되어 있습니다.
세그먼트 트리에서 특정 구간에 대한 답을 찾을 때, 발생할 수 있는 경우는 3가지로 나눌 수 있습니다.
- 현재 찾는 구간 (l, r)에 탐색중인 구간 (s, e)가 온전히 포함되는 경우
- 현재 트리의 노드 값을 그대로 반환하고 탐색 종료.
- 현재 찾는 구간 (l, r)에 탐색중인 구간 (s, e)가 아예 포함되지 않는 경우
- 탐색이 실패했음을 뜻하는 임의의 값을 반환하고 탐색 종료.
- 구간 합을 구해야 하므로, 덧셈의 항등원인 0을 반환합니다.
- 현재 찾는 구간 (l, r)에 탐색중인 구간 (s, e)가 일부만 포함되는 경우
- 재귀 호출을 통해 탐색 후 결과 종합
query 함수는 입력에서 주어지는 매개변수 2개를 받아서, 실제로 호출해야 하는 query_segment_tree를 호출하도록 구현했습니다.
이때, 탐색할 구간 (left, right)는 객체 변수인 query_param에 저장된 뒤 _query 재귀 호출 내부에서 참조됩니다. (재귀 호출의 매개변수를 줄이기 위함)
3) 요소 변경
class SegmentTree:
...
def _update(self, node, start, end):
if start > self.update_param[0] or end < self.update_param[0]:
return
elif start == end and start == self.update_param[0]:
self.tree[node] = self.update_param[1]
else:
mid = (start + end) // 2
self._update(node * 2, start, mid)
self._update(node * 2 + 1, mid + 1, end)
self.tree[node] = self.tree[node * 2] + self.tree[node * 2 + 1]
def update(self, idx, value):
if self.base_arr[idx] != value: # 홀수/짝수를 구분하는 배열이므로, 수가 바뀌어도 홀짝이 같다면 갱신할 필요가 없다.
self.base_arr[idx] = value
self.update_param[0] = idx
self.update_param[1] = value
self._update(1, 0, N - 1)특정 위치의 요소를 변경하는 코드
변경된 인덱스가 구간에 포함되지 않는다면 재귀 중단, 리프 노드일 경우 트리 노드 자체를 갱신합니다.
그 외의 경우에서는 재귀호출로 자식 노드들을 갱신해준 뒤, 현재 노드는 구간 쿼리와 동일한 방식으로 구간 합 저장합니다.
update 함수는 입력에서 주어지는 매개변수 2개를 받아서, 배열의 원소를 직접 갱신한 뒤 세그먼트 트리를 갱신하는 update_segment_tree를 호출하도록 구현했습니다.
변경된 인덱스와 값은 update_param이라는 객체 변수에 저장되어 _update의 재귀 호출 내부에서 참조합니다.
또한, 홀수/짝수 여부를 저장한 배열 base_arr의 구간 합을 구하고 있으므로 입력 배열 (arr)의 원소가 다른 수로 바뀌었지만 수의 홀짝은 그대로일 수 있습니다. 이 경우, 홀수/짝수의 개수는 변하지 않으므로 _update를 호출하지 않습니다.
전체 코드
from math import ceil, log2
input = open(0).readline
N = int(input())
arr = list(map(int, input().split()))
tree_size = 1 << (ceil(log2(N)) +1)
class SegmentTree:
"""세그먼트 트리 구현체"""
__slots__ = ("tree", "base_arr", "query_param", "update_param") # 객체를 가볍게 만들기 위해 사용 (Python3)
def __init__(self, base_arr):
self.tree = [0 for _ in range(tree_size)]
self.base_arr = base_arr
self.init(1, 0, N - 1)
self.query_param = [0, N - 1] # start, end
self.update_param = [0, 0] # idx, value
def init(self, node, start, end):
if start == end:
self.tree[node] = self.base_arr[start]
else:
mid = (start + end) // 2
self.init(node * 2, start, mid)
self.init(node * 2 + 1, mid + 1, end)
self.tree[node] = self.tree[node * 2] + self.tree[node * 2 + 1]
def _query(self, node, start, end):
if start > self.query_param[1] or end < self.query_param[0]: # 범위 밖 (포함 X)
return 0
elif start >= self.query_param[0] and end <= self.query_param[1]: # 범위 안에 완전히 포함
return self.tree[node]
else:
mid = (start + end) // 2
lsum = self._query(node * 2, start, mid)
rsum = self._query(node * 2 + 1, mid + 1, end)
return lsum + rsum
def query(self, left, right):
self.query_param[0] = left
self.query_param[1] = right
return self._query(1, 0, N - 1)
def _update(self, node, start, end):
if start > self.update_param[0] or end < self.update_param[0]:
return
elif start == end and start == self.update_param[0]:
self.tree[node] = self.update_param[1]
else:
mid = (start + end) // 2
self._update(node * 2, start, mid)
self._update(node * 2 + 1, mid + 1, end)
self.tree[node] = self.tree[node * 2] + self.tree[node * 2 + 1]
def update(self, idx, value):
if self.base_arr[idx] != value: # 홀수/짝수를 구분하는 배열이므로, 수가 바뀌어도 홀짝이 같다면 갱신할 필요가 없다.
self.base_arr[idx] = value
self.update_param[0] = idx
self.update_param[1] = value
self._update(1, 0, N - 1)
is_odd = [i & 1 for i in arr]
odd_tree = SegmentTree(is_odd) # 홀수면 1, 짝수면 0
even_tree = SegmentTree([1 - i for i in is_odd]) # 짝수면 1, 홀수면 0
for _ in range(int(input())): # 쿼리 입력
op, i, j = map(int, input().split())
if op == 1: # update i x
i -= 1
arr[i] = j
d = j & 1 # j가 홀수면 1, 짝수면 0
odd_tree.update(i, d)
even_tree.update(i, 1 - d) # d가 1이면 0, 0이면 1이 된다.
elif op == 2: # query(even) l r
print(even_tree.query(i - 1, j - 1))
else: # query(odd) l r
print(odd_tree.query(i - 1, j - 1))
solution.py