[PS] BOJ 11505 / 구간 곱 구하기

[PS] BOJ 11505 / 구간 곱 구하기
Thumbnail: Photo by Sofiya Levchenko (Unsplash)
문제 링크: https://www.acmicpc.net/problem/11505

세그먼트 트리를 활용하는 기본적인 문제입니다.

풀이

세그먼트 트리?

세그먼트 트리에 대한 설명과 구간 합을 저장하는 세그먼트 이전 글을 참고해 주세요 😄

구간 합 대신에 구간 곱 저장하기

이전에 구현했던 세그먼트 트리는 구간의 합을 저장했습니다. 이번 문제에서는 구간 곱을 1,000,000,007로 나눈 나머지를 저장하도록 노드에 값을 저장하는 부분만 수정해주면 됩니다.

# 배열 정의하기
tree_height = ceil(log2(N))            # 트리의 높이
tree_size = 1 << (tree_height + 1)     # 트리의 크기 (포화 이진 트리일 때의 노드 개수)
arr = [0] * N                          # 원본 배열
tree = [0] * tree_size                 # 세그먼트 트리를 저장할 배열

# 트리 배열을 초기화하는 함수
def build_segment_tree(node, start, end):
    if start == end:
        tree[node] = arr[start]
    else:
        mid = (start + end) // 2
        build_segment_tree(node * 2, start, mid)
        build_segment_tree(node * 2 + 1, mid + 1, end)
        tree[node] = (tree[node * 2] * tree[node * 2 + 1]) % 1_000_000_007

build_segment_tree(1, 0, N - 1)

트리 배열을 초기화하는 함수

def query(node, start, end, left, right):
    """세그먼트 트리에서 (left, right) 구간에 해당하는 구간 합을 구한다.
    node: 현재 노드의 번호 (실제 tree 배열에 저장된 위치)
    start, end: 현재 노드에 저장된 구간 정보
    left, right: 찾고자 하는 구간 정보
    """
    if left > end or right < start: # (l, r)이 이 (s, e)에 아예 포함되지 않는 경우
        return 1 # 곱셈의 항등원은 1이다.
    elif left <= start and right >= end: # (s, e)가 (l, r)에 완전히 포함되는 경우
        return tree[node]
    else: # (l, r)이 (s, e)에 완전히 포함되는 경우 / (s, e)가 (l, r)에 일부만 포함되는 경우
        mid = (start + end) // 2
        lsum = query(node * 2, start, mid, left, right)
        rsum = query(node * 2 + 1, mid + 1, end, left, right)
        return (lsum * rsum) % 1_000_000_007

구간 곱을 구하는 함수

def update(node, start, end, index, val):
    """기존 배열에서 index 위치에 저장된 값을 val로 수정한다.
    이후 세그먼트 트리에 저장된 구간 곱을 갱신한다.
    node: 현재 노드의 번호 (tree에 저장된 위치)
    start, end: 현재 노드에 저장된 구간
    index, val: 기존 배열에서 값이 변한 위치와 그 값
    """
    if start > index or end < index: # 현재 구간에 index를 포함되지 않는 경우.
        return
    elif start == end: # 현재 노드가 leaf일 경우.
        tree[node] = val
        arr[index] = val
    else: # 현재 구간에 index가 포함되는 경우
        mid = (start + end) // 2
        update(node * 2, start, mid, index, val)
        update(node * 2 + 1, mid + 1, end, index, val)
        tree[node] = (tree[node * 2] * tree[node * 2 + 1]) % 1_000_000_007

배열의 원소를 변경하는 함수 구현

전체 코드

from math import ceil, log2

MOD = 1_000_000_007
input = open(0).readline
N, M, K = map(int, input().split())

tree_height = ceil(log2(N))
tree_size = 1 << (tree_height + 1)
arr = list(int(input()) for _ in range(N))
tree = [0] * tree_size

def build_segment_tree(node, start, end):
    """세그먼트 트리를 초기화한다."""
    if start == end:
        tree[node] = arr[start]
    else:
        mid = (start + end) // 2
        build_segment_tree(node * 2, start, mid)
        build_segment_tree(node * 2 + 1, mid + 1, end)
        tree[node] = (tree[node * 2] * tree[node * 2 + 1]) % MOD

build_segment_tree(1, 0, N - 1)

def query(node, start, end, left, right):
    """세그먼트 트리에서 (left, right) 구간에 해당하는 곱을 구한다.
    node: 현재 노드의 번호 (tree에 저장된 위치)
    start, end: 현재 노드에 저장된 구간
    left, right: 찾고자 하는 구간
    """
    if start > right or end < left: # 찾고자 하는 구간에 현재 구간이 완전히 포함되지 않는 경우
        return 1
    elif start >= left and end <= right: # 찾고자 하는 구간에 현재 구간이 완전히 포함될 경우
        return tree[node]
    else: # 찾고자 하는 구간에 현재 구간의 일부만 포함될 경우
        mid = (start + end) // 2
        lmul = query(node * 2, start, mid, left, right)
        rmul = query(node * 2 + 1, mid + 1, end, left, right)
        return (lmul * rmul) % MOD

def update(node, start, end, index, val):
    """기존 배열에서 index 위치에 저장된 값을 val로 수정한다.
    이후 세그먼트 트리에 저장된 구간 곱을 갱신한다.
    node: 현재 노드의 번호 (tree에 저장된 위치)
    start, end: 현재 노드에 저장된 구간
    index, val: 기존 배열에서 값이 변한 위치와 그 값
    """
    if start > index or end < index: # 현재 구간에 index를 포함되지 않는 경우.
        return
    elif start == end: # 현재 노드가 leaf일 경우.
        tree[node] = val
        arr[index] = val
    else: # 현재 구간에 index가 포함되는 경우
        mid = (start + end) // 2
        update(node * 2, start, mid, index, val)
        update(node * 2 + 1, mid + 1, end, index, val)
        tree[node] = (tree[node * 2] * tree[node * 2 + 1]) % MOD
    

for _ in range(M + K):
    a, b, c = map(int, input().split())
    if a == 1:
        update(1, 0, N - 1, b - 1, c)      # 실제 인덱스는 0...N-1이므로 1씩 빼서 사용.
    else:
        print(query(1, 0, N - 1, b - 1, c - 1)) # 실제 인덱스는 0...N-1이므로 1씩 빼서 사용.

solution.py