如何将这个操作向量化

5

假设我有两个列表(长度始终相同):

l0 = [0, 4, 4, 4, 0, 0, 0, 8, 8, 0] 
l1 = [0, 1, 1, 1, 0, 0, 0, 8, 8, 8]

我有以下关于交集和并集的规则,需要逐个元素地将这些列表进行比较:

# union and intersect
uni = [0]*len(l0)
intersec = [0]*len(l0)
for i in range(len(l0)):
    if l0[i] == l1[i]:
        uni[i] = l0[i]
        intersec[i] = l0[i]
    else:
        intersec[i] = 0  
        if l0[i] == 0:
            uni[i] = l1[i]
        elif l1[i] == 0:
            uni[i] = l0[i]
        else:
            uni[i] = [l0[i], l1[i]]

因此,所需的输出为:
uni: [0, [4, 1], [4, 1], [4, 1], 0, 0, 0, 8, 8, 8] 
intersec: [0, 0, 0, 0, 0, 0, 0, 8, 8, 0]

虽然这样做是可行的,但我需要对几百个非常大的列表(每个列表都有数千个元素)进行此操作,因此我正在寻找一种向量化的方法。我尝试使用np.where和各种掩码策略,但进展缓慢。欢迎提供建议。

*编辑*

关于

uni: [0, [4, 1], [4, 1], [4, 1], 0, 0, 0, 8, 8, 8]

对抗

uni: [0, [4, 1], [4, 1], [4, 1], 0, 0, 0, 8, 8, [0, 8]]

我仍在脑海中纠结于8与[0, 8]。这些列表源自系统注释中的BIO标记(请参见文本块的IOB标记),其中每个列表元素都是文档中的字符索引,其值是分配的枚举标签。0表示代表无注释的标签(即用于确定混淆矩阵中的负面因素);而非零元素表示该字符的分配枚举标签。由于我忽略了真正的负面因素,所以我认为8等同于[0, 8]。至于这是否简化了事情,我还不确定。

* 编辑2 *

我使用[0, 8]来保持简单,并使intersectionunion的定义与集合论一致。


2
向量化交集很简单。np.where(a0==l1,a0,0),其中 a0 = np.array(l0)。向量化你的 uni 将会很困难,因为输出不是一个有效的 numpy 数组。它可能是,但它的 dtype 将是 object,从而使大部分向量化收益失效。 - rafaelc
uni 中列表和标量的混合是一个很好的指示,表明完全“向量化”的解决方案是不可能的。如果解决方案具有长度不同的列表(或数组),则同样适用。 - hpaulj
你能定义一下“nowhere fast”吗?如果有几百个大列表,你的解决方案现在有多快? - Chris
嗯,无处快速的意思是:我没有解决方案!在速度方面什么都没有意义。我有交集,但并集是个难点。 - horcle_buzz
2个回答

2

我建议不要称它们为“交集”和“并集”,因为这些操作在集合中有明确定义的含义,而你想执行的操作并不是它们中的任何一个。

然而,要实现你想要的功能:

l0 = [0, 4, 4, 4, 0, 0, 0, 8, 8, 0]
l1 = [0, 1, 1, 1, 0, 0, 0, 8, 8, 8]

values = [
    (x
     if x == y else 0,
     0
     if x == y == 0
     else x if y == 0
     else y if x == 0
     else [x, y]) 
    for x, y in zip(l0, l1)
]

result_a, result_b = map(list, zip(*values))

print(result_a)
print(result_b)

这对于成千上万,甚至数百万个元素来说已经足够了,因为这个操作非常基础。当然,如果我们要处理数十亿个元素,你可能需要考虑使用numpy。


1
请注意,对于元组而言解决方案要简单一些,但是此例子要求使用列表。 - Grismar
我接受了你的答案,因为它确实很简单,并且可以很好地推广到我的要求上。我还没有进行任何速度测试。 - horcle_buzz
不错!与我之前的代码块相比,对于我的“并集”和“交集”的修改,我得到了以下结果: 旧的并集:10.6微秒±44.3纳秒每个循环(平均值±7次运行的标准差,每个循环100000次)旧的交集:10.9微秒±82纳秒每个循环(平均值±7次运行的标准差,每个循环100000次)新的并集:656纳秒±9.45纳秒每个循环(平均值±7次运行的标准差,每个循环1000000次)新的交集:653纳秒±11.3纳秒每个循环(平均值±7次运行的标准差,每个循环1000000次) - horcle_buzz
1
@horcle_buzz 显然你没有解包最终的 map 对象。 - rafaelc
这个程序在元素更多的列表上运行时会怎样:每次循环17.1纳秒±0.276纳秒(平均值±7次运行的标准差,每次循环100000000次) 每次循环20.5纳秒±0.456纳秒(平均值±7次运行的标准差,每次循环10000000次) - horcle_buzz

0

针对并集和交集的半向量化解决方案:

import numpy as np

l0 = np.array(l0)
l1 = np.array(l1)
intersec = np.zeros(l0.shape[0])
intersec_idx = np.where(l0==l1)
intersec[intersec_idx] = l0[intersec_idx]
intersec = intersec.astype(int).tolist()

union = np.zeros(l0.shape[0])
union_idx = np.where(l0==l1)
union[union_idx] = l0[union_idx]
no_union_idx = np.where(l0!=l1)
union = union.astype(int).tolist()
for idx in no_union_idx[0]:
    union[idx] = [l0[idx], l1[idx]]

以及输出:

>>> intersection
[0, 0, 0, 0, 0, 0, 0, 8, 8, 0]
>>> union  
[0, [4, 1], [4, 1], [4, 1], 0, 0, 0, 8, 8, [0, 8]]

注意:我认为你的原始联合解决方案是不正确的。请看最后的输出8与[0,8]的区别。


2
在包含 OP 的代码方面没有任何问题,特别是如果你正在纠正结果或进行计时。如果 OP 提供了一个好的 [mcve],我经常会将代码和数据复制到自己的会话中,然后开发自己的解决方案。这使我能够比较结果和计时。 - hpaulj
我仍在脑海中纠结于8[0, 8]。这些列表是从系统注释中的BIO标签中派生出来的,其中每个列表元素都是文档中的字符索引。 0表示表示没有注释的标签(即用于确定混淆矩阵中的负面情况);而非零元素表示为该字符分配的枚举标签。由于我忽略了真正的负面情况,所以我认为8等同于[0, 8]。至于这是否简化了事情,我还不确定。 - horcle_buzz
我喜欢这个回答,但在决定是否接受之前,我需要对真实数据进行一些基准测试。CPU时间似乎是主要瓶颈。 - horcle_buzz
唯一让我感到紧张的是将union用于数组名称,特别是因为存在同名方法。这就是为什么我选择了uni的原因。 - horcle_buzz

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接