def permutations(n):
    """Return all permutations of ``range(n)`` as rows of a ``uint8`` array.

    The result has shape ``(n!, n)``. Its rows are in lexicographic order,
    the same order as ``itertools.permutations(range(n))``.

    The table is built bottom-up. Before step ``m``, the top-right
    ``(m-1)! x (m-1)`` corner holds the permutations of ``range(m-1)``.
    Step ``m`` extends this one column to the left by copying and
    relabelling that block downwards.

    >>> permutations(3)
    array([[0, 1, 2],
           [0, 2, 1],
           [1, 0, 2],
           [1, 2, 0],
           [2, 0, 1],
           [2, 1, 0]], dtype=uint8)
    """
    # np.math was only an alias of the stdlib math module and was removed
    # in NumPy 2.0, so use math.factorial directly.
    import math

    import numpy as np

    a = np.zeros((math.factorial(n), n), np.uint8)
    f = 1  # rows filled so far, == (m - 1)! at the start of each step
    for m in range(2, n + 1):
        # View (not a copy) of the block of permutations of range(m-1).
        b = a[:f, n - m + 1:]
        for i in range(1, m):
            # Block i: leading element i ...
            a[i * f:(i + 1) * f, n - m] = i
            # ... followed by the remaining values in order: shift every
            # value >= i up by one so that i itself is skipped.
            a[i * f:(i + 1) * f, n - m + 1:] = b + (b >= i)
        # Block 0 keeps its leading 0 (from np.zeros). Relabel its tail in
        # place to range(1, m). This must happen after the loop above,
        # which reads the un-shifted b.
        b += 1
        f *= m
    return a
演示:
>>> permutations(3)
array([[0, 1, 2],
[0, 2, 1],
[1, 0, 2],
[1, 2, 0],
[2, 0, 1],
[2, 1, 0]], dtype=uint8)
对于n=10,使用itertools解决方案需要5.5秒,而这个NumPy解决方案只需要0.2秒。
它的工作过程如下：先创建一个目标大小的全零数组，其右上角已经包含了 range(1) 的排列（数组的其余部分我用“.”表示）：
[[. . 0]
[. . .]
[. . .]
[. . .]
[. . .]
[. . .]]
然后将它扩展为 range(2) 的排列：
[[. 0 1]
[. 1 0]
[. . .]
[. . .]
[. . .]
[. . .]]
接着扩展为 range(3) 的排列：
[[0 1 2]
[0 2 1]
[1 0 2]
[1 2 0]
[2 0 1]
[2 1 0]]
它通过填充下一列左侧并复制/修改上一个排列块向下来实现。def faster_permutations(n):
# empty() is fast because it does not initialize the values of the array
# order='F' uses Fortran ordering, which makes accessing elements in the same column fast
perms = np.empty((np.math.factorial(n), n), dtype=np.uint8, order='F')
perms[0, 0] = 0
rows_to_copy = 1
for i in range(1, n):
perms[:rows_to_copy, i] = i
for j in range(1, i + 1):
start_row = rows_to_copy * j
end_row = rows_to_copy * (j + 1)
splitter = i - j
perms[start_row: end_row, splitter] = i
perms[start_row: end_row, :splitter] = perms[:rows_to_copy, :splitter] # left side
perms[start_row: end_row, splitter + 1:i + 1] = perms[:rows_to_copy, splitter:i] # right side
rows_to_copy *= i + 1
return perms
在我的计算机上，n=11 时的耗时如下：
faster_permutations(): 0.12 seconds
permutations() [superb rain's approach]: 1.44 seconds
permutations() with memory order optimization: 0.62 seconds
根据superb rain的回答,这是一个更快的版本,具有更高效的内存访问模式:
def fast_permutations(n):
    """Return all permutations of ``range(n)`` as rows of a ``uint8`` array.

    The result has shape ``(n!, n)``, with rows in lexicographic order
    (the same as ``itertools.permutations(range(n))``).

    This is superb rain's construction on the transposed table. Each
    permutation is stored as a column of an ``(n, n!)`` C-ordered array and
    the transpose (a Fortran-ordered view) is returned. This gives the slice
    writes below a more cache-friendly memory access pattern.
    """
    # np.math (an alias of the stdlib math module) was removed in NumPy 2.0.
    import math

    import numpy as np

    a = np.zeros((n, math.factorial(n)), np.uint8)
    f = 1  # columns filled so far, == (m - 1)! at the start of each step
    for m in range(2, n + 1):
        # View of the block of permutations of range(m-1) (bottom-left).
        b = a[n - m + 1:, :f]
        for i in range(1, m):
            # Block i: leading element i, then the remaining values with
            # every value >= i shifted up by one (so i itself is skipped).
            a[n - m, i * f:(i + 1) * f] = i
            a[n - m + 1:, i * f:(i + 1) * f] = b + (b >= i)
        # Block 0 keeps its leading 0. Relabel its tail in place only after
        # the loop above, which reads the un-shifted b.
        b += 1
        f *= m
    return a.T
……n=10 时（0.05 秒与 0.12 秒相比）。n=11 时，superb rain 的版本需要 6.7 秒，而你的版本只需要 3.5 秒。因此，你的解决方案确实更快。谢谢！ - Arty
只需用 a = np.zeros((n, np.math.factorial(n)), np.uint8).T 替换单行 a = np.zeros((np.math.factorial(n), n), np.uint8) 即可。 - Arty
由于我没有找到一个好的/快速的解决方案，所以我决定使用 Numba JIT/AOT 代码编译器/优化器从头实现全排列算法。
对于足够大的 n，我下面基于 Numba 的解决方案比使用 itertools.permutations(...) 执行相同任务快 25–50 倍。请参见代码后的计时结果。
如果一次只迭代一个排列，我的代码只比 itertools.permutations(...) 快约 1.25 倍；但按照最初的问题，我需要的是包含所有排列的完整数组，或者至少能按大块进行迭代。
我的函数可以一次迭代一个排列（iter_ = True, iter_batches = False），或者一次迭代一批排列，这样更快（iter_ = True, iter_batches = True），或者不进行迭代、直接返回包含所有排列的完整数组（iter_ = False）。还可以通过 batch_size 参数（默认 batch_size = 1000）调整批量大小。主要函数是 next_batch(...)，它实现了根据上一个排列生成下一个排列的完整算法。这是唯一由 Numba JIT/AOT 编译的函数，其余都是纯 Python 的辅助包装器。
# Needs: python -m pip install numba numpy timerit
def permutations(
    n, *, iter_ = True, numba_ = True, numba_aot = False,
    batch_size = 1000, iter_batches = False, state = {},
):
    """Generate all permutations of ``range(n)`` in lexicographic order.

    Output modes:

    * ``iter_=True, iter_batches=False``: a generator yielding one 1-D
      ``uint8`` array per permutation. The rows are views into internal
      batch arrays.
    * ``iter_=True, iter_batches=True``: a generator yielding 2-D ``uint8``
      batches. The first batch holds only the identity permutation; later
      batches hold up to ``batch_size`` rows each.
    * ``iter_=False``: a single ``(n!, n)`` ``uint8`` array.

    Args:
        n: number of elements to permute (values ``0 .. n-1``).
        iter_: return a generator instead of one array.
        numba_: compile the ``next_batch`` kernel with Numba (JIT by default).
        numba_aot: together with ``numba_``, compile the kernel ahead of time
            with ``numba.pycc`` into a ``permutations_numba`` extension
            module and import it.
        batch_size: number of permutations produced per kernel call.
        iter_batches: together with ``iter_``, yield whole batches.
        state: cache of prepared implementations keyed by
            ``(numba_, numba_aot)``. The mutable default is deliberate: it
            persists across calls, so each backend is prepared/compiled only
            once per process.

    Raises:
        ValueError: if ``batch_size`` is smaller than 1.
    """
    n, iter_ = int(n), bool(iter_)
    batch_size, iter_batches = int(batch_size), bool(iter_batches)
    # With a non-positive step, range(1, total, batch_size) below either
    # fails with an unrelated error (0) or is empty (< 0). An empty range
    # leaves the np.empty() result uninitialized in array mode and truncates
    # the generator in iter mode. Fail loudly up front instead.
    if batch_size < 1:
        raise ValueError(f'batch_size must be >= 1, got {batch_size}')
    key = (bool(numba_), bool(numba_aot))
    if key in state:
        return state[key](n, iter_, batch_size, iter_batches)

    def prepare(numba_, numba_aot):
        """Build the (n, iter_, batch_size, iter_batches) -> result callable."""
        import numpy as np

        def next_batch(a, r):
            # Fill each row of r with the lexicographic successor of the
            # previous permutation, starting from the successor of a.
            c, n = r.shape[0], r.shape[1]
            for ic in range(c):
                r[ic] = a
                a = r[ic]  # work in place on the freshly copied row
                # Rightmost ascent a[i] < a[i + 1].
                for i in range(n - 2, -1, -1):
                    if a[i] < a[i + 1]:
                        break
                else:
                    assert False # Already last permutation
                # Rightmost element greater than a[i]; swap it into place.
                for j in range(n - 1, i, -1):
                    if a[i] < a[j]:
                        break
                a[i], a[j] = a[j], a[i]
                # Reverse the descending suffix a[i + 1:] to make it ascending.
                for k in range(1, (n - i + 1) >> 1):
                    a[i + k], a[n - k] = a[n - k], a[i + k]

        def factorial(n):
            res = 1
            for i in range(2, n + 1):
                res *= i
            return res

        def permutations_iter(nxb, n, batch_size, iter_batches):
            a = np.arange(n, dtype = np.uint8)
            if iter_batches:
                yield a[None, :]
            else:
                yield a
            if n <= 1:
                return
            total = factorial(n)
            for i in range(1, total, batch_size):
                batch = np.empty((min(batch_size, total - i), n), dtype = np.uint8)
                nxb(a, batch)
                if iter_batches:
                    yield batch
                else:
                    yield from iter(batch)
                a = batch[-1]  # seed for the next batch

        def permutations_arr(nxb, n, batch_size):
            total = factorial(n)
            res = np.empty((total, n), dtype = np.uint8)
            res[0] = np.arange(n, dtype = np.uint8)
            for i in range(1, total, batch_size):
                # Each chunk is seeded with the row just before it.
                nxb(res[i - 1], res[i : i + min(batch_size, total - i)])
            return res

        if not numba_:
            return lambda n, it, bs, ib: permutations_iter(next_batch, n, bs, ib) if it else permutations_arr(next_batch, n, bs)
        else:
            if not numba_aot:
                import numba
                nxb = numba.njit('void(u1[:], u1[:, :])', cache = True)(next_batch)
            else:
                import numba, numba.pycc
                cc = numba.pycc.CC('permutations_numba')
                cc.export('next_batch', 'void(u1[:], u1[:, :])')(next_batch)
                cc.compile()
                from permutations_numba import next_batch as nxb
            return lambda n, it, bs, ib: permutations_iter(nxb, n, bs, ib) if it else permutations_arr(nxb, n, bs)

    state[key] = prepare(numba_, numba_aot)
    return state[key](n, iter_, batch_size, iter_batches)
def test(max_n = 12):
    """Benchmark and cross-check the permutation generators.

    For every n in range(max_n) this times the following, printing timerit
    results:

    * ``itertools.permutations``, which serves as the reference;
    * superb rain's NumPy construction with the column-major tweak;
    * the pure-Python array path (only for n <= 9);
    * the three Numba output modes, for several batch sizes.

    Every result is asserted equal to the itertools reference.

    Args:
        max_n: exclusive upper bound on n. The default of 12 reproduces the
            original run. itertools alone took about 95 s at n=11 in the
            published timings, so pass 11 for a quicker run.
    """
    import itertools
    import math

    import numpy as np
    from timerit import Timerit
    Timerit._default_asciimode = True

    def superbrain(n):
        # superb rain's construction. Allocating (n, n!) and transposing
        # yields a Fortran-ordered (n!, n) array, so each column slice
        # written below is contiguous. np.math was removed in NumPy 2.0,
        # hence math.factorial.
        a = np.zeros((n, math.factorial(n)), np.uint8).T
        f = 1
        for m in range(2, n + 1):
            b = a[:f, n - m + 1:]  # the block of permutations of range(m-1)
            for i in range(1, m):
                a[i * f:(i + 1) * f, n - m] = i
                a[i * f:(i + 1) * f, n - m + 1:] = b + (b >= i)
            b += 1
            f *= m
        return a

    # Heat-up / pre-compile
    permutations(2, numba_ = False)
    permutations(2, numba_ = True)
    for n in range(max_n):
        # Fewer repetitions as the factorial work explodes.
        num = 99 if n <= 7 else 15 if n <= 8 else 3 if n <= 9 else 1
        print('-' * 60 + f'\nn = {str(n).rjust(2)}')
        print('itertools : ', end = '', flush = True)
        for t in Timerit(num = num, verbose = 1):
            with t:
                ref = np.array(list(itertools.permutations(range(n))), dtype = np.uint8)
        print('superbrain : ', end = '', flush = True)
        for t in Timerit(num = num, verbose = 1):
            with t:
                cur = superbrain(n)
        assert np.array_equal(ref, cur)
        if n <= 9:
            print('python_array : ', end = '', flush = True)
            for t in Timerit(num = num, verbose = 1):
                with t:
                    curpa = permutations(n, iter_ = False, numba_ = False)
            assert np.array_equal(ref, curpa)
        for batch_size in [10, 100, 1000, 10000]:
            print(f'batch_size = {str(batch_size).rjust(5)}')
            print('numba_iter : ', end = '', flush = True)
            for t in Timerit(num = num, verbose = 1):
                with t:
                    curi = np.array(list(permutations(n, iter_ = True, numba_ = True, batch_size = batch_size)))
            assert np.array_equal(ref, curi)
            print('numba_iter_batches : ', end = '', flush = True)
            for t in Timerit(num = num, verbose = 1):
                with t:
                    curib = np.concatenate(list(permutations(n, iter_ = True, numba_ = True, batch_size = batch_size, iter_batches = True)))
            assert np.array_equal(ref, curib)
            print('numba_array : ', end = '', flush = True)
            for t in Timerit(num = num, verbose = 1):
                with t:
                    cura = permutations(n, iter_ = False, numba_ = True, batch_size = batch_size)
            assert np.array_equal(ref, cura)

if __name__ == '__main__':
    test()
------------------------------------------------------------
n = 0
itertools : Timed best=8.210 us, mean=8.335 +- 0.4 us
python_array : Timed best=14.881 us, mean=15.457 +- 0.5 us
batch_size = 10
numba_iter : Timed best=15.908 us, mean=16.126 +- 0.3 us
numba_iter_batches : Timed best=17.447 us, mean=17.929 +- 0.3 us
numba_array : Timed best=15.394 us, mean=15.519 +- 0.3 us
batch_size = 100
numba_iter : Timed best=15.908 us, mean=16.250 +- 0.3 us
numba_iter_batches : Timed best=17.447 us, mean=18.038 +- 0.2 us
numba_array : Timed best=15.394 us, mean=15.519 +- 0.3 us
batch_size = 1000
numba_iter : Timed best=15.908 us, mean=16.328 +- 0.3 us
numba_iter_batches : Timed best=17.960 us, mean=18.069 +- 0.2 us
numba_array : Timed best=15.394 us, mean=15.441 +- 0.1 us
batch_size = 10000
numba_iter : Timed best=15.908 us, mean=16.328 +- 0.2 us
numba_iter_batches : Timed best=17.448 us, mean=17.976 +- 0.2 us
numba_array : Timed best=14.881 us, mean=15.410 +- 0.3 us
------------------------------------------------------------
n = 1
itertools : Timed best=7.697 us, mean=7.790 +- 0.3 us
python_array : Timed best=14.882 us, mean=15.488 +- 0.3 us
batch_size = 10
numba_iter : Timed best=15.908 us, mean=16.064 +- 0.3 us
numba_iter_batches : Timed best=17.960 us, mean=18.318 +- 0.3 us
numba_array : Timed best=14.881 us, mean=15.348 +- 0.3 us
batch_size = 100
numba_iter : Timed best=15.908 us, mean=16.203 +- 0.3 us
numba_iter_batches : Timed best=17.960 us, mean=18.054 +- 0.2 us
numba_array : Timed best=15.394 us, mean=15.472 +- 0.2 us
batch_size = 1000
numba_iter : Timed best=15.908 us, mean=16.421 +- 0.1 us
numba_iter_batches : Timed best=17.960 us, mean=18.147 +- 0.3 us
numba_array : Timed best=14.882 us, mean=15.379 +- 0.2 us
batch_size = 10000
numba_iter : Timed best=15.908 us, mean=16.095 +- 0.2 us
numba_iter_batches : Timed best=17.960 us, mean=18.132 +- 0.3 us
numba_array : Timed best=14.881 us, mean=15.395 +- 0.3 us
------------------------------------------------------------
n = 2
itertools : Timed best=8.723 us, mean=8.786 +- 0.2 us
python_array : Timed best=29.250 us, mean=29.670 +- 0.4 us
batch_size = 10
numba_iter : Timed best=34.381 us, mean=35.035 +- 0.7 us
numba_iter_batches : Timed best=30.276 us, mean=30.790 +- 0.4 us
numba_array : Timed best=22.579 us, mean=22.672 +- 0.2 us
batch_size = 100
numba_iter : Timed best=34.381 us, mean=34.584 +- 0.3 us
numba_iter_batches : Timed best=30.277 us, mean=30.836 +- 0.2 us
numba_array : Timed best=22.066 us, mean=22.595 +- 0.2 us
batch_size = 1000
numba_iter : Timed best=34.381 us, mean=34.739 +- 0.4 us
numba_iter_batches : Timed best=30.277 us, mean=30.851 +- 0.3 us
numba_array : Timed best=22.579 us, mean=22.626 +- 0.1 us
batch_size = 10000
numba_iter : Timed best=34.381 us, mean=34.786 +- 0.4 us
numba_iter_batches : Timed best=30.276 us, mean=30.650 +- 0.3 us
numba_array : Timed best=22.066 us, mean=22.641 +- 0.3 us
------------------------------------------------------------
n = 3
itertools : Timed best=12.829 us, mean=13.093 +- 0.3 us
python_array : Timed best=62.606 us, mean=63.461 +- 0.6 us
batch_size = 10
numba_iter : Timed best=39.513 us, mean=40.120 +- 0.4 us
numba_iter_batches : Timed best=31.302 us, mean=31.661 +- 0.2 us
numba_array : Timed best=22.579 us, mean=23.077 +- 0.3 us
batch_size = 100
numba_iter : Timed best=39.513 us, mean=40.042 +- 0.2 us
numba_iter_batches : Timed best=31.302 us, mean=31.629 +- 0.3 us
numba_array : Timed best=22.579 us, mean=23.154 +- 0.2 us
batch_size = 1000
numba_iter : Timed best=39.513 us, mean=39.840 +- 0.4 us
numba_iter_batches : Timed best=31.302 us, mean=31.629 +- 0.4 us
numba_array : Timed best=22.579 us, mean=23.170 +- 0.2 us
batch_size = 10000
numba_iter : Timed best=39.513 us, mean=40.120 +- 0.5 us
numba_iter_batches : Timed best=30.789 us, mean=31.412 +- 0.3 us
numba_array : Timed best=23.092 us, mean=23.232 +- 0.3 us
------------------------------------------------------------
n = 4
itertools : Timed best=34.381 us, mean=34.911 +- 0.4 us
python_array : Timed best=207.830 us, mean=209.152 +- 1.0 us
batch_size = 10
numba_iter : Timed best=82.619 us, mean=83.054 +- 0.7 us
numba_iter_batches : Timed best=44.645 us, mean=44.754 +- 0.2 us
numba_array : Timed best=31.302 us, mean=31.458 +- 0.2 us
batch_size = 100
numba_iter : Timed best=63.632 us, mean=64.036 +- 0.4 us
numba_iter_batches : Timed best=32.329 us, mean=32.889 +- 0.2 us
numba_array : Timed best=24.118 us, mean=24.600 +- 0.3 us
batch_size = 1000
numba_iter : Timed best=63.632 us, mean=64.083 +- 0.5 us
numba_iter_batches : Timed best=32.329 us, mean=32.904 +- 0.3 us
numba_array : Timed best=24.118 us, mean=24.569 +- 0.3 us
batch_size = 10000
numba_iter : Timed best=63.119 us, mean=63.927 +- 0.4 us
numba_iter_batches : Timed best=32.329 us, mean=32.889 +- 0.5 us
numba_array : Timed best=24.118 us, mean=24.461 +- 0.3 us
------------------------------------------------------------
n = 5
itertools : Timed best=156.001 us, mean=166.311 +- 20.5 us
python_array : Timed best=0.999 ms, mean=1.002 +- 0.0 ms
batch_size = 10
numba_iter : Timed best=293.528 us, mean=294.461 +- 0.8 us
numba_iter_batches : Timed best=102.632 us, mean=103.254 +- 0.4 us
numba_array : Timed best=64.145 us, mean=64.985 +- 0.5 us
batch_size = 100
numba_iter : Timed best=198.080 us, mean=199.107 +- 0.8 us
numba_iter_batches : Timed best=44.132 us, mean=44.894 +- 0.4 us
numba_array : Timed best=33.355 us, mean=33.884 +- 0.3 us
batch_size = 1000
numba_iter : Timed best=186.791 us, mean=187.522 +- 0.4 us
numba_iter_batches : Timed best=37.973 us, mean=38.471 +- 0.3 us
numba_array : Timed best=29.763 us, mean=30.183 +- 0.3 us
batch_size = 10000
numba_iter : Timed best=186.790 us, mean=187.646 +- 0.7 us
numba_iter_batches : Timed best=37.974 us, mean=38.534 +- 0.3 us
numba_array : Timed best=29.763 us, mean=30.245 +- 0.3 us
------------------------------------------------------------
n = 6
itertools : Timed best=0.991 ms, mean=1.007 +- 0.0 ms
python_array : Timed best=5.873 ms, mean=6.012 +- 0.0 ms
batch_size = 10
numba_iter : Timed best=1.668 ms, mean=1.673 +- 0.0 ms
numba_iter_batches : Timed best=503.411 us, mean=506.506 +- 1.2 us
numba_array : Timed best=293.015 us, mean=296.047 +- 1.2 us
batch_size = 100
numba_iter : Timed best=1.036 ms, mean=1.145 +- 0.3 ms
numba_iter_batches : Timed best=120.593 us, mean=132.878 +- 23.0 us
numba_array : Timed best=93.908 us, mean=97.438 +- 2.4 us
batch_size = 1000
numba_iter : Timed best=962.178 us, mean=976.624 +- 23.9 us
numba_iter_batches : Timed best=78.001 us, mean=82.992 +- 7.7 us
numba_array : Timed best=68.250 us, mean=69.852 +- 4.3 us
batch_size = 10000
numba_iter : Timed best=963.717 us, mean=977.044 +- 27.3 us
numba_iter_batches : Timed best=77.487 us, mean=80.084 +- 7.5 us
numba_array : Timed best=68.250 us, mean=69.634 +- 4.4 us
------------------------------------------------------------
n = 7
itertools : Timed best=8.502 ms, mean=8.579 +- 0.0 ms
python_array : Timed best=41.690 ms, mean=42.358 +- 0.8 ms
batch_size = 10
numba_iter : Timed best=11.523 ms, mean=11.646 +- 0.2 ms
numba_iter_batches : Timed best=3.407 ms, mean=3.497 +- 0.1 ms
numba_array : Timed best=1.944 ms, mean=1.975 +- 0.0 ms
batch_size = 100
numba_iter : Timed best=7.050 ms, mean=7.397 +- 0.3 ms
numba_iter_batches : Timed best=659.925 us, mean=668.198 +- 5.9 us
numba_array : Timed best=503.411 us, mean=506.086 +- 3.3 us
batch_size = 1000
numba_iter : Timed best=6.576 ms, mean=6.630 +- 0.0 ms
numba_iter_batches : Timed best=382.305 us, mean=389.707 +- 4.4 us
numba_array : Timed best=354.081 us, mean=360.364 +- 4.3 us
batch_size = 10000
numba_iter : Timed best=6.463 ms, mean=6.504 +- 0.0 ms
numba_iter_batches : Timed best=349.976 us, mean=352.091 +- 1.5 us
numba_array : Timed best=330.989 us, mean=337.194 +- 1.8 us
------------------------------------------------------------
n = 8
itertools : Timed best=71.003 ms, mean=71.824 +- 0.5 ms
python_array : Timed best=331.176 ms, mean=339.746 +- 7.3 ms
batch_size = 10
numba_iter : Timed best=99.929 ms, mean=101.098 +- 1.3 ms
numba_iter_batches : Timed best=27.489 ms, mean=27.905 +- 0.3 ms
numba_array : Timed best=15.370 ms, mean=15.560 +- 0.1 ms
batch_size = 100
numba_iter : Timed best=62.168 ms, mean=62.765 +- 0.7 ms
numba_iter_batches : Timed best=5.083 ms, mean=5.119 +- 0.0 ms
numba_array : Timed best=3.824 ms, mean=3.842 +- 0.0 ms
batch_size = 1000
numba_iter : Timed best=57.706 ms, mean=57.935 +- 0.2 ms
numba_iter_batches : Timed best=2.824 ms, mean=2.832 +- 0.0 ms
numba_array : Timed best=2.656 ms, mean=2.670 +- 0.0 ms
batch_size = 10000
numba_iter : Timed best=57.457 ms, mean=60.128 +- 2.1 ms
numba_iter_batches : Timed best=2.615 ms, mean=2.635 +- 0.0 ms
numba_array : Timed best=2.550 ms, mean=2.565 +- 0.0 ms
------------------------------------------------------------
n = 9
itertools : Timed best=724.017 ms, mean=724.017 +- 0.0 ms
python_array : Timed best=3.071 s, mean=3.071 +- 0.0 s
batch_size = 10
numba_iter : Timed best=950.892 ms, mean=950.892 +- 0.0 ms
numba_iter_batches : Timed best=261.376 ms, mean=261.376 +- 0.0 ms
numba_array : Timed best=145.207 ms, mean=145.207 +- 0.0 ms
batch_size = 100
numba_iter : Timed best=584.761 ms, mean=584.761 +- 0.0 ms
numba_iter_batches : Timed best=50.632 ms, mean=50.632 +- 0.0 ms
numba_array : Timed best=39.945 ms, mean=39.945 +- 0.0 ms
batch_size = 1000
numba_iter : Timed best=535.190 ms, mean=535.190 +- 0.0 ms
numba_iter_batches : Timed best=29.557 ms, mean=29.557 +- 0.0 ms
numba_array : Timed best=26.541 ms, mean=26.541 +- 0.0 ms
batch_size = 10000
numba_iter : Timed best=533.592 ms, mean=533.592 +- 0.0 ms
numba_iter_batches : Timed best=27.507 ms, mean=27.507 +- 0.0 ms
numba_array : Timed best=25.115 ms, mean=25.115 +- 0.0 ms
------------------------------------------------------------
n = 10
itertools : Timed best=15.483 s, mean=15.483 +- 0.0 s
batch_size = 10
numba_iter : Timed best=24.163 s, mean=24.163 +- 0.0 s
numba_iter_batches : Timed best=6.039 s, mean=6.039 +- 0.0 s
numba_array : Timed best=3.246 s, mean=3.246 +- 0.0 s
batch_size = 100
numba_iter : Timed best=13.891 s, mean=13.891 +- 0.0 s
numba_iter_batches : Timed best=1.136 s, mean=1.136 +- 0.0 s
numba_array : Timed best=890.228 ms, mean=890.228 +- 0.0 ms
batch_size = 1000
numba_iter : Timed best=12.768 s, mean=12.768 +- 0.0 s
numba_iter_batches : Timed best=693.685 ms, mean=693.685 +- 0.0 ms
numba_array : Timed best=658.007 ms, mean=658.007 +- 0.0 ms
batch_size = 10000
numba_iter : Timed best=11.175 s, mean=11.175 +- 0.0 s
numba_iter_batches : Timed best=278.304 ms, mean=278.304 +- 0.0 ms
numba_array : Timed best=251.208 ms, mean=251.208 +- 0.0 ms
------------------------------------------------------------
n = 11
itertools : Timed best=95.118 s, mean=95.118 +- 0.0 s
batch_size = 10
numba_iter : Timed best=124.414 s, mean=124.414 +- 0.0 s
numba_iter_batches : Timed best=75.427 s, mean=75.427 +- 0.0 s
numba_array : Timed best=28.079 s, mean=28.079 +- 0.0 s
batch_size = 100
numba_iter : Timed best=70.749 s, mean=70.749 +- 0.0 s
numba_iter_batches : Timed best=6.084 s, mean=6.084 +- 0.0 s
numba_array : Timed best=4.357 s, mean=4.357 +- 0.0 s
batch_size = 1000
numba_iter : Timed best=67.576 s, mean=67.576 +- 0.0 s
numba_iter_batches : Timed best=8.572 s, mean=8.572 +- 0.0 s
numba_array : Timed best=6.915 s, mean=6.915 +- 0.0 s
batch_size = 10000
numba_iter : Timed best=123.208 s, mean=123.208 +- 0.0 s
numba_iter_batches : Timed best=3.348 s, mean=3.348 +- 0.0 s
numba_array : Timed best=2.789 s, mean=2.789 +- 0.0 s
…….so/.pyd 模块。如果它们在速度上相差 5%–30%，那完全没问题。 - Arty
只需用 a = np.zeros((n, np.math.factorial(n)), np.uint8).T 替换 a = np.zeros((np.math.factorial(n), n), np.uint8)，代码就会快将近两倍。 - Arty
在 n=10 的情况下，您的解决方案的耗时几乎与我的 Numba 解决方案相同。采纳 @DanielGiger 的建议后（见我上面的上一条评论），您改进后的代码的运行时间约为我的 Numba 解决方案的 60%，因此大约快 1.8 倍。您可以运行我答案中的代码：今天我在其中添加了与您的解决方案（含 @DanielGiger 的改进）的计时对比，因此可以直接将其与我的 Numba 代码进行比较。您可以在我的代码中一直运行计时到 n=10，因为 n=11 时 itertools 需要非常长的时间。 - Arty