从pandas dataframe中获取N个最小距离对

3
考虑以下代码,该代码从标记坐标的列表生成距离矩阵:
import numpy as np
import pandas as pd
from scipy.spatial.distance import pdist, squareform

coord_data = [
    [1, 2],
    [4, 3],
    [5, 8],
    [6, 7],
]

df = pd.DataFrame(coord_data, index=list('ABCD'))

dist_matrix = squareform(pdist(df, metric='euclidean'))
dist_df = pd.DataFrame(dist_matrix, index=df.index, columns=df.index)

print(dist_df)

          A         B         C         D
A  0.000000  3.162278  7.211103  7.071068
B  3.162278  0.000000  5.099020  4.472136
C  7.211103  5.099020  0.000000  1.414214
D  7.071068  4.472136  1.414214  0.000000

有没有一种高效的方法(使用numpy,pandas等)从这个距离矩阵中获取N个最小距离对?

例如,如果N=2,则希望得到类似于以下示例的输出:

[['C', 'D'], ['A', 'B']] # corresponding to minimum distances [1.414214, 3.162278]

1
我会使用melt(又称为unpivot)函数并按距离排序。 - Laurent R
我试图使用你建议的方法,但在我能够操作之前就出现了一个答案。不过,感谢你提供有用的信息,我学到了新东西。 - Enigma Machine
2个回答

4

以下是使用 np.argpartition 进行性能优化的示例代码 -

def topN_index_columns_from_symmmdist(df, N):
    a = dist_df.to_numpy(copy=True)
    a[np.tri(len(a), dtype=bool)] = np.inf
    idx = np.argpartition(a.ravel(),range(N))[:N]
    r,c = np.unravel_index(idx, a.shape)
    return list(zip(dist_df.index[r], dist_df.columns[c]))

示例运行 -

In [43]: dist_df
Out[43]: 
          A         B         C         D
A  0.000000  3.162278  7.211103  7.071068
B  3.162278  0.000000  5.099020  4.472136
C  7.211103  5.099020  0.000000  1.414214
D  7.071068  4.472136  1.414214  0.000000

In [44]: topN_index_columns_from_symmmdist(df, N=2)
Out[44]: [('C', 'D'), ('A', 'B')]

In [45]: topN_index_columns_from_symmmdist(df, N=4)
Out[45]: [('C', 'D'), ('A', 'B'), ('B', 'D'), ('B', 'C')]

1

为了完整起见,这里提供另一种使用pandas的答案,基于Laurent R的评论。但最终我还是采用了Divakar的解决方案。

def topN_index_columns_from_symmmdist2(dist_df, N):
    dist_df = pd.melt(dist_df.reset_index(), id_vars="index")
    dist_df = dist_df.rename(columns={"index": "start", "variable": "end"})
    dist_df = dist_df.sort_values("value")
    dist_df = dist_df.drop_duplicates(subset=["value"], keep="last")
    dist_pair_list = dist_df.iloc[1:N+1, :2].values.tolist()
    return dist_pair_list

示例输出:

print(topN_index_columns_from_symmmdist2(dist_df, 2))
print(topN_index_columns_from_symmmdist2(dist_df, 4))

[['C', 'D'], ['A', 'B']]
[['C', 'D'], ['A', 'B'], ['B', 'D'], ['B', 'C']]

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接