检查numpy数组是否为另一个数组的子集

6

有类似的问题已经在SO上提出,但它们具有更具体的限制,并且它们的答案不适用于我的问题。

通常来说,确定任意numpy数组是否是另一个数组的子集的最pythonic方法是什么? 更具体地说,我有一个大约20000x3的数组,我需要知道完全包含在集合中的1x3元素的索引。 更普遍地说,下面的写法是否更符合python风格:

master = [12, 155, 179, 234, 670, 981, 1054, 1209, 1526, 1667, 1853]  # some indices of interest
triangles = np.random.randint(2000, size=(20000, 3))  # some data

for i, x in enumerate(triangles):
    if x[0] in master and x[1] in master and x[2] in master:
        print i

对于我的使用情况,我可以安全地假设len(master) << 20000。(因此,也可以安全地假设master已排序,因为这很便宜。)

5个回答

4

可以使用np.isin,这可能比列表解析在@petrichor的答案中更有效。使用相同的设置:

import numpy as np

x = np.arange(30).reshape(10, 3)
searchKey = [4, 5, 8]
x[[0, 3, 7], :] = searchKey
array([[ 4,  5,  8],
       [ 3,  4,  5],
       [ 6,  7,  8],
       [ 4,  5,  8],
       [12, 13, 14],
       [15, 16, 17],
       [18, 19, 20],
       [ 4,  5,  8],
       [24, 25, 26],
       [27, 28, 29]])

现在可以使用np.isin;默认情况下,它将逐个元素地工作:
np.isin(x, searchKey)
array([[ True,  True,  True],
       [False,  True,  True],
       [False, False,  True],
       [ True,  True,  True],
       [False, False, False],
       [False, False, False],
       [False, False, False],
       [ True,  True,  True],
       [False, False, False],
       [False, False, False]])

现在我们需要筛选所有条目都为 True 的行,我们可以使用 all 来实现:

np.isin(x, searchKey).all(1)
array([ True, False, False,  True, False, False, False,  True, False,
       False])

如果现在想要相应的索引,可以使用 np.where:
np.where(np.isin(x, searchKey).all(1))
(array([0, 3, 7]),)

编辑:

我发现需要小心谨慎。例如,如果我执行以下操作:

x[4, :] = [8, 4, 5]

因此,在分配中,我使用与“searchKey”中相同的值,但顺序不同,当执行时仍将其返回。
np.where(np.isin(x, searchKey).all(1))

打印

(array([0, 3, 4, 7]),)

那可能是不期望的。


4
你可以通过在列表推导式中迭代数组来轻松实现这一点。一个玩具示例如下所示:
import numpy as np
x = np.arange(30).reshape(10,3)
searchKey = [4,5,8]
x[[0,3,7],:] = searchKey
x

提供

 array([[ 4,  5,  8],
        [ 3,  4,  5],
        [ 6,  7,  8],
        [ 4,  5,  8],
        [12, 13, 14],
        [15, 16, 17],
        [18, 19, 20],
        [ 4,  5,  8],
        [24, 25, 26],
        [27, 28, 29]])

现在对元素进行迭代:
ismember = [row==searchKey for row in x.tolist()]

结果是:
[True, False, False, True, False, False, False, True, False, False]

您可以按照您的问题将其修改为子集:

searchKey = [2,4,10,5,8,9]  # Add more elements for testing
setSearchKey = set(searchKey)
ismember = [setSearchKey.issuperset(row) for row in x.tolist()]

如果您需要索引,则使用:

np.where(ismember)[0]

它提供了

array([0, 3, 7])

谢谢你的回答——你的列表推导式比我的for循环更符合Pythonic风格——但是你的回答没有考虑问题的子集部分。row in searchKey并不能判断row是否为searchKey的子集。在这个例子中,它总是返回一个由False组成的数组。 - aestrivex
我已经更新了你的问题答案。更新后的版本考虑到了集合。 - petrichor
1
整体想法不错,但你应该简化列表推导式。可以像这样 set_searchKey = set(searchKey); [set_searchKey.issuperset(row) for row in x],这样你就不需要在每次迭代中将searchkey转换为集合了。另外注意,x不需要转换为列表。 - Bi Rico
@BiRico 我已经相应地进行了修改。非常感谢。 - petrichor
不错的解决方案。如果我只想找到返回 True第一行,有更好的方法吗? - Filippo Bistaffa
显示剩余3条评论

4
这里有两种方法可供尝试:
1、使用集合。集合的实现方式与Python字典类似,并且具有恒定时间查找。这看起来与您已经拥有的代码非常相似,只需从主要代码中创建一个集合即可:
master = [12,155,179,234,670,981,1054,1209,1526,1667,1853]
master_set = set(master)
triangles = np.random.randint(2000,size=(20000,3)) #some data
for i, x in enumerate(triangles):
  if master_set.issuperset(x):
    print i

2、使用searchsorted。这很好,因为它不需要你使用可哈希类型并且使用了numpy的内置函数。searchsorted在主数组大小上是log(N),在三角形大小上是O(N),所以它应该也很快,根据您的数组大小等情况,可能会更快。

master = [12,155,179,234,670,981,1054,1209,1526,1667,1853]
master = np.asarray(master)
triangles = np.random.randint(2000,size=(20000,3)) #some data
idx = master.searchsorted(triangles)
idx.clip(max=len(master) - 1, out=idx)
print np.where(np.all(triangles == master[idx], axis=1))

这第二种情况假设主数据已排序,正如searchsorted所暗示的那样。

searchsorted在这种情况下没有帮助,因为它会将元素虚拟插入到列表中的正确位置,无论它们是否实际存在于'master'中。也就是说,虚假条目[11,154,178]将返回与感兴趣的条目[12,155,179]相同的结果。事实上,你的代码甚至没有执行到那一步,因为在1853和2000之间的插入操作超出了'master'的大小,导致程序崩溃。 - aestrivex
最后一行处理了虚拟插入,但是你说得对,你需要进行剪裁以解决大小问题。我加了一行剪裁代码。 - Bi Rico

2

在numpy中,更自然(可能更快)的解决集合操作的方法是使用numpy.lib.arraysetops中的函数。这些函数通常允许您避免在Python的set类型之间来回转换。要检查一个数组是否是另一个数组的子集,请使用numpy.setdiff1d()并测试返回的数组长度是否为0:

import numpy as np
a = np.arange(10)
b = np.array([1, 5, 9])
c = np.array([-5, 5, 9])
# is `a` a subset of `b`?
len(np.setdiff1d(a, b)) == 0 # gives False
# is `b` a subset of `a`?
len(np.setdiff1d(b, a)) == 0 # gives True
# is `c` a subset of `a`?
len(np.setdiff1d(c, a)) == 0 # gives False

您还可以选择设置assume_unique=True以提高速度。

我有点惊讶的是,numpy没有像内置的issubset()函数一样做到以上所述(类似于set.issubset())。

另一个选项是使用numpy.in1d()(请参见https://stackoverflow.com/a/37262010/2020363

编辑:我刚刚意识到,在遥远的过去的某个时候,这件事困扰了我很久,所以我编写了自己简单的函数:

def issubset(a, b):
    """Return whether sequence `a` is a subset of sequence `b`"""
    return len(np.setdiff1d(a, b)) == 0

1

从这里开始:

master=[12,155,179,234,670,981,1054,1209,1526,1667,1853] #一些感兴趣的索引

triangles=np.random.randint(2000,size=(20000,3)) #一些数据

最Pythonic的方法是什么,以找到包含在master中的三元组的索引?尝试使用np.in1d与列表推导:

inds = [j for j in range(len(triangles)) if all(np.in1d(triangles[j], master))]

%timeit 显示 ~0.5 秒 = 半秒钟

--> 更快的方法(1000 倍!),避免使用 Python 的慢循环?尝试使用 np.isinnp.sum 来获取 np.arange 的布尔掩码:

inds = np.where(
 np.sum(np.isin(triangles, master), axis=-1) == triangles.shape[-1])

%timeit 表明 ~0.0005 秒 = 半毫秒!

建议:尽可能避免循环列表,因为与包含一个算术运算的 python 循环的单次迭代相同价格,您可以调用执行数千个相同算术操作的 numpy 函数。

结论

似乎 np.isin(arr1=triangles, arr2=master) 是您正在寻找的函数,它给出了一个布尔掩码,形状与 arr1 相同,告诉每个 arr1 元素是否也是 arr2 的元素;从这里,要求掩码行的总和为 3(即三角形行的完整长度)给出所需行的 1d 掩码(或使用 np.arange 的索引)。


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接