如何快速检查两个数组是否具有相同的行

3
我正在尝试找出一种更好的方法来检查两个二维数组是否包含相同的行。以以下示例为简短示例:
>>> a
array([[0, 1, 2],
       [3, 4, 5],
       [6, 7, 8]])
>>> b
array([[6, 7, 8],
       [3, 4, 5],
       [0, 1, 2]])

在这种情况下,b=a[::-1]。要检查两行是否相等:
>>>a=a[np.lexsort((a[:,0],a[:,1],a[:,2]))]
>>>b=b[np.lexsort((b[:,0],b[:,1],b[:,2]))]
>>> np.all(a-b==0)
True

这很好而且相当快。但当两行“靠近”时,问题就出现了:
array([[-1.57839867  2.355354   -1.4225235 ],
       [-0.94728367  0.         -1.4225235 ],
       [-1.57839867 -2.355354   -1.4225215 ]]) <---note ends in 215 not 235
array([[-1.57839867 -2.355354   -1.4225225 ],
       [-1.57839867  2.355354   -1.4225225 ],
       [-0.94728367  0.         -1.4225225 ]])

在1E-5的允差范围内,这两个数组按行相等,但lexsort会告诉你它们不相等。这可以通过不同的排序顺序来解决,但我想要更一般的情况。

我正在考虑以下想法:

a=a.reshape(-1,1,3)
>>> a-b
array([[[-6, -6, -6],
        [-3, -3, -3],
        [ 0,  0,  0]],

       [[-3, -3, -3],
        [ 0,  0,  0],
        [ 3,  3,  3]],

       [[ 0,  0,  0],
        [ 3,  3,  3],
        [ 6,  6,  6]]])
>>> np.all(np.around(a-b,5)==0,axis=2)
array([[False, False,  True],
       [False,  True, False],
       [ True, False, False]], dtype=bool)
>>>np.all(np.any(np.all(np.around(a-b,5)==0,axis=2),axis=1))
True

这并不能告诉你两个数组是否按行相等,只是判断b中的所有元素是否都接近于a中的某个值。需要比较的行数可能有几百行,我需要经常进行此操作。有什么好的想法吗?


1
我会加入 scipy.spatial.cKDTree(根据scipy版本和用法可能是KDTree),作为一种更直接的方法。 - seberg
这正是我一直在寻找的。知道一定有更好的方法。 - Daniel
1个回答

1
你最后的代码并没有做你想象中的事情。它告诉你的是每一行在b是否接近于a中的每一行。如果你改变用于外部调用np.anynp.allaxis,你可以检查每一行在a中是否接近于b中的某些行。如果b中的每一行都接近于a中的一行,而且a中的每一行都接近于b中的一行,那么这些集合相等。可能不是很计算效率高,但对于中等大小的numpy数组来说,可能非常快速。
def same_rows(a, b, tol=5) :
    rows_close = np.all(np.round(a - b[:, None], tol) == 0, axis=-1)
    return (np.all(np.any(rows_close, axis=-1), axis=-1) and
            np.all(np.any(rows_close, axis=0), axis=0))

>>> rows, cols = 5, 3
>>> a = np.arange(rows * cols).reshape(rows, cols)
>>> b = np.arange(rows)
>>> np.random.shuffle(b)
>>> b = a[b]
>>> a
array([[ 0,  1,  2],
       [ 3,  4,  5],
       [ 6,  7,  8],
       [ 9, 10, 11],
       [12, 13, 14]])
>>> b
array([[ 9, 10, 11],
       [ 3,  4,  5],
       [ 0,  1,  2],
       [ 6,  7,  8],
       [12, 13, 14]])
>>> same_rows(a, b)
True
>>> b[0] = b[1]
>>> b
array([[ 3,  4,  5],
       [ 3,  4,  5],
       [ 0,  1,  2],
       [ 6,  7,  8],
       [12, 13, 14]])
>>> same_rows(a, b) # not all rows in a are close to a row in b
False

对于不太大的数组,性能是合理的,即使它需要构建一个 (行数,行数,列数) 的数组:

In [2]: rows, cols = 1000, 10

In [3]: a = np.arange(rows * cols).reshape(rows, cols)

In [4]: b = np.arange(rows)

In [5]: np.random.shuffle(b)

In [6]: b = a[b]

In [7]: %timeit same_rows(a, b)
10 loops, best of 3: 103 ms per loop

我提到了我发布的代码中存在一个问题。这是我最终编写的代码,加入了一些额外的参数。我添加了一个距离公式,以便更好地了解一个点有多接近,并使用lexsort方法大大减少了需要传递给此类型检查的行数。如果明天没有人提出更好的想法,我会检查你的答案。 - Daniel

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接