如何加速矩阵代码

4

我有以下简单的代码,用于估计大小为h乘n的二进制矩阵具有某种属性的概率。它运行时间呈指数级增长(这本来就很糟糕),但是我惊讶地发现即使对于n = 12和h = 9,它的速度也非常慢。

#!/usr/bin/python

import numpy as np
import itertools

n = 12
h = 9

F = np.matrix(list(itertools.product([0,1],repeat = n))).transpose()

count = 0
iters = 100
for i in xrange(iters):
    M =  np.random.randint(2, size=(h,n))
    product = np.dot(M,F)
    setofcols = set()
    for column in product.T:
        setofcols.add(repr(column))
    if (len(setofcols)==2**n):
        count = count + 1
print count*1.0/iters

我使用n=10和h=7进行了剖析,输出结果相当长,但以下是花费更多时间的行。

        23447867 function calls (23038179 primitive calls) in 35.785 seconds

   Ordered by: standard name

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        2    0.002    0.001    0.019    0.010 __init__.py:1(<module>)
        1    0.001    0.001    0.054    0.054 __init__.py:106(<module>)
        1    0.001    0.001    0.022    0.022 __init__.py:15(<module>)
        2    0.003    0.002    0.013    0.006 __init__.py:2(<module>)
        1    0.001    0.001    0.003    0.003 __init__.py:38(<module>)
        1    0.001    0.001    0.001    0.001 __init__.py:4(<module>)
        1    0.001    0.001    0.004    0.004 __init__.py:45(<module>)
        1    0.001    0.001    0.002    0.002 __init__.py:88(<module>)
   307200    0.306    0.000    1.584    0.000 _methods.py:24(_any)
   102400    0.026    0.000    0.026    0.000 arrayprint.py:22(product)
   102400    1.345    0.000   32.795    0.000 arrayprint.py:225(_array2string)
307200/102400    1.166    0.000   33.350    0.000 arrayprint.py:335(array2string)
   716800    0.820    0.000    1.162    0.000 arrayprint.py:448(_extendLine)
204800/102400    1.699    0.000    5.090    0.000 arrayprint.py:456(_formatArray)
   307200    0.651    0.000   22.510    0.000 arrayprint.py:524(__init__)
   307200   11.783    0.000   21.859    0.000 arrayprint.py:538(fillFormat)
  1353748    1.920    0.000    2.537    0.000 arrayprint.py:627(_digits)
   102400    0.576    0.000    2.523    0.000 arrayprint.py:636(__init__)
   716800    2.159    0.000    2.159    0.000 arrayprint.py:649(__call__)
   307200    0.099    0.000    0.099    0.000 arrayprint.py:658(__init__)
   102400    0.163    0.000    0.225    0.000 arrayprint.py:686(__init__)
   102400    0.307    0.000   13.784    0.000 arrayprint.py:697(__init__)
   102400    0.110    0.000    0.110    0.000 arrayprint.py:713(__init__)
   102400    0.043    0.000    0.043    0.000 arrayprint.py:741(__init__)
        1    0.003    0.003    0.003    0.003 chebyshev.py:87(<module>)
        2    0.001    0.000    0.001    0.000 collections.py:284(namedtuple)
        1    0.277    0.277   35.786   35.786 counterfeit.py:3(<module>)
   205002    0.222    0.000    0.247    0.000 defmatrix.py:279(__array_finalize__)
   102500    0.747    0.000    1.077    0.000 defmatrix.py:301(__getitem__)
   102400    0.322    0.000   34.236    0.000 defmatrix.py:352(__repr__)
   102400    0.100    0.000    0.508    0.000 fromnumeric.py:1087(ravel)
   307200    0.382    0.000    2.829    0.000 fromnumeric.py:1563(any)
      271    0.004    0.000    0.005    0.000 function_base.py:3220(add_newdoc)
        1    0.003    0.003    0.003    0.003 hermite.py:59(<module>)
        1    0.003    0.003    0.003    0.003 hermite_e.py:59(<module>)
        1    0.001    0.001    0.002    0.002 index_tricks.py:1(<module>)
        1    0.003    0.003    0.003    0.003 laguerre.py:59(<module>)
        1    0.003    0.003    0.003    0.003 legendre.py:83(<module>)
        1    0.001    0.001    0.001    0.001 linalg.py:10(<module>)
        1    0.001    0.001    0.001    0.001 numeric.py:1(<module>)
   102400    0.247    0.000   33.598    0.000 numeric.py:1365(array_repr)
   204800    0.321    0.000    1.143    0.000 numeric.py:1437(array_str)
   614400    1.199    0.000    2.627    0.000 numeric.py:2178(seterr)
   614400    0.837    0.000    0.918    0.000 numeric.py:2274(geterr)
   102400    0.081    0.000    0.186    0.000 numeric.py:252(asarray)
   307200    0.259    0.000    0.622    0.000 numeric.py:322(asanyarray)
        1    0.003    0.003    0.004    0.004 polynomial.py:54(<module>)
   513130    0.134    0.000    0.134    0.000 {isinstance}
   307229    0.075    0.000    0.075    0.000 {issubclass}
5985327/5985305    0.595    0.000    0.595    0.000 {len}
 306988    0.120    0.000    0.120    0.000 {max}
   102400    0.061    0.000    0.061    0.000 {method '__array__' of 'numpy.ndarray' objects}
   102406    0.027    0.000    0.027    0.000 {method 'add' of 'set' objects}
   307200    0.241    0.000    1.824    0.000 {method 'any' of 'numpy.ndarray' objects}
   307200    0.482    0.000    0.482    0.000 {method 'compress' of 'numpy.ndarray' objects}
   204800    0.035    0.000    0.035    0.000 {method 'item' of 'numpy.ndarray' objects}
   102451    0.014    0.000    0.014    0.000 {method 'join' of 'str' objects}
   102400    0.222    0.000    0.222    0.000 {method 'ravel' of 'numpy.ndarray' objects}
   921176    3.330    0.000    3.330    0.000 {method 'reduce' of 'numpy.ufunc' objects}
   102405    0.057    0.000    0.057    0.000 {method 'replace' of 'str' objects}
  2992167    0.660    0.000    0.660    0.000 {method 'rstrip' of 'str' objects}
   102400    0.041    0.000    0.041    0.000 {method 'splitlines' of 'str' objects}
        6    0.003    0.000    0.003    0.001 {method 'sub' of '_sre.SRE_Pattern' objects}
   307276    0.090    0.000    0.090    0.000 {min}
      100    0.013    0.000    0.013    0.000 {numpy.core._dotblas.dot}
   409639    0.473    0.000    0.473    0.000 {numpy.core.multiarray.array}
  1228800    0.239    0.000    0.239    0.000 {numpy.core.umath.geterrobj}
   614401    0.352    0.000    0.352    0.000 {numpy.core.umath.seterrobj}
   102475    0.031    0.000    0.031    0.000 {range}
   102400    0.076    0.000    0.102    0.000 {reduce}
204845/102445    0.198    0.000   34.333    0.000 {repr}

矩阵相乘似乎只占用了很少一部分的时间。有没有可能加速其他部分呢?

结果

现在有三个答案,但其中一个似乎目前存在错误。我已经使用 n=18、h=11 和 iters=10 测试了剩下的两个答案。

  • bubble - 21 秒,185MB 内存。对 "sort" 操作只需要 16 秒。
  • hpaulj - 7.5 秒,130MB 内存。对 "tolist" 操作只需要 3 秒,"numpy.core.multiarray.array" 操作需要 1.5 秒,"genexpr"(即 'set' 行) 需要 1.5 秒。

有趣的是,矩阵相乘所花费的时间仍然只是总体时间的一小部分。


2
显然可以令其速度惊人地加快。如果你看到arrayprint,你认为哪一部分是最慢的呢?;) - seberg
1
你应该阅读我在你之前的问题中提到的链接https://dev59.com/mmQn5IYBdhLWcg3wRVRs,以获取更高效的唯一计数方法。`arrapyprint.fillFormat`是你代码中 repr(column) 部分所调用的,因此它是最慢的部分。 - alko
集合方法可能非常好。但是,如果您知道dtype/row是安全的(包括连续性和字节顺序),您可以直接使用arr.tostring() - seberg
大部分时间都在arrayprint中,它由repr调用。将数字转换为字符串的目的是什么?numpy旨在快速处理数字数组。处理字符串时,它使用常规的Python方法。 - hpaulj
@hpaulj 这个转换只是为了让你知道有多少列是唯一的。 - marshall
3个回答

3
为了加快上述代码的执行速度,您应该避免使用循环。
import numpy as np
import itertools

def unique_rows(a):
    a = np.ascontiguousarray(a)
    unique_a = np.unique(a.view([('', a.dtype)]*a.shape[1]))
    return unique_a.view(a.dtype).reshape((unique_a.shape[0], a.shape[1]))


n = 12
h = 9
iters=100
F = np.matrix(list(itertools.product([0,1],repeat = n))).transpose()
M =  np.random.randint(2, size=(h*iters,n))
product = np.dot(M,F)
counts = map(lambda x: len(unique_rows(x.T))==2**n, np.split(product,iters,axis=0))
prob=float(sum(counts))/iters

#All unique submatrices M (hxn) with the sophisticated property...
[np.split(M,iters,axis=0)[j] for j in range(len(counts)) if counts[j]==True]

代码中的瓶颈是慢的unique_rows实现。 - alko
这些解决方案很慢吗?https://dev59.com/mmQn5IYBdhLWcg3wRVRs - bubble
不,我也指的是那些链接,但你的回答根本没有包含它们,所以对OP的问题没有任何帮助。 - alko
如果您愿意详细说明您的答案并提供可工作的代码,我想它会被接受。 - alko
@Bubble 谢谢!这个快多了。不过避免循环有两个问题。第一个是现在它使用了大量的RAM。第二个是当不同列的数量为2 ** n时,我现在无法打印M。有没有一种简单的方法来修改你的代码以解决这个问题? - marshall
显示剩余2条评论

2

尝试用以下代码替换repr(col)

setofcols.add(tuple(column.A1.tolist()))

set 接受一个 tuplecolumn.A1 是转换为 1d 数组的矩阵。然后,元组类似于 (0, 1, 0)set 可以轻松比较。

仅替换昂贵的 repr 格式化程序就可以节省大量时间(速度提升了 25 倍)。

编辑

通过在一个语句中创建和填充 set,我获得了进一步的 10 倍速度提升。在我的测试中,它比 bubble 的向量化方法快 2 倍。

count = 0
for i in xrange(iters):
    M =  np.random.randint(2, size=(h,n))
    product = np.dot(M,F)
    setofcols = set(tuple(x) for x in product.T.tolist())
    # or {tuple(x) for x in product.T.tolist()} if new enough Python
    if (len(setofcols)==2**n):
        count += 1
        # print M # to see the unique M
print count*1.0/iters

编辑

这里有一个更快的方法 - 将每列的9个整数转换为1个整数,使用dot([1,10,100,...],column)。然后对整数列表应用np.unique(或set)。这是一个2-3倍的进一步加速。

count = 0
X = 10**np.arange(h)
for i in xrange(iters):
    M =  np.random.randint(2, size=(h,n))
    product = np.dot(M,F)
    setofcols = np.unique(np.dot(X,product).A1)
    if (setofcols.size==2**n):
        count += 1
print count*1.0/iters

通过这个,顶部的调用是

  200    0.201    0.001    0.204    0.001 {numpy.core._dotblas.dot}
  100    0.026    0.000    0.026    0.000 {method 'sort' of 'numpy.ndarray' objects}
  100    0.007    0.000    0.035    0.000 arraysetops.py:93(unique)

谢谢。它仍然花费了一小部分时间来计算矩阵的乘积。对于不同的n和h参数以及iters=2,我得到了以下结果:"39 6.465 0.166 6.465 0.166 {numpy.core.multiarray.array}", "2 3.803 1.902 3.803 1.902 {method 'tolist' of 'numpy.ndarray' objects}", "2 1.072 0.536 1.072 0.536 {numpy.core._dotblas.dot}"。 - marshall
使用BLAS的np.dotnumpy中最高效(时间方面)的操作之一。因此,我并不惊讶tolist需要更长的时间。我怀疑即使您删除了set操作,众多的array调用仍然存在。 - hpaulj
在IPython中进行分析显示,在生成F时调用了multiarray.array,而不是在iters循环期间调用。 - hpaulj
我通过将每列压缩为一个整数(通过乘以[1,10,100,...])获得了更进一步的加速。仍需逐列执行uniqueset操作。 - hpaulj

1

正如alko和seberg指出的那样,你正在浪费大量时间将数组转换为大字符串以将它们存储在列集中。

如果我正确理解了你的代码,你试图找到在product矩阵中不同列的数量是否等于该矩阵的长度。你可以通过对其进行排序并查看从一列到下一列的差异来轻松实现这一点:

D = (np.diff(np.sort(product.T, axis=0), axis=0) == 0)

这将为您提供一个由布尔值组成的矩阵D。然后,您可以查看是否至少有一个元素从一列到另一列发生了变化:
C = (1 - np.prod(D, axis=1)) # i.e. 'not all(D[i,:]) for all i'

然后,您只需要查看all的所有值是否不同即可:

hasproperty = np.all(C)

这将给您完整的代码:

def f(n, h, iters):
    F = np.array(list(itertools.product([0,1], repeat=n))).T
    counts = []
    for _ in xrange(iters):
        M = np.random.randint(2, size=(h,n))
        product = M.dot(F)
        D = (np.diff(np.sort(product.T, axis=1), axis=0) == 0)
        C =  (1 - np.prod(D, axis=1))
        hasproperty = np.all(C)
        counts.append(1. if hasproperty else 0.)
    return np.mean(counts)

对于f(12, 9, 100),大约需要8秒钟。

如果你更喜欢滑稽紧凑的表达方式:

def g(n, h, iters):
    F = np.array(list(itertools.product([0,1], repeat=n))).T
    return np.mean([np.all(1 - np.prod(np.diff(np.sort(np.random.randint(2,size=(h,n)).dot(F).T, axis=1), axis=0)==0, axis=1)) for _ in xrange(iters)])

计时它的结果如下:
>>> setup = """import numpy as np
def g(n, h, iters):
    F = np.array(list(itertools.product([0,1], repeat=n))).T
    return np.mean([np.all(1 - np.prod(np.diff(np.sort(np.random.randint(2,size=(h,n)).dot(F).T, axis=1), axis=0)==0, axis=1)) for _ in xrange(iters)])
"""
>>> timeit.timeit('g(10, 7, 100)', setup=setup, number=10)
17.358669997900734
>>> timeit.timeit('g(10, 7, 100)', setup=setup, number=50)
83.06966196163967

或者每次调用g(10,7,100)大约需要1.7秒。


谢谢,但我认为f函数至少有一些问题。当你使用f(12,8,100)时,你得到的结果是0.0。但是使用@bubble的代码却不是这样。 - marshall
确实,存在一个错误:product.T 必须沿 axis=1 进行排序。已经进行了更正,感谢您的反馈 :) - val
我怕还有一个错误。尝试n = 5,h = 4和iters = 100。输出应该大约为0.07。 - marshall

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接