LC 1268. 搜索推荐系统

题目描述

这是 LeetCode 上的 1268. 搜索推荐系统 ,难度为 中等

给你一个产品数组 products 和一个字符串 searchWordproducts 数组中每个产品都是一个字符串。

请你设计一个推荐系统,在依次输入单词 searchWord 的每一个字母后,推荐 products 数组中前缀与 searchWord 相同的最多三个产品。如果前缀相同的可推荐产品超过三个,请按字典序返回最小的三个。

请你以二维列表的形式,返回在输入 searchWord 每个字母后相应的推荐产品的列表。

示例 1:

1
2
3
4
5
6
7
8
9
10
11
12
输入:products = ["mobile","mouse","moneypot","monitor","mousepad"], searchWord = "mouse"
输出:[
["mobile","moneypot","monitor"],
["mobile","moneypot","monitor"],
["mouse","mousepad"],
["mouse","mousepad"],
["mouse","mousepad"]
]

解释:按字典序排序后的产品列表是 ["mobile","moneypot","monitor","mouse","mousepad"]
输入 m 和 mo,由于所有产品的前缀都相同,所以系统返回字典序最小的三个产品 ["mobile","moneypot","monitor"]
输入 mou, mous 和 mouse 后系统都返回 ["mouse","mousepad"]

示例 2:
1
2
3
输入:products = ["havana"], searchWord = "havana"

输出:[["havana"],["havana"],["havana"],["havana"],["havana"],["havana"]]

示例 3:
1
2
3
输入:products = ["bags","baggage","banner","box","cloths"], searchWord = "bags"

输出:[["baggage","bags","banner"],["baggage","bags","banner"],["baggage","bags"],["bags"]]

示例 4:
1
2
3
输入:products = ["havana"], searchWord = "tatiana"

输出:[[],[],[],[],[],[],[]]

提示:

  • $1 <= products.length <= 1000$
  • $1 <= Σ products[i].length <= 2 \times 10^4$
  • products[i] 中所有的字符都是小写英文字母。
  • $1 <= searchWord.length <= 1000$
  • searchWord 中所有字符都是小写英文字母。

排序 + 字典树 + 哈希表

为了方便,将 products 记为 ps,将 searchWord 记为 w

这是一个 "Suggestion string" 问题,容易想到字典树进行求解,不了解字典树的同学,可看 前置 🧀

由于题目要求「若有超过三个的产品可推荐,返回字典序最小的三个」,我们不妨先对 ps 进行排序,使 ps 从前往后满足字典序从小到大。

将所有 ps[i] 按顺序添加到字典树 tr 中,添加过程中,使用两个哈希表 minMapmaxMap 分别记录经过某个 tr[i][j] 时的最小 ps 下标和最大 ps 下标。即哈希表的 key 为具体的 tr[i][j],对应 value 为经过该节点的最小或最大下标。

构建答案时,根据当前 w 子串到字典树 tr 中查询,定位到该子串对应的 tr[i][j] 为何值,再从哈希表中获取建议字符串在 ps 中的左右端点 lr,并根据在 ps[l:r] (可能为空集)中最多取三个的原则来构建答案。

考虑实现字典树的两个关键方法,添加查询

  • 添加函数 void add(String s, int num) :其中 s 为待添加到字典树的字符串,num 则是该字符串在 ps 中的下标编号。

    往字典树添加过程中,按照首次访问字典树节点 tr[i][j] 的下标存入 minMap,最后一次访问字典树节点 tr[i][j] 的下标存入 maxMap 的规则来更新哈希表。

  • 查询函数 int[] query(String s):其中 s 为某个 w 子串,通过查询该子串(最后字符)在字典树的节点值,来得知建议列表对应 ps 的左右端点下标为何值,从而构建答案。

Java 代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
class Solution {
int[][] tr = new int[20010][26];
int idx = 0;
Map<Integer, Integer> min = new HashMap<>(), max = new HashMap<>();
void add(String s, int num) {
int p = 0;
for (int i = 0; i < s.length(); i++) {
int u = s.charAt(i) - 'a';
if (tr[p][u] == 0) {
tr[p][u] = ++idx;
min.put(tr[p][u], num);
}
max.put(tr[p][u], num);
p = tr[p][u];
}
}
int[] query(String s) {
int a = -1, b = -1, p = 0;
for (int i = 0; i < s.length(); i++) {
int u = s.charAt(i) - 'a';
if (tr[p][u] == 0) return new int[]{-1, -1};
a = min.get(tr[p][u]); b = max.get(tr[p][u]);
p = tr[p][u];
}
return new int[]{a, b};
}
public List<List<String>> suggestedProducts(String[] ps, String w) {
Arrays.sort(ps);
List<List<String>> ans = new ArrayList<>();
for (int i = 0; i < ps.length; i++) add(ps[i], i);
for (int i = 0; i < w.length(); i++) {
List<String> list = new ArrayList<>();
int[] info = query(w.substring(0, i + 1));
int l = info[0], r = info[1];
for (int j = l; j <= Math.min(l + 2, r) && l != -1; j++) list.add(ps[j]);
ans.add(list);
}
return ans;
}
}

C++ 代码:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
class Solution {
public:
int tr[20010][26] = {0};
int idx = 0;
unordered_map<int, int> minMap, maxMap;
void add(string s, int num) {
int p = 0;
for (char c : s) {
int u = c - 'a';
if (tr[p][u] == 0) {
tr[p][u] = ++idx;
minMap[tr[p][u]] = num;
}
maxMap[tr[p][u]] = num;
p = tr[p][u];
}
}
pair<int, int> query(string s) {
int a = -1, b = -1, p = 0;
for (char c : s) {
int u = c - 'a';
if (tr[p][u] == 0) return {-1, -1};
a = minMap[tr[p][u]];
b = maxMap[tr[p][u]];
p = tr[p][u];
}
return {a, b};
}
vector<vector<string>> suggestedProducts(vector<string>& ps, string w) {
sort(ps.begin(), ps.end());
vector<vector<string>> ans;
for (int i = 0; i < ps.size(); i++) add(ps[i], i);
for (int i = 0; i < w.length(); i++) {
vector<string> list;
pair<int, int> info = query(w.substr(0, i + 1));
int l = info.first, r = info.second;
for (int j = l; j <= min(l + 2, r) && l != -1; j++) list.push_back(ps[j]);
ans.push_back(list);
}
return ans;
}
};

Python 代码:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
class Solution:
def suggestedProducts(self, ps: List[str], w: str) -> List[List[str]]:
tr = defaultdict(lambda: defaultdict(int))
idx = 0
minMap, maxMap = {}, {}
def add(s, num):
nonlocal idx
p = 0
for c in s:
u = ord(c) - ord('a')
if tr[p][u] == 0:
idx += 1
tr[p][u] = idx
minMap[tr[p][u]] = num
maxMap[tr[p][u]] = num
p = tr[p][u]
def query(s):
a, b, p = -1, -1, 0
for c in s:
u = ord(c) - ord('a')
if tr[p][u] == 0:
return (-1, -1)
a = minMap[tr[p][u]]
b = maxMap[tr[p][u]]
p = tr[p][u]
return (a, b)
ps.sort()
ans = []
for i in range(len(ps)):
add(ps[i], i)
for i in range(len(w)):
l, r = query(w[:i + 1])
lst = [ps[j] for j in range(l, min(l + 3, r + 1)) if l != -1]
ans.append(lst)
return ans

TypeScript 代码:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
let tr: number[][];
let idx: number;
let minMap: Map<number, number>, maxMap: Map<number, number>;
function add(s: string, num: number): void {
let p = 0;
for (let i = 0; i < s.length; i++) {
const u = s.charCodeAt(i) - 'a'.charCodeAt(0);
if (tr[p][u] === 0) {
idx++;
tr[p][u] = idx;
minMap.set(tr[p][u], num);
}
maxMap.set(tr[p][u], num);
p = tr[p][u];
}
}
function query(s: string): number[] {
let a = -1, b = -1, p = 0;
for (let i = 0; i < s.length; i++) {
const u = s.charCodeAt(i) - 'a'.charCodeAt(0);
if (tr[p][u] === 0) return [-1, -1];
a = minMap.get(tr[p][u])!;
b = maxMap.get(tr[p][u])!;
p = tr[p][u];
}
return [a, b];
}
function suggestedProducts(ps: string[], w: string): string[][] {
tr = new Array(20010).fill(0).map(() => new Array(26).fill(0));
idx = 0;
minMap = new Map(), maxMap = new Map();
ps.sort();
const ans = [];
for (let i = 0; i < ps.length; i++) add(ps[i], i);
for (let i = 0; i < w.length; i++) {
const list = [];
const [l, r] = query(w.substring(0, i + 1));
for (let j = l; j <= Math.min(l + 2, r) && l !== -1; j++) list.push(ps[j]);
ans.push(list);
}
return ans;
};

  • 时间复杂度:将 ps 长度记为 nw 长度记为 m。对 ps 进行排序复杂度为 $O(n\log{n})$;构建字典树复杂度为 $O(\sum{i = 0}^{n - 1}ps[i].length)$;根据 w 构建答案复杂度为 $O(m^2)$;整体复杂度为 $O(\sum{i = 0}^{n - 1}ps[i].length + m^2)$
  • 空间复杂度:$O(\sum_{i = 0}^{n - 1}ps[i].length \times C)$,其中 $C = 26$ 为字符集大小

排序 + 二分

由于每个 w 子串只会对应最多三个的建议字符串,同时又可以先通过排序来确保 ps 的有序性。

因此对于每个 w 子串而言,可以 先找到满足要求的,字典序最小的建议字符串 ps[i],接着往后逐个检查,组成最终的建议字符串列表(最多检查三个)

这个「在 ps 中找符合要求,字典序最小的建议字符串」操作,除了能够利用上述解法来做(省掉一个 maxMap)以外,还能利用字符串本身的有序性进行「二分」,因为该操作本质上,是在找第一个满足 ps[i] 大于等于当前子串的位置

Java 代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
class Solution {
public List<List<String>> suggestedProducts(String[] ps, String w) {
Arrays.sort(ps);
int n = ps.length;
List<List<String>> ans = new ArrayList<>();
for (int i = 0; i < w.length(); i++) {
String cur = w.substring(0, i + 1);
int l = 0, r = n - 1;
while (l < r) {
int mid = l + r >> 1;
if (ps[mid].compareTo(cur) >= 0) r = mid;
else l = mid + 1;
}
List<String> list = new ArrayList<>();
if (ps[r].compareTo(cur) >= 0) {
for (int j = r; j <= Math.min(n - 1, r + 2); j++) {
if (ps[j].length() < cur.length()) break;
if (!ps[j].substring(0, i + 1).equals(cur)) break;
list.add(ps[j]);
}
}
ans.add(list);
}
return ans;
}
}

C++ 代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
class Solution {
public:
vector<vector<string>> suggestedProducts(vector<string>& ps, string w) {
sort(ps.begin(), ps.end());
int n = ps.size();
vector<vector<string>> ans;
for (int i = 0; i < w.length(); i++) {
string cur = w.substr(0, i + 1);
int l = 0, r = n - 1;
while (l < r) {
int mid = (l + r) >> 1;
if (ps[mid].compare(cur) >= 0) r = mid;
else l = mid + 1;
}
vector<string> list;
if (ps[r].compare(cur) >= 0) {
for (int j = r; j <= min(n - 1, r + 2); j++) {
if (ps[j].length() < cur.length()) break;
if (ps[j].substr(0, i + 1) != cur) break;
list.push_back(ps[j]);
}
}
ans.push_back(list);
}
return ans;
}
};

Python 代码:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
class Solution:
def suggestedProducts(self, ps: List[str], w: str) -> List[List[str]]:
ps.sort()
n = len(ps)
ans = []
for i in range(len(w)):
cur = w[:i + 1]
l, r = 0, n - 1
while l < r:
mid = (l + r) // 2
if ps[mid] >= cur:
r = mid
else:
l = mid + 1
lst = []
if ps[r] >= cur:
for j in range(r, min(n - 1, r + 2) + 1):
if len(ps[j]) < len(cur) or ps[j][:i + 1] != cur:
break
lst.append(ps[j])
ans.append(lst)
return ans

TypeScript 代码:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
function suggestedProducts(ps: string[], w: string): string[][] {
ps.sort();
const n = ps.length;
const ans = [];
for (let i = 0; i < w.length; i++) {
const cur = w.substring(0, i + 1);
let l = 0, r = n - 1;
while (l < r) {
const mid = (l + r) >> 1;
if (ps[mid].localeCompare(cur) >= 0) r = mid;
else l = mid + 1;
}
const list: string[] = [];
if (ps[r].localeCompare(cur) >= 0) {
for (let j = r; j <= Math.min(n - 1, r + 2); j++) {
if (ps[j].length < cur.length || !ps[j].startsWith(cur)) break;
list.push(ps[j]);
}
}
ans.push(list);
}
return ans;
};

  • 时间复杂度:将 ps 长度记为 nw 长度记为 m。对 ps 进行排序复杂度为 $O(n\log{n})$;每次二分需要进行字符串比较,复杂度为 $O(m\log{n})$;二分到左端点后需要往后检查最多三个字符串,复杂度为 $O(m)$。整体复杂度为 $O(n\log{n} + m\log{n})$
  • 空间复杂度:$O(\log{n})$

最后

这是我们「刷穿 LeetCode」系列文章的第 No.1268 篇,系列开始于 2021/01/01,截止于起始日 LeetCode 上共有 1916 道题目,部分是有锁题,我们将先把所有不带锁的题目刷完。

在这个系列文章里面,除了讲解解题思路以外,还会尽可能给出最为简洁的代码。如果涉及通解还会相应的代码模板。

为了方便各位同学能够电脑上进行调试和提交代码,我建立了相关的仓库:https://github.com/SharingSource/LogicStack-LeetCode

在仓库地址里,你可以看到系列文章的题解链接、系列文章的相应代码、LeetCode 原题链接和其他优选题解。