문제
https://www.acmicpc.net/problem/24060
오늘도 서준이는 병합 정렬 수업 조교를 하고 있다. 아빠가 수업한 내용을 학생들이 잘 이해했는지 문제를 통해서 확인해보자.
N개의 서로 다른 양의 정수가 저장된 배열 A가 있다. 병합 정렬로 배열 A를 오름차순 정렬할 경우 배열 A에 K 번째 저장되는 수를 구해서 우리 서준이를 도와주자.
크기가 N인 배열에 대한 병합 정렬 의사 코드는 다음과 같다.
merge_sort(A[p..r]) { # A[p..r]을 오름차순 정렬한다.
if (p < r) then {
q <- ⌊(p + r) / 2⌋; # q는 p, r의 중간 지점
merge_sort(A, p, q); # 전반부 정렬
merge_sort(A, q + 1, r); # 후반부 정렬
merge(A, p, q, r); # 병합
}
}
# A[p..q]와 A[q+1..r]을 병합하여 A[p..r]을 오름차순 정렬된 상태로 만든다.
# A[p..q]와 A[q+1..r]은 이미 오름차순으로 정렬되어 있다.
merge(A[], p, q, r) {
i <- p; j <- q + 1; t <- 1;
while (i ≤ q and j ≤ r) {
if (A[i] ≤ A[j])
then tmp[t++] <- A[i++]; # tmp[t] <- A[i]; t++; i++;
else tmp[t++] <- A[j++]; # tmp[t] <- A[j]; t++; j++;
}
while (i ≤ q) # 왼쪽 배열 부분이 남은 경우
tmp[t++] <- A[i++];
while (j ≤ r) # 오른쪽 배열 부분이 남은 경우
tmp[t++] <- A[j++];
i <- p; t <- 1;
while (i ≤ r) # 결과를 A[p..r]에 저장
A[i++] <- tmp[t++];
}
입력
첫째 줄에 배열 A의 크기 N(5 ≤ N ≤ 500,000), 저장 횟수 K(1 ≤ K ≤ 10^8)가 주어진다.
다음 줄에 서로 다른 배열 A의 원소 A1, A2, ..., AN이 주어진다. (1 ≤ Ai ≤ 10^9)
출력
배열 A에 K 번째 저장 되는 수를 출력한다. 저장 횟수가 K 보다 작으면 -1을 출력한다.
예제
나의 풀이
N, K = map(int, input().split())
nums = list(map(int, input().split()))
def merge_sort(arr):
if len(arr) <= 1:
return arr
mid = (len(arr) + 1) // 2
left = merge_sort(arr[:mid])
right = merge_sort(arr[mid:])
return merge(left, right)
def merge(left, right):
merged = []
l = r = 0
while l < len(left) and r < len(right):
if left[l] > right[r]:
merged.append(right[r])
res.append(right[r])
r += 1
else:
merged.append(left[l])
res.append(left[l])
l += 1
while l < len(left):
merged.append(left[l])
res.append(left[l])
l += 1
while r < len(right):
merged.append(right[r])
res.append(right[r])
r += 1
return merged
res = []
merge_sort(nums)
if len(res) >= K:
print(res[K-1])
else:
print(-1)
병합 정렬은 배열을 절반씩 나누며 정렬한 뒤, 정렬된 두 배열을 병합하는 알고리즘이다.
이 구조는 그대로 유지하면서, K번째로 저장된 수를 출력하기 위해 값이 배열에 저장될 때마다 기록하는 부분을 추가했다.
mid = (len(arr) + 1) // 2
보통 병합 정렬은 `mid = len(arr) // 2`로 나누지만,
문제와 동일한 분할 조건을 구현하기 위해 길이에 1을 더해서 나누는 방식로 수정했다.
# 요소를 병합할 때마다 저장 리스트에 기록
res.append(value)
병합할 때마다 `res` 리스트에 값을 저장해 K번째 저장되는 값을 추적한다.
나머지 리스트의 값을 병합할 때에도 `left[l:]` 또는 `right[r:]`와 같이 리스트를 한 번에 병합하지 않고,
while문을 사용해서 값을 하나씩 저장하면서 기록했다.
최종적으로 저장된 횟수가 K 이상이라면 `res`에서 K번째 저장된 값을 출력하고, 그렇지 않으면 -1을 출력한다.