[PS] BOJ 32036 / 수열과 쿼리 45

[PS] BOJ 32036 / 수열과 쿼리 45
Thumbnail: Photo by Kevin Snow (Unsplash)
문제 링크: https://www.acmicpc.net/problem/32036

아이디어를 도저히 떠올리지 못해서 다른 사람들의 풀이를 참고했습니다,,
설명을 이해하고 나니, 재밌는 문제였습니다.

풀이

문제 이해하기

길이가 $2 \times $ 109 + 1인 배열 $A$에는 -109 $\cdots$ 109까지의 각 인덱스에 0이 초기값으로 들어 있습니다. 이제, 다음의 두 쿼리로 배열 $A$를 조작합니다.

  1. 모든 -109 $ \leq i \leq $ 109 에 대해, $A[i] = A[i] + \vert i - a \vert + b$를 대입합니다.
  2. $A$의 최솟값의 위치 $m$과, $A[m]$을 출력합니다.

$A$를 실제 배열로 구현하기엔 메모리를 과도하게 사용하며, 일일이 1번 쿼리를 계산하는 것 또한 효율적이지 않습니다.

문제 변환하기

1번 쿼리를 $N$회 반복해서 시행하게 되면, $x$번 인덱스에 저장된 값은 다음과 같아집니다.

$$\begin{align*} A[x] &= \sum_{i=1}^{N} \vert x - a_{i} \vert + b_{i} \\ &= \sum_{i=1}^{N} \vert x - a_{i} \vert + \sum_{i=1}^{N} b_{i} \end{align*}$$

이 식을 x에 대한 함수 $F(x)$라고 하면, 2번 쿼리는 $y = F(x)$에서 $y$가 최소가 되도록 하는 $x$의 값을 찾는 문제가 됩니다.

A의 최솟값의 위치 찾기

앞서 $A[x]$를 $x$에 대한 함수 $F(x)$로 변환했으니, A의 최솟값의 위치는 $F(x)$의 값이 최소가 되는 $x$를 찾는 문제로 생각할 수 있습니다.

$a_i$를 크기 순으로 정렬했을 때, x가 $a_i$의 중앙값일 때 $F(x)$가 최소가 될 것이라 생각할 수 있습니다. 따라서, 이 문제는 $a_i$의 중앙값을 찾는 문제로 다시 한 번 바꿔 생각할 수 있습니다.

$a_i$의 중앙값 찾기

$a_i$의 중앙값은 2개의 우선순위 큐(heap)를 활용해 구했습니다.

정렬된 상태의 배열을 절반은 최대 힙에, 절반은 최소 힙에 저장해 중앙값을 찾는 방식으로, 이 경우 최대 힙의 루트에 중앙값이 저장됩니다.

$a_i$의 배열에 새로운 수를 추가할 때, 최대 힙과 최소 힙을 아래 규칙을 유지하면서 관리해야 합니다:

  1. 최대 힙의 크기는 최소 힙의 크기와 같거나 하나 더 크다.
  2. 최대 힙의 최댓값은 최소 힙의 최솟값보다 작거나 같다.

새로운 수를 추가하는 함수는 다음과 같이 구현합니다.

max_heap = []
min_heap = []

def add_element(v):
    if len(max_heap) == len(min_heap):
        heappush(max_heap, (-v, v)) # Python3의 heapq는 최소 힙으로 구현되어 있으므로, 최대 힙으로 사용하려면 우선순위 값을 명시해 줘야 한다.
    else:
        heappush(min_heap, v)

    # 만약 최대 힙의 최댓값이 최소 힙의 최솟값보다 크다면, 서로 바꿔준다.
    if len(max_heap) > 0 and len(min_heap) > 0:
        max_top = max_heap[0][1]
        min_top = min_heap[0]

        if max_top > min_top:
            heappop(max_heap)
            heappop(min_heap)
            heappush(max_heap, (-min_top, min_top))
            heappush(min_heap, max_top)

# 이후 중앙값이 필요할 때에는 최대 힙의 루트에 저장된 값을 쓴다.
_, mid = max_heap[0]

중앙값을 찾는 방법

$F(x)$ 계산하기

이제 중앙값을 찾았으니, 이를 토대로 $F(x)$의 값을 계산해야 합니다.
앞서 정리한 $F(x)$의 식은 다음과 같습니다.

$$F(x) = A[x] = \sum_{i=1}^{N} \vert x - a_{i} \vert + \sum_{i=1}^{N} b_{i}$$

$\vert x - a_i \vert$는 $x$의 값에 따라 나누어 계산하면 됩니다. 앞서 구한 중앙값을 $a_m$, 중앙값의 인덱스를 $m$이라고 할 때, 위 식은 다음과 같이 정리됩니다.

$$\begin{align*} F(x) &= \sum_{i=1}^{m} {a_{i} - x} + \sum_{i=m}^{N} {x - a_{i}} + \sum_{i=1}^{N} b_{i} \\ &= ( \sum_{i=1}^{m} a_{i} - m \times x ) + ( (N-M) \times x - \sum_{i=m+1}^{N} a_{i} ) + \sum_{i=1}^{N} b_{i}\end{align*}$$

  • $b_{i}$의 합은 1번 쿼리를 수행할 때 마다 미리 계산해 두면 됩니다.
  • $a_{i}$의 합의 경우, 최대 힙과 최소 힙에 저장된 원소 합을 각각 계산해 두고, 이를 $\sum_{i=1}^{m} a_{i}$, $\sum_{i=m+1}^{N} a_{i}$의 계산 결과로 사용하면 됩니다.
max_heap = []
min_heap = []
sum_max_heap = 0 # 최대 힙의 원소 합
sum_min_heap = 0 # 최소 힙의 원소 합
sum_b = 0        # b의 합

def add_element(v):
    global sum_max_heap, sum_min_heap
    if len(max_heap) == len(min_heap):
        heappush(max_heap, (-v, v))
        sum_max_heap += v # 최대 힙의 원소 합 갱신
    else:
        heappush(min_heap, v)
        sum_min_heap += v # 최소 힙의 원소 합 갱신
    
    if len(max_heap) > 0 and len(min_heap) > 0:
        max_top = max_heap[0][1]
        min_top = min_heap[0]

        if max_top > min_top:
            heappop(max_heap)
            heappop(min_heap)
            heappush(max_heap, (-min_top, min_top))
            heappush(min_heap, max_top)
            # 최대 힙과 최소 힙의 원소 합도 변경해준다.
            sum_max_heap += min_top - max_top
            sum_min_heap += max_top - min_top

for _ in range(int(input())):
    cmd = input().split()
    if cmd[0] == "1":
        a = int(cmd[1])
        b = int(cmd[2])
        sum_b += b
        add_element(a)
        
    else:
        x = max_heap[0][1] # 중앙값
        y = (x * len(max_heap) - sum_max_heap) + (sum_min_heap - x * len(min_heap)) + sum_b # F(x)
        print(x, y)

전체 코드

from heapq import heappush, heappop

input = open(0).readline

max_heap = []
min_heap = []
sum_max_heap = 0
sum_min_heap = 0
sum_b = 0

def add_element(v):
    global sum_max_heap, sum_min_heap
    if len(max_heap) == len(min_heap):
        heappush(max_heap, (-v, v))
        sum_max_heap += v
    else:
        heappush(min_heap, v)
        sum_min_heap += v
    
    if len(max_heap) > 0 and len(min_heap) > 0:
        max_top = max_heap[0][1]
        min_top = min_heap[0]

        if max_top > min_top:
            heappop(max_heap)
            heappop(min_heap)
            heappush(max_heap, (-min_top, min_top))
            heappush(min_heap, max_top)
            sum_max_heap += min_top - max_top
            sum_min_heap += max_top - min_top

for _ in range(int(input())):
    cmd = input().split()
    if cmd[0] == "1":
        a = int(cmd[1])
        b = int(cmd[2])
        sum_b += b
        add_element(a)
        
    else:
        x = max_heap[0][1]
        y = (x * len(max_heap) - sum_max_heap) + (sum_min_heap - x * len(min_heap)) + sum_b
        print(x, y)

solution.py