[PS] BOJ 32036 / 수열과 쿼리 45
문제 링크: https://www.acmicpc.net/problem/32036
아이디어를 도저히 떠올리지 못해서 다른 사람들의 풀이를 참고했습니다,,
설명을 이해하고 나니, 재밌는 문제였습니다.
풀이
문제 이해하기
길이가 $2 \times $ 109 + 1인 배열 $A$에는 -109 $\cdots$ 109까지의 각 인덱스에 0이 초기값으로 들어 있습니다. 이제, 다음의 두 쿼리로 배열 $A$를 조작합니다.
- 모든 -109 $ \leq i \leq $ 109 에 대해, $A[i] = A[i] + \vert i - a \vert + b$를 대입합니다.
- $A$의 최솟값의 위치 $m$과, $A[m]$을 출력합니다.
$A$를 실제 배열로 구현하기엔 메모리를 과도하게 사용하며, 일일이 1번 쿼리를 계산하는 것 또한 효율적이지 않습니다.
문제 변환하기
1번 쿼리를 $N$회 반복해서 시행하게 되면, $x$번 인덱스에 저장된 값은 다음과 같아집니다.
$$\begin{align*} A[x] &= \sum_{i=1}^{N} \vert x - a_{i} \vert + b_{i} \\ &= \sum_{i=1}^{N} \vert x - a_{i} \vert + \sum_{i=1}^{N} b_{i} \end{align*}$$
이 식을 x에 대한 함수 $F(x)$라고 하면, 2번 쿼리는 $y = F(x)$에서 $y$가 최소가 되도록 하는 $x$의 값을 찾는 문제가 됩니다.
A의 최솟값의 위치 찾기
앞서 $A[x]$를 $x$에 대한 함수 $F(x)$로 변환했으니, A의 최솟값의 위치는 $F(x)$의 값이 최소가 되는 $x$를 찾는 문제로 생각할 수 있습니다.
$a_i$를 크기 순으로 정렬했을 때, x가 $a_i$의 중앙값일 때 $F(x)$가 최소가 될 것이라 생각할 수 있습니다. 따라서, 이 문제는 $a_i$의 중앙값을 찾는 문제로 다시 한 번 바꿔 생각할 수 있습니다.
$a_i$의 중앙값 찾기
$a_i$의 중앙값은 2개의 우선순위 큐(heap)를 활용해 구했습니다.
정렬된 상태의 배열을 절반은 최대 힙에, 절반은 최소 힙에 저장해 중앙값을 찾는 방식으로, 이 경우 최대 힙의 루트에 중앙값이 저장됩니다.
$a_i$의 배열에 새로운 수를 추가할 때, 최대 힙과 최소 힙을 아래 규칙을 유지하면서 관리해야 합니다:
- 최대 힙의 크기는 최소 힙의 크기와 같거나 하나 더 크다.
- 최대 힙의 최댓값은 최소 힙의 최솟값보다 작거나 같다.
새로운 수를 추가하는 함수는 다음과 같이 구현합니다.
max_heap = []
min_heap = []
def add_element(v):
if len(max_heap) == len(min_heap):
heappush(max_heap, (-v, v)) # Python3의 heapq는 최소 힙으로 구현되어 있으므로, 최대 힙으로 사용하려면 우선순위 값을 명시해 줘야 한다.
else:
heappush(min_heap, v)
# 만약 최대 힙의 최댓값이 최소 힙의 최솟값보다 크다면, 서로 바꿔준다.
if len(max_heap) > 0 and len(min_heap) > 0:
max_top = max_heap[0][1]
min_top = min_heap[0]
if max_top > min_top:
heappop(max_heap)
heappop(min_heap)
heappush(max_heap, (-min_top, min_top))
heappush(min_heap, max_top)
# 이후 중앙값이 필요할 때에는 최대 힙의 루트에 저장된 값을 쓴다.
_, mid = max_heap[0]중앙값을 찾는 방법
$F(x)$ 계산하기
이제 중앙값을 찾았으니, 이를 토대로 $F(x)$의 값을 계산해야 합니다.
앞서 정리한 $F(x)$의 식은 다음과 같습니다.
$$F(x) = A[x] = \sum_{i=1}^{N} \vert x - a_{i} \vert + \sum_{i=1}^{N} b_{i}$$
$\vert x - a_i \vert$는 $x$의 값에 따라 나누어 계산하면 됩니다. 앞서 구한 중앙값을 $a_m$, 중앙값의 인덱스를 $m$이라고 할 때, 위 식은 다음과 같이 정리됩니다.
$$\begin{align*} F(x) &= \sum_{i=1}^{m} {a_{i} - x} + \sum_{i=m}^{N} {x - a_{i}} + \sum_{i=1}^{N} b_{i} \\ &= ( \sum_{i=1}^{m} a_{i} - m \times x ) + ( (N-M) \times x - \sum_{i=m+1}^{N} a_{i} ) + \sum_{i=1}^{N} b_{i}\end{align*}$$
- $b_{i}$의 합은 1번 쿼리를 수행할 때 마다 미리 계산해 두면 됩니다.
- $a_{i}$의 합의 경우, 최대 힙과 최소 힙에 저장된 원소 합을 각각 계산해 두고, 이를 $\sum_{i=1}^{m} a_{i}$, $\sum_{i=m+1}^{N} a_{i}$의 계산 결과로 사용하면 됩니다.
max_heap = []
min_heap = []
sum_max_heap = 0 # 최대 힙의 원소 합
sum_min_heap = 0 # 최소 힙의 원소 합
sum_b = 0 # b의 합
def add_element(v):
global sum_max_heap, sum_min_heap
if len(max_heap) == len(min_heap):
heappush(max_heap, (-v, v))
sum_max_heap += v # 최대 힙의 원소 합 갱신
else:
heappush(min_heap, v)
sum_min_heap += v # 최소 힙의 원소 합 갱신
if len(max_heap) > 0 and len(min_heap) > 0:
max_top = max_heap[0][1]
min_top = min_heap[0]
if max_top > min_top:
heappop(max_heap)
heappop(min_heap)
heappush(max_heap, (-min_top, min_top))
heappush(min_heap, max_top)
# 최대 힙과 최소 힙의 원소 합도 변경해준다.
sum_max_heap += min_top - max_top
sum_min_heap += max_top - min_top
for _ in range(int(input())):
cmd = input().split()
if cmd[0] == "1":
a = int(cmd[1])
b = int(cmd[2])
sum_b += b
add_element(a)
else:
x = max_heap[0][1] # 중앙값
y = (x * len(max_heap) - sum_max_heap) + (sum_min_heap - x * len(min_heap)) + sum_b # F(x)
print(x, y)전체 코드
from heapq import heappush, heappop
input = open(0).readline
max_heap = []
min_heap = []
sum_max_heap = 0
sum_min_heap = 0
sum_b = 0
def add_element(v):
global sum_max_heap, sum_min_heap
if len(max_heap) == len(min_heap):
heappush(max_heap, (-v, v))
sum_max_heap += v
else:
heappush(min_heap, v)
sum_min_heap += v
if len(max_heap) > 0 and len(min_heap) > 0:
max_top = max_heap[0][1]
min_top = min_heap[0]
if max_top > min_top:
heappop(max_heap)
heappop(min_heap)
heappush(max_heap, (-min_top, min_top))
heappush(min_heap, max_top)
sum_max_heap += min_top - max_top
sum_min_heap += max_top - min_top
for _ in range(int(input())):
cmd = input().split()
if cmd[0] == "1":
a = int(cmd[1])
b = int(cmd[2])
sum_b += b
add_element(a)
else:
x = max_heap[0][1]
y = (x * len(max_heap) - sum_max_heap) + (sum_min_heap - x * len(min_heap)) + sum_b
print(x, y)
solution.py