np.broadcast_to
可以将 A
'重塑' 以匹配 B
; 然后你可以同时迭代这两个数组。它使用 striding,因此实际上不会增加内存使用。
In [370]: def f(a,b):
...: assert(a.shape==(1,3))
...: assert(b.shape==(1,3))
...: return a+b
...:
In [371]: B=np.arange(12).reshape(4,3)
In [372]: A=np.arange(3).reshape(1,3)
In [373]: np.broadcast_to(A, B.shape)
Out[373]:
array([[0, 1, 2],
[0, 1, 2],
[0, 1, 2],
[0, 1, 2]])
In [374]: np.broadcast_to(B, B.shape)
Out[374]:
array([[ 0, 1, 2],
[ 3, 4, 5],
[ 6, 7, 8],
[ 9, 10, 11]])
我通常使用列表推导式而不是 map 函数:
In [375]: [f(np.atleast_2d(a),np.atleast_2d(b)) for a,b in zip(np.broadcast_to(A,B.shape),B)]
Out[375]:
[array([[0, 2, 4]]),
array([[3, 5, 7]]),
array([[ 6, 8, 10]]),
array([[ 9, 11, 13]])]
In [376]: [f(np.atleast_2d(a),np.atleast_2d(b)) for a,b in zip(np.broadcast_to(B,B.shape),B)]
Out[376]:
[array([[0, 2, 4]]),
array([[ 6, 8, 10]]),
array([[12, 14, 16]]),
array([[18, 20, 22]])]
迭代2D数组会产生1D数组的列表,因此需要使用
np.atleast_2d
来满足我的
f
断言。如果
f
也接受(3,)输入,则不需要这样做。或者使用
map
:
In [377]: map(lambda a,b: f(np.atleast_2d(a),np.atleast_2d(b)), np.broadcast_to(B,B.shape),B)
Out[377]: <map at 0xb14f4c6c>
In [378]: list(_)
Out[378]:
[array([[0, 2, 4]]),
array([[ 6, 8, 10]]),
array([[12, 14, 16]]),
array([[18, 20, 22]])]
In [379]: map(lambda a,b: f(np.atleast_2d(a),np.atleast_2d(b)), np.broadcast_to(A,B.shape),B)
Out[379]: <map at 0xb0871a8c>
In [380]: list(_)
Out[380]:
[array([[0, 2, 4]]),
array([[3, 5, 7]]),
array([[ 6, 8, 10]]),
array([[ 9, 11, 13]])]
np.vectorize
和np.frompyfunc
也可以处理这种广播,但它们是为接受标量的函数设计的,而不是1d数组。
使用broadcast_arrays
,我可以平等地处理两个数组:
In [386]: map(lambda a,b: f(np.atleast_2d(a),np.atleast_2d(b)), *np.broadcast_arrays(B,A))
Out[386]: <map at 0xb69851ac>
In [387]: list(_)
Out[387]:
[array([[0, 2, 4]]),
array([[3, 5, 7]]),
array([[ 6, 8, 10]]),
array([[ 9, 11, 13]])]
更一般地说,
A
和
B
可以是任何产生所需
(N,3)
数组的东西。我可以通过生成
(N,1,3)
数组来摆脱使用
atleast_2d
:
In [397]: map(f, *np.broadcast_arrays(np.arange(3)[None,None,:], np.arange(0,40,10)[:,None,None]))
Out[397]: <map at 0xb08b562c>
In [398]: list(_)
Out[398]:
[array([[0, 1, 2]]),
array([[10, 11, 12]]),
array([[20, 21, 22]]),
array([[30, 31, 32]])]