站长留言
转自 分享|算法竞赛中 Python 的优化技巧 - FatalError。
PyPy 下 STD/IO 效率对比 CPython 高很多。但仍然建议优化。
前言
本文旨在介绍在算法竞赛使用 Python 时,在复杂度正确且不变的前提下,通过一些技巧优化程序、缩短运行时间。以下介绍的技巧分为三个维度评价:修改的复杂程度,优化的显著程度,以及实用程度。
值得注意的是,Python 自身性质决定了其无论如何优化,都无法通过数据范围较大、复杂度较高的题目1。此外,用 Python 写完代码再额外优化的时间,并不见得比直接用 C++ 快。因此尽管 Python 有一些优秀的语言特性,但如果想从事专业算法竞赛,还是放弃对 Python 的执念、多学一门 C++ 为好。
由于以下原因,本文内容多为经验性、有不够严谨之处:尚不完全明确原理;需要对比实验验证;可能在 CPython/PyPy 以及不同版本的 Python 上表现有差异。错误之处敬请指正。
读写
【简单,显著,实用】 读写是耗时瓶颈之一。在输入行数较多时,使用标准输入 sys.stdin 相较于使用 input 优化明显。
import sysinput = lambda: sys.stdin.readline().rstrip() # 删除行末换行符II = lambda: int(input())LII = lambda: list(map(int, input().split()))【简单,显著,实用】 如果还想要避免频繁地读取,还可以一次性读入所有输入到内存。
import sysit = map(int, sys.stdin.read().split())II = lambda: next(it)# 如果输入包含字符串,则可以修改为# it = iter(sys.stdin.read().split())# SI = lambda: next(it)# II = lambda: int(SI())【简单,不显著,实用】 同理,避免频繁地输出,还可以把所有结果暂存下来再统一输出。
output = []for _ in range(n): ans = solve() output.append(ans)print(*output, sep='\n')【复杂】 BufferedReader, BufferedWriter 实现过于麻烦,故未测试。
数据类型
int
【简单,不显著,实用】 取模优化:
# 修改前ans = 0for i in range(n): ans = (ans + comb(n, i) * pow(2, i, MOD) % MOD) % MOD
# 修改后ans = 0for i in range(n): ans += comb(n, i) * pow(2, i, MOD)ans %= MOD【简单,不显著,实用】 float('inf') 是浮点数,比较运算很慢,所以尽量用大整数。
# 修改前from math import infinf = float('inf')
# 修改后inf = 1 << 60dis = [inf] * nstr
【简单,显著,实用】 字符串拼接优化:
# 修改前ans = ''for s in strs: ans += s
# 修改后ans = ''.join(strs)【简单,不显著,不实用】 bytearray 模拟可变字符串:
# 修改前t = list(s)t[0] = 'a's = ''.join(t)
# 修改后t = bytearray(s, encoding='ascii')t[0] = ord('a')s = t.decode('ascii')list
【简单,不显著,实用】 使用 enumerate:
# 修改前for i in range(len(nums)): x = nums[i]
# 修改后for i, x in enumerate(nums):【简单,不显著,实用】 提前分配空间:
# 修改前nums = []for i in range(n): nums.append(i)
# 修改后nums = [0] * nfor i in range(n): nums[i] = i【简单,显著,实用】 多维 list 优化:
# 修改前n, k = 10**5, 20dp = [[0] * k for _ in range(n)]
# 修改后n, k = 10**5, 20dp = [[0] * n for _ in range(k)]【简单,显著,实用】 二维转一维:
# 修改前dp = [[0] * n for _ in range(m)]
# 修改后dp = [0] * (m*n)compress = lambda i, j: i*n+jdecompress = lambda k: divmod(k, n)【复杂,显著,实用】 链式前向星代替邻接表:
# 修改前g = [[] for _ in range(n)]def add_edge(u: int, v: int, w: int): g[u].append((v, w))
# 修改后head = [-1] * nto = [-1] * mweight = [0] * mnxt = [-1] * mptr = 0def add_edge(u: int, v: int, w: int): nonlocal ptr to[ptr] = v weight[ptr] = w nxt[ptr] = head[u] head[u] = ptr ptr += 1【简单,显著,实用】 array.array 替代 list:
from array import arraynums = array('i', [0] * n)【简单,显著,不实用】 布尔数组用 bytearray:
vis = bytearray(bytes(n))【简单,显著,不实用】 ctypes C 数组:
from ctypes import c_int32rank = (c_int32 * n)()pa = (c_int32 * n)(*range(n))tuple
【简单,不显著,不实用】 多个 list 替代 tuple:
# 修改前items = [(w1, v1), (w2, v2), ...]
# 修改后weights = [w1, w2, ...]values = [v1, v2, ...]dict
【简单,显著,实用】 dict 替换为 list:
# 修改前g = defaultdict(list)
# 修改后g = [[] for _ in range(n)]【简单,不显著,实用】 遍历 dict 用 .items():
# 修改前for k in mp: v = mp[k]
# 修改后for k, v in mp.items():【简单,不显著,不实用】 清空 dict 直接新建:
mp = {}【简单,显著,实用】 defaultdict(int) 代替 Counter:
from collections import defaultdictcnt = defaultdict(int)【简单,不显著,实用】 避免不必要键插入:
x = mp.get(k, 0)【简单,显著,实用】 随机化防哈希冲突:
from random import getrandbitsRD = getrandbits(31)pos = defaultdict(list)for i, x in enumerate(nums): pos[x ^ RD].append(i)【简单,显著,实用】 离散化:
sarr = sorted(set(nums))mp = {x: i for i, x in enumerate(sarr)}nums = [mp[x] for x in nums]deque
【显著,不实用】 数组模拟队列更快:
q = [0] * nhead, tail = 0, 1while head < tail: u = q[head] head += 1 for v in g[u]: q[tail] = v tail += 1函数
【简单,不显著,实用】 accumulate 优化:
from itertools import accumulatepres = list(accumulate(nums, initial=0))【简单,显著,实用】 手写 min/max:
fmin = lambda x, y: x if x < y else yfmax = lambda x, y: x if x > y else y【简单,显著,实用】 手写快速幂:
def qpow(x, k): res = 1 while k: if k & 1: res = res * x % MOD x = x * x % MOD k >>= 1 return res【简单,不显著,实用】 生成器优化:
s = sum(x**2 for x in range(n))【简单,显著,不实用】 避免 sum(list, []) 拼接:
longlist = []for lst in lsts: longlist.extend(lst)【复杂,显著,实用】 迭代改写递归 DFS:
order = []parents = [-1] * len(tree)stk = [root]while stk: u = stk.pop() order.append(u) for v in g[u]: if parents[u] != v: parents[v] = u stk.append(v)类
【复杂,显著,实用】 数组代替类:
class StaticTrie: def __init__(self, lengths): lengths += 1 self.children = [[-1] * lengths for _ in range(26)] self.isend = [False] * lengths self.cnt = [0] * lengths self.ptr = 1【简单,显著,实用】 __slots__ 优化:
class DSU: __slots__ = 'parent', 'size' def __init__(self, n: int): self.parent = list(range(n)) self.size = [1] * n参考
[1] Python performance tips. https://codeforces.com/blog/entry/21851
[2] PyRival. https://github.com/cheran-senthil/PyRival/blob/master/pyrival/misc/bootstrap.py
[3] Python Docs. https://docs.python.org/zh-cn/3.13/reference/datamodel.html#object.__slots
[4] AtCoder Library Python. https://github.com/not522/ac-library-python
