【转】分享|算法竞赛中 Python 的优化技巧

站长留言:转自 分享|算法竞赛中 Python 的优化技巧 – FatalError

PyPy 下 STD/IO 效率对比 CPython 高很多。但仍然建议优化。

前言

本文旨在介绍在算法竞赛使用 Python 时,在复杂度正确且不变的前提下,通过一些技巧优化程序、缩短运行时间。以下介绍的技巧分为三个维度评价:修改的复杂程度,优化的显著程度,以及实用程度。

值得注意的是,Python 自身性质决定了其无论如何优化,都无法通过数据范围较大、复杂度较高的题目1。此外,用 Python 写完代码再额外优化的时间,并不见得比直接用 C++ 快。因此尽管 Python 有一些优秀的语言特性,但如果想从事专业算法竞赛,还是放弃对 Python 的执念、多学一门 C++ 为好。

由于以下原因,本文内容多为经验性、有不够严谨之处:尚不完全明确原理;需要对比实验验证;可能在 CPython/PyPy 以及不同版本的 Python 上表现有差异。错误之处敬请指正。


读写

【简单,显著,实用】读写是耗时瓶颈之一。在输入行数较多时,使用标准输入 sys.stdin 相较于使用 input 优化明显。

import sys
input = lambda: sys.stdin.readline().rstrip()  # 删除行末换行符
II = lambda: int(input())
LII = lambda: list(map(int, input().split()))

【简单,显著,实用】如果还想要避免频繁地读取,还可以一次性读入所有输入到内存。

import sys
it = map(int, sys.stdin.read().split())
II = lambda: next(it)
# 如果输入包含字符串,则可以修改为
# it = iter(sys.stdin.read().split())
# SI = lambda: next(it)
# II = lambda: int(SI())

【简单,不显著,实用】同理,避免频繁地输出,还可以把所有结果暂存下来再统一输出。

output = []
for _ in range(n):
    ans = solve()
    output.append(ans)
print(*output, sep='\n')

【复杂】BufferedReader, BufferedWriter 实现过于麻烦,故未测试。


数据类型

int

【简单,不显著,实用】取模优化:

# 修改前
ans = 0
for i in range(n):
    ans = (ans + comb(n, i) * pow(2, i, MOD) % MOD) % MOD

# 修改后
ans = 0
for i in range(n):
    ans += comb(n, i) * pow(2, i, MOD)
ans %= MOD

【简单,不显著,实用】float(‘inf’) 是浮点数,比较运算很慢,所以尽量用大整数。

# 修改前
from math import inf
inf = float('inf')

# 修改后
inf = 1 << 60
dis = [inf] * n

str

【简单,显著,实用】字符串拼接优化:

# 修改前
ans = ''
for s in strs:
    ans += s

# 修改后
ans = ''.join(strs)

【简单,不显著,不实用】bytearray 模拟可变字符串:

# 修改前
t = list(s)
t[0] = 'a'
s = ''.join(t)

# 修改后
t = bytearray(s, encoding='ascii')
t[0] = ord('a')
s = t.decode('ascii')

list

【简单,不显著,实用】使用 enumerate:

# 修改前
for i in range(len(nums)):
    x = nums[i]

# 修改后
for i, x in enumerate(nums):

【简单,不显著,实用】提前分配空间:

# 修改前
nums = []
for i in range(n):
    nums.append(i)

# 修改后
nums = [0] * n
for i in range(n):
    nums[i] = i

【简单,显著,实用】多维 list 优化:

# 修改前
n, k = 10**5, 20
dp = [[0] * k for _ in range(n)]

# 修改后
n, k = 10**5, 20
dp = [[0] * n for _ in range(k)]

【简单,显著,实用】二维转一维:

# 修改前
dp = [[0] * n for _ in range(m)]

# 修改后
dp = [0] * (m*n)
compress = lambda i, j: i*n+j
decompress = lambda k: divmod(k, n)

【复杂,显著,实用】链式前向星代替邻接表:

# 修改前
g = [[] for _ in range(n)]
def add_edge(u: int, v: int, w: int):
    g[u].append((v, w))

# 修改后
head = [-1] * n
to = [-1] * m
weight = [0] * m
nxt = [-1] * m
ptr = 0
def add_edge(u: int, v: int, w: int):
    nonlocal ptr
    to[ptr] = v
    weight[ptr] = w
    nxt[ptr] = head[u]
    head[u] = ptr
    ptr += 1

【简单,显著,实用】array.array 替代 list:

from array import array
nums = array('i', [0] * n)

【简单,显著,不实用】布尔数组用 bytearray:

vis = bytearray(bytes(n))

【简单,显著,不实用】ctypes C 数组:

from ctypes import c_int32
rank = (c_int32 * n)()
pa = (c_int32 * n)(*range(n))

tuple

【简单,不显著,不实用】多个 list 替代 tuple:

# 修改前
items = [(w1, v1), (w2, v2), ...]

# 修改后
weights = [w1, w2, ...]
values = [v1, v2, ...]

dict

【简单,显著,实用】dict 替换为 list:

# 修改前
g = defaultdict(list)

# 修改后
g = [[] for _ in range(n)]

【简单,不显著,实用】遍历 dict 用 .items():

# 修改前
for k in mp:
    v = mp[k]

# 修改后
for k, v in mp.items():

【简单,不显著,不实用】清空 dict 直接新建:

mp = {}

【简单,显著,实用】defaultdict(int) 代替 Counter:

from collections import defaultdict
cnt = defaultdict(int)

【简单,不显著,实用】避免不必要键插入:

x = mp.get(k, 0)

【简单,显著,实用】随机化防哈希冲突:

from random import getrandbits
RD = getrandbits(31)
pos = defaultdict(list)
for i, x in enumerate(nums):
    pos[x ^ RD].append(i)

【简单,显著,实用】离散化:

sarr = sorted(set(nums))
mp = {x: i for i, x in enumerate(sarr)}
nums = [mp[x] for x in nums]

deque

【显著,不实用】数组模拟队列更快:

q = [0] * n
head, tail = 0, 1
while head < tail:
    u = q[head]
    head += 1
    for v in g[u]:
        q[tail] = v
        tail += 1

函数

【简单,不显著,实用】accumulate 优化:

from itertools import accumulate
pres = list(accumulate(nums, initial=0))

【简单,显著,实用】手写 min/max:

fmin = lambda x, y: x if x < y else y
fmax = lambda x, y: x if x > y else y

【简单,显著,实用】手写快速幂:

def qpow(x, k):
    res = 1
    while k:
        if k & 1:
            res = res * x % MOD
        x = x * x % MOD
        k >>= 1
    return res

【简单,不显著,实用】生成器优化:

s = sum(x**2 for x in range(n))

【简单,显著,不实用】避免 sum(list, []) 拼接:

longlist = []
for lst in lsts:
    longlist.extend(lst)

【复杂,显著,实用】迭代改写递归 DFS:

order = []
parents = [-1] * len(tree)
stk = [root]
while stk:
    u = stk.pop()
    order.append(u)
    for v in g[u]:
        if parents[u] != v:
            parents[v] = u
            stk.append(v)

【复杂,显著,实用】数组代替类:

class StaticTrie:
    def __init__(self, lengths):
        lengths += 1
        self.children = [[-1] * lengths for _ in range(26)]
        self.isend = [False] * lengths
        self.cnt = [0] * lengths
        self.ptr = 1

【简单,显著,实用】__slots__ 优化:

class DSU:
    __slots__ = 'parent', 'size'
    def __init__(self, n: int):
        self.parent = list(range(n))
        self.size = [1] * n

参考

[1]  Python performance tips. https://codeforces.com/blog/entry/21851

[2]  PyRival. https://github.com/cheran-senthil/PyRival/blob/master/pyrival/misc/bootstrap.py

[3]  Python Docs. https://docs.python.org/zh-cn/3.13/reference/datamodel.html#object.__slots

[4]  AtCoder Library Python. https://github.com/not522/ac-library-python