嵌套循环的向量化

5
我希望对一个嵌套循环进行向量化,这个循环将用于处理300,000个列表,每个列表包含3个值。该嵌套循环会比较每个列表中的值与其他列表中相应的值,并仅附加具有最大差异小于0.1的相应值的列表索引。因此,包含[0.234, 0.456, 0.567]和[0.246, 0.479, 0.580]的列表将属于此类别,因为它们的相应值(即0.234和0.246;0.456和0.479;0.567和0.580)之间的差异小于0.1。

我目前使用以下嵌套循环来完成此操作,但它当前需要约58小时才能完成(总计90万亿次迭代);

import numpy as np
variable = np.random.random((300000,3)).tolist()
out1=list()
out2=list()
for i in range(0:300000):
    for j in range(0:300000):
        if ((i<j) and ((abs(variable[i][0]-variable[j][0]))<0.1) and ((abs(variable[i][1]-variable[j] [1]))<0.1) and ((abs(variable[i][2]-variable[j][2]))<0.1)):
        out1.append(i)  
        out2.append(j)

你的 variable 是随机的,这只是为了举例吗?还是你实际上在模拟某些东西? - Julien
是的,这只是一个例子 - 实际上我有一个列表的列表,通过模拟生成,其中数据实际上在我提到的阈值范围内。 - JBorg
2个回答

3

将数据转换为NumPy数组,以便于在之后使用NumPy函数。然后,可以提出两种方法。

方法一

可以使用NumPy广播来将它们扩展到3D数组,并以矢量化的方式执行操作。因此,我们会有如下实现 -

th = 0.1 # Threshold
arr = np.asarray(variable)
out1,out2 = np.where(np.triu((np.abs(arr[:,None,:] - arr) < th).all(-1),1))

方法二

这是一种专注于内存效率的替代实现,它使用选择性索引来负责这些迭代 -

th = 0.1 # Threshold
arr = np.asarray(variable)
R,C = np.triu_indices(arr.shape[0],1)
mask = (np.abs(arr[R] - arr[C])<th).all(-1)
out1,out2 = R[mask], C[mask]

如果你有1TB的内存,那么这个方案是可行的 :) - Eelco Hoogendoorn
@EelcoHoogendoorn “方法二”可能会更轻巧 :) - Divakar
尝试过这个,但出现了内存不足的错误;再次尝试只使用了30,000个列表,但仍在运行中;我猜想这可能还需要很长时间?运行时占用256Gb RAM。 - JBorg
@JBorg,你试过Approach #2了吗?另外,建议你也试试基于KDTrees的解决方案,就像@Eelco Hoogendoorn所建议的那样。 - Divakar
仍然使用300,000个列表进行第二种方法时报告内存错误,并且在30,000个列表中运行了大约5-7分钟。再次运行以检查实际时间。顺便说一下,感谢您的回复! :) - JBorg

3

请查看scipy.spatial;它具有大量用于高效解决此类空间查询的功能;KDTrees尤其如此:

import scipy.spatial
out = scipy.spatial.cKDTree(variable).query_pairs(r=0.1, p=np.infinity)

尝试了一下,返回“需要浮点数”。我认为这是一个简单的问题。谢谢您的回复! - JBorg
啊,我误解了文档;它们在这个主题上并不特别清晰。尝试编辑一下。'无穷范数'应该归结为你要寻找的度量标准;任何一个分量的最大绝对值。 - Eelco Hoogendoorn
为了效率起见,最好完全放弃使用列表,转而使用ndarray。这适用于您的输入,也适用于输出;请注意,您可以在此调用中添加output_type ='ndarray'关键字参数。 - Eelco Hoogendoorn
将np.infinity更改为np.infty后,它可以正常工作,但在运行了300,000个列表几分钟后,出现了“Killed”消息,这又一次证明是内存问题。如何添加output_type关键字参数? - JBorg
请查看我帖子中提供的链接以获取确切的签名。你仍然会收到内存错误的原因是你生成的数据对数非常大;由np.random.random生成的所有点都位于一个单位立方体中,因此在r=0.1的盒子内的邻居数量仍然是巨大的。你确定这实际上代表了你的真实数据吗? - Eelco Hoogendoorn
这完全解决了问题 - 整个函数现在运行时间约为30秒。谢谢! - JBorg

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接