[PS] BOJ 2268 / 수들의 합 7
문제 링크: https://www.acmicpc.net/problem/2268
세그먼트 트리를 활용하는 기본적인 문제입니다.
풀이
세그먼트 트리 (Segment Tree)
세그먼트 트리의 동작 과정과 예제를 함께 볼 수 있는 자세한 글을 BOJ BOOK에서 확인할 수 있습니다.
https://book.acmicpc.net/ds/segment-tree
다음과 같은 문제를 생각해 봅시다.
크기가 N인 정수 배열 A가 있고, 여기서 다음과 같은 연산을 최대 M번 수행해야 하는 문제가 있습니다.
- 구간 l, r (l≤r)이 주어졌을 때, A[l]+A[l+1]+⋯+A[r−1]+A[r]을 구해서 출력하기
- i번째 수를 v로 바꾸기 (A[i]=v)
1) 그냥 매번 더하기
1번 연산을 매번 배열의 합을 구하는 방식으로 구현할 수 있습니다.
for _ in range(M):
cmd, *args = map(int, input().split())
if cmd == 1: # 1번 연산
l, r = args
total = 0
for i in range(l, r + 1):
total += A[i]
print(total)
else:
i, v = args
A[i] = v가장 단순한 구현
- 1번 연산: $O(N)$
- 2번 연산: $O(1)$
- 전체 시간 복잡도: $O(NM)$
2) 누적 합 방식
미리 누적 합을 구해 두면 1번 연산을 $O(1)$로 구할 수 있습니다. 하지만, 이 경우 2번 연산이 $O(N)$이 됩니다.
prefix_sum = [0 for _ in range(N + 1)]
for i in range(N):
prefix_sum[i + 1] = prefix_sum[i - 1] + A[i]
for _ in range(M):
cmd, *args = map(int, input().split())
if cmd == 1: # 1번 연산
print(prefix_sum[r] - prefix_sum[l - 1])
else:
i, v = args
diff = v - A[i]
for i in range(i, N + 1): # i번째 원소가 포함된 모든 누적 합을 갱신해준다.
prefix_sum[i] += diff가장 단순한 구현
- 1번 연산: $O(1)$
- 2번 연산: $O(N)$
- 전체 시간 복잡도: $O(NM)$
3) 세그먼트 트리 사용하기
세그먼트 트리는 1번 연산과 2번 연산을 모두 $O(lg N)$로 수행할 수 있는 자료구조입니다.
세그먼트 트리 구현하기
세그먼트 트리의 각 노드에는 다음과 같은 정보를 저장합니다.
- 리프 노드: 원본 배열의 원소
- 리프 노드가 아닌 모든 노드: 왼쪽 자식과 오른쪽 자식의 합
세그먼트 트리는 기본적으로 정 이진 트리(Full Binary Tree) 구조를 가지며, N이 2의 거듭제곱 꼴인 경우 포화 이진 트리(Perfect Binary Tree)가 됩니다.
정 이진 트리(Full Binary Tree): 모든 노드가 0개 또는 2개의 자식 노드를 가지는 이진 트리
포화 이진 트리(Perfect Binary Tree): 모든 리프 노드가 동일한 깊이(또는 레벨)을 가지며 리프 노드가 아닌 모든 노드가 2개의 자식을 가지는 이진 트리

세그먼트 트리를 배열에 저장하기 위해, 세그먼트 트리의 각 노드 번호는 다음과 같이 부여합니다.
현재 노드의 번호를 $x$라 할 때,
- 왼쪽 자식 노드: $2x$
- 오른쪽 자식 노드: $2x + 1$

세그먼트 트리의 구조를 보면, 리프 노드가 $N$개인 세그먼트 트리에는 리프 노드가 아닌 노드가 $N-1$개 존재합니다. 따라서, 전체 노드의 개수는 $2N - 1$개입니다.
세그먼트 트리의 높이는 N이 2의 제곱꼴이 아닌 경우에 $H=\lceil lg N \rceil$입니다.
배열의 크기는 세그먼트 트리가 포화 이진 트리가 될 경우의 노드 개수인 $2^{H + 1}$과 같습니다.
이제, N개의 원소를 가지는 원본 배열의 구간 합을 저장하는 세그먼트 트리를 만들어 봅시다.
1) 세그먼트 트리 배열 초기화하기
앞서 트리의 높이와 포화 이진 트리일 때의 노드 개수를 각각 구했습니다. 이를 통해, 세그먼트 트리를 저장할 배열의 크기를 알 수 있습니다.
트리 배열을 초기화하는 함수는 재귀적으로 동작하며, 다음을 수행합니다.
- 현재 노드가 리프 노드라면 (start = end)
- 원본 배열의 수를 노드에 저장
- 그렇지 않은 경우
- 왼쪽 자식 노드를 초기화
- 오른쪽 자식 노드를 초기화
- 현재 노드의 값은 왼쪽 자식 노드와 오른쪽 자식 노드의 합이 된다.
tree[$x$] = tree[$2x$] + tree[$2x + 1$]
# 배열 정의하기
tree_height = ceil(log2(N)) # 트리의 높이
tree_size = 1 << (tree_height + 1) # 트리의 크기 (포화 이진 트리일 때의 노드 개수)
arr = [0] * N # 원본 배열
tree = [0] * tree_size # 세그먼트 트리를 저장할 배열
# 트리 배열을 초기화하는 함수
def build_segment_tree(node, start, end):
if start == end:
tree[node] = arr[start]
else:
mid = (start + end) // 2
build_segment_tree(node * 2, start, mid)
build_segment_tree(node * 2 + 1, mid + 1, end)
tree[node] = tree[node * 2] + tree[node * 2 + 1]
build_segment_tree(1, 0, N - 1)트리 배열을 초기화하는 함수
2) 세그먼트 트리에서 구간 합 구하기
앞선 문제의 1번 연산에 해당하는, 임의의 구간 $(l, r)$의 합을 구하는 함수 query를 구현해봅시다.
기본적으로, 트리의 각 노드에서 구간 합을 구하는 경우는 다음과 같습니다:
현재 노드에 저장된 구간 $(s, e)$와 합을 구해야 하는 구간 $(l, r)$에 대해, 가능한 경우는 4가지입니다.
- $(l, r)$에 완전히 포함될 경우
- 현재 노드에 저장된 값을 그대로 반환하고 탐색 종료.
- $(l, r)$에 완전히 포함되지 않는 경우
- 0을 반환하고 탐색 종료.
- $(l, r)$에 일부만 포함되는 경우
- 두 자식 노드에 대해 재귀적으로 탐색을 수행한 후 그 합을 반환한다.
- $(s, e)$에 $(l, r)$이 완전히 포함되는 경우
- 3과 동일
def query(node, start, end, left, right):
"""세그먼트 트리에서 (left, right) 구간에 해당하는 구간 합을 구한다.
node: 현재 노드의 번호 (실제 tree 배열에 저장된 위치)
start, end: 현재 노드에 저장된 구간 정보
left, right: 찾고자 하는 구간 정보
"""
if left > end or right < start: # (l, r)이 이 (s, e)에 아예 포함되지 않는 경우
return 0
elif left <= start and right >= end: # (s, e)가 (l, r)에 완전히 포함되는 경우
return tree[node]
else: # (l, r)이 (s, e)에 완전히 포함되는 경우 / (s, e)가 (l, r)에 일부만 포함되는 경우
mid = (start + end) // 2
lsum = query(node * 2, start, mid, left, right)
rsum = query(node * 2 + 1, mid + 1, end, left, right)
return lsum + rsum구간 합을 구하는 함수
3) 배열 원소 변경
앞선 문제의 2번 연산에 해당하는, 배열의 $i$번째 원소를 $v$로 변경하는 함수 update를 구현해 봅시다.
update_tree함수는 $i$가 포함되는 모든 구간에 대해 노드에 저장된 값을 갱신합니다.update함수는 update_tree의 긴 호출 구문을 감싸고, $i$번 노드의 값을 직접 갱신해주는 함수입니다.
이 구현 방식은 기존에 새로 저장할 값 $v$와 $i$번에 저장되어 있던 값의 차이 diff를 직접적으로 각 노드에 더해가며 $i$가 포함된 모든 구간의 노드를 갱신합니다.
def update_tree(node, start, end, index, diff):
"""배열의 index번째 원소를 변경한다. 구간 합 정보를 트리 전체에 갱신해준다.
node: 현재 노드의 번호 (실제 tree 배열에 저장된 위치)
start, end: 현재 노드에 저장된 구간 정보
index: 갱신된 배열 원소의 위치
diff: 기존 값과 갱신된 값의 차.
"""
if index < start or index > end: # 갱신된 원소가 이 구간에 포함되지 않는다.
return
tree[node] += diff
if start != end: # leaf 노드가 아니라면
mid = (start + end) // 2
update_tree(node * 2, start, mid, index, diff)
update_tree(node * 2 + 1, mid + 1, end, index, diff)
def update(index, val):
"""배열의 index번째 원소를 val로 변경한다."""
diff = val - arr[index]
arr[index] = val
update_tree(1, 0, N - 1, index, diff)배열의 원소를 변경하는 함수 구현 (1)
diff를 사용하지 않고 직접 구간 합을 계산해가며 세그먼트 트리를 갱신할 수도 있습니다.
def update(node, start, end, index, val):
"""기존 배열에서 index 위치에 저장된 값을 val로 수정한다.
이후 세그먼트 트리에 저장된 구간 곱을 갱신한다.
node: 현재 노드의 번호 (tree에 저장된 위치)
start, end: 현재 노드에 저장된 구간
index, val: 기존 배열에서 값이 변한 위치와 그 값
"""
if start > index or end < index: # 현재 구간에 index를 포함되지 않는 경우.
return
elif start == end: # 현재 노드가 leaf일 경우.
tree[node] = val
arr[index] = val
else: # 현재 구간에 index가 포함되는 경우
mid = (start + end) // 2
update(node * 2, start, mid, index, val)
update(node * 2 + 1, mid + 1, end, index, val)
tree[node] = tree[node * 2] + tree[node * 2 + 1]배열의 원소를 변경하는 함수 구현 (2)
전체 코드
from math import ceil, log2
input = open(0).readline
N, M = map(int, input().split())
tree_height = ceil(log2(N))
tree_size = 1 << (tree_height + 1)
arr = [0] * N
tree = [0] * tree_size
def build_segment_tree(node, start, end):
if start == end:
tree[node] = arr[start]
else:
mid = (start + end) // 2
build_segment_tree(node * 2, start, mid)
build_segment_tree(node * 2 + 1, mid + 1, end)
tree[node] = tree[node * 2] + tree[node * 2 + 1]
build_segment_tree(1, 0, N - 1)
def query(node, start, end, left, right):
"""세그먼트 트리에서 (left, right) 구간에 해당하는 구간 합을 구한다.
node: 현재 노드의 번호 (실제 tree 배열에 저장된 위치)
start, end: 현재 노드에 저장된 구간 정보
left, right: 찾고자 하는 구간 정보
"""
if left > end or right < start: # (l, r)이 이 (s, e)에 아예 포함되지 않는 경우
return 0
elif left <= start and right >= end: # (s, e)가 (l, r)에 완전히 포함되는 경우
return tree[node]
else: # (l, r)이 (s, e)에 완전히 포함되는 경우 / (s, e)가 (l, r)에 일부만 포함되는 경우
mid = (start + end) // 2
lsum = query(node * 2, start, mid, left, right)
rsum = query(node * 2 + 1, mid + 1, end, left, right)
return lsum + rsum
def update_tree(node, start, end, index, diff):
"""배열의 index번째 원소를 변경한다. 구간 합 정보를 트리 전체에 갱신해준다.
node: 현재 노드의 번호 (실제 tree 배열에 저장된 위치)
start, end: 현재 노드에 저장된 구간 정보
index: 갱신된 배열 원소의 위치
diff: 기존 값과 갱신된 값의 차.
"""
if index < start or index > end: # 갱신된 원소가 이 구간에 포함되지 않는다.
return
tree[node] += diff
if start != end: # leaf 노드가 아니라면
mid = (start + end) // 2
update_tree(node * 2, start, mid, index, diff)
update_tree(node * 2 + 1, mid + 1, end, index, diff)
def update(index, val):
"""배열의 index번째 원소를 val로 변경한다."""
diff = val - arr[index]
arr[index] = val
update_tree(1, 0, N - 1, index, diff)
for _ in range(M):
cmd, i, j = map(int, input().split())
if cmd == 0:
if i > j:
i, j = j, i
print(query(1, 0, N - 1, i - 1, j - 1))
else:
update(i - 1, j)solution.py