将Python函数广播到NumPy数组上

10

假设我们有一个非常简单的函数,例如:

import scipy as sp
def func(x, y):
   return x + y

这个函数显然适用于几种Python内置数据类型,例如字符串、列表、整数、浮点数、数组等。由于我们特别关注数组,因此我们考虑两个数组:

x = sp.array([-2, -1, 0, 1, 2])
y = sp.array([-2, -1, 0, 1, 2])

xx = x[:, sp.newaxis]
yy = y[sp.newaxis, :]

>>> func(xx, yy)

这将返回

array([[-4, -3, -2, -1,  0],
  [-3, -2, -1,  0,  1],
  [-2, -1,  0,  1,  2],
  [-1,  0,  1,  2,  3],
  [ 0,  1,  2,  3,  4]])

正如我们所预期的那样。

现在,如果有人想要将数组作为以下函数的输入,该怎么办呢?

def func2(x, y):
  if x > y:
     return x + y
  else:
     return x - y

执行 >>>func(xx, yy) 会引发错误。

最常见的方法是使用scipy/numpy中的 sp.vectorize 函数。然而,这种方法被证明并不是很高效。有没有更加健壮的方法可以将任何函数广播到numpy数组上呢?

如果重写代码以适应数组的方式是唯一的方法,那么在此处提及会很有帮助。

3个回答

16

np.vectorize是将操作数字的Python函数转换为操作ndarrays的numpy函数的一种常规方法。

然而,正如你所指出的那样,它并不快,因为它在“内部”使用了一个Python循环。

要获得更好的速度,您必须手工制作一个期望以numpy数组作为输入并利用numpy特性的函数:

import numpy as np

def func2(x, y):
    return np.where(x>y,x+y,x-y)      

x = np.array([-2, -1, 0, 1, 2])
y = np.array([-2, -1, 0, 1, 2])

xx = x[:, np.newaxis]
yy = y[np.newaxis, :]

print(func2(xx, yy))
# [[ 0 -1 -2 -3 -4]
#  [-3  0 -1 -2 -3]
#  [-2 -1  0 -1 -2]
#  [-1  0  1  0 -1]
#  [ 0  1  2  3  0]]

关于性能:

test.py:

import numpy as np

def func2a(x, y):
    return np.where(x>y,x+y,x-y)      

def func2b(x, y):
    ind=x>y
    z=np.empty(ind.shape,dtype=x.dtype)
    z[ind]=(x+y)[ind]
    z[~ind]=(x-y)[~ind]
    return z

def func2c(x, y):
    # x, y= x[:, None], y[None, :]
    A, L= x+ y, x<= y
    A[L]= (x- y)[L]
    return A

N=40
x = np.random.random(N)
y = np.random.random(N)

xx = x[:, np.newaxis]
yy = y[np.newaxis, :]

运行中:

N=30:

% python -mtimeit -s'import test' 'test.func2a(test.xx,test.yy)'
1000 loops, best of 3: 219 usec per loop

% python -mtimeit -s'import test' 'test.func2b(test.xx,test.yy)'
1000 loops, best of 3: 488 usec per loop

% python -mtimeit -s'import test' 'test.func2c(test.xx,test.yy)'
1000 loops, best of 3: 248 usec per loop

当N=1000:

% python -mtimeit -s'import test' 'test.func2a(test.xx,test.yy)'
10 loops, best of 3: 93.7 msec per loop

% python -mtimeit -s'import test' 'test.func2b(test.xx,test.yy)'
10 loops, best of 3: 367 msec per loop

% python -mtimeit -s'import test' 'test.func2c(test.xx,test.yy)'
10 loops, best of 3: 186 msec per loop

这似乎表明func2afunc2c稍微快一些(而func2b则非常慢)。


1
np.where也可以在这种情况下非常有用。 - matt
@unutbu:使用where确实看起来很好,但你有考虑过在实现时对性能的影响吗?谢谢。 - eat
@unutbu:有趣的时间。能否使用N进行计时,例如在1e3... 1e4范围内?到目前为止,使用“where”实现似乎是最合理的。谢谢。 - eat
@eat:当我不经意地设置N=10000时,我的小电脑就会卡住 :)。当x的形状为(10000,)时,func2*的返回值的形状为(10000,10000)。使用dtype = float64,至少需要760 MiB。这使我进入了缓冲区交换的领域。无论如何,我倾向于相信随着N的增长,结果的排序不会改变。你认为呢? - unutbu
@unutbu:已经有点晚了,但我明天会自己做一些时间测试(但你仍然可以尝试使用1e3级别)。嗯,我不知道排序是否会以任何戏剧性的方式改变,但根据我的有限经验,在某些情况下,“where”相对于纯“逻辑索引”会带来额外的开销。谢谢。 - eat
显示剩余2条评论

13

对于这种特殊情况,您还可以编写一个函数来操作NumPy数组和普通Python浮点数:

def func2d(x, y):
    z = 2.0 * (x > y) - 1.0
    z *= y
    return x + z

这个版本比unutbu的func2a()快四倍以上(在 N = 100的情况下测试)。


+1:干得好!看起来func2d更快,因为它需要较少的内存分配。你同意吗? - unutbu
@unutbu:不确定。我写的第一个版本使用了更少的临时变量(例如在第一行中省略了“-1.0”,并使用“z -= 1.0”),但这样会更慢。 - Sven Marnach
嗯,这很奇怪。对于我来说(当N=1000时),使用“z -= 1.0”每个循环需要40.7毫秒,而“func2d”需要47.8毫秒。 - unutbu
1
+1,干得好!在我的机器上,当N=10、100、1000时,func2a/func2d的性能比率分别为[1.144、1.885、1.624]。现在希望能够得到OP的反馈。谢谢。 - eat

1

为了获得基本的想法,您可以修改您的函数,例如这种方式:

def func2(x, y):
    x, y= x[:, None], y[None, :]
    A= x+ y
    A[x<= y]= (x- y)[x<= y]
    return A

因此,对于您的情况,以下内容应该是一个非常合理的起点:

In []: def func(x, y):
   ..:     x, y= x[:, None], y[None, :]
   ..:     return x+ y
   ..:
In []: def func2(x, y):
   ..:     x, y= x[:, None], y[None, :]
   ..:     A, L= x+ y, x<= y
   ..:     A[L]= (x- y)[L]
   ..:     return A
   ..:
In []: x, y= arange(-2, 3), arange(-2, 3)
In []: func(x, y)
Out[]:
array([[-4, -3, -2, -1,  0],
       [-3, -2, -1,  0,  1],
       [-2, -1,  0,  1,  2],
       [-1,  0,  1,  2,  3],
       [ 0,  1,  2,  3,  4]])
In []: func2(x, y)
Out[]:
array([[ 0, -1, -2, -3, -4],
       [-3,  0, -1, -2, -3],
       [-2, -1,  0, -1, -2],
       [-1,  0,  1,  0, -1],
       [ 0,  1,  2,  3,  0]])

尽管这种处理方式可能会浪费资源,但并不一定如此。始终要测量程序的实际性能,并在必要时进行更改(而不是过早地进行更改)。

我认为另一个优点是:这种“向量化”使您的代码最终变得一致且易读。


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接