不广播Numpy数组

4
在一个庞大的代码库中,我正在使用np.broadcast_to来广播数组(这里只是使用简单的例子):
In [1]: x = np.array([1,2,3])

In [2]: y = np.broadcast_to(x, (2,1,3))

In [3]: y.shape
Out[3]: (2, 1, 3)

在代码的其他地方,我使用可以在Numpy数组上向量化操作但不是ufunc的第三方函数。这些函数不理解广播,这意味着对于像y这样的数组调用这样的函数是低效的。 Numpy的vectorize等解决方案也不好,因为虽然它们理解广播,但会引入一个for循环遍历数组元素,这是非常低效的。
理想情况下,我希望能够有一个函数,我们可以称之为unbroadcast,它返回一个具有最小形状的数组,如果需要,可以将其广播回完整大小。例如:
In [4]: z = unbroadcast(y)

In [5]: z.shape
Out[5]: (1, 1, 3)

我可以在z上运行第三方函数,然后将结果广播回y.shape。是否有一种依赖于Numpy公共API实现的unbroadcast方法?如果没有,是否有任何技巧可以产生所需的结果?

y[None,0] 这样的东西? - Divakar
“minimal shape”是什么意思?unbroadcast返回的N维度中的最小形状不总是(1, 1, ..., 1)(甚至可能是(1,))吗? - Alex Riley
我的意思是最小的形状,仍然包含所有必需的数据以将其广播回完整的数组。因此,在上面的示例中,“z.shape”为“(1,1,3)”而不是“(1,1,1)”。 - astrofrog
2个回答

4
我有一个可能的解决方案,所以我会在这里发布它(但如果有更好的,请随时回复!)。一个解决方法是检查数组的strides参数,在广播维度上将为0:
def unbroadcast(array):
    slices = []
    for i in range(array.ndim):
        if array.strides[i] == 0:
            slices.append(slice(0, 1))
        else:
            slices.append(slice(None))
    return array[slices]

这将会得到:
In [14]: unbroadcast(y).shape
Out[14]: (1, 1, 3)

3
这可能与您自己的解决方案等价,只是更加内置化。它在 numpy.lib.stride_tricks 中使用了 as_strided
import numpy as np
from numpy.lib.stride_tricks import as_strided

x = np.arange(16).reshape(2,1,8,1)  # shape (2,1,8,1)
y = np.broadcast_to(x,(2,3,8,5))    # shape (2,3,8,5) broadcast

def unbroadcast(arr):
    #determine unbroadcast shape
    newshape = np.where(np.array(arr.strides) == 0,1,arr.shape) # [2,1,8,1], thanks to @Divakar
    return as_strided(arr,shape=newshape)    # strides are automatically set here

z = unbroadcast(x)
np.all(z==x)  # is True

请注意,在我的原始答案中,我没有定义一个函数,结果的z数组的strides(64,0,8,0),而输入的strides(64,64,8,8)。在当前版本中,返回的z数组与x具有相同的步幅,我猜测传递和返回数组会强制创建一个副本。无论如何,我们可以在as_strided中手动设置步幅以在所有情况下获得相同的数组,但在上述设置中似乎并不必要。

1
还是使用np.where(np.array(y.strides) == 0,1,y.shape)来获取新形状? - Divakar
@Divakar 对啊,我老是忘记 np.where :) 很显然这是numpy习惯用法,谢谢。 - Andras Deak -- Слава Україні
为了后代,将这两行关键代码编辑成一个“未广播”函数可能会很好? - astrofrog
@astrofrog 谢谢,我希望任何试图这样做的人都能将其转换为函数;) 无论如何,我编辑了我的答案。结果是得到一个与 x 具有相同步幅的数组。 - Andras Deak -- Слава Україні

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接