用公式在Python中定义一个ndarray

3

我有一个多维数组,初始化为C=np.zeros([20,20,20,20])。然后,我正在尝试通过一些公式(C(x)=(exp(-|x|^2) 在这种情况下)给C赋值。以下代码是有效的,但非常缓慢。

it=np.nditer(C, flags=['multi_index'], op_flags=['readwrite'])
while not it.finished:
    diff=np.linalg.norm(np.array(it.multi_index))
    it[0]=np.exp(-diff**2)
    it.iternext()

这个问题是否可以更快地使用更符合Python风格的方式解决?


从你的例子来看,你已经找到了nditer教程页面。在最后,它有一个在Cython代码中使用nditer的示例。那是很快的。但在Python代码中,它并不比其他迭代方法更快。最好完全避免迭代。 - hpaulj
2个回答

2
这里有一种方法可以实现。
第一步,使用代码中的np.array(it.multi_index)获取所有索引计算所对应的所有组合。在这方面,可以利用itertools库中的product函数
第二步,以向量化的方式执行所有组合的L2范数计算。
第三步,最后以逐元素的方式执行C(x)=(exp(-|x|^2)
# Get combinations using itertools.product
combs = np.array(list(product(range(N), repeat=4)))

# Perform L2 norm and elementwise exponential calculations to get final o/p 
out = np.exp(-np.sqrt((combs**2).sum(1))**2).reshape(N,N,N,N)

运行时测试和验证输出 -

In [42]: def vectorized_app(N):
    ...:     combs = np.array(list(product(range(N), repeat=4)))
    ...:     return np.exp(-np.sqrt((combs**2).sum(1))**2).reshape(N,N,N,N)
    ...: 
    ...: def original_app(N):
    ...:     C=np.zeros([N,N,N,N])
    ...:     it=np.nditer(C, flags=['multi_index'], op_flags=['readwrite'])
    ...:     while not it.finished:
    ...:         diff_n=np.linalg.norm(np.array(it.multi_index))
    ...:         it[0]=np.exp(-diff_n**2)
    ...:         it.iternext()
    ...:     return C
    ...: 

In [43]: N = 10

In [44]: %timeit original_app(N)
1 loops, best of 3: 288 ms per loop

In [45]: %timeit vectorized_app(N)
100 loops, best of 3: 8.63 ms per loop

In [46]: np.allclose(vectorized_app(N),original_app(N))
Out[46]: True

非常感谢。我刚刚收到了“product() got an unexpected keyword argument 'repeat'”的错误,但可能是Python的旧版本(?) - Peter Franek
@PeterFranek 你的Python版本是什么? - Divakar
2.7.6.(在Ubuntu下) 我应该重新安装它吗? - Peter Franek
@PeterFranek 好的,我使用的是 2.7.9 版本。这个功能似乎很旧了,就像 2013年的问题 中使用的那样。你能否尝试重新安装?我认为即使不考虑这个 带重复产品 的功能,它也会非常有用! - Divakar
最后一个问题,你的代码真的至少快两倍吗?如果是的话,我会重新安装 :) - Peter Franek
@PeterFranek 它快了 33倍 ;) - Divakar

1

看起来你只想将操作应用于每个元素的索引?这样怎么样:

x = np.exp(-np.linalg.norm(np.indices([20,20,20,20]), axis=0)**2)

np.indices是一个非常巧妙的函数。mgrid和meshgrid与此相关,用于更复杂的操作。在这种情况下,由于有4个维度,它返回一个形状为(4,20,20,20,20)的数组。

而且纯numpy会更快 :)

In [13]: timeit posted_code()
1 loops, best of 3: 843 ms per loop

In [14]: timeit np.exp(-np.linalg.norm(np.indices([20,20,20,20]), axis=0)**2)
100 loops, best of 3: 3.76 ms per loop

而且这正是相同的结果:
In [26]: np.all(C == x)
Out[26]: True

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接