[PS] BOJ 2357 / 최솟값과 최댓값
문제 링크: https://www.acmicpc.net/problem/2357
세그먼트 트리를 활용하는 기본적인 문제입니다. 다만, 구간 내 최솟값과 최댓값을 동시에 구해야 합니다.
풀이
세그먼트 트리?
세그먼트 트리에 대한 설명과 구간 합을 저장하는 세그먼트 이전 글을 참고해 주세요 😄
각 구간의 최솟값과 최댓값을 모두 구하기?
세그먼트 트리를 활용해 각 구간의 최솟값과 최댓값을 동시에 구해야 합니다. 트리 하나에 최솟값과 최댓값을 동시에 저장할 수도 있으나, 그냥 2개의 세그먼트 트리를 구현하는 것으로 해결했습니다.
1) 세그먼트 트리 구현
최솟값과 최댓값을 저장하는 세그먼트 트리는 기본 골격이 동일합니다. 단지, 트리의 노드에 저장할 값을 결정할 때 사용할 비교 구문과 구간 쿼리 중 구간에 포함되지 않는 경우 어떤 값을 반환할 지만 서로 다를 뿐입니다.
따라서, 비교 구문을 cmp라는 이름으로 세그먼트 트리 객체에 저장하고 사용했습니다. cmp는 2개의 매개변수를 비교해, 둘 중 하나를 반환하는 함수를 받습니다.
또, 구간 쿼리 중 구간에 포함되지 않는 경우 반환할 값은 fallback_value의 이름으로 객체 내부에 저장하고 사용했습니다.
최댓값을 저장할 세그먼트 트리의 경우, 비교 구문은 max 함수를, fallback_value는 기존 배열에 저장된 숫자 범위의 최솟값보다 더 작은 임의의 수를 사용하면 됩니다. 여기서는 0을 사용했습니다. (NUMBER_MIN)
최솟값을 저장할 세그먼트 트리의 경우, 비교 구문은 min 함수를, fallback_value는 기존 배열에 저장된 숫자 범위의 최댓값보다 더 큰 임의의 수를 사용하면 됩니다.
여기서는 1,000,000,001을 사용했습니다. (NUMBER_MAX)
문제에서는 구간 내에서 최솟값/최댓값을 찾는 입력만 있고, 배열을 갱신하는 입력은 없습니다. 따라서, query 함수만 구현했습니다. 클래스로 구현해 얻을 수 있는 이점을 살려, 구간 내에서 값을 찾을 시 찾는 구간 범위를 객체에 저장해 재귀 호출의 매개변수를 줄였습니다.
# Constants
N, M = map(int, input().split())
tree_height = ceil(log2(N))
tree_size = 1 << (tree_height + 1)
NUMBER_MAX = 1_000_000_001
NUMBER_MIN = 0
# Original Data
arr = list(int(input()) for _ in range(N))
# Segment Tree Impl
class QueryParam:
"""재귀 호출의 중복 인수를 제거하기 위한 데이터 객체"""
__slots__ = ("left", "right")
NONE = 0
ALL = 1
PARTIAL = 2
def __init__(self):
self.left = 0
self.right = N - 1
def set(self, left, right):
self.left = left
self.right = right
def check_range(self, start, end):
if start > self.right or end < self.left: # 포함 X
return QueryParam.NONE
elif start >= self.left and end <= self.right: # 완전히 포함됨
return QueryParam.ALL
else:
return QueryParam.PARTIAL
class CmpSegmentTree:
"""특정한 비교 조건을 통해 값을 저장하는 세그먼트 트리"""
__slots__ = ("tree", "query_param", "cmp", "fallback_value")
def __init__(self, cmp, fallback_value):
self.tree = [0] * tree_size
self.query_param = QueryParam()
self.cmp = cmp
self.fallback_value = fallback_value
def init(self, node, start, end):
if start == end:
self.tree[node] = arr[start]
else:
mid = (start + end) // 2
self.init(node * 2, start, mid)
self.init(node * 2 + 1, mid + 1, end)
self.tree[node] = self.cmp(self.tree[node * 2], self.tree[node * 2 + 1])
def _query_tree(self, node, start, end):
check = self.query_param.check_range(start, end)
if check == QueryParam.NONE: # 범위 밖 (포함 X)
return self.fallback_value
elif check == QueryParam.ALL: # 범위 안에 완전히 포함됨.
return self.tree[node]
else:
mid = (start + end) // 2
left_res = self._query_tree(node * 2, start, mid)
right_res = self._query_tree(node * 2 + 1, mid + 1, end)
return self.cmp(left_res, right_res)
def query(self, left, right):
self.query_param.set(left, right)
return self._query_tree(1, 0, N - 1)
max_segment_tree = CmpSegmentTree(max, NUMBER_MIN)
min_segment_tree = CmpSegmentTree(min, NUMBER_MAX)
max_segment_tree.init(1, 0, N - 1)
min_segment_tree.init(1, 0, N - 1)세그먼트 트리 구현
전체 코드
from math import ceil, log2
input = open(0).readline
# Constants
N, M = map(int, input().split())
tree_height = ceil(log2(N))
tree_size = 1 << (tree_height + 1)
NUMBER_MAX = 1_000_000_001
NUMBER_MIN = 0
# Original Data
arr = list(int(input()) for _ in range(N))
# Segment Tree Impl
class QueryParam:
"""재귀 호출의 중복 인수를 제거하기 위한 데이터 객체"""
__slots__ = ("left", "right")
NONE = 0
ALL = 1
PARTIAL = 2
def __init__(self):
self.left = 0
self.right = N - 1
def set(self, left, right):
self.left = left
self.right = right
def check_range(self, start, end):
if start > self.right or end < self.left: # 포함 X
return QueryParam.NONE
elif start >= self.left and end <= self.right: # 완전히 포함됨
return QueryParam.ALL
else:
return QueryParam.PARTIAL
class CmpSegmentTree:
"""특정한 비교 조건을 통해 값을 저장하는 세그먼트 트리"""
__slots__ = ("tree", "query_param", "cmp", "fallback_value")
def __init__(self, cmp, fallback_value):
self.tree = [0] * tree_size
self.query_param = QueryParam()
self.cmp = cmp
self.fallback_value = fallback_value
def init(self, node, start, end):
if start == end:
self.tree[node] = arr[start]
else:
mid = (start + end) // 2
self.init(node * 2, start, mid)
self.init(node * 2 + 1, mid + 1, end)
self.tree[node] = self.cmp(self.tree[node * 2], self.tree[node * 2 + 1])
def _query_tree(self, node, start, end):
check = self.query_param.check_range(start, end)
if check == QueryParam.NONE: # 범위 밖 (포함 X)
return self.fallback_value
elif check == QueryParam.ALL: # 범위 안에 완전히 포함됨.
return self.tree[node]
else:
mid = (start + end) // 2
left_res = self._query_tree(node * 2, start, mid)
right_res = self._query_tree(node * 2 + 1, mid + 1, end)
return self.cmp(left_res, right_res)
def query(self, left, right):
self.query_param.set(left, right)
return self._query_tree(1, 0, N - 1)
max_segment_tree = CmpSegmentTree(max, NUMBER_MIN)
min_segment_tree = CmpSegmentTree(min, NUMBER_MAX)
max_segment_tree.init(1, 0, N - 1)
min_segment_tree.init(1, 0, N - 1)
for _ in range(M):
a, b = map(lambda v: int(v) - 1, input().split())
if a > b:
a, b = b, a
print(min_segment_tree.query(a, b), max_segment_tree.query(a, b))solution.py