互联网笔试常用数据结构与算法总结(python 模板)

常用数据结构和算法模板(python)

经典

1.埃拉托斯特尼筛法

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
def countPrimes(n):
    """Count the primes p with 2 <= p <= n (Sieve of Eratosthenes).

    Every composite number up to n has a prime factor no larger than
    sqrt(n), so only those primes need to cross off their multiples.

    :param n: inclusive upper bound
    :return: number of primes not exceeding n
    """
    if n < 2:
        return 0

    sieve = [True] * (n + 1)
    sieve[0] = sieve[1] = False
    p = 2
    while p * p <= n:
        if sieve[p]:
            # Multiples below p*p were already crossed off by smaller primes.
            sieve[p * p::p] = [False] * len(range(p * p, n + 1, p))
        p += 1
    return sum(sieve)

参考习题 https://leetcode-cn.com/problems/count-primes/, 答案如下,和模板略有区别:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
class Solution:
    def countPrimes(self, n: int) -> int:
        """Return how many primes are strictly less than n (LeetCode 204).

        The bound is exclusive, so the sieve only covers indices 0 .. n-1.
        """
        if n <= 2:
            return 0

        composite = [False] * n
        composite[0] = composite[1] = True
        p = 2
        while p * p < n:
            if not composite[p]:
                for multiple in range(p * p, n, p):
                    composite[multiple] = True
            p += 1
        return composite.count(False)

2. 快速幂

1
2
3
4
5
6
7
8
9
10
11
def myPow(x: float, n: int) -> float:
    """Compute x ** n by binary (fast) exponentiation in O(log |n|) steps.

    Walk the bits of |n| from least to most significant. Whenever a bit is
    set, multiply the running result by the current power x ** (2 ** k).
    A negative exponent is handled by inverting the final result.
    """
    result = 1.0
    power = x
    remaining = -n if n < 0 else n
    while remaining > 0:
        if remaining % 2:
            result *= power
        power *= power
        remaining //= 2
    if n < 0:
        return 1.0 / result
    return result

练习题:https://leetcode-cn.com/problems/powx-n/

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
class Solution:
    def myPow(self, x: float, n: int) -> float:
        """LeetCode 50: compute x ** n with square-and-multiply."""

        def quick_pow(base, exp):
            # Accumulate base ** exp for a non-negative exp, one bit at a time.
            acc = 1
            while exp:
                if exp & 1:
                    acc *= base
                base *= base
                exp >>= 1
            return acc

        if n > 0:
            return quick_pow(x, n)
        # n == 0 also lands here and yields 1 / 1 == 1.0.
        return 1 / quick_pow(x, -n)

3. 大数模拟

大数加法

练习题:leetcode 415 https://leetcode-cn.com/problems/add-strings/

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
def addStrings(self, num1: str, num2: str) -> str:
    """Add two non-negative integers given as decimal strings (LeetCode 415).

    Digits are added from the least significant end with a running carry.
    The shorter number is implicitly padded with leading zeros.
    """
    digits = []
    carry = 0
    p, q = len(num1) - 1, len(num2) - 1
    while p >= 0 or q >= 0:
        a = int(num1[p]) if p >= 0 else 0
        b = int(num2[q]) if q >= 0 else 0
        carry, digit = divmod(a + b + carry, 10)
        digits.append(str(digit))
        p -= 1
        q -= 1
    if carry:
        digits.append("1")
    return "".join(reversed(digits))

大数乘法

练习题:leetcode 43 https://leetcode.com/problems/multiply-strings/

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
def multiply(num1, num2):
    """Multiply two non-negative integers given as decimal strings (LeetCode 43).

    Grade-school multiplication. An m-digit by n-digit product has at most
    m + n digits, so partial results go into that many slots. The digit pair
    num1[i] * num2[j] (both counted from the right) lands in slot i + j from
    the right, and its carry is pushed one slot further left immediately.
    """
    width = len(num1) + len(num2)
    slots = [0] * width
    for i, d1 in enumerate(reversed(num1)):
        for j, d2 in enumerate(reversed(num2)):
            k = width - 1 - i - j
            slots[k] += int(d1) * int(d2)
            slots[k - 1] += slots[k] // 10
            slots[k] %= 10

    # Skip leading zeros, but always keep the last digit so "0" survives.
    start = 0
    while start < width - 1 and slots[start] == 0:
        start += 1
    return "".join(str(d) for d in slots[start:])

以下思路与复杂度分析针对上文的大数加法(addStrings):
计算进位: 计算 carry = tmp // 10,代表当前位相加是否产生进位;
添加当前位: 计算 tmp = n1 + n2 + carry,并将当前位 tmp % 10 添加至 res 头部;
索引溢出处理: 当指针 i 或 j 走过数字首部后,给 n1、n2 赋值为 0,相当于给 num1、num2 中长度较短的数字前面补 0,以便后续计算。
当遍历完 num1、num2 后跳出循环,并根据 carry 值决定是否在头部添加进位 1,最终返回 res 即可。
复杂度分析:

时间复杂度 O(max(M,N)):其中 M、N 为两个数字的长度,按位遍历一遍数字(以较长的数字为准)。注意模板中 res = str(...) + res 每次都会新建字符串,严格来说是 O(max(M,N)^2);改为用列表收集各位、最后 ''.join 即可做到线性。
空间复杂度 O(1):指针与变量使用常数大小空间(不计返回的结果字符串)。

4. GCD

1
2
3
4
5
6
7
8
def gcd(a, b):
    """Greatest common divisor via the Euclidean algorithm.

    No need to order a and b first: if a < b, the first step turns
    (a, b) into (b, a % b) == (b, a), which swaps them automatically.
    """
    while b != 0:
        a, b = b, a % b
    return a

参考习题:https://leetcode-cn.com/problems/simplified-fractions/,最简分数,gcd的应用

1
2
3
4
5
6
7
8
9
10
11
12
13
class Solution:
    def simplifiedFractions(self, n: int) -> List[str]:
        """LeetCode 1447: all reduced fractions strictly between 0 and 1
        whose denominator is at most n.

        num/den is irreducible exactly when gcd(num, den) == 1.
        """

        def gcd(a, b):
            while b:
                a, b = b, a % b
            return a

        return [
            "{}/{}".format(num, den)
            for den in range(2, n + 1)
            for num in range(1, den)
            if gcd(num, den) == 1
        ]

5. LCM (最小公倍数)

1
2
def lcm(a, b):
    """Least common multiple of two integers (always non-negative).

    Uses lcm(a, b) = |a * b| / gcd(a, b), dividing before multiplying to
    keep the intermediate value small. lcm(0, x) is defined as 0; this also
    avoids dividing by gcd(0, 0) == 0.
    Requires the ``gcd`` helper defined in this file.
    """
    if a == 0 or b == 0:
        return 0
    return abs(a // gcd(a, b) * b)

7. 二分搜索

万能模板

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
def binary_search1(left, right, predicate=None):
    """Binary search on [left, right] that uses the LEFT (lower) midpoint.

    Use this shape when the else-branch cannot exclude mid (``right = mid``).
    With the upper midpoint, ``left == right - 1`` would give mid == right
    and the loop would never shrink.

    If the predicate is True on a prefix of the range and False afterwards,
    the result is the first index where it is False. If it is True on every
    index before ``right``, the result is ``right``.

    :param left: lower bound of the search range (inclusive)
    :param right: upper bound of the search range (inclusive)
    :param predicate: callable mid -> bool; when omitted, the module-level
        ``check`` function the template was written against is used
    :return: the converged index (left == right); depending on the problem,
        the caller may still need to verify that it satisfies the condition
    """
    cond = check if predicate is None else predicate
    while left < right:
        mid = (left + right) >> 1
        if cond(mid):
            left = mid + 1
        else:
            right = mid
    return left

def binary_search2(left, right, predicate=None):
    """Binary search on [left, right] that uses the RIGHT (upper) midpoint.

    Use this shape when the else-branch keeps mid (``left = mid``). With the
    lower midpoint, ``left == right - 1`` would give mid == left and the loop
    would never progress.

    If the predicate is False on a prefix of the range and True afterwards,
    the result is the last index where it is False. If it is True everywhere,
    the result is the initial ``left``.

    :param left: lower bound of the search range (inclusive)
    :param right: upper bound of the search range (inclusive)
    :param predicate: callable mid -> bool; when omitted, the module-level
        ``check`` function the template was written against is used
    :return: the converged index (left == right); depending on the problem,
        the caller may still need to verify that it satisfies the condition
    """
    cond = check if predicate is None else predicate
    # A single loop is enough; the original template nested a second,
    # redundant `while left < right` inside this one.
    while left < right:
        mid = (left + right + 1) >> 1
        if cond(mid):
            right = mid - 1
        else:
            left = mid
    return left

8. 并查集

1
2
3
4
5
6
7
8
9
10
11
12
13
# parent[x] is the parent of element x; every element starts as its own root.
# N (the number of elements, labelled 0..N-1) must be defined beforehand.
parent = list(range(N))

def find(x):
    """Return the root of x's set, compressing the path on the way.

    Iterative rather than recursive: roots are linked without union by rank,
    so trees can degenerate into long chains, and a recursive find would hit
    Python's default recursion limit (about 1000 frames). Every node on the
    path is re-pointed directly at the root (full path compression).
    """
    root = x
    while parent[root] != root:
        root = parent[root]
    # Second pass: point every node on the path straight at the root.
    while parent[x] != root:
        parent[x], x = root, parent[x]
    return root

def union(x, y):
    """Merge the sets containing x and y by hanging y's root under x's root."""
    root_x, root_y = find(x), find(y)
    # When x and y are already joined this assigns a root to itself (no-op).
    parent[root_y] = root_x

类形式:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
class DSU:
    """Disjoint-set union over the labels 0..N.

    There are N + 1 slots, so both 0-indexed and 1-indexed labels up to N fit.
    ``edges`` counts how many union() calls actually merged two different
    sets.
    """

    def __init__(self, N):
        self.parent = list(range(N + 1))
        # Number of successful merges performed so far.
        self.edges = 0

    def find(self, x):
        """Return the root of x's set with full path compression.

        Iterative, so the long chains that union() can build (it does not
        balance by rank) cannot exceed Python's recursion limit.
        """
        parent = self.parent
        root = x
        while parent[root] != root:
            root = parent[root]
        while parent[x] != root:
            parent[x], x = root, parent[x]
        return root

    def union(self, x, y):
        """Merge the sets of x and y.

        :return: True if a merge happened; False if x and y were already in
            the same set, in which case ``edges`` is left unchanged.
        """
        root1 = self.find(x)
        root2 = self.find(y)
        if root1 == root2:
            return False
        self.parent[root2] = root1
        self.edges += 1
        return True

图论

9. 最小生成树(贪心思想)

解析 花花酱:https://www.youtube.com/watch?v=wmW8G8SrXDs

图论500 题 https://blog.csdn.net/luomingjun12315/article/details/47438607

Kruskal算法

主体部分如下,需要用到并查集(下面完整测试案例中会给出)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
def kruskal():
    """Kruskal's minimum spanning tree; suited to sparse graphs.

    Relies on the module-level ``edges`` (a list of [u, v, w]), ``parent``
    and ``find`` from the union-find template.

    :return: total weight of the chosen edges
    """
    total = 0
    for u, v, w in sorted(edges, key=lambda e: e[2]):
        ru = find(u)
        rv = find(v)
        if ru != rv:
            # Inline union: attach u's root under v's root.
            parent[ru] = rv
            total += w
    return total

leetcode中缺少相关的习题,故用花花酱视频中的例子测试了一下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
# Example graph from the video: 4 vertices, edges given as [u, v, weight].
n = 4
edges = [[0, 1, 1], [0, 3, 3], [0, 2, 6], [2, 3, 2], [1, 2, 4]]

# Union-find parent array: every vertex starts as its own root.
parent = list(range(n))

def find(x):
    """Return the root of x's set, compressing the path on the way.

    Iterative rather than recursive: roots are linked without union by rank,
    so trees can degenerate into long chains, and a recursive find would hit
    Python's default recursion limit (about 1000 frames). Every node on the
    path is re-pointed directly at the root (full path compression).
    """
    root = x
    while parent[root] != root:
        root = parent[root]
    # Second pass: point every node on the path straight at the root.
    while parent[x] != root:
        parent[x], x = root, parent[x]
    return root


def union(x, y):
    """Merge the sets containing x and y by hanging y's root under x's root."""
    root_x, root_y = find(x), find(y)
    # When x and y are already joined this assigns a root to itself (no-op).
    parent[root_y] = root_x


def kruskal():
    """Kruskal's minimum spanning tree; suited to sparse graphs.

    Relies on the module-level ``edges`` (a list of [u, v, w]), ``parent``
    and ``find``.

    :return: total weight of the chosen edges
    """
    total = 0
    for u, v, w in sorted(edges, key=lambda e: e[2]):
        ru = find(u)
        rv = find(v)
        if ru != rv:
            # Inline union: attach u's root under v's root.
            parent[ru] = rv
            total += w
    return total

Prim算法

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
def prime():
    """Prim's minimum spanning tree, lazy binary-heap version.

    Relies on the module-level ``n`` (number of vertices, labelled 0..n-1),
    ``graph`` (graph[u] iterates (neighbor, weight) pairs) and heapq's
    heappush / heappop.

    :return: total MST weight. For a disconnected graph, this is the weight
        of the spanning tree of vertex 0's component.
    """
    q = [(0, 0)]  # (edge weight, vertex); vertex 0 is reached at cost 0
    cost = 0
    seen = set()
    # Loop until every vertex is in the tree. Counting heap pops instead
    # (`for _ in range(n)`) is wrong: stale entries for already-added vertices
    # also consume pops, so the loop could stop early and miss vertices.
    while q and len(seen) < n:
        w, u = heappop(q)
        if u in seen:
            continue  # stale entry: u was already reached more cheaply
        cost += w
        seen.add(u)
        for v, edge_w in graph[u]:
            if v not in seen:
                heappush(q, (edge_w, v))
    return cost

测试代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
from collections import defaultdict
from heapq import *

n = 4  # number of vertices
# Edge list: [u, v, weight].
edges = [[0, 1, 1], [0, 3, 3], [0, 2, 6], [2, 3, 2], [1, 2, 4]]

# Undirected adjacency list: vertex -> list of (neighbor, weight); each edge
# is stored in both directions.
graph = defaultdict(list)
for e in edges:
    graph[e[0]].append((e[1], e[2]))
    graph[e[1]].append((e[0], e[2]))

def prime():
    """Prim's minimum spanning tree, lazy binary-heap version.

    Relies on the module-level ``n`` (number of vertices, labelled 0..n-1),
    ``graph`` (graph[u] iterates (neighbor, weight) pairs) and heapq's
    heappush / heappop.

    :return: total MST weight. For a disconnected graph, this is the weight
        of the spanning tree of vertex 0's component.
    """
    q = [(0, 0)]  # (edge weight, vertex); vertex 0 is reached at cost 0
    cost = 0
    seen = set()
    # Loop until every vertex is in the tree. Counting heap pops instead
    # (`for _ in range(n)`) is wrong: stale entries for already-added vertices
    # also consume pops, so the loop could stop early and miss vertices.
    while q and len(seen) < n:
        w, u = heappop(q)
        if u in seen:
            continue  # stale entry: u was already reached more cheaply
        cost += w
        seen.add(u)
        for v, edge_w in graph[u]:
            if v not in seen:
                heappush(q, (edge_w, v))
    return cost

# Run Prim on the example graph; the MST uses edges 0-1, 2-3 and 0-3.
print(prime())
# expected output: 6

最短路算法合集

统一习题:LC743 https://leetcode-cn.com/problems/network-delay-time/

最短路算法的分类:

  • 单源最短路
    • 所有边权都是正数
      • 朴素的Dijkstra算法 O(n^2) 适合稠密图
      • 堆优化版的Dijkstra算法 O(m log n)(m 是图中边的条数,n 是节点数)适合稀疏图
    • 存在负权边
      • Bellman-Ford O(nm)
      • spfa 一般O(m), 最坏O(nm)
  • 多源汇最短路 Floyd算法 O(n^3)

参考代码:https://leetcode.com/problems/network-delay-time/discuss/283711/python-bellman-ford-spfa-dijkstra-floyd-clean-and-easy-to-understand

11. 迪杰斯特拉算法

单源最短路

  • 所有边权都是正数
  • 朴素的Dijkstra算法 O(n^2) 适合稠密图
  • 堆优化版的Dijkstra算法 O(m log n)(m 是图中边的条数,n 是节点数)适合稀疏图
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
def dijkstra(graph, source, N):
    '''
    Single-source shortest paths (Dijkstra with a binary heap, lazy deletion).

    All edge weights must be non-negative.

    :param graph: graph[u] must map each neighbor v to the weight of u -> v
        (e.g. a dict of dicts / defaultdict(dict)). A plain list-of-lists
        adjacency matrix does NOT work, because iterating a row yields
        weights, not neighbor indices.
    :param source: start node
    :param N: number of nodes
    :return: (dist, prev), both of length N + 1; dist[v] is the shortest
        distance (inf if unreachable) and prev[v] the predecessor on that
        path (-1 if none)
    '''
    import heapq

    # Nodes are assumed 1-indexed (1..N); slot 0 is unused and set to 0,
    # presumably so that it does not affect max(dist) in callers.
    # NOTE(review): with 0-indexed nodes, slot N stays inf.
    dist = [float('inf')] * (N + 1)
    prev = [-1] * (N + 1)

    dist[source] = dist[0] = 0

    hq = [(0, source)]
    while hq:
        d, u = heapq.heappop(hq)
        if d > dist[u]:
            # Stale heap entry: u was already settled with a shorter distance.
            # Skipping it keeps the running time at O(m log n).
            continue
        for v in graph[u]:
            alt = dist[u] + graph[u][v]
            if alt < dist[v]:
                dist[v] = alt
                prev[v] = u
                heapq.heappush(hq, (alt, v))
    return dist, prev

习题解答:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
class Solution:
    def networkDelayTime(self, times: List[List[int]], N: int, K: int) -> int:
        """LeetCode 743: time for a signal sent from K to reach all N nodes.

        Runs heap-based Dijkstra from K; the answer is the largest shortest
        distance, or -1 if some node is unreachable.
        """
        import collections
        import heapq

        # Adjacency map: graph[u][v] = travel time of edge u -> v.
        graph = collections.defaultdict(dict)
        for u, v, time in times:
            graph[u][v] = time

        def dijkstra(graph, source, N):
            '''
            Single-source shortest paths (Dijkstra with a binary heap).

            :param graph: graph[u] maps each neighbor v to the weight of u -> v
            :param source: start node
            :param N: number of nodes, labelled 1..N
            :return: (dist, prev), both of length N + 1
            '''
            dist = [float('inf')] * (N + 1)
            prev = [-1] * (N + 1)

            # Nodes are 1-indexed, so slot 0 is unused; zeroing it keeps it
            # from affecting max(dist) below.
            dist[source] = dist[0] = 0

            hq = [(0, source)]
            while hq:
                d, u = heapq.heappop(hq)
                if d > dist[u]:
                    # Stale heap entry: u was already settled with a smaller
                    # distance, so re-scanning its edges would be wasted work.
                    continue
                for v in graph[u]:
                    alt = dist[u] + graph[u][v]
                    if alt < dist[v]:
                        dist[v] = alt
                        prev[v] = u
                        heapq.heappush(hq, (alt, v))
            return dist, prev

        dist, _ = dijkstra(graph, K, N)
        ans = max(dist)
        return ans if ans != float('inf') else -1

Bellman-Ford 算法

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
# Python3 program for Bellman-Ford's single source
# shortest path algorithm.

# Class to represent a graph
class Graph:
    """Directed weighted graph stored as an edge list, for Bellman-Ford."""

    def __init__(self, vertices):
        self.V = vertices  # number of vertices, labelled 0..V-1
        self.graph = []  # list of [u, v, w] edges

    def addEdge(self, u, v, w):
        """Add a directed edge u -> v with weight w."""
        self.graph.append([u, v, w])

    def printArr(self, dist):
        """Print each vertex with its distance from the source."""
        print("Vertex Distance from Source")
        for i in range(self.V):
            print("{0}\t\t{1}".format(i, dist[i]))

    def BellmanFord(self, src):
        """Single-source shortest paths that tolerate negative edge weights.

        Also detects negative-weight cycles reachable from src.

        :param src: source vertex
        :return: the distance list (which is also printed), or None if a
            negative-weight cycle is found (a message is printed instead)
        """
        INF = float("Inf")

        # Step 1: every vertex except the source starts at infinity.
        dist = [INF] * self.V
        dist[src] = 0

        # Step 2: relax every edge |V| - 1 times. A simple shortest path has
        # at most |V| - 1 edges, so that many rounds are always enough.
        for _ in range(self.V - 1):
            updated = False
            for u, v, w in self.graph:
                if dist[u] != INF and dist[u] + w < dist[v]:
                    dist[v] = dist[u] + w
                    updated = True
            if not updated:
                # Nothing changed this round, so nothing will change in the
                # following rounds either.
                break

        # Step 3: if an edge can still be relaxed, some shortest "path" is
        # unbounded, i.e. a negative-weight cycle is reachable.
        for u, v, w in self.graph:
            if dist[u] != INF and dist[u] + w < dist[v]:
                print("Graph contains negative weight cycle")
                return None

        self.printArr(dist)
        return dist


# Demo: 5 vertices, some negative edges, but no negative-weight cycle.
# The printed distances from vertex 0 are 0, -1, 2, -2, 1.
g = Graph(5)
g.addEdge(0, 1, -1)
g.addEdge(0, 2, 4)
g.addEdge(1, 2, 3)
g.addEdge(1, 3, 2)
g.addEdge(1, 4, 2)
g.addEdge(3, 2, 5)
g.addEdge(3, 1, 1)
g.addEdge(4, 3, -3)

# Print the solution
g.BellmanFord(0)

# Initially, Contributed by Neelam Yadav
# Later On, Edited by Himanshu Garg
1
2
3
4
5
6
7
8
9
class Solution:
    def networkDelayTime(self, times: List[List[int]], N: int, K: int) -> int:
        """LeetCode 743 via Bellman-Ford: relax every edge N - 1 times.

        Nodes are labelled 1..N, so they are shifted to 0-based indices.
        Returns the largest shortest distance, or -1 if a node is unreachable.
        """
        INF = float("inf")
        dist = [INF] * N
        dist[K - 1] = 0
        for _round in range(N - 1):
            for src, dst, cost in times:
                candidate = dist[src - 1] + cost
                if candidate < dist[dst - 1]:
                    dist[dst - 1] = candidate
        slowest = max(dist)
        return slowest if slowest < INF else -1

12. spfa

判断有无负环:如果某个点进入队列的次数超过N次则存在负环(SPFA无法处理带负环的图,但是可以判断是否出现负权环)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
class Solution:
    def networkDelayTime(self, times: List[List[int]], N: int, K: int) -> int:
        """LeetCode 743 via SPFA (queue-based Bellman-Ford).

        Only vertices whose distance just improved are re-examined. A vertex
        already waiting in the queue is not enqueued again.

        :return: the largest shortest distance from K, or -1 if some node is
            unreachable
        :raises ValueError: if a negative-weight cycle is reachable from K
            (without the check, the queue would never drain)
        """
        import collections

        dist = [float("inf") for _ in range(N)]
        K -= 1  # switch to 0-based node indices
        dist[K] = 0
        weight = collections.defaultdict(dict)
        for u, v, w in times:
            weight[u - 1][v - 1] = w

        in_queue = [False] * N
        in_queue[K] = True
        # hops[v]: number of edges on the walk that produced dist[v]. A walk
        # with >= N edges repeats a vertex, and in SPFA that can only happen
        # when the repeated loop has negative weight.
        hops = [0] * N
        queue = collections.deque([K])
        while queue:
            u = queue.popleft()
            in_queue[u] = False
            for v, w in weight[u].items():
                if dist[u] + w < dist[v]:
                    dist[v] = dist[u] + w
                    hops[v] = hops[u] + 1
                    if hops[v] >= N:
                        raise ValueError("graph contains a negative-weight cycle")
                    if not in_queue[v]:
                        in_queue[v] = True
                        queue.append(v)
        return max(dist) if max(dist) < float("inf") else -1

13. Floyd-Warshall

1
2
3
4
5
6
7
8
9
10
def floyd_warshall(graph, N):
    '''
    All-pairs shortest paths (Floyd-Warshall), O(N^3).

    :param graph: N x N adjacency matrix; graph[i][j] is the weight of edge
        i -> j, float('inf') when the edge is absent, and it should be 0 on
        the diagonal. Updated in place.
    :param N: number of nodes
    :return: the same matrix, now holding shortest-path distances
    '''
    # Allow paths through intermediate nodes 0..k, one k at a time.
    for k in range(N):
        for i in range(N):
            for j in range(N):
                graph[i][j] = min(graph[i][j], graph[i][k] + graph[k][j])
    return graph

习题答案:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
class Solution:
    def networkDelayTime(self, times: List[List[int]], N: int, K: int) -> int:
        """LeetCode 743 via Floyd-Warshall: compute all pairs, then read row K.

        :return: the largest shortest distance from K, or -1 if some node is
            unreachable
        """
        INF = float("inf")
        # Adjacency matrix over 0-based node indices.
        dist = [[INF] * N for _ in range(N)]
        for u, v, w in times:
            dist[u - 1][v - 1] = w
        for i in range(N):
            dist[i][i] = 0

        def floyd_warshall(graph, N):
            '''
            :param graph: N x N adjacency matrix, updated in place
            :param N: number of nodes
            :return: the same matrix, holding shortest-path distances
            '''
            # Operate on the parameter rather than the enclosing `dist`, so the
            # helper works for any matrix it is given.
            for k in range(N):
                for i in range(N):
                    for j in range(N):
                        graph[i][j] = min(graph[i][j], graph[i][k] + graph[k][j])
            return graph

        floyd_warshall(dist, N)
        slowest = max(dist[K - 1])
        return slowest if slowest < INF else -1

二分图

14. 染色法

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import collections


def isBipartite(graph):
    """Two-colour a graph with BFS.

    :param graph: adjacency list; graph[i] lists the neighbours of node i
    :return: (True, colors), where colors[i] is 1 (RED) or 2 (GREEN), or
        (False, None) as soon as an edge joins two same-coloured nodes
    """
    UNCOLORED, RED, GREEN = 0, 1, 2
    color = [UNCOLORED] * len(graph)

    # The graph may be disconnected, so start a BFS from every uncoloured node.
    for start in range(len(graph)):
        if color[start] != UNCOLORED:
            continue
        color[start] = RED
        queue = collections.deque((start,))
        while queue:
            node = queue.popleft()
            other = RED if color[node] == GREEN else GREEN
            for nb in graph[node]:
                if color[nb] == UNCOLORED:
                    color[nb] = other
                    queue.append(nb)
                elif color[nb] != other:
                    return False, None

    return True, color

# Demo: the 4-cycle 0-1-2-3-0 is bipartite; the expected output is shown below.
graph = [[1,3], [0,2], [1,3], [0,2]]
print(isBipartite(graph))
'''
(True, [1, 2, 1, 2])
'''

15. 匈牙利算法 (用于寻找最大匹配)

讲解:https://www.renfei.org/blog/bipartite-matching.html

https://blog.csdn.net/dark_scope/article/details/8880547

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
class DFS_hungary():
    """Maximum bipartite matching with the Hungarian (augmenting-path) method.

    Based on https://www.icode9.com/content-1-615590.html
    """

    def __init__(self, set_A, set_B, edge, cx, cy, visited):
        self.set_A, self.set_B = set_A, set_B  # the two vertex sets
        self.edge = edge  # edge[u][v] truthy when u (in A) may be matched to v (in B)
        self.cx, self.cy = cx, cy  # current partner of each vertex, -1 if unmatched
        self.visited = visited  # per-vertex-of-B "visited in this search" flags
        self.M = []  # matched pairs (u, v)
        self.res = 0  # size of the matching

    def max_match(self):
        """Try to augment from every unmatched vertex of A.

        The cx / cy / visited dicts passed to the constructor are updated in
        place.

        :return: the size of the matching (also stored in ``self.res``)
        """
        for i in self.set_A:
            if self.cx[i] == -1:  # unmatched
                for key in self.set_B:  # reset visited flags for this search
                    self.visited[key] = 0
                self.res += self.path(i)
                print('i', i, 'M', self.M)
        return self.res

    def path(self, u):
        """DFS for an augmenting path starting at u in A.

        :return: 1 if one was found and applied, else 0
        """
        for v in self.set_B:
            if self.edge[u][v] and (not self.visited[v]):  # adjacent and not visited yet
                self.visited[v] = 1
                if self.cy[v] == -1:  # v is free: match u-v directly
                    self.cx[u], self.cy[v] = v, u
                    self.M.append((u, v))
                    return 1
                else:
                    partner = self.cy[v]
                    # Tentatively drop (partner, v) and try to re-match partner.
                    self.M.remove((partner, v))
                    if self.path(partner):  # recursive search
                        self.cx[u], self.cy[v] = v, u
                        self.M.append((u, v))
                        return 1
                    # Re-matching failed, so partner keeps v; restore the pair
                    # so that M stays consistent with cx / cy.
                    self.M.append((partner, v))
            print('v', v, 'M', self.M)
        return 0


if __name__ == '__main__':
    # Demo: left set {A, B, C, D}, right set {E, F, G, H}.
    set_A, set_B = ['A', 'B', 'C', 'D'], ['E', 'F', 'G', 'H']
    edge = {'A': {'E': 1, 'F': 0, 'G': 1, 'H': 0}, 'B': {'E': 0, 'F': 1, 'G': 0, 'H': 1},
            'C': {'E': 1, 'F': 0, 'G': 0, 'H': 1}, 'D': {'E': 0, 'F': 0, 'G': 1, 'H': 0}}  # 1 = may be matched, 0 = may not
    # -1 marks "not matched yet"; visited flags start at 0.
    cx, cy = {'A': -1, 'B': -1, 'C': -1, 'D': -1}, {'E': -1, 'F': -1, 'G': -1, 'H': -1}
    visited = {'E': 0, 'F': 0, 'G': 0, 'H': 0}
    dh = DFS_hungary(set_A, set_B, edge, cx, cy, visited)
    dh.max_match()
    print('res', dh.res)
    print('cx', cx)
    print('cy', cy)
    print('visited', visited)

# Output ('i' lines come from max_match, 'v' lines from path):
# i A M [('A', 'E')]  # A takes the free vertex E directly
# v E M [('A', 'E')]  # searching from B: no edge B-E
# i B M [('A', 'E'), ('B', 'F')]  # B takes the free vertex F
# v E M [('B', 'F')]  # C wants E (held by A): (A, E) is dropped and path(A) runs; E is already visited there
# v F M [('B', 'F')]  # inside path(A): no edge A-F; A then takes the free G
# i C M [('B', 'F'), ('A', 'G'), ('C', 'E')]  # augmenting path C-E-A-G
# v E M [('B', 'F'), ('A', 'G'), ('C', 'E')]  # searching from D: no edge D-E
# v F M [('B', 'F'), ('A', 'G'), ('C', 'E')]  # no edge D-F
# v E M [('B', 'F')]  # D wants G (held by A) -> A wants E (held by C) -> in path(C), E is already visited
# v F M [('B', 'F')]  # path(C): no edge C-F
# v G M [('B', 'F')]  # path(C): no edge C-G; C then takes the free H
# i D M [('B', 'F'), ('C', 'H'), ('A', 'E'), ('D', 'G')]  # augmenting path D-G-A-E-C-H; every vertex is matched
# res 4
# cx {'A': 'E', 'B': 'F', 'C': 'H', 'D': 'G'}
# cy {'E': 'A', 'F': 'B', 'G': 'D', 'H': 'C'}
# visited {'E': 1, 'F': 0, 'G': 1, 'H': 1}

动态规划

16. 背包问题

reference:https://zhuanlan.zhihu.com/p/93857890

0-1背包

  1. 不装入第i件物品,即dp[i−1][j]
  2. 装入第i件物品(前提是能装下),即dp[i−1][j−w[i]] + v[i]

即状态转移方程为

1
dp[i][j] = max(dp[i−1][j], dp[i−1][j−w[i]]+v[i]) // j >= w[i]

由上述状态转移方程可知,dp[i][j]的值只与dp[i-1][0,...,j]有关,所以我们可以采用动态规划常用的方法(滚动数组)对空间进行优化(即去掉dp的第一维)。需要注意的是,为了防止上一层循环的dp[0,...,j-1]被覆盖,循环的时候 j 只能逆向枚举(空间优化前没有这个限制),伪代码为:

1
2
3
4
5
// 01背包问题伪代码(空间优化版)
dp[0,...,W] = 0
for i = 1,...,N
for j = W,...,w[i] // 必须逆向枚举!!!
dp[j] = max(dp[j], dp[j−w[i]]+v[i])

动态规划的核心思想避免重复计算在01背包问题中体现得淋漓尽致。第i件物品装入或者不装入而获得的最大价值完全可以由前面i-1件物品的最大价值决定,暴力枚举忽略了这个事实。

完全背包

分析一
  1. 不装入第i种物品,即dp[i−1][j],同01背包;
  2. 装入第i种物品,此时和01背包不太一样,因为每种物品有无限个(但注意书包限重是有限的),所以此时不应该转移到dp[i−1][j−w[i]]而应该转移到dp[i][j−w[i]],即装入第i种物品后还可以再继续装入第i种物品。

所以状态转移方程为

1
dp[i][j] = max(dp[i−1][j], dp[i][j−w[i]]+v[i]) // j >= w[i]

这个状态转移方程与01背包问题唯一不同就是max第二项不是dp[i-1]而是dp[i]。

和01背包问题类似,也可进行空间优化,优化后不同点在于这里的 j 只能正向枚举而01背包只能逆向枚举,因为这里的max第二项是dp[i]而01背包是dp[i-1],即这里就是需要覆盖而01背包需要避免覆盖。所以伪代码如下:

1
2
3
4
5
// 完全背包问题思路一伪代码(空间优化版)
dp[0,...,W] = 0
for i = 1,...,N
for j = w[i],...,W // 必须正向枚举!!!
dp[j] = max(dp[j], dp[j−w[i]]+v[i])

由上述伪代码看出,01背包和完全背包问题此解法的空间优化版解法唯一不同就是前者的 j 只能逆向枚举而后者的 j 只能正向枚举,这是由二者的状态转移方程决定的。此解法时间复杂度为O(NW), 空间复杂度为O(W)。

分析二

除了分析一的思路外,完全背包还有一种常见的思路,但是复杂度高一些。我们从装入第 i 种物品多少件出发,01背包只有两种情况即取0件和取1件,而这里是取0件、1件、2件…直到超过限重(k > j/w[i]),所以状态转移方程为:

1
2
# k为装入第i种物品的件数, k <= j/w[i]
dp[i][j] = max{(dp[i-1][j − k*w[i]] + k*v[i]) for every k}

同理也可以进行空间优化,需要注意的是,这里max里面是dp[i-1],和01背包一样,所以 j 必须逆向枚举,优化后伪代码为

1
2
3
4
5
6
// 完全背包问题思路二伪代码(空间优化版)
dp[0,...,W] = 0
for i = 1,...,N
for j = W,...,w[i] // 必须逆向枚举!!!
for k = [0, 1,..., j/w[i]]
dp[j] = max(dp[j], dp[j−k*w[i]]+k*v[i])

相比于分析一,此种方法不是在O(1)时间求得dp[i][j],所以总的时间复杂度就比分析一大些了,为 $O\left(W\sum_{i=1}^{N}\frac{W}{w_i}\right)$ 级别。

分析三、转换成01背包

01背包问题是最基本的背包问题,我们可以考虑把完全背包问题转化为01背包问题来解:将一种物品转换成若干件只能装入0件或者1件的01背包中的物品。

最简单的想法是,考虑到第 i 种物品最多装入 W/w[i] 件,于是可以把第 i 种物品转化为 W/w[i] 件费用及价值均不变的物品,然后求解这个01背包问题。

更高效的转化方法是采用二进制的思想:把第 i 种物品拆成重量为 $w_i 2^k$、价值为 $v_i 2^k$ 的若干件物品,其中 k 取遍满足 $w_i 2^k \le W$ 的非负整数。这是因为不管最优策略选几件第 i 种物品,总可以表示成若干个刚才这些物品的和(例:13 = 1 + 4 + 8)。这样就将转换后的物品数目降成了对数级别。

多重背包

分析一

此时的分析和完全背包的分析二差不多,也是从装入第 i 种物品多少件出发:装入第i种物品0件、1件、…n[i]件(还要满足不超过限重)。所以状态方程为:

1
2
# k为装入第i种物品的件数, k <= min(n[i], j/w[i])
dp[i][j] = max{(dp[i-1][j − k*w[i]] + k*v[i]) for every k}

同理也可以进行空间优化,而且 j 也必须逆向枚举,优化后伪代码为

1
2
3
4
5
6
// 多重背包问题思路一伪代码(空间优化版)
dp[0,...,W] = 0
for i = 1,...,N
for j = W,...,w[i] // 必须逆向枚举!!!
for k = [0, 1,..., min(n[i], j/w[i])]
dp[j] = max(dp[j], dp[j−k*w[i]]+k*v[i])

总的时间复杂度约为 $O\left(W\sum_{i=1}^{N} n_i\right)$ 级别。

其他情形

参考https://blog.csdn.net/weixin_41162823/article/details/87878853

1 恰好装满

背包问题有时候还有一个限制就是必须恰好装满背包,此时基本思路没有区别,只是在初始化的时候有所不同。

如果没有恰好装满背包的限制,我们将dp全部初始化成0就可以了。因为任何容量的背包都有一个合法解“什么都不装”,这个解的价值为0,所以初始时状态的值也就全部为0了。如果有恰好装满的限制,那只应该将dp[0,…,N][0]初始为0,其它dp值均初始化为-inf,因为此时只有容量为0的背包可以在什么也不装情况下被“恰好装满”,其它容量的背包初始均没有合法的解,应该被初始化为-inf

2 求方案总数

除了在给定每个物品的价值后求可得到的最大价值外,还有一类问题是问装满背包或将背包装至某一指定容量的方案总数。对于这类问题,需要将状态转移方程中的 max 改成 sum ,大体思路是不变的。例如若每件物品均是完全背包中的物品,转移方程即为

1
dp[i][j] = sum(dp[i−1][j], dp[i][j−w[i]]) // j >= w[i]
3 二维背包

前面讨论的背包容量都是一个量:重量。二维背包问题是指每个背包有两个限制条件(比如重量和体积限制),选择物品必须要满足这两个条件。此类问题的解法和一维背包问题不同就是dp数组要多开一维,其他和一维背包完全一样,例如 LeetCode 474. 一和零(每个字符串同时消耗 0 和 1 两种容量)。

4 求最优方案

一般而言,背包问题是要求一个最优值,如果要求输出这个最优值的方案,可以参照一般动态规划问题输出方案的方法:记录下每个状态的最优值是由哪一个策略推出来的,这样便可根据这条策略找到上一个状态,从上一个状态接着向前推即可。

以01背包为例,我们可以再用一个数组G[i][j]来记录方案,设 G[i][j] = 0表示计算 dp[i][j] 的值时是采用了max中的前一项(也即dp[i−1][j]),G[i][j] = 1 表示采用了方程的后一项。即分别表示了两种策略: 未装入第 i 个物品及装了第 i 个物品。其实我们也可以直接从求好的dp[i][j]反推方案:若 dp[i][j] = dp[i−1][j] 说明未选第i个物品,反之说明选了。

Leetcode相关练习题

0 - 1 背包问题:416. 分割等和子集

题目给定一个只包含正整数的非空数组。问是否可以将这个数组分割成两个子集,使得两个子集的元素和相等。

由于所有元素的和sum已知,所以两个子集的和都应该是sum/2(所以前提是sum不能是奇数),即题目转换成从这个数组里面选取一些元素使这些元素和为sum/2。如果我们将所有元素的值看做是物品的重量,每件物品价值都为1,所以这就是一个恰好装满的01背包问题。

我们定义空间优化后的状态数组dp,由于是恰好装满,所以应该将dp[0]初始化为0而将其他全部初始化为INT_MIN,然后按照上文 01 背包空间优化版的伪代码更新dp:

1
2
3
4
5
6
// Exactly-full 0-1 knapsack: capacity is sum / 2 and every item has value 1.
int capacity = sum / 2;
// INT_MIN marks "this capacity cannot be filled exactly"; only dp[0] starts reachable.
// Unreachable entries stay hugely negative, so dp[capacity] > 0 means reachable.
vector<int>dp(capacity + 1, INT_MIN);
dp[0] = 0;
for(int i = 1; i <= n; i++)
    // j runs downwards so each number is used at most once.
    for(int j = capacity; j >= nums[i-1]; j--)
        dp[j] = max(dp[j], 1 + dp[j - nums[i-1]]);

更新完毕后,如果dp[sum/2]大于0说明满足题意。

由于此题最后求的是能不能进行划分,所以dp的每个元素定义成bool型就可以了,然后将dp[0]初始为true其他初始化为false,而转移方程就应该是用或操作而不是max操作。完整代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
class Solution:
    def canPartition(self, nums: List[int]) -> bool:
        """LeetCode 416: can nums be split into two subsets of equal sum?

        This is a 0-1 knapsack that must be filled exactly to total // 2;
        reachable[j] tells whether some subset sums to j.
        """
        total = sum(nums)
        if total % 2:
            return False

        target = total // 2
        reachable = [True] + [False] * target
        for num in nums:
            # Go downwards so each number is used at most once.
            for j in range(target, num - 1, -1):
                if reachable[j - num]:
                    reachable[j] = True
        return reachable[target]

完全背包问题:322. 零钱兑换

题目给定一个价值amount和一些面值,假设每个面值的硬币数都是无限的,问我们最少能用几个硬币组成给定的价值。

如果我们将面值看作是物品,面值金额看成是物品的重量,每件物品的价值均为1,这样此题就是一个恰好装满的完全背包问题了。不过这里不是求最多装入多少物品而是求最少,我们只需要将完全背包(分析一)的状态转移方程中的max改成min即可;又由于是恰好装满,所以除了dp[0],其他都应初始化为INT_MAX(Python 中用 float('inf'))。完整代码如下:

1
2
3
4
5
6
7
8
9
class Solution:
    def coinChange(self, coins: List[int], amount: int) -> int:
        """LeetCode 322: fewest coins summing to amount, or -1 if impossible.

        Unbounded knapsack that must be filled exactly: each coin weighs its
        face value and costs one. j goes upward so a coin may be reused.
        """
        INF = float('inf')
        best = [0] + [INF] * amount
        for coin in coins:
            for j in range(coin, amount + 1):
                candidate = best[j - coin] + 1
                if candidate < best[j]:
                    best[j] = candidate
        answer = best[amount]
        return -1 if answer == INF else answer

17. 最长上升子序列

1
2
3
4
5
6
7
8
9
10
11
class Solution:
    def lengthOfLIS(self, nums: List[int]) -> int:
        """LeetCode 300 with the O(n^2) DP.

        best[i] is the length of the longest strictly increasing subsequence
        that ends at nums[i].
        """
        if not nums:
            return 0
        best = []
        for i, x in enumerate(nums):
            best.append(1 + max((best[j] for j in range(i) if nums[j] < x), default=0))
        return max(best)

贪心加二分优化:

1
2
3
4
5
6
7
8
9
10
11
class Solution:
    def lengthOfLIS(self, nums: List[int]) -> int:
        """LeetCode 300 in O(n log n): greedy + binary search.

        tails[k] holds the smallest possible tail of a strictly increasing
        subsequence of length k + 1 seen so far. tails stays sorted, so each
        number either extends the longest subsequence or replaces the first
        tail that is >= it (found with bisect_left; using bisect_left keeps
        the subsequence strictly increasing).
        """
        from bisect import bisect_left

        tails = []
        for x in nums:
            k = bisect_left(tails, x)
            if k == len(tails):
                tails.append(x)
            else:
                tails[k] = x
        return len(tails)

18. 最长公共子序列

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
class Solution:
    def longestCommonSubsequence(self, text1: str, text2: str) -> int:
        """LeetCode 1143: length of the longest common subsequence.

        table[i][j] is the LCS length of text1[:i] and text2[:j].
        """
        rows, cols = len(text1), len(text2)
        table = [[0] * (cols + 1) for _ in range(rows + 1)]
        for i, a in enumerate(text1, start=1):
            for j, b in enumerate(text2, start=1):
                if a == b:
                    table[i][j] = table[i - 1][j - 1] + 1
                else:
                    table[i][j] = max(table[i - 1][j], table[i][j - 1])
        return table[rows][cols]

字符串

26. KMP 字符串匹配

练习题 https://leetcode.com/problems/implement-strstr/

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
# Python program for KMP Algorithm
def KMPSearch(pat, txt):
    """Return the index of the first occurrence of pat in txt, or -1.

    Knuth-Morris-Pratt: after a mismatch, the lps table says how much of pat
    is still matched, so the text index never moves backwards. Runs in
    O(len(pat) + len(txt)). An empty pattern matches at index 0.
    Requires ``computeLPSArray`` from this file.
    """
    M = len(pat)
    N = len(txt)

    if M == 0:
        return 0

    # lps[k]: length of the longest proper prefix of pat[:k+1] that is also
    # a suffix of it.
    lps = [0] * M
    computeLPSArray(pat, M, lps)

    i = 0  # index into txt
    j = 0  # index into pat
    while i < N:
        if pat[j] == txt[i]:
            i += 1
            j += 1
            if j == M:
                return i - j
        elif j != 0:
            # Mismatch after j matches: lps[j-1] characters are known to match
            # already, so fall back in the pattern and keep i where it is.
            j = lps[j - 1]
        else:
            i += 1
    return -1

def computeLPSArray(pat, M, lps):
    """Fill lps[0:M] in place with the KMP failure function of pat.

    lps[k] is the length of the longest proper prefix of pat[:k+1] that is
    also a suffix of it.

    :param pat: the pattern
    :param M: len(pat)
    :param lps: pre-allocated list with at least M elements; overwritten
    """
    if M == 0:
        return
    lps[0] = 0  # a single character has no proper prefix that is also a suffix
    length = 0  # length of the previous longest prefix-suffix
    i = 1

    # Computes lps[i] for i = 1 .. M-1.
    while i < M:
        if pat[i] == pat[length]:
            length += 1
            lps[i] = length
            i += 1
        elif length != 0:
            # Fall back to the next-shorter prefix-suffix (e.g. "AAACAAAA"
            # at i = 7). i is deliberately not advanced.
            length = lps[length - 1]
        else:
            lps[i] = 0
            i += 1


# Demo: prints 10, the index at which "ABABCABAB" starts in the text.
txt = "ABABDABACDABABCABAB"
pat = "ABABCABAB"
print(KMPSearch(pat, txt))

# This code is contributed by Bhavya Jain

27. 字典树

练习题:https://leetcode.com/problems/implement-trie-prefix-tree/

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
class TrieNode:
    """One trie node: child nodes keyed by character, plus an end-of-word flag."""

    def __init__(self):
        # True when the path from the root to this node spells an inserted word.
        self.is_word = False
        # Indexing a missing character creates the child node on the fly.
        self.children = collections.defaultdict(TrieNode)

class Trie:
    """Prefix tree supporting insert, exact-word search and prefix search (LeetCode 208)."""

    def __init__(self):
        self.root = TrieNode()

    def insert(self, word):
        """Add word, creating any missing nodes along the way."""
        node = self.root
        for ch in word:
            node = node.children[ch]  # the defaultdict creates absent children
        node.is_word = True

    def _walk(self, s):
        """Return the node reached by following s from the root, or None.

        Uses .get so that lookups never create nodes.
        """
        node = self.root
        for ch in s:
            node = node.children.get(ch)
            if node is None:
                return None
        return node

    def search(self, word):
        """Return True if word itself was inserted."""
        node = self._walk(word)
        return node is not None and node.is_word

    def startsWith(self, prefix):
        """Return True if some inserted word starts with prefix."""
        return self._walk(prefix) is not None

区间查询

29. 线段树

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
class Node:
    """Segment-tree node covering the index range [start, end]."""

    def __init__(self, start, end):
        self.start, self.end = start, end
        self.total = 0  # sum of the covered range, filled in by the builder
        self.left = self.right = None  # children; None for a leaf


class NumArray:
    """Range-sum queries with point updates (LeetCode 307) on a segment tree.

    Each Node stores the sum over its [start, end] range; the left child
    covers [start, mid] and the right child covers [mid + 1, end].
    """

    def __init__(self, nums: List[int]):
        def build(lo, hi):
            # An empty range (only possible when nums is empty) has no node.
            if lo > hi:
                return None
            node = Node(lo, hi)
            if lo == hi:
                node.total = nums[lo]
                return node
            mid = (lo + hi) // 2
            node.left = build(lo, mid)
            node.right = build(mid + 1, hi)
            node.total = node.left.total + node.right.total
            return node

        self.root = build(0, len(nums) - 1)

    def update(self, i: int, val: int) -> None:
        """Set element i to val and refresh the sums on the root-to-leaf path."""

        def descend(node):
            if node.start == node.end and node.start == i:
                node.total = val
                return
            mid = (node.start + node.end) // 2
            descend(node.left if i <= mid else node.right)
            node.total = node.left.total + node.right.total

        descend(self.root)

    def sumRange(self, i: int, j: int) -> int:
        """Return the sum of elements i..j (inclusive)."""

        def query(node, lo, hi):
            if node.start == lo and node.end == hi:
                return node.total
            mid = (node.start + node.end) // 2
            if hi <= mid:
                return query(node.left, lo, hi)
            if lo > mid:
                return query(node.right, lo, hi)
            # The range straddles mid: split it between the two children.
            return query(node.left, lo, mid) + query(node.right, mid + 1, hi)

        return query(self.root, i, j)

30. 树状数组

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
class FenwickTree:
    """Binary indexed tree over positions 1..n: point add and prefix sum.

    Both operations take O(log n). Positions are 1-based; slot 0 of
    ``sums_`` is never written.
    """

    def __init__(self, n):
        self.sums_ = [0] * (n + 1)

    def update(self, i, delta):
        """Add delta to position i (1 <= i <= n).

        :raises IndexError: if i < 1. Such an i would otherwise loop forever,
            because lowbit(0) == 0 stops i from advancing.
        """
        if i < 1:
            raise IndexError("FenwickTree positions are 1-based, got %r" % (i,))
        while i < len(self.sums_):
            self.sums_[i] += delta
            i += self.lowbit(i)

    def query(self, i):
        """Return the sum of positions 1..i (0 when i <= 0)."""
        sum_ = 0
        while i > 0:
            sum_ += self.sums_[i]
            i -= self.lowbit(i)
        return sum_

    def lowbit(self, x):
        """Return the value of the lowest set bit of x, e.g. lowbit(12) == 4."""
        return x & (-x)
0%