在NumPy数组中赋值

5

我有一个由零组成的numpy数组。为了更具体,假设它是2x3x4:

x = np.zeros((2,3,4))

假设我有一个2x3的随机整数数组,范围在0到3之间(x的第三个维度的索引)。
>>> y = sp.stats.distributions.randint.rvs(0, 4, size=(2,3))
>>> y
[[2 1 0]
 [3 2 0]]

我应该如何高效地完成以下任务(编辑:不使用for循环,适用于具有任意维度和任意元素数量的x)?
>>> x[0,0,y[0,0]]=1
>>> x[0,1,y[0,1]]=1
>>> x[0,2,y[0,2]]=1
>>> x[1,0,y[1,0]]=1
>>> x[1,1,y[1,1]]=1
>>> x[1,2,y[1,2]]=1
>>> x
array([[[ 0.,  0.,  1.,  0.],
        [ 0.,  1.,  0.,  0.],
        [ 1.,  0.,  0.,  0.]],

       [[ 0.,  0.,  0.,  1.],
        [ 0.,  0.,  1.,  0.],
        [ 1.,  0.,  0.,  0.]]])

感谢,詹姆斯。
3个回答

2

目前,我只能想到“简单”版本,它涉及沿着前两个维度进行平坦化。这段代码应该可以工作:

shape_last = x.shape[-1]
x.reshape((-1, shape_last))[np.arange(y.size), y.flatten()] = 1

这将产生以下结果(使用我的随机生成的y):
array([[[ 0.,  0.,  0.,  1.],
        [ 0.,  0.,  1.,  0.],
        [ 0.,  1.,  0.,  0.]],

       [[ 0.,  1.,  0.,  0.],
        [ 0.,  0.,  0.,  1.],
        [ 0.,  1.,  0.,  0.]]])

关键是,如果您使用多个numpy数组进行索引(高级索引),numpy将使用索引对来索引数组。

当然,确保xy都是C顺序或F顺序 - 否则,调用reshapeflatten可能会给出不同的顺序。


这肯定比我之前使用的for循环更快,而且似乎随着维度数量和每个维度中元素数量的增加,它的可扩展性变得非常好。谢谢。 - user1857751

2
使用numpy.meshgrid()函数创建索引数组,您可以使用这些数组来索引原始数组和第三维值数组。
import numpy as np
import scipy as sp
import scipy.stats.distributions

a = np.zeros((2,3,4))
z = sp.stats.distributions.randint.rvs(0, 4, size=(2,3))

xx, yy = np.meshgrid( np.arange(2), np.arange(3) )
a[ xx, yy, z[xx, yy] ] = 1
print a

我已将您的数组从x重命名为a,并将索引数组从y重命名为z,以提高清晰度。
a = np.zeros((2,3,4,5))
z = sp.stats.distributions.randint.rvs(0, 4, size=(2,3))
w = sp.stats.distributions.randint.rvs(0, 5, size=(2,3))

xx, yy = np.meshgrid( np.arange(2), np.arange(3) )
a[ xx, yy, z[xx, yy], w[xx, yy] ] = 1

谢谢你的回复。你知道如何将这个推广到任意维度吗? - user1857751
@user1857751:是的,当然,它可以直接推广。使用meshgrid在所有维度上创建索引向量,这些向量是您不想查找的维度,然后使用它们来索引原始数组和您想要查找的索引。 - Alex I

0
x = np.zeros((2,3,4))
y=np.array([[2, 1, 0],[3, 2, 0]]) # or y=sp.stats...
for i in range(2):
    for j in range(3):
        x[i,j,y[i,j]]=1

如果我没记错的话,这将产生期望的结果。如果数组维度永远不会改变,请考虑通过以下方式替换两个for循环及其负担:

for j in range(3):
    x[0,j,y[0,j]] = x[1,j,y[1,j]] = 1

2
谢谢回复,但问题是我需要高效地处理大型数组。换句话说,我想避免使用for循环。 - user1857751
我需要处理任何维度的代码,即任意数量的维度或每个维度中的任意数量元素的代码。 - user1857751
我对numpy不是非常熟悉,所以我不知道它是否有相关的内置函数来完成这项工作。当然,有一些相关的例程(argwhere、nonzero、where和extract)可以完成相反的工作,例如可以在数组中找到非零元素的位置。 - James Waldby - jwpat7

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接