LC 215. 数组中的第K个最大元素

题目描述

这是 LeetCode 上的 215. 数组中的第K个最大元素 ,难度为 中等

给定整数数组 nums 和整数 k,请返回数组中第 k 个最大的元素。

请注意,你需要找的是数组排序后的第 k 个最大的元素,而不是第 k 个不同的元素。

你必须设计并实现时间复杂度为 $O(n)$ 的算法解决此问题。

示例 1:

1
2
3
输入: [3,2,1,5,6,4], k = 2

输出: 5

示例 2:
1
2
3
输入: [3,2,3,1,2,4,5,5,6], k = 4

输出: 4

提示:

  • $1 <= k <= nums.length <= 10^5$
  • $-10^4 <= nums[i] <= 10^4$

值域映射 + 树状数组 + 二分

除了直接对数组进行排序,取第 $k$ 位的 $O(n\log{n})$ 做法以外。

对于值域大小 小于 数组长度本身时,我们还能使用「树状数组 + 二分」的 $O(n\log{m})$ 做法,其中 $m$ 为值域大小。

首先值域大小为 $[-10^4, 10^4]$,为了方便,我们为每个 $nums[i]$ 增加大小为 $1e4 + 10$ 的偏移量,将值域映射到 $[10, 2 \times 10^4 + 10]$ 的空间。

将每个增加偏移量后的 $nums[i]$ 存入树状数组,考虑在 $[0, m)$ 范围内进行二分,假设我们真实第 $k$ 大的值为 $t$,那么在以 $t$ 为分割点的数轴上,具有二段性质:

  • 在 $[0, t]$ 范围内的数 $cur$ 满足「树状数组中大于等于 $cur$ 的数不低于 $k$ 个」
  • 在 $(t, m)$ 范围内的数 $cur$ 不满足「树状数组中大于等于 $cur$ 的数不低于 $k$ 个」

二分出结果后再减去刚开始添加的偏移量即是答案。

Java 代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
class Solution {
int M = 100010, N = 2 * M;
int[] tr = new int[N];
int lowbit(int x) {
return x & -x;
}
int query(int x) {
int ans = 0;
for (int i = x; i > 0; i -= lowbit(i)) ans += tr[i];
return ans;
}
void add(int x) {
for (int i = x; i < N; i += lowbit(i)) tr[i]++;
}
public int findKthLargest(int[] nums, int k) {
for (int x : nums) add(x + M);
int l = 0, r = N - 1;
while (l < r) {
int mid = l + r + 1 >> 1;
if (query(N - 1) - query(mid - 1) >= k) l = mid;
else r = mid - 1;
}
return r - M;
}
}

C++ 代码:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
class Solution {
public:
int N = 200010, M = 100010, tr[200010];
int lowbit(int x) {
return x & -x;
}
int query(int x) {
int ans = 0;
for (int i = x; i > 0; i -= lowbit(i)) ans += tr[i];
return ans;
}
void add(int x) {
for (int i = x; i < N; i += lowbit(i)) tr[i]++;
}
int findKthLargest(vector<int>& nums, int k) {
for (int x : nums) add(x + M);
int l = 0, r = N - 1;
while (l < r) {
int mid = l + r + 1 >> 1;
if (query(N - 1) - query(mid - 1) >= k) l = mid;
else r = mid - 1;
}
return r - M;
}
};

Python 代码:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
class Solution:
def findKthLargest(self, nums: List[int], k: int) -> int:
N, M = 200010, 100010
tr = [0] * N
def lowbit(x):
return x & -x
def query(x):
ans = 0
i = x
while i > 0:
ans += tr[i]
i -= lowbit(i)
return ans
def add(x):
i = x
while i < N:
tr[i] += 1
i += lowbit(i)
for x in nums:
add(x + M)
l, r = 0, N - 1
while l < r:
mid = l + r + 1 >> 1
if query(N - 1) - query(mid - 1) >= k: l = mid
else: r = mid - 1
return r - M

TypeScript 代码:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
function findKthLargest(nums: number[], k: number): number {
const N = 200010, M = 100010;
const tr = new Array(N).fill(0);
const lowbit = function(x: number): number {
return x & -x;
};
const add = function(x: number): void {
for (let i = x; i < N; i += lowbit(i)) tr[i]++;
};
const query = function(x: number): number {
let ans = 0;
for (let i = x; i > 0; i -= lowbit(i)) ans += tr[i];
return ans;
};
for (const x of nums) add(x + M);
let l = 0, r = N - 1;
while (l < r) {
const mid = l + r + 1 >> 1;
if (query(N - 1) - query(mid - 1) >= k) l = mid;
else r = mid - 1;
}
return r - M;
};

  • 时间复杂度:将所有数字放入树状数组复杂度为 $O(n\log{m})$;二分出答案复杂度为 $O(\log^2{m})$,其中 $m = 2 \times 10^4$ 为值域大小。整体复杂度为 $O(n\log{m})$
  • 空间复杂度:$O(m)$

优先队列(堆)

另外一个容易想到的想法是利用优先队列(堆),由于题目要我们求的是第 $k$ 大的元素,因此我们建立一个小根堆。

根据当前队列元素个数或当前元素与栈顶元素的大小关系进行分情况讨论:

  • 当优先队列元素不足 $k$ 个,可将当前元素直接放入队列中;
  • 当优先队列元素达到 $k$ 个,并且当前元素大于栈顶元素(栈顶元素必然不是答案),可将当前元素放入队列中。

Java 代码:

1
2
3
4
5
6
7
8
9
10
class Solution {
public int findKthLargest(int[] nums, int k) {
PriorityQueue<Integer> q = new PriorityQueue<>((a,b)->a-b);
for (int x : nums) {
if (q.size() < k || q.peek() < x) q.add(x);
if (q.size() > k) q.poll();
}
return q.peek();
}
}

C++ 代码:
1
2
3
4
5
6
7
8
9
10
11
class Solution {
public:
int findKthLargest(vector<int>& nums, int k) {
priority_queue<int, vector<int>, greater<int>> q;
for (int x : nums) {
if (q.size() < k || q.top() < x) q.push(x);
if (q.size() > k) q.pop();
}
return q.top();
}
};

Python 代码:
1
2
3
4
5
6
7
8
9
class Solution:
def findKthLargest(self, nums: List[int], k: int) -> int:
q = []
for x in nums:
if len(q) < k or q[0] < x:
heapq.heappush(q, x)
if len(q) > k:
heapq.heappop(q)
return q[0]

  • 时间复杂度:$O(n\log{k})$
  • 空间复杂度:$O(k)$

快速选择

对于给定数组,求解第 $k$ 大元素,且要求线性复杂度,正解为使用「快速选择」做法。

基本思路与「快速排序」一致,每次敲定一个基准值 x,根据当前与 x 的大小关系,将范围在 $[l, r]$ 的 $nums[i]$ 划分为到两边。

同时利用,利用题目只要求输出第 $k$ 大的值,而不需要对数组进行整体排序,我们只需要根据划分两边后,第 $k$ 大数会落在哪一边,来决定对哪边进行递归处理即可。

快速排序模板为面试向重点内容,需要重要掌握。

Java 代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
class Solution {
int[] nums;
int qselect(int l, int r, int k) {
if (l == r) return nums[k];
int x = nums[l], i = l - 1, j = r + 1;
while (i < j) {
do i++; while (nums[i] < x);
do j--; while (nums[j] > x);
if (i < j) swap(i, j);
}
if (k <= j) return qselect(l, j, k);
else return qselect(j + 1, r, k);
}
void swap(int i, int j) {
int c = nums[i];
nums[i] = nums[j];
nums[j] = c;
}
public int findKthLargest(int[] _nums, int k) {
nums = _nums;
int n = nums.length;
return qselect(0, n - 1, n - k);
}
}

C++ 代码:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
class Solution {
public:
vector<int> nums;
int qselect(int l, int r, int k) {
if (l == r) return nums[k];
int x = nums[l], i = l - 1, j = r + 1;
while (i < j) {
do i++; while (nums[i] < x);
do j--; while (nums[j] > x);
if (i < j) swap(nums[i], nums[j]);
}
if (k <= j) return qselect(l, j, k);
else return qselect(j + 1, r, k);
}
int findKthLargest(vector<int>& _nums, int k) {
nums = _nums;
int n = nums.size();
return qselect(0, n - 1, n - k);
}
};

Python 代码:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
class Solution:
def findKthLargest(self, nums: List[int], k: int) -> int:
def qselect(l, r, k):
if l == r:
return nums[k]
x, i, j = nums[l], l - 1, r + 1
while i < j:
i += 1
while nums[i] < x:
i += 1
j -= 1
while nums[j] > x:
j -= 1
if i < j:
nums[i], nums[j] = nums[j], nums[i]
if k <= j:
return qselect(l, j, k)
else:
return qselect(j + 1, r, k)

n = len(nums)
return qselect(0, n - 1, n - k)

TypeScript 代码:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
function findKthLargest(nums: number[], k: number): number {
const qselect = function(l: number, r: number, k: number): number {
if (l === r) return nums[k];
const x = nums[l];
let i = l - 1, j = r + 1;
while (i < j) {
i++;
while (nums[i] < x) i++;
j--;
while (nums[j] > x) j--;
if (i < j) [nums[i], nums[j]] = [nums[j], nums[i]];
}
if (k <= j) return qselect(l, j, k);
else return qselect(j + 1, r, k);
};
const n = nums.length;
return qselect(0, n - 1, n - k);
};

  • 时间复杂度:期望 $O(n)$
  • 空间复杂度:忽略递归带来的额外空间开销,复杂度为 $O(1)$

最后

这是我们「刷穿 LeetCode」系列文章的第 No.215 篇,系列开始于 2021/01/01,截止于起始日 LeetCode 上共有 1916 道题目,部分是有锁题,我们将先把所有不带锁的题目刷完。

在这个系列文章里面,除了讲解解题思路以外,还会尽可能给出最为简洁的代码。如果涉及通解还会相应的代码模板。

为了方便各位同学能够电脑上进行调试和提交代码,我建立了相关的仓库:https://github.com/SharingSource/LogicStack-LeetCode

在仓库地址里,你可以看到系列文章的题解链接、系列文章的相应代码、LeetCode 原题链接和其他优选题解。