Python Numpy获取两个二维数组之间的差异

4

我有一个简单的问题一直困扰着我,基本上我有两个二维数组,它们都包含[x,y]坐标,我想将第一个数组与第二个数组进行比较,并生成一个包含第一个数组中所有不在第二个数组中出现的元素的第三个数组。这很简单,但我根本无法让它起作用。数组的大小差异很大,第一个数组可以有1000到200万个坐标,而第二个数组只有1到1000个。

这个操作会发生多次,第一个数组越大,它会发生的次数越多

示例:

arr1 = np.array([[0, 3], [0, 4], [1, 3], [1, 7], ])

arr2 = np.array([[0, 3], [1, 7]])

result = np.array([[0, 4], [1, 3]])

深入:基本上我有一个分辨率可变的二进制图像,由0和1(255)组成,我单独分析每个像素(使用已经优化过的算法),但是(故意地)每次执行这个函数时,它只分析一小部分像素,并且当它完成时,它会将这些像素的所有坐标返回给我。问题在于,当它执行时,运行以下代码:

ones = np.argwhere(img == 255) # ones = pixels array

这部分代码需要约0.02秒的时间,是最慢的部分。我的想法是将此变量创建一次,并在每次函数执行结束时,删除已解析的像素并将新数组作为参数传递,直到数组为空。


2
你能提供一个最简示例来说明你想要实现的内容吗?(例如:样本输入和期望输出)? - norok2
我有点困惑,ones = np.argwhere(img == 255) 并不能实现你的要求。那段代码中的 arr1arr2 是什么?话虽如此,我认为你无法对 np.argwhere 函数进行太多优化。 - Quang Hoang
@QuangHoang 所以,正如我之前解释的那样,这是旧代码,我现在正在使用它,我想要改进性能的方法是只调用一次argwhere并将其转换为arr1,在每个函数的末尾它将返回给我arr2,这样我打算使新的“ones”成为arr1和arr2之间的差异,这样我就不必多次调用argwhere(这是代码最慢的部分)。 - Pedro Bzz
1
我认为这可能是一个 XY 问题。也许你应该更好地描述你的原始问题,并展示一个更完整的例子,包含一些现实数据,特别是如果你追求某个性能目标的话。 - norok2
2个回答

4

我不确定你想要如何处理额外的维度,因为集合差异(像任何过滤器一样)本质上会丢失形状信息。

无论如何,NumPy提供了np.setdiff1d()来优雅地解决这个问题。


编辑 根据提供的澄清,您似乎正在寻找一种在给定轴上计算集合差异的方法,即集合的元素实际上是数组。

NumPy 中没有专门针对此问题的内置函数,但制作一个并不太困难。 为了简单起见,我们假设操作轴是第一个轴(因此集合的元素为 arr[i]),第一个数组中仅出现唯一元素,并且数组为 2D。

它们都基于这样一个思想:渐近最优的方法是构建第二个数组的 set(),然后使用它来过滤掉第一个数组中的条目。

在 Python / NumPy 中构建这样的 set 的惯用方式是使用:

set(map(tuple, arr))

在将 tuple 映射时,会冻结 arr[i],使它们可哈希,从而可以与 set() 一起使用。

不幸的是,由于过滤器可能产生大小不可预测的结果,因此 NumPy 数组并不是理想的容器。

为了解决这个问题,可以使用:

  1. 一个中间的 list
import numpy as np


def setdiff2d_list(arr1, arr2):
    delta = set(map(tuple, arr2))
    return np.array([x for x in arr1 if tuple(x) not in delta])
  1. np.fromiter()后跟np.reshape()
import numpy as np


def setdiff2d_iter(arr1, arr2):
    delta = set(map(tuple, arr2))
    return np.fromiter((x for xs in arr1 if tuple(xs) not in delta for x in xs), dtype=arr1.dtype).reshape(-1, arr1.shape[-1])
  1. NumPy的高级索引
def setdiff2d_idx(arr1, arr2):
    delta = set(map(tuple, arr2))
    idx = [tuple(x) not in delta for x in arr1]
    return arr1[idx]

将两个输入都转换为set()(这将强制输出元素的唯一性并且会失去排序):
import numpy as np


def setdiff2d_set(arr1, arr2):
    set1 = set(map(tuple, arr1))
    set2 = set(map(tuple, arr2))
    return np.array(list(set1 - set2))

或者,可以使用广播, np.any()np.all()构建高级索引:

def setdiff2d_bc(arr1, arr2):
    idx = (arr1[:, None] != arr2).any(-1).all(1)
    return arr1[idx]

上述方法的某些形式最初是在@QuangHoang's answer中提出的。

类似的方法也可以在Numba中实现,遵循与上述相同的思路,但使用哈希而不是实际的数组视图arr [i](由于Numba中set()支持的限制),并预先计算输出大小(以提高速度):

import numpy as np
import numba as nb


@nb.njit
def mul_xor_hash(arr, init=65537, k=37):
    result = init
    for x in arr.view(np.uint64):
        result = (result * k) ^ x
    return result


@nb.njit
def setdiff2d_nb(arr1, arr2):
    # : build `delta` set using hashes
    delta = {mul_xor_hash(arr2[0])}
    for i in range(1, arr2.shape[0]):
        delta.add(mul_xor_hash(arr2[i]))
    # : compute the size of the result
    n = 0
    for i in range(arr1.shape[0]):
        if mul_xor_hash(arr1[i]) not in delta:
            n += 1
    # : build the result
    result = np.empty((n, arr1.shape[-1]), dtype=arr1.dtype)
    j = 0
    for i in range(arr1.shape[0]):
        if mul_xor_hash(arr1[i]) not in delta:
            result[j] = arr1[i]
            j += 1
    return result

虽然它们都会产生相同的结果:

funcs = setdiff2d_iter, setdiff2d_list, setdiff2d_idx, setdiff2d_set, setdiff2d_bc, setdiff2d_nb

arr1 = np.array([[0, 3], [0, 4], [1, 3], [1, 7]])
print(arr1)
# [[0 3]
#  [0 4]
#  [1 3]
#  [1 7]]

arr2 = np.array([[0, 3], [1, 7], [4, 0]])
print(arr2)
# [[0 3]
#  [1 7]
#  [4 0]]

result = funcs[0](arr1, arr2)
print(result)
# [[0 4]
#  [1 3]]

for func in funcs:
    print(f'{func.__name__:>24s}', np.all(result == func(arr1, arr2)))
#           setdiff2d_iter True
#           setdiff2d_list True
#            setdiff2d_idx True
#            setdiff2d_set False  # because of ordering
#             setdiff2d_bc True
#             setdiff2d_nb True

他们的表现似乎存在变化:

for func in funcs:
    print(f'{func.__name__:>24s}', end='  ')
    %timeit func(arr1, arr2)
#           setdiff2d_iter  16.3 µs ± 719 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
#           setdiff2d_list  14.9 µs ± 528 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
#            setdiff2d_idx  17.8 µs ± 1.75 µs per loop (mean ± std. dev. of 7 runs, 100000 loops each)
#            setdiff2d_set  17.5 µs ± 1.31 µs per loop (mean ± std. dev. of 7 runs, 100000 loops each)
#             setdiff2d_bc  9.45 µs ± 405 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
#             setdiff2d_nb  1.58 µs ± 51.8 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

所提出的基于Numba的方法似乎比其他方法表现更好(使用给定输入大约快10倍)。

使用更大的输入时观察到类似的时间:

np.random.seed(42)

arr1 = np.random.randint(0, 100, (1000, 2))
arr2 = np.random.randint(0, 100, (1000, 2))
print(setdiff2d_nb(arr1, arr2).shape)
# (736, 2)


for func in funcs:
    print(f'{func.__name__:>24s}', end='  ')
    %timeit func(arr1, arr2)
#           setdiff2d_iter  3.51 ms ± 75.2 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
#           setdiff2d_list  2.92 ms ± 32.1 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
#            setdiff2d_idx  2.61 ms ± 38.4 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
#            setdiff2d_set  3.52 ms ± 67.4 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
#             setdiff2d_bc  25.6 ms ± 198 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
#             setdiff2d_nb  192 µs ± 1.66 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

作为一个侧面的注释,setdiff2d_bc() 受第二个输入的大小影响最大。

@PedroBzz 不确定是否符合您的性能要求,但请查看基于Numba的快速实现以及它与其他几种方法的比较。 - norok2
1
请注意,哈希表的碰撞处理也应该被实现。 - norok2

1

根据你的数组大小而定。如果它们不是太大(几千个),你可以

  1. 使用广播将x中的每个点与y中的每个点进行比较
  2. 使用any检查最后一个维度上的不等式
  3. 使用all检查匹配情况

代码:

idx = (arr1[:,None]!=arr2).any(-1).all(1)

arr1[idx]

输出:

array([[0, 4],
       [1, 3]])

更新:对于更长的数据,您可以尝试使用set和for循环:

set_y = set(map(tuple, y))
idx = [tuple(point) not in set_y for point in x]

x[idx]

大小差异很大,从一千到两百万不等,我会将其添加到原始帖子中。 - Pedro Bzz
@PedroBzz,如果有帮助,请查看更新的答案。 - Quang Hoang
谢谢!这个可以工作,但是执行时间大约为0.2秒,我正在寻找一个执行时间在0.02以下的东西,这就是我的当前代码所做的。这是一张图片的分析,我会在原始帖子中更详细地说明。 - Pedro Bzz
@PedroBzz 非常出色的表现。也许你可以/应该分享你的代码(这是发布问题时建议的)。此外,如果您不关心arr1中点的顺序,您也可以将其转换为集合,这样会更快。 - Quang Hoang
@QuangHoang 看起来将两个输入转换为 set() 并不是那么快,因为需要进行许多转换。 - norok2

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接