纯Python
你可能想测试你的代码,因为它似乎没有达到你的预期。请运行这个脚本,将你的代码与我的进行比较并检查输出:
import numpy as np
def find_first(a, index, value):
    """Scan forward from ``index`` for the next element equal to ``value``.

    Intended for a 1-D array (one row). Returns the first position
    ``>= index`` that holds ``value``. If no such position exists, returns
    ``a.size``, or ``index`` unchanged when it already lies at or past the
    end.
    """
    pos = index
    stop = a.size
    while pos < stop:
        if a[pos] == value:
            break
        pos += 1
    return pos
def find_end(a, index, value):
    """Return the position just past the run of ``value`` starting at ``index``.

    Intended for a 1-D array (one row). Returns ``a.size`` when the run
    reaches the end. Returns ``index`` itself when ``a[index]`` is not
    ``value`` or ``index`` is already at or past the end.
    """
    cursor = index
    limit = a.size
    while cursor < limit:
        if not a[cursor] == value:
            return cursor
        cursor += 1
    return cursor
def replace_run(a, begin, end, threshold, replace):
    """Overwrite the half-open run ``a[begin:end]`` with ``replace`` in place.

    The run is only overwritten when ``end - begin + 1 > threshold``. For an
    integer threshold this means the run length ``end - begin`` is at least
    ``threshold``. Returns ``None``.
    """
    # Guard clause: leave short runs untouched.
    if not end - begin + 1 > threshold:
        return
    a[begin:end] = replace
def process_row(a, value, threshold, replace):
    """Walk one 1-D row and hand every run of ``value`` to ``replace_run``.

    The row is modified in place. Whether a run is actually overwritten is
    decided by ``replace_run`` using ``threshold``.
    """
    pos = 0
    size = a.size
    while pos < size:
        if a[pos] != value:
            # Not inside a run: jump to where the next run starts.
            pos = find_first(a, pos, value)
            continue
        run_stop = find_end(a, pos, value)
        replace_run(a, pos, run_stop, threshold, replace)
        pos = run_stop
def replace_py(a, value, length, replace):
    """Return a copy of ``a`` with long runs of ``value`` replaced row by row.

    Each row of the copy is passed to ``process_row`` with ``length`` as the
    threshold. The input array ``a`` itself is not modified.
    """
    out = a.copy()
    # Iterating a 2-D array yields row views, so in-place edits land in `out`.
    for line in out:
        process_row(line, value, length, replace)
    return out
def replace_runs(a, search, run_length, replace=2):
    """Return a copy of ``a`` where runs of ``search`` are set to ``replace``.

    A run is a maximal stretch of consecutive ``search`` values within a
    row. It is replaced only if it has at least ``run_length`` elements.
    With an integer ``run_length`` this is the same rule ``replace_py``
    applies. The input ``a`` is not modified.

    The originally posted version also tested ``j == len(row) - 1``. That
    kept any run ending at or just before the last column, whatever its
    length, so short runs near the end of a row were wrongly replaced. It
    also read a stale or undefined ``j`` on empty rows.
    """
    a_copy = a.copy()
    for i, row in enumerate(a):
        runs = []
        current_run = []
        for j, val in enumerate(row):
            if val == search:
                current_run.append(j)
            else:
                # A non-matching value closes the current run.
                if len(current_run) >= run_length:
                    runs.append(current_run)
                current_run = []
        # A run that reaches the last column is never closed inside the
        # loop, so it is checked separately here.
        if len(current_run) >= run_length:
            runs.append(current_run)
        for run in runs:
            for col in run:
                a_copy[i][col] = replace
    return a_copy
def print_mismatch(a, b):
    """Print the elementwise comparison of ``a`` and ``b`` and a per-row verdict.

    For 2-D inputs of equal shape, each row index is printed followed by
    ``True`` if every element in that row matches. The prints use forms that
    produce the same output on Python 2 and Python 3 (the original print
    statements were a SyntaxError on Python 3).
    """
    print('Elementwise equals')
    mat_equals = a == b
    print(mat_equals)
    print('Reduced to rows')
    # A row matches only if all of its elements compare equal.
    for i, outcome in enumerate(np.all(mat_equals, axis=1)):
        print('%s %s' % (i, outcome))
if __name__ == '__main__':
    # Fixed seed so both implementations see the same reproducible input.
    np.random.seed(31)
    shape = (20, 10)
    # size=shape already yields an array of this shape. A bare
    # mat.reshape(shape) would return a new array that is discarded, so it
    # is not needed.
    mat = np.asarray(np.random.binomial(1, p=0.5, size=shape), dtype=np.int32)
    runs = replace_runs(mat, 1, 3, 2)
    py = replace_py(mat, 1, 3, 2)
    # print(x) with a single argument behaves the same on Python 2 and 3.
    print('Original')
    print(mat)
    print('replace_runs()')
    print(runs)
    print('replace_py()')
    print(py)
    print('Mismatch between replace_runs() and replace_py()')
    print_mismatch(runs, py)
在你的代码修正之前,对它做基准测试没有意义。因此,我将使用我的 replace_py() 函数进行基准测试。
我认为 replace_py() 实现了你想要的功能,但它不够 Pythonic,存在许多反模式。尽管如此,它似乎是正确的。
计时:
np.random.seed(31)
shape = (100000, 10)
mat = np.asarray(a=np.random.binomial(1, p=0.5, size=shape), dtype=np.int32)
mat.reshape(shape)
%timeit replace_py(mat, 1, 3, 2)
1 loops, best of 3: 9.49 s per loop
Cython
我认为你的问题不容易改写成使用 NumPy 向量化技术的形式。也许 NumPy 专家能做到,但我担心那样的代码会变得非常晦涩或很慢(或两者兼有)。引用一位 NumPy 开发者的话:
[...] 当把解决方案向量化需要一个 "NumPy 学" 博士学位,或者会带来过多的内存开销时,你可以改用 Cython [...]。
因此,我使用类型化内存视图(typed memoryviews)在 Cython 中重写了 replace_py() 及其调用的函数:
import numpy as np
cimport numpy as np
cdef inline int find_first(int[:] a, int index, int n, int value) nogil:
    # Return the first position in [index, n) whose element equals `value`.
    # Returns n if there is none, or `index` unchanged when index >= n.
    # `n` is the row length supplied by the caller.
    cdef int pos = index
    while pos < n:
        if a[pos] == value:
            break
        pos += 1
    return pos
cdef inline int find_end(int[:] a, int index, int n, int value) nogil:
    # Return the position just past the run of `value` that starts at
    # `index`. Returns n when the run reaches the end of the row, and
    # `index` itself when index >= n.
    cdef int pos
    if index >= n:
        return index
    for pos in range(index, n):
        if a[pos] != value:
            return pos
    return n
cdef inline void replace_run(int[:] a, int begin, int end, int threshold, int replace) nogil:
    # Overwrite a[begin:end] with `replace` when end - begin + 1 > threshold,
    # i.e. when the run length end - begin is at least `threshold`.
    cdef int k
    if end - begin + 1 <= threshold:
        return
    for k in range(begin, end):
        a[k] = replace
cdef inline void process_row(int[:] a, int value, int threshold, int replace) nogil:
    # Walk one row in place and pass every run of `value` to replace_run,
    # which decides from `threshold` whether to overwrite it.
    cdef int pos = 0
    cdef int run_stop
    cdef int n = a.shape[0]
    while pos < n:
        if a[pos] != value:
            # Not inside a run: jump to the start of the next one.
            pos = find_first(a, pos, n, value)
            continue
        run_stop = find_end(a, pos, n, value)
        replace_run(a, pos, run_stop, threshold, replace)
        pos = run_stop
def replace_cy(np.ndarray[np.int32_t, ndim=2] a, int value, int length, int replace):
    """Return a copy of the int32 2-D array ``a`` with long runs of ``value``
    replaced row by row; ``a`` itself is not modified.

    The copy is viewed through an ``int[:, ::1]`` memoryview. This requires
    the copy to be C-contiguous (``a.copy()`` defaults to C order) and its
    dtype to match C ``int``. NOTE(review): this assumes a 32-bit C int,
    as on common platforms.
    """
    cdef int[:, ::1] view
    cdef int row, nrows
    mat = a.copy()
    view = mat
    nrows = view.shape[0]
    for row in range(nrows):
        process_row(view[row], value, length, replace)
    return mat
这需要做一些调整,代码也比上面的 Python 代码更杂乱一些,但工作量不大,而且改法相当直接。
计时:
np.random.seed(31)
shape = (100000, 10)
mat = np.asarray(a=np.random.binomial(1, p=0.5, size=shape), dtype=np.int32)
mat.reshape(shape)
%timeit replace_cy(mat, 1, 3, 2)
100 loops, best of 3: 8.16 ms per loop
这是约 1163 倍的加速(9.49 s / 8.16 ms ≈ 1163)!
我在 GitHub 上得到了帮助,现在 Numba 版本也能工作了:我基本上只是在纯 Python 代码上添加了 @autojit,唯一的例外是把 a[begin:end] = replace 改写成了逐元素赋值的循环。详见我在 GitHub 上得到帮助的那次讨论。
import numpy as np
from numba import autojit
@autojit
def find_first(a, index, value):
    """Return the first position >= ``index`` in the 1-D ``a`` holding ``value``.

    Returns ``a.size`` if there is none, or ``index`` unchanged when it is
    already at or past the end. NOTE(review): ``autojit`` is an old Numba
    API; recent Numba releases expose ``numba.jit`` instead. Confirm
    against the installed version.
    """
    pos = index
    stop = a.size
    while pos < stop:
        if a[pos] == value:
            break
        pos += 1
    return pos
@autojit
def find_end(a, index, value):
    """Return the position just past the run of ``value`` starting at ``index``.

    Returns ``a.size`` when the run reaches the end of the 1-D ``a``, and
    ``index`` itself when ``a[index]`` differs or ``index`` is at or past
    the end.
    """
    cursor = index
    limit = a.size
    while cursor < limit:
        if not a[cursor] == value:
            return cursor
        cursor += 1
    return cursor
@autojit
def replace_run(a, begin, end, threshold, replace):
    """Overwrite ``a[begin:end]`` with ``replace`` in place when the run is long enough.

    The run is overwritten when ``end - begin + 1 > threshold``, i.e. for
    integer thresholds when its length ``end - begin`` is at least
    ``threshold``. Uses ``range`` rather than ``xrange``: ``xrange`` does
    not exist on Python 3, while ``range`` works on both and is supported
    by Numba.
    """
    if end - begin + 1 > threshold:
        # Element-by-element assignment instead of the slice assignment
        # used in the pure-Python version.
        for i in range(begin, end):
            a[i] = replace
@autojit
def process_row(a, value, threshold, replace):
    """Walk one 1-D row in place and pass every run of ``value`` to ``replace_run``."""
    pos = 0
    size = a.size
    while pos < size:
        if a[pos] != value:
            # Not inside a run: jump to the start of the next one.
            pos = find_first(a, pos, value)
            continue
        run_stop = find_end(a, pos, value)
        replace_run(a, pos, run_stop, threshold, replace)
        pos = run_stop
@autojit
def replace_numba(a, value, length, replace):
    """Return a copy of ``a`` whose rows have had long runs of ``value`` replaced.

    ``length`` is forwarded as the threshold to ``process_row``. The input
    ``a`` is not modified.
    """
    out = a.copy()
    # Each iterated row is a view into `out`, so edits are kept.
    for line in out:
        process_row(line, value, length, replace)
    return out
计时(输入与上面相同,代码从略):
1 loops, best of 3: 86.5 ms per loop
这比纯 Python 代码快约 110 倍(9.49 s / 86.5 ms ≈ 110)!!!Numba 版本仍比 Cython 版本慢约 10 倍,这很可能是因为小函数没有被内联;但我认为这种加速基本上是白得的,而且不用把我们的 Python 代码弄乱!