Numpy: 从给定的范围中生成组合的高效方法

4

我有一个如下所示的n维数组:

np.array([[0,3],[0,3],[0,10]])

在这个数组中,元素表示低值和高值。例如:[0,3] 表示 [0,1,2,3]
我需要使用上述给定的范围生成所有值的组合。例如,我想要 [0,0,0], [0,0,1] ... [0,1,0] ... [3,3,10]
我尝试了以下方法来得到我想要的结果:
ds = np.array([[0,3],[0,3],[0,10]])
nItems = int(reduce(lambda a,b: a * (b[1] - b[0] + 1), ds, 1))
myCombinations = np.zeros((nItems,))
nArrays = []
for x in range(ds.shape[0]):
    low = ds[x][0]
    high= ds[x][1]
    nitm = high - low + 1
    ar = [x+low for x in range(nitm) ]
    nArrays.append(ar)

myCombinations = cartesian(nArrays)

这个笛卡尔函数取自于使用numpy构建两个数组的所有组合的数组

我需要执行这个操作几百万次

我的问题是:有没有更好/更有效的方法来做到这一点?

2个回答

25

我认为你需要的是np.mgrid。不幸的是,它返回的数组格式与你所需的不同,因此你需要进行一些后处理:

a = np.mgrid[0:4, 0:4, 0:11]     # All points in a 3D grid within the given ranges
a = np.rollaxis(a, 0, 4)         # Make the 0th axis into the last axis
a = a.reshape((4 * 4 * 11, 3))   # Now you can safely reshape while preserving order

解释

np.mgrid函数可以在N维空间中生成一组网格点。为了更好地理解,让我通过一个小的例子来展示:

>>> a = np.mgrid[0:2, 0:2]
>>> a
array([[[0, 0],
        [1, 1]],

       [[0, 1],
        [0, 1]]])
自从我给出了两组范围0:2, 0:2,我得到了一个二维网格。 mgrid 返回的是对应于二维空间中点(0, 0), (0, 1), (1, 0) 和 (1, 1) 的x值和y值。数组a[0] 告诉你这四个点的x值是什么,a[1]告诉你它们的y值。
但你真正想要的是我写出来的实际网格点列表,而不是这些点的x值和y值。第一反应是按所需方式重塑数组:
>>> a.reshape((4, 2))
array([[0, 0],
       [1, 1],
       [0, 1],
       [0, 1]])

但是很明显这样做是不行的,因为它会有效地重新塑造扁平化的数组(通过按顺序读取所有元素得到的数组),而这不是你想要的。

你想要做的是沿着a第三个维度查看,并创建一个数组:

[ [a[0][0, 0], a[1][0, 0]],
  [a[0][0, 1], a[1][0, 1]],
  [a[0][1, 0], a[1][1, 0]],
  [a[0][1, 1], a[1][1, 1]] ]

读作“首先告诉我第一个点(x1, y1),然后是第二个点(x2, y2)…”等等。或许通过一张图可以更好地解释。这就是a的样子:

                you want to read
                in this direction
                 (0, 0)   (0, 1)
                   |        |
                   |        |
                   v        v

          /        0--------0            +----> axis0
 x-values |       /|       /|           /|
          |      / |      / |    axis1 / |
          \     1--------1  |         L  |
                |  |     |  |            v
          /     |  0-----|--1           axis2
 y-values |     | /      | /
          |     |/       |/
          \     0--------1

                |        |
                |        |
                v        v
              (1, 0)   (1, 1)

np.rollaxis提供了一种实现此操作的方法。np.rollaxis(a, 0, 3)在上面的例子中表示“取第0个(或最外层)轴,并将其变为最后一个(或最内层)轴。(注意:这里只有轴0、1和2存在。所以说“将第0轴发送到第3个位置”是告诉Python将第0轴放在最后一个轴之后的一种方式)。您可能还想阅读这篇文章

>>> a = np.rollaxis(a, 0, 3)
>>> a
array([[[0, 0],
        [0, 1]],

       [[1, 0],
        [1, 1]]])

现在看起来已经接近你想要的结果了,只是多了一维数组。我们想要合并0和1维度,以便得到一个仅包含网格点的单一数组。但是现在,由于展平后的数组已按照预期方式读取,因此您可以安全地对其进行reshape操作,以获得所需的结果。

>>> a = a.reshape((4, 2))
>>> a
array([[0, 0],
       [0, 1],
       [1, 0],
       [1, 1]])

3D版本的作用和2D版本相同,只是我无法为其制作一个图形,因为它将处于4D空间。


这个很高效(100000次运行大约需要4秒),但是有点令人困惑,你能解释一下它是如何工作的吗?(或者请指引我一些文档,让我能够理解它?) - okkhoy
我已添加了一些解释以便您理解,但在我的计算机上,itertools.product 实际运行速度约快 6 倍。我的方法大部分时间都被 mgrid 占用,所以您甚至不能通过避免使用 rollaxisreshape 来加快它。出于好奇,您使用的 Python 和 numpy 版本是什么? - Praveen
我刚刚意识到实现rollaxis+reshape效果的另一种方法,但在此过程中失去了numpy的特性,那就是使用zip(a[0].flatten(), a[1].flatten(), a[2].flatten()) - Praveen
哇!感谢解释!我正在运行Python 2.7.6和Numpy 1.8.1,我再次检查,在我的机器上结果类似。itertools需要更长时间! - okkhoy

3
您可以使用itertools.product来实现此功能:
In [16]: from itertools import product

In [17]: values = list(product(range(4), range(4), range(11)))

In [18]: values[:5]
Out[18]: [(0, 0, 0), (0, 0, 1), (0, 0, 2), (0, 0, 3), (0, 0, 4)]

In [19]: values[-5:]
Out[19]: [(3, 3, 6), (3, 3, 7), (3, 3, 8), (3, 3, 9), (3, 3, 10)]

给定一组范围的数组,可以像下面这样做。 (我使用了一些非零低值来演示一般情况 - 并缩小输出的大小。 :))

In [41]: ranges = np.array([[0, 3], [1, 3], [8, 10]])

In [42]: list(product(*(range(lo, hi+1) for lo, hi in ranges)))
Out[42]: 
[(0, 1, 8),
 (0, 1, 9),
 (0, 1, 10),
 (0, 2, 8),
 (0, 2, 9),
 (0, 2, 10),
 (0, 3, 8),
 (0, 3, 9),
 (0, 3, 10),
 (1, 1, 8),
 (1, 1, 9),
 (1, 1, 10),
 (1, 2, 8),
 (1, 2, 9),
 (1, 2, 10),
 (1, 3, 8),
 (1, 3, 9),
 (1, 3, 10),
 (2, 1, 8),
 (2, 1, 9),
 (2, 1, 10),
 (2, 2, 8),
 (2, 2, 9),
 (2, 2, 10),
 (2, 3, 8),
 (2, 3, 9),
 (2, 3, 10),
 (3, 1, 8),
 (3, 1, 9),
 (3, 1, 10),
 (3, 2, 8),
 (3, 2, 9),
 (3, 2, 10),
 (3, 3, 8),
 (3, 3, 9),
 (3, 3, 10)]

如果所有范围的最低值为0,则可以使用 np.ndindex
In [52]: values = list(np.ndindex(4, 4, 11))

In [53]: values[:5]
Out[53]: [(0, 0, 0), (0, 0, 1), (0, 0, 2), (0, 0, 3), (0, 0, 4)]

In [54]: values[-5:]
Out[34]: [(3, 3, 6), (3, 3, 7), (3, 3, 8), (3, 3, 9), (3, 3, 10)]

不,所有低值都不是0,因此我认为我不能使用np.ndindex。另一种方法适合我。一旦我有元组列表,我可以将其转换为numpy数组。谢谢! - okkhoy
我刚刚注意到,运行该方法100000次,我的方法在9秒内给出结果,而使用itertools需要44秒。这种方法编码更简单,但由于我必须执行几百万次,所以我一直在关注效率。 - okkhoy

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接