使用numpy进行逐元素“in”的Pythonic和高效方式

13

我正在寻找一种高效获取布尔类型数组的方法,给定两个大小相等的数组a和b,如果a的对应元素出现在b的对应元素中,则每个元素都为true。

例如,下面的程序:

a = numpy.array([1, 2, 3, 4])
b = numpy.array([[1, 2, 13], [2, 8, 9], [5, 6], [7]])
print(numpy.magic_function(a, b))

应该打印

[True, True, False, False]

请记住,这个函数应该等同于

[x in y for x, y in zip(a, b)]

只有在ab很大,并且b的每个元素都相对较小的情况下,才会对numpy进行优化。


3
顺便提一下:如果你关心效率,就不应该使用类似于b这样的数组:它只是一个对象类型的1D数组,而不是一个快速的NumPy 2D整数数组。请注意,在翻译过程中不要改变原文的含义。 - DSM
公平地说,大部分工作(检查一个整数是否在一个numpy数组中)由你的列表推导式处理。这可能是最高效的(并且绝对是最Pythonic的)方法来完成它。 - Alyssa Haroldsen
@Kupiakos 不错,我忘记提到了,我需要这个用例是针对大的 ab,并且 b 的每个元素都有几个元素。 - Martín Fixman
根据http://docs.scipy.org/doc/numpy/reference/arrays.nditer.html,我认为你想要使用nditer或者可能是in1d http://docs.scipy.org/doc/numpy/reference/generated/numpy.in1d.html#numpy.in1d。 - Bennett Brown
您IP地址为143.198.54.68,由于运营成本限制,当前对于免费用户的使用频率限制为每个IP每72小时10次对话,如需解除限制,请点击左下角设置图标按钮(手机用户先点击左上角菜单按钮)。 - Martín Fixman
1
由于“b”的元素长度不同,将其作为数组意义不大。可以将其转换为真正的二维数组,或者将其视为列表处理。 - hpaulj
3个回答

4
为了利用NumPy的broadcasting规则,您应该首先使数组b平方,这可以使用itertools.izip_longest实现:
from itertools import izip_longest

c = np.array(list(izip_longest(*b))).astype(float)

导致:
array([[  1.,   2.,   5.,   7.],
       [  2.,   8.,   6.,  nan],
       [ 13.,   9.,  nan,  nan]])

然后,通过执行np.isclose(c, a),您将得到一个二维布尔数组,显示每个c[:, i]a[i]之间的差异,根据广播规则,结果如下:
array([[ True,  True, False, False],
       [False, False, False, False],
       [False, False, False, False]], dtype=bool)

可以用来获取您的答案:

np.any(np.isclose(c, a), axis=0)
#array([ True,  True, False, False], dtype=bool)

3
有没有小列表长度的上限?如果有,您可以将b制成1000x5的矩阵,并使用nan填充子数组中太短的间隙。然后,您可以使用numpy.any来获取所需的答案,类似于以下内容:
In [42]: a = np.array([1, 2, 3, 4])
    ...: b = np.array([[1, 2, 13], [2, 8, 9], [5, 6], [7]])

In [43]: bb = np.full((len(b), max(len(i) for i in b)), np.nan)

In [44]: for irow, row in enumerate(b):
    ...:     bb[irow, :len(row)] = row

In [45]: bb
Out[45]: 
array([[  1.,   2.,  13.],
       [  2.,   8.,   9.],
       [  5.,   6.,  nan],
       [  7.,  nan,  nan]])

In [46]: a[:,np.newaxis] == bb
Out[46]: 
array([[ True, False, False],
       [ True, False, False],
       [False, False, False],
       [False, False, False]], dtype=bool)

In [47]: np.any(a[:,np.newaxis] == bb, axis=1)
Out[47]: array([ True,  True, False, False], dtype=bool)

我不知道这是否对您的数据更快。

1

概述

Sauldo Castro的方法在目前发布的方法中运行速度最快。原始帖子中的生成器表达式是第二快的。

生成测试数据的代码:

import numpy
import random

alength = 100
a = numpy.array([random.randint(1, 6) for i in range(alength)])
b = []
for i in range(alength):
    length = random.randint(1, 5)
    element = []
    for i in range(length):
        element.append(random.randint(1, 6))
    b.append(element)
b = numpy.array(b)
print a, b

选项:

from itertools import izip_longest
def magic_function1(a, b): # From OP Martin Fixman
    return [x in y for x, y in zip(a, b)]  

def magic_function2(a, b): # What I thought might be better.
    bools = []
    for x, y in zip(a,b):
        found = False
        for j in y:
            if x == j:
                found=True
                break
        bools.append(found)

def magic_function3(a, b): # What I tried first
    bools = []
    for i in range(len(a)):
        found = False
        for j in range(len(b[i])):
            if a[i] == b[i][j]:
                found=True
                break
        bools.append(found)

def magic_function4(a, b): # From Bas Swinkels
    bb = numpy.full((len(b), max(len(i) for i in b)), numpy.nan)
    for irow, row in enumerate(b):
        bb[irow, :len(row)] = row
    a[:,numpy.newaxis] == bb
    return numpy.any(a[:,numpy.newaxis] == bb, axis=1)

def magic_function5(a, b): # From Sauldo Castro, revised version
    c = numpy.array(list(izip_longest(*b))).astype(float)
    return numpy.isclose(c, a), axis=0)  

时间 n_executions

n_executions = 100
clock = timeit.Timer(stmt="magic_function1(a, b)", setup="from __main__ import magic_function1, a, b")
print clock.timeit(n_executions), "seconds"
# Repeat with each candidate function

结果:

  • magic_function1 花费 0.158078225475 秒
  • magic_function2 花费 0.181080926835 秒
  • magic_function3 花费 0.259621047822 秒
  • magic_function4 花费 0.287054750224 秒
  • magic_function5 花费 0.0839162196207 秒

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接