在2D数组中使用NumPy的argsort和take函数

4
目标是计算两组点(set1set2)之间的距离矩阵,使用argsort()获取排序后的索引,然后用take()提取排序后的数组。虽然我知道可以直接使用sort(),但我需要这些索引来进行下一步操作。
我正在使用这里讨论的精细索引概念。我不能直接使用获得的索引矩阵进行take(),但是为每一行添加相应数量使其起作用,因为take()会展开源数组,使第二行元素的索引+=len(set2),第三行索引+=2*len(set2)等(见下文):
dist  = np.subtract.outer( set1[:,0], set2[:,0] )**2
dist += np.subtract.outer( set1[:,1], set2[:,1] )**2
dist += np.subtract.outer( set1[:,2], set2[:,2] )**2
a = np.argsort( dist, axis=1 )
a += np.array([[ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
               [10, 10, 10, 10, 10, 10, 10, 10, 10, 10],
               [20, 20, 20, 20, 20, 20, 20, 20, 20, 20],
               [30, 30, 30, 30, 30, 30, 30, 30, 30, 30]])
s1 = np.sort(dist,axis=1)
s2 = np.take(dist,a)
np.nonzero((s1-s2)) == False
#True # meaning that it works...

主要问题是:是否有一种直接使用take()而不需要对这些索引求和的方法?
可用数据如下:
set1 = np.array([[ 250., 0.,    0.],
                 [ 250., 0.,  510.],
                 [-250., 0.,    0.],
                 [-250., 0.,    0.]])

set2 = np.array([[  61.0, 243.1, 8.3],
                 [ -43.6, 246.8, 8.4],
                 [ 102.5, 228.8, 8.4],
                 [  69.5, 240.9, 8.4],
                 [ 133.4, 212.2, 8.4],
                 [ -52.3, 245.1, 8.4],
                 [-125.8, 216.8, 8.5],
                 [-154.9, 197.1, 8.6],
                 [  61.0, 243.1, 8.7],
                 [ -26.2, 249.3, 8.7]])

其他相关问题:

- 两个不同的Numpy数组中点之间的欧几里得距离,而不是在同一数组内

2个回答

4

我认为使用np.take没有办法不转换为平坦索引。由于维度很可能会改变,因此最好使用np.ravel_multi_index来实现这一点,做如下操作:

a = np.argsort(dist, axis=1)
a = np.ravel_multi_index((np.arange(dist.shape[0])[:, None], a), dims=dist.shape)

另外,您也可以使用精确索引而不使用 take

s2 = dist[np.arange(4)[:, None], a]

1
截至2018年5月,存在np.take_along_axis
s2 = np.take_along_axis(dist, a, axis=1)

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接