我不确定你想要如何处理额外的维度,因为集合差异(像任何过滤器一样)本质上会丢失形状信息。
无论如何,NumPy提供了np.setdiff1d()
来优雅地解决这个问题。
编辑 根据提供的澄清,您似乎正在寻找一种在给定轴上计算集合差异的方法,即集合的元素实际上是数组。
NumPy 中没有专门针对此问题的内置函数,但制作一个并不太困难。
为了简单起见,我们假设操作轴是第一个轴(因此集合的元素为 arr[i]
),第一个数组中仅出现唯一元素,并且数组为 2D。
它们都基于这样一个思想:渐近最优的方法是构建第二个数组的 set()
,然后使用它来过滤掉第一个数组中的条目。
在 Python / NumPy 中构建这样的 set 的惯用方式是使用:
set(map(tuple, arr))
在将 tuple
映射时,会冻结 arr[i]
,使它们可哈希,从而可以与 set()
一起使用。
不幸的是,由于过滤器可能产生大小不可预测的结果,因此 NumPy 数组并不是理想的容器。
为了解决这个问题,可以使用:
- 一个中间的
list
import numpy as np
def setdiff2d_list(arr1, arr2):
delta = set(map(tuple, arr2))
return np.array([x for x in arr1 if tuple(x) not in delta])
np.fromiter()
后跟np.reshape()
import numpy as np
def setdiff2d_iter(arr1, arr2):
delta = set(map(tuple, arr2))
return np.fromiter((x for xs in arr1 if tuple(xs) not in delta for x in xs), dtype=arr1.dtype).reshape(-1, arr1.shape[-1])
- NumPy的高级索引
def setdiff2d_idx(arr1, arr2):
delta = set(map(tuple, arr2))
idx = [tuple(x) not in delta for x in arr1]
return arr1[idx]
将两个输入都转换为
set()
(这将强制输出元素的唯一性并且会失去排序):
import numpy as np
def setdiff2d_set(arr1, arr2):
set1 = set(map(tuple, arr1))
set2 = set(map(tuple, arr2))
return np.array(list(set1 - set2))
或者,可以使用广播, np.any()
和np.all()
构建高级索引:
def setdiff2d_bc(arr1, arr2):
idx = (arr1[:, None] != arr2).any(-1).all(1)
return arr1[idx]
上述方法的某些形式最初是在@QuangHoang's answer中提出的。
类似的方法也可以在Numba中实现,遵循与上述相同的思路,但使用哈希而不是实际的数组视图arr [i]
(由于Numba中set()
支持的限制),并预先计算输出大小(以提高速度):
import numpy as np
import numba as nb
@nb.njit
def mul_xor_hash(arr, init=65537, k=37):
result = init
for x in arr.view(np.uint64):
result = (result * k) ^ x
return result
@nb.njit
def setdiff2d_nb(arr1, arr2):
delta = {mul_xor_hash(arr2[0])}
for i in range(1, arr2.shape[0]):
delta.add(mul_xor_hash(arr2[i]))
n = 0
for i in range(arr1.shape[0]):
if mul_xor_hash(arr1[i]) not in delta:
n += 1
result = np.empty((n, arr1.shape[-1]), dtype=arr1.dtype)
j = 0
for i in range(arr1.shape[0]):
if mul_xor_hash(arr1[i]) not in delta:
result[j] = arr1[i]
j += 1
return result
虽然它们都会产生相同的结果:
funcs = setdiff2d_iter, setdiff2d_list, setdiff2d_idx, setdiff2d_set, setdiff2d_bc, setdiff2d_nb
arr1 = np.array([[0, 3], [0, 4], [1, 3], [1, 7]])
print(arr1)
arr2 = np.array([[0, 3], [1, 7], [4, 0]])
print(arr2)
result = funcs[0](arr1, arr2)
print(result)
for func in funcs:
print(f'{func.__name__:>24s}', np.all(result == func(arr1, arr2)))
他们的表现似乎存在变化:
for func in funcs:
print(f'{func.__name__:>24s}', end=' ')
%timeit func(arr1, arr2)
# setdiff2d_iter 16.3 µs ± 719 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
# setdiff2d_list 14.9 µs ± 528 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
# setdiff2d_idx 17.8 µs ± 1.75 µs per loop (mean ± std. dev. of 7 runs, 100000 loops each)
# setdiff2d_set 17.5 µs ± 1.31 µs per loop (mean ± std. dev. of 7 runs, 100000 loops each)
# setdiff2d_bc 9.45 µs ± 405 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
# setdiff2d_nb 1.58 µs ± 51.8 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
所提出的基于Numba的方法似乎比其他方法表现更好(使用给定输入大约快10倍)。
使用更大的输入时观察到类似的时间:
np.random.seed(42)
arr1 = np.random.randint(0, 100, (1000, 2))
arr2 = np.random.randint(0, 100, (1000, 2))
print(setdiff2d_nb(arr1, arr2).shape)
for func in funcs:
print(f'{func.__name__:>24s}', end=' ')
%timeit func(arr1, arr2)
作为一个侧面的注释,
setdiff2d_bc()
受第二个输入的大小影响最大。
ones = np.argwhere(img == 255)
并不能实现你的要求。那段代码中的arr1
和arr2
是什么?话虽如此,我认为你无法对np.argwhere
函数进行太多优化。 - Quang Hoang