高效的Python实现numpy数组比较

7

背景

我有两个numpy数组,想以最高效/最快的方式进行一些比较操作。两个数组中都只包含无符号整数。

pairs是一个n x 2 x 3数组,其中包含一长串成对的三维坐标(为了方便命名,pairs数组包含一组配对...)-即

# full pairs array
In [145]: pairs
Out[145]:
    array([[[1, 2, 4],
        [3, 4, 4]],
        .....
       [[1, 2, 5],
        [5, 6, 5]]])

# each entry contains a pair of 3D coordinates
In [149]: pairs[0]
Out[149]:
array([[1, 2, 4],
       [3, 4, 4]])

positions是一个n x 3的数组,其中存储了一组三维坐标。

In [162]: positions
Out[162]:
array([[ 1,  2,  4],
       [ 3,  4,  5],
       [ 5,  6,  3],
       [ 3,  5,  6],
       [ 6,  7,  5],
       [12,  2,  5]])

目标 我想创建一个数组,它是pairs数组的子集,但只包含最多有一个对在positions数组中的条目 - 即不应该有两个对都在positions数组中。对于一些域信息,每个对将至少有一对位置在positions列表内。

已尝试的方法 我的最初的天真方法是循环遍历pairs数组中的每个对,并从positions向量中减去每个对的两个位置,确定在两种情况下是否找到了匹配,这由来自减法操作的两个向量中都存在0来表示:

 if (~(positions-pair[0]).any(axis=1)).any() and 
    (~(positions-pair[1]).any(axis=1)).any():
    # both members of the pair were in the positions array -
    # these weren't the droids we were looking for
    pass
 else:
    # append this set of pairs to a new matrix 

这样做还不错,并利用了一些向量化技术,但可能有更好的方法吗?

对于程序中其他性能敏感的部分,我已经使用Cython重写了一些内容,这带来了巨大的加速,尽管在这种情况下(至少基于一个天真的嵌套for循环实现),这种方法略慢于上面概述的方法。

如果有人有建议,我很乐意进行分析并报告结果(我已经设置好了所有的分析基础设施)。


http://stackoverflow.com/a/31889183/901925 中使用的方法应该可行。它扩展了维度(或一个或两个数组),以便您可以执行逐元素比较,然后使用“all”在一个或多个维度上合并结果。或者在您的情况下,我会使用'sum'为1的'rows'。我稍后可以详细说明。 - hpaulj
2个回答

6

方法一

如问题所述,两个数组都只包含无符号的 int,这可以利用起来将 XYZ 合并成一个线性索引等效版本,对于每个唯一的 XYZ 三元组都是独一无二的。实现代码如下 -

maxlen = np.max(pairs,axis=(0,1))
dims = np.append(maxlen[::-1][:-1].cumprod()[::-1],1)

pairs1D = np.dot(pairs.reshape(-1,3),dims)
positions1D = np.dot(positions,dims)
mask_idx = ~(np.in1d(pairs1D,positions1D).reshape(-1,2).all(1))
out = pairs[mask_idx]

由于涉及到3D坐标,因此您还可以使用cdist来检查输入数组之间是否存在相同的XYZ三元组。以下是两种采用该思路的实现方式。

方法二

from scipy.spatial.distance import cdist

p0 = cdist(pairs[:,0,:],positions)
p1 = cdist(pairs[:,1,:],positions)
out = pairs[((p0==0) | (p1==0)).sum(1)!=2]

第三种方法

mask_idx = ~((cdist(pairs.reshape(-1,3),positions)==0).any(1).reshape(-1,2).all(1))
out = pairs[mask_idx]

运行时测试 -
In [80]: n = 5000
    ...: pairs = np.random.randint(0,100,(n,2,3))
    ...: positions= np.random.randint(0,100,(n,3))
    ...: 

In [81]: def cdist_split(pairs,positions):
    ...:    p0 = cdist(pairs[:,0,:],positions)
    ...:    p1 = cdist(pairs[:,1,:],positions)
    ...:    return pairs[((p0==0) | (p1==0)).sum(1)!=2]
    ...: 
    ...: def cdist_merged(pairs,positions):
    ...:    mask_idx = ~((cdist(pairs.reshape(-1,3),positions)==0).any(1).reshape(-1,2).all(1))
    ...:    return pairs[mask_idx]
    ...: 
    ...: def XYZ_merged(pairs,positions):
    ...:    maxlen = np.max(pairs,axis=(0,1))
    ...:    dims = np.append(maxlen[::-1][:-1].cumprod()[::-1],1)
    ...:    pairs1D = np.dot(pairs.reshape(-1,3),dims)
    ...:    positions1D = np.dot(positions,dims)
    ...:    mask_idx1 = ~(np.in1d(pairs1D,positions1D).reshape(-1,2).all(1))
    ...:    return pairs[mask_idx1]
    ...: 

In [82]: %timeit cdist_split(pairs,positions)
1 loops, best of 3: 662 ms per loop

In [83]: %timeit cdist_merged(pairs,positions)
1 loops, best of 3: 615 ms per loop

In [84]: %timeit XYZ_merged(pairs,positions)
100 loops, best of 3: 4.02 ms per loop

验证结果 -

In [85]: np.allclose(cdist_split(pairs,positions),cdist_merged(pairs,positions))
Out[85]: True

In [86]: np.allclose(cdist_split(pairs,positions),XYZ_merged(pairs,positions))
Out[86]: True

3

我来详细说明一下我的评论:

扩展pairs,让它更有趣。 可以随意尝试使用更大、更现实的数组进行测试:

In [260]: pairs = np.array([[[1,2,4],[3,4,4]],[[1,2,5],[5,6,5]],[[3,4,5],[3,5,6]],[[6,7,5],[1,2,3]]])

In [261]: positions = np.array([[ 1,  2,  4],
       [ 3,  4,  5],
       [ 5,  6,  3],
       [ 3,  5,  6],
       [ 6,  7,  5],
       [12,  2,  5]])

将两个数组扩展成可广播的形状:

In [262]: I = pairs[None,...]==positions[:,None,None,:]

In [263]: I.shape
Out[263]: (6, 4, 2, 3)

大型布尔数组,展示所有维度上逐个元素的匹配情况。可以随意替换其他比较方式(例如difference ==0np.isclose 用于浮点数等)。

In [264]: J = I.all(axis=-1).any(axis=0).sum(axis=-1)

In [265]: J
Out[265]: array([1, 0, 2, 1])

整合不同维度的结果。在坐标上匹配所有数字,在位置上匹配任意数字,通过成对匹配计算匹配数量。

In [266]: pairs[J==1,...]
Out[266]: 
array([[[1, 2, 4],
        [3, 4, 4]],

       [[6, 7, 5],
        [1, 2, 3]]])

J==1代表只有一对值匹配的元素。(见注释)

anyandsum的组合可以用于测试用例,但可能需要针对更大的测试用例进行调整。但是这个思路通常适用。


对于https://dev59.com/vo3da4cB1Zd3GeqP5-K5#31901675测试的数组大小,我的解决方案非常慢。特别是它正在执行==测试,导致I形状为(5000, 5000, 2, 3)

压缩最后一个维度会有很大帮助。

dims = np.array([10000,100,1])  # simpler version of dims from XYZmerged
pairs1D = np.dot(pairs.reshape(-1,3),dims)
positions1D = np.dot(positions,dims)
I1d = pairs1D[:,None]==positions1D[None,:]
J1d = I1d.any(axis=-1).reshape(pairs.shape[:2]).sum(axis=-1)

我将J1d表达式更改为匹配我的表达式-以计算每对的匹配次数。

Divakar使用的in1d1甚至更快:

mask = np.in1d(pairs1D, positions1D).reshape(-1,2)
Jmask = mask.sum(axis=-1)

我刚意识到OP要求“最多有一对在positions数组中”,而我测试的是“每对恰好匹配”。因此我的测试都应该改为

(在我的随机样本n=5000中,这就是全部内容。没有任何pairs同时出现在positions中。其中54个J等于1,其余都是0,没有匹配)。


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接