测试一个数组是否可以广播到一个指定的形状?

6
什么是检测数组能否广播到给定形状的最佳方法?
对于我的情况,“pythonic”方法 try 不起作用,因为意图是对操作进行惰性评估。
我询问如何实现下面的 is_broadcastable:
>>> x = np.ones([2,2,2])
>>> y = np.ones([2,2])
>>> is_broadcastable(x,y)
True
>>> y = np.ones([2,3])
>>> is_broadcastable(x,y)
False

或者更好的方式是:
>>> is_broadcastable(x.shape, y.shape)

请参考此链接:http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html - Saullo G. P. Castro
7个回答

12

我真的认为你们过于复杂化了这个问题,为什么不简单点呢?

def is_broadcastable(shp1, shp2):
    for a, b in zip(shp1[::-1], shp2[::-1]):
        if a == 1 or b == 1 or a == b:
            pass
        else:
            return False
    return True

我想不出有什么问题... 我可能漏掉了一些微妙的额外规则,对吗? - keflavich

5

如果您只想避免根据给定形状实例化数组,则可以使用as_strided:

import numpy as np
from numpy.lib.stride_tricks import as_strided

def is_broadcastable(shp1, shp2):
    x = np.array([1])
    a = as_strided(x, shape=shp1, strides=[0] * len(shp1))
    b = as_strided(x, shape=shp2, strides=[0] * len(shp2))
    try:
        c = np.broadcast_arrays(a, b)
        return True
    except ValueError:
        return False

is_broadcastable((1000, 1000, 1000), (1000, 1, 1000))  # True
is_broadcastable((1000, 1000, 1000), (3,))  # False

这种做法非常节省内存空间,因为a和b都是由同一条记录支持的。


4
您可以使用 np.broadcast。例如:
In [47]: x = np.ones([2,2,2])

In [48]: y = np.ones([2,3])

In [49]: try:
   ....:     b = np.broadcast(x, y)
   ....:     print "Result has shape", b.shape
   ....: except ValueError:
   ....:     print "Not compatible for broadcasting"
   ....:     
Not compatible for broadcasting

In [50]: y = np.ones([2,2])

In [51]: try:
   ....:     b = np.broadcast(x, y)
   ....:     print "Result has shape", b.shape
   ....: except ValueError:
   ....:     print "Not compatible for broadcasting"
   ....:
Result has shape (2, 2, 2)

对于您实现惰性评估的需求,您可能会发现np.broadcast_arrays非常有用。


2
“broadcast_arrays” 是我寻找的懒惰计算方式,但我仍需要一个验证阶段 - 如果假设我手头没有一个形状为B的数组,我该如何测试A的形状是否可以广播到B的形状? - keflavich

3

如果您想检查任意数量的类似数组的对象(而不是传递形状),我们可以利用np.nditer进行广播数组迭代

def is_broadcastable(*arrays):
    try:
        np.nditer(arrays)
        return True
    except ValueError:
        return False

请注意,这仅适用于np.ndarray或定义了__array__(将会被调用)的类。


0

numpy.broadcast_shapes自numpy 1.20版本开始提供,因此可以像这样轻松实现:

import numpy as np

def is_broadcastable(shp1, shp2):
    try:
        c = np.broadcast_shapes(shp1, shp2)
        return True
    except ValueError:
        return False

在底层,它使用零长度列表的NumPy数组来调用broadcast_arrays,方法如下:

np.empty(shp, dtype=[])

这种方法避免了分配内存。它类似于@ChrisB提出的解决方案,但不依赖于as_strided技巧,我觉得有点令人困惑。


0

要将此推广到任意数量的形状,可以按照以下方式进行:

def is_broadcast_compatible(*shapes):
    if len(shapes) < 2:
        return True
    else:
        for dim in zip(*[shape[::-1] for shape in shapes]):
            if len(set(dim).union({1})) <= 2:
                pass
            else:
                return False
        return True

相应的测试用例如下:

import unittest


class TestBroadcastCompatibility(unittest.TestCase):
    def check_true(self, *shapes):
        self.assertTrue(is_broadcast_compatible(*shapes), msg=shapes)

    def check_false(self, *shapes):
        self.assertFalse(is_broadcast_compatible(*shapes), msg=shapes)

    def test(self):
        self.check_true((1, 2, 3), (1, 2, 3))
        self.check_true((3, 1, 3), (3, 3, 3))
        self.check_true((1,), (2,), (2,))

        self.check_false((1, 2, 3), (1, 2, 2))
        self.check_false((1, 2, 3), (1, 2, 3, 4))
        self.check_false((1,), (2,), (3,))

0

当顺序很重要时,比如测试是否np.broadcast_to(a, b.shape)能够正常工作,以下方法似乎效果不错:

def is_broadcastable(src, dst):
    try:
        return np.broadcast_shapes(src, dst) == dst
    except ValueError:
        return False

np.broadcast_shaps 的返回值与其给定的参数顺序无关,我们需要检查结果是否与 dst 相同;而且 dst 是“较大”的尺寸。


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接