如何检查一个numpy数组的所有元素是否在另一个numpy数组中

5
我有两个2D numpy数组,例如:
A = numpy.array([[1, 2, 4, 8], [16, 32, 32, 8], [64, 32, 16, 8]])

B = numpy.array([[1, 2], [32, 32]])
我想要得到所有包含B的任何一行中的所有元素的A的行。当B的一行中有2个相同的元素时,A的行必须至少包含2个相同的元素。在我的例子中,我想要实现这个目标:
A_filtered = [[1, 2, 4, 8], [16, 32, 32, 8]]
我可以控制值的表示方式,因此我选择二进制表示中只有一个位置为1的数字(例如:0b000000010b00000010等)。这样,我可以使用np.logical_or.reduce()函数轻松检查所有类型的值是否在一行中,但我无法检查A行中相同元素的数量是否大于或等于。我真的希望我能避免简单的for循环和数组的深复制,因为对于我来说性能是非常重要的方面。
在numpy中如何以高效的方式实现这一点?

更新:

这里的一个解决方案可能有效,但我认为性能是我关心的一个重要问题,A 可能非常大(>300000行),而 B 可以适中(>30):

[set(row).issuperset(hand) for row in A.tolist() for hand in B.tolist()]

更新2:

set() 解决方案不起作用,因为 set() 会删除所有重复值。


我建议查看这个问题:如何检查一个 NumPy 数组是否是另一个数组的子集 - Easton Bornemeier
“至少2个”很重要吗?如果A和B中的行需要在B中每个标记具有相等的计数,则我认为我知道一种优雅的解决方案。 - Eelco Hoogendoorn
在你的例子中,重叠的元素恰好以相同的顺序出现在“A”和“B”中;这是一个要求还是巧合?也就是说,如果“A”的第二个元素是“[32, 16, 32, 8]”,它是否仍应该被包括在内? - fuglede
1
我认为集合并不完全符合您的描述,因为单个和多个出现之间的区别会丢失。 - Eelco Hoogendoorn
这个回答解决了你的问题吗?检查numpy数组是否为另一个数组的子集 - Martin Spacek
显示剩余4条评论
2个回答

1

我认为这应该可以解决:

首先,按以下方式编码数据(这假设有限数量的“标记”,因为您的二进制方案似乎也暗示了这一点):

创建A形状[n_rows,n_tokens],dtype int8,其中每个元素计算标记数。以相同方式对B进行编码,形状为[n_hands,n_tokens]

这允许单个向量化表达式输出; matches =(A [None,:,:] > = B [:,None,:])。 all(axis = -1)。 (如何将此匹配数组映射到所需的输出格式留给读者作为练习,因为问题未定义多个匹配)。

但我们在这里谈论的是每个标记大于10M字节的内存。即使使用32个标记,这也不应该是不可想象的;但在这种情况下,不要对n_tokens或n_hands进行向量化循环,或两者都不要;对于小n,或者如果主体中有足够的工作要做,使循环开销微不足道,则for循环很好。

只要n_tokens和n_hands保持适度,我认为这将是最快的解决方案,如果仅停留在纯python和numpy领域内。


“_where each element counts the number of tokens_” 的意思是什么? - doodoroma
你的示例 A 数组有 7 个唯一的标记。为每个标记分配一列,并将每个元素分配为行中该标记的计数,得到 A = numpy.array([[1, 1, 1, 1, 0, 0, 0], [0, 0, 0, 1, 1, 2, 0], [0, 0, 0, 1, 1, 1, 1]])。 - Eelco Hoogendoorn
是的,这就是我在寻找的解决方案。我只能离线转换 A,而 B 必须每次都转换,我刚刚数了一下,我有 71 个令牌,这似乎很多。 - doodoroma
尽管我一开始考虑的解决方案与这个答案类似,但由于内存限制,我选择了@max9111的想法。拥有71个标记将需要太大的矩阵来处理(300000*71)。不管怎样,这真的很有用! - doodoroma

1
我希望我正确理解了你的问题。至少它可以解决你在问题描述中提到的问题。如果输出顺序应保持与输入相同,则更改原地排序。
代码看起来很丑,但应该表现良好,而且不应该太难理解。 代码
import time
import numba as nb
import numpy as np

@nb.njit(fastmath=True,parallel=True)
def filter(A,B):
  iFilter=np.zeros(A.shape[0],dtype=nb.bool_)

  for i in nb.prange(A.shape[0]):
    break_loop=False

    for j in range(B.shape[0]):
      ind_to_B=0
      for k in range(A.shape[1]):
        if A[i,k]==B[j,ind_to_B]:
          ind_to_B+=1

        if ind_to_B==B.shape[1]:
          iFilter[i]=True
          break_loop=True
          break

      if break_loop==True:
        break

  return A[iFilter,:]

测量性能。
####First call has some compilation overhead####
A=np.random.randint(low=0, high=60, size=300_000*4).reshape(300_000,4)
B=np.random.randint(low=0, high=60, size=30*2).reshape(30,2)

t1=time.time()
#At first sort the arrays
A.sort()
B.sort()
A_filtered=filter(A,B)
print(time.time()-t1)

####Let's measure the second call too####
A=np.random.randint(low=0, high=60, size=300_000*4).reshape(300_000,4)
B=np.random.randint(low=0, high=60, size=30*2).reshape(30,2)

t1=time.time()
#At first sort the arrays
A.sort()
B.sort()
A_filtered=filter(A,B)
print(time.time()-t1)

Results

46ms after the first run on a dual-core Notebook (sorting included)
32ms (sorting excluded)

我尝试了一下,它确实提高了我的代码效率,谢谢! 我对numba还不熟悉,只是稍微了解了一下。在这种情况下,我可以使用nopython模式吗?或者numpy需要使用object模式? - doodoroma
1
@doodoroma (at)njit是@(at)jit(nopython=True)的一种简写方式。 - max9111

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接