For a pure NumPy solution, you could do something like this:
Use np.unique to get the sorted unique values of x and y separately, together with the index of the first occurrence of each value (return_index=True):
u_x, u_idx_x = np.unique(x, return_index=True)
u_y, u_idx_y = np.unique(y, return_index=True)
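To make the return_index output concrete, here is a tiny example on a hypothetical array with a duplicate. np.unique returns the values sorted, and the indices point at the first occurrence of each value in the original array:

```python
import numpy as np

x = np.array([3, 1, 3, 2])  # hypothetical input; 3 appears twice
u_x, u_idx_x = np.unique(x, return_index=True)
print(u_x)      # [1 2 3]  sorted unique values
print(u_idx_x)  # [1 3 0]  first index of 1, 2 and 3 in x

# Indexing back into x recovers the unique values
assert np.array_equal(x[u_idx_x], u_x)
```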
Find the intersection of the unique values using np.intersect1d:
i_xy = np.intersect1d(u_x, u_y, assume_unique=True)
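A small sketch of this step on hypothetical data. Passing assume_unique=True is safe here only because u_x and u_y come straight out of np.unique; with duplicated inputs the flag can give wrong results:

```python
import numpy as np

x = np.array([4, 8, 1, 5])  # hypothetical inputs
y = np.array([9, 5, 4])
u_x, u_idx_x = np.unique(x, return_index=True)  # u_x = [1 4 5 8]
u_y, u_idx_y = np.unique(y, return_index=True)  # u_y = [4 5 9]

# Both inputs are already unique, so intersect1d can skip its own uniquing pass
i_xy = np.intersect1d(u_x, u_y, assume_unique=True)
print(i_xy)  # [4 5]
```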
Finally, use np.isin (the replacement for np.in1d, which is deprecated as of NumPy 2.0) to keep only those indices whose unique values in x or y are also in the intersection of x and y:
i_idx_x = u_idx_x[np.isin(u_x, i_xy, assume_unique=True)]
i_idx_y = u_idx_y[np.isin(u_y, i_xy, assume_unique=True)]
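Continuing the same hypothetical example, this shows the boolean mask produced by np.isin and how it selects the first-occurrence indices. The final check confirms that the selected positions in x and y hold the same values:

```python
import numpy as np

x = np.array([4, 8, 1, 5])  # hypothetical inputs
y = np.array([9, 5, 4])
u_x, u_idx_x = np.unique(x, return_index=True)  # u_idx_x = [2 0 3 1]
u_y, u_idx_y = np.unique(y, return_index=True)  # u_idx_y = [2 1 0]
i_xy = np.intersect1d(u_x, u_y, assume_unique=True)  # [4 5]

mask_x = np.isin(u_x, i_xy, assume_unique=True)
print(mask_x)  # [False  True  True False]

i_idx_x = u_idx_x[mask_x]                                 # [0 3]
i_idx_y = u_idx_y[np.isin(u_y, i_xy, assume_unique=True)]  # [2 1]

# Same values, in the same (sorted) order, at the returned positions
assert np.array_equal(x[i_idx_x], y[i_idx_y])
```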
Putting it all together in a function:
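import numpy as np

def intersect_indices(x, y):
    # Sorted unique values and the index of the first occurrence of each
    u_x, u_idx_x = np.unique(x, return_index=True)
    u_y, u_idx_y = np.unique(y, return_index=True)
    # Values present in both arrays (u_x and u_y are unique by construction)
    i_xy = np.intersect1d(u_x, u_y, assume_unique=True)
    # Keep only the first-occurrence indices whose value is in the intersection
    i_idx_x = u_idx_x[np.isin(u_x, i_xy, assume_unique=True)]
    i_idx_y = u_idx_y[np.isin(u_y, i_xy, assume_unique=True)]
    return i_idx_x, i_idx_y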
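One way to gain confidence in the function is to compare it against a slow but obviously correct pure-Python version on random inputs that contain duplicates. The reference function below (intersect_indices_reference) is a hypothetical helper written for this check; it records the first index of each value with a dict and returns the indices ordered by the sorted common values, which is the same order intersect_indices produces:

```python
import numpy as np

def intersect_indices(x, y):
    u_x, u_idx_x = np.unique(x, return_index=True)
    u_y, u_idx_y = np.unique(y, return_index=True)
    i_xy = np.intersect1d(u_x, u_y, assume_unique=True)
    i_idx_x = u_idx_x[np.isin(u_x, i_xy, assume_unique=True)]
    i_idx_y = u_idx_y[np.isin(u_y, i_xy, assume_unique=True)]
    return i_idx_x, i_idx_y

def intersect_indices_reference(x, y):
    # Hypothetical reference: first index of every value, via plain dicts
    first_x, first_y = {}, {}
    for i, v in enumerate(x.tolist()):
        first_x.setdefault(v, i)
    for i, v in enumerate(y.tolist()):
        first_y.setdefault(v, i)
    common = sorted(first_x.keys() & first_y.keys())
    return [first_x[v] for v in common], [first_y[v] for v in common]

# Small value range so that duplicates are guaranteed
rng = np.random.default_rng(0)
x = rng.integers(0, 50, size=200)
y = rng.integers(0, 50, size=200)

ix, iy = intersect_indices(x, y)
rx, ry = intersect_indices_reference(x, y)
assert ix.tolist() == rx and iy.tolist() == ry
```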
For example:
x = np.array([4, 1, 10, 5, 8, 13, 11])
y = np.array([20, 5, 4, 9, 11, 7, 25])
i_idx_x, i_idx_y = intersect_indices(x, y)
print(i_idx_x, i_idx_y)  # [0 3 6] [2 1 4]
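As a cross-check, and as an alternative worth knowing about, np.intersect1d has accepted return_indices=True since NumPy 1.15. It returns the common values together with the indices of their first occurrences in each input, which for the example above gives the same result:

```python
import numpy as np

x = np.array([4, 1, 10, 5, 8, 13, 11])
y = np.array([20, 5, 4, 9, 11, 7, 25])

# Common values plus the first-occurrence index of each in x and in y
common, idx_x, idx_y = np.intersect1d(x, y, return_indices=True)
print(idx_x, idx_y)  # [0 3 6] [2 1 4]

assert np.array_equal(x[idx_x], common)
assert np.array_equal(y[idx_y], common)
```

This is convenient for one-off use; the timings below are for the hand-written functions only.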
Speed test:
In [1]: k = 1000000

In [2]: %%timeit x, y = np.random.randint(k, size=(2, k))
   ...: intersect_indices(x, y)
   ...:
1 loops, best of 3: 597 ms per loop
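If you are not in IPython, a roughly equivalent measurement can be made with the standard timeit module. This is a sketch under the assumption that generating the input once (rather than once per repeat, as %%timeit setup does) is acceptable; absolute numbers will depend on the machine and NumPy version:

```python
import timeit

import numpy as np

def intersect_indices(x, y):
    u_x, u_idx_x = np.unique(x, return_index=True)
    u_y, u_idx_y = np.unique(y, return_index=True)
    i_xy = np.intersect1d(u_x, u_y, assume_unique=True)
    i_idx_x = u_idx_x[np.isin(u_x, i_xy, assume_unique=True)]
    i_idx_y = u_idx_y[np.isin(u_y, i_xy, assume_unique=True)]
    return i_idx_x, i_idx_y

k = 1_000_000
rng = np.random.default_rng()
x, y = rng.integers(k, size=(2, k))  # two rows unpacked into x and y

# One call per timing, three timings, report the best (like "best of 3")
times = timeit.repeat(lambda: intersect_indices(x, y), number=1, repeat=3)
print(f"best of 3: {min(times) * 1e3:.0f} ms per loop")
```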
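Update:
I initially overlooked the fact that, in your case, x and y both contain only unique values. Taking that into account, it is slightly more efficient to use an indirect sort: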
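def intersect_indices_unique(x, y):
    # Indirect sort: x[u_idx_x] is x in ascending order
    u_idx_x = np.argsort(x)
    u_idx_y = np.argsort(y)
    # Valid only because x and y contain no duplicates
    i_xy = np.intersect1d(x, y, assume_unique=True)
    # Find each common value in the sorted arrays, then map back to original positions
    i_idx_x = u_idx_x[x[u_idx_x].searchsorted(i_xy)]
    i_idx_y = u_idx_y[y[u_idx_y].searchsorted(i_xy)]
    return i_idx_x, i_idx_y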
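The indirect-sort trick is easiest to see on a tiny hypothetical array: argsort gives the permutation that sorts x, searchsorted finds where each common value sits in the sorted copy, and indexing the permutation with those positions maps them back to the original array. Note that this relies on x having no duplicates; with repeated values, np.argsort's default sort is not stable, so the returned index need not be the first occurrence:

```python
import numpy as np

x = np.array([40, 10, 30, 20])       # hypothetical unique input
order = np.argsort(x)                # [1 3 2 0]: positions of 10, 20, 30, 40 in x
x_sorted = x[order]                  # [10 20 30 40]
common = np.array([20, 40])          # hypothetical intersection with another array
pos = x_sorted.searchsorted(common)  # [1 3]: where 20 and 40 sit in x_sorted
idx = order[pos]                     # [3 0]: positions of 20 and 40 in the original x

assert np.array_equal(x[idx], common)
```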
Here is a more realistic test case, in which x and y each contain unique (but partially overlapping) values:
In [1]: n, k = 10000000, 1000000

In [2]: %%timeit x, y = (np.random.choice(n, size=k, replace=False) for _ in range(2))
   ...: intersect_indices(x, y)
   ...:
1 loops, best of 3: 593 ms per loop

In [3]: %%timeit x, y = (np.random.choice(n, size=k, replace=False) for _ in range(2))
   ...: intersect_indices_unique(x, y)
   ...:
1 loops, best of 3: 453 ms per loop
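Before comparing timings it is worth confirming that the two functions agree on this kind of input. The sketch below uses smaller sizes than the benchmark (an assumption made only to keep it quick) and numpy.random.default_rng instead of the legacy np.random.choice; both functions return indices ordered by the sorted common values, so their outputs should be identical:

```python
import numpy as np

def intersect_indices(x, y):
    u_x, u_idx_x = np.unique(x, return_index=True)
    u_y, u_idx_y = np.unique(y, return_index=True)
    i_xy = np.intersect1d(u_x, u_y, assume_unique=True)
    i_idx_x = u_idx_x[np.isin(u_x, i_xy, assume_unique=True)]
    i_idx_y = u_idx_y[np.isin(u_y, i_xy, assume_unique=True)]
    return i_idx_x, i_idx_y

def intersect_indices_unique(x, y):
    u_idx_x = np.argsort(x)
    u_idx_y = np.argsort(y)
    i_xy = np.intersect1d(x, y, assume_unique=True)
    i_idx_x = u_idx_x[x[u_idx_x].searchsorted(i_xy)]
    i_idx_y = u_idx_y[y[u_idx_y].searchsorted(i_xy)]
    return i_idx_x, i_idx_y

# Unique but partially overlapping inputs (smaller than the benchmark sizes)
n, k = 100_000, 10_000
rng = np.random.default_rng(42)
x, y = (rng.choice(n, size=k, replace=False) for _ in range(2))

a_x, a_y = intersect_indices(x, y)
b_x, b_y = intersect_indices_unique(x, y)
assert np.array_equal(a_x, b_x) and np.array_equal(a_y, b_y)
```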
@Divakar's solution performs very similarly:
In [4]: %%timeit x, y = (np.random.choice(n, size=k, replace=False) for _ in range(2))
   ...: searchsorted_based(x, y)
   ...:
1 loops, best of 3: 472 ms per loop