[PS] BOJ 11505 / 구간 곱 구하기
문제 링크: https://www.acmicpc.net/problem/11505
세그먼트 트리를 활용하는 기본적인 문제입니다.
풀이
세그먼트 트리?
세그먼트 트리에 대한 설명과 구간 합을 저장하는 세그먼트 이전 글을 참고해 주세요 😄
구간 합 대신에 구간 곱 저장하기
이전에 구현했던 세그먼트 트리는 구간의 합을 저장했습니다. 이번 문제에서는 구간 곱을 1,000,000,007로 나눈 나머지를 저장하도록 노드에 값을 저장하는 부분만 수정해주면 됩니다.
# 배열 정의하기
tree_height = ceil(log2(N)) # 트리의 높이
tree_size = 1 << (tree_height + 1) # 트리의 크기 (포화 이진 트리일 때의 노드 개수)
arr = [0] * N # 원본 배열
tree = [0] * tree_size # 세그먼트 트리를 저장할 배열
# 트리 배열을 초기화하는 함수
def build_segment_tree(node, start, end):
if start == end:
tree[node] = arr[start]
else:
mid = (start + end) // 2
build_segment_tree(node * 2, start, mid)
build_segment_tree(node * 2 + 1, mid + 1, end)
tree[node] = (tree[node * 2] * tree[node * 2 + 1]) % 1_000_000_007
build_segment_tree(1, 0, N - 1)트리 배열을 초기화하는 함수
def query(node, start, end, left, right):
"""세그먼트 트리에서 (left, right) 구간에 해당하는 구간 합을 구한다.
node: 현재 노드의 번호 (실제 tree 배열에 저장된 위치)
start, end: 현재 노드에 저장된 구간 정보
left, right: 찾고자 하는 구간 정보
"""
if left > end or right < start: # (l, r)이 이 (s, e)에 아예 포함되지 않는 경우
return 1 # 곱셈의 항등원은 1이다.
elif left <= start and right >= end: # (s, e)가 (l, r)에 완전히 포함되는 경우
return tree[node]
else: # (l, r)이 (s, e)에 완전히 포함되는 경우 / (s, e)가 (l, r)에 일부만 포함되는 경우
mid = (start + end) // 2
lsum = query(node * 2, start, mid, left, right)
rsum = query(node * 2 + 1, mid + 1, end, left, right)
return (lsum * rsum) % 1_000_000_007구간 곱을 구하는 함수
def update(node, start, end, index, val):
"""기존 배열에서 index 위치에 저장된 값을 val로 수정한다.
이후 세그먼트 트리에 저장된 구간 곱을 갱신한다.
node: 현재 노드의 번호 (tree에 저장된 위치)
start, end: 현재 노드에 저장된 구간
index, val: 기존 배열에서 값이 변한 위치와 그 값
"""
if start > index or end < index: # 현재 구간에 index를 포함되지 않는 경우.
return
elif start == end: # 현재 노드가 leaf일 경우.
tree[node] = val
arr[index] = val
else: # 현재 구간에 index가 포함되는 경우
mid = (start + end) // 2
update(node * 2, start, mid, index, val)
update(node * 2 + 1, mid + 1, end, index, val)
tree[node] = (tree[node * 2] * tree[node * 2 + 1]) % 1_000_000_007배열의 원소를 변경하는 함수 구현
전체 코드
from math import ceil, log2
MOD = 1_000_000_007
input = open(0).readline
N, M, K = map(int, input().split())
tree_height = ceil(log2(N))
tree_size = 1 << (tree_height + 1)
arr = list(int(input()) for _ in range(N))
tree = [0] * tree_size
def build_segment_tree(node, start, end):
"""세그먼트 트리를 초기화한다."""
if start == end:
tree[node] = arr[start]
else:
mid = (start + end) // 2
build_segment_tree(node * 2, start, mid)
build_segment_tree(node * 2 + 1, mid + 1, end)
tree[node] = (tree[node * 2] * tree[node * 2 + 1]) % MOD
build_segment_tree(1, 0, N - 1)
def query(node, start, end, left, right):
"""세그먼트 트리에서 (left, right) 구간에 해당하는 곱을 구한다.
node: 현재 노드의 번호 (tree에 저장된 위치)
start, end: 현재 노드에 저장된 구간
left, right: 찾고자 하는 구간
"""
if start > right or end < left: # 찾고자 하는 구간에 현재 구간이 완전히 포함되지 않는 경우
return 1
elif start >= left and end <= right: # 찾고자 하는 구간에 현재 구간이 완전히 포함될 경우
return tree[node]
else: # 찾고자 하는 구간에 현재 구간의 일부만 포함될 경우
mid = (start + end) // 2
lmul = query(node * 2, start, mid, left, right)
rmul = query(node * 2 + 1, mid + 1, end, left, right)
return (lmul * rmul) % MOD
def update(node, start, end, index, val):
"""기존 배열에서 index 위치에 저장된 값을 val로 수정한다.
이후 세그먼트 트리에 저장된 구간 곱을 갱신한다.
node: 현재 노드의 번호 (tree에 저장된 위치)
start, end: 현재 노드에 저장된 구간
index, val: 기존 배열에서 값이 변한 위치와 그 값
"""
if start > index or end < index: # 현재 구간에 index를 포함되지 않는 경우.
return
elif start == end: # 현재 노드가 leaf일 경우.
tree[node] = val
arr[index] = val
else: # 현재 구간에 index가 포함되는 경우
mid = (start + end) // 2
update(node * 2, start, mid, index, val)
update(node * 2 + 1, mid + 1, end, index, val)
tree[node] = (tree[node * 2] * tree[node * 2 + 1]) % MOD
for _ in range(M + K):
a, b, c = map(int, input().split())
if a == 1:
update(1, 0, N - 1, b - 1, c) # 실제 인덱스는 0...N-1이므로 1씩 빼서 사용.
else:
print(query(1, 0, N - 1, b - 1, c - 1)) # 실제 인덱스는 0...N-1이므로 1씩 빼서 사용.
solution.py