NumPy数组索引

4

这里有一个涉及到对数组进行索引以获取其一部分值的简单问题。比如说我有一个recarray,其中一个空间存储年龄,另一个空间存储相应的数值。我也有一个数组,它是我想要的年龄子集。以下是我的意思:

ages = np.arange(100)
values = np.random.uniform(low=0, high= 1, size = ages.shape)
data = np.core.rec.fromarrays([ages, values], names='ages,values')
desired_ages = np.array([1,4, 16, 29, 80])

我要做的是类似这样的事情:

data.values[data.ages==desired_ages]

但是,它没有起作用。

3个回答

4

您想创建一个子数组,其中只包含索引在desired_ages中的值。

Python没有直接对应此操作的语法,但是列表推导可以完成这个任务:

result = [value for index, value in enumerate(data.values) if index in desired_ages]

然而,如果这样做,Python需要扫描data.values中的每个元素以查找desired_ages,这会很慢。 如果您能插入...
desired_ages = set(desired_ages)

在前一行中,这将提高性能。(您可以在常数时间内确定一个值是否在集合中,而不受集合大小的影响。)

完整示例

import numpy as np

ages = np.arange(100)
values = np.random.uniform(low=0, high= 1, size = ages.shape)
data = np.core.rec.fromarrays([ages, values], names='ages,values')
desired_ages = np.array([1,4, 16, 29, 80])

result = [value for index, value in enumerate(data.values) if index in desired_ages]
print result

[0.45852624094611272, 0.0099713014816563694, 0.26695859251958864, 0.10143425810157047, 0.93647796171383935]

2

这是一个合理的首要方法:

>>> bool_indices = reduce(numpy.logical_or, 
                          (data.ages == x for x in desired_ages))
>>> data.values[bool_indices]
array([ 0.63143784,  0.93852927,  0.0026815 ,  0.66263594,  0.2603184 ])

但这个方法使用了 Python 函数,所以可能会更慢。我们可以很容易地将其转换为纯 Numpy,使用 ix_ 使数组相互广播。(meshgrid 也可以,但会使用更多内存。)

>>> bools_2d = numpy.equal(*numpy.ix_(desired_ages, data.ages))
>>> bool_indices = numpy.logical_or.reduce(bools_2d)
>>> data.ages[bool_indices]
array([ 1,  4, 16, 29, 80])
>>> data.values[bool_indices]
array([ 0.32324063,  0.65453647,  0.9300062 ,  0.34534668,  0.12151951])

另请参阅HYRY的答案,该答案可能提供更快的解决方案(使用searchsorted)和更易于阅读的解决方案(使用in1d)。


2

我稍微修改了你的示例,随机改变了年龄的顺序:

import numpy as np
np.random.seed(0)
ages = np.arange(3,103)
np.random.shuffle(ages)
values = np.random.uniform(low=0, high= 1, size = ages.shape)
data = np.core.rec.fromarrays([ages, values], names='ages,values')
desired_ages = np.array([4, 16, 29, 80])

如果desired_ages中的所有元素都在data.ages中,你可以先按年龄字段对数据进行排序,然后使用searchsorted()快速找到所有索引:
data.sort(order="ages") # sort by ages
print data.values[np.searchsorted(data.ages, desired_ages)]

或者你可以使用np.in1d来获取一个布尔数组,并将其用作索引:
print data.values[np.in1d(data.ages, desired_ages)]

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接