最近邻算法中的不确定点问题

3
我有两个2D点集A和B。我想为B中的每个点找到A中最近的邻居。但是,我正在处理不确定的点(即点具有均值(2D向量)和2×2协方差矩阵)。
因此,我想使用Mahalanobis距离,但在scikit-learn(例如)中,我无法为每个点传递协方差矩阵,因为它期望单个协方差矩阵。
目前,只考虑平均位置(即我2D正态分布的平均值),我有:
nearest_neighbors = NearestNeighbors(n_neighbors=1, metric='l2').fit(A)
distance, indices = nearest_neighbors.kneighbors(B)

针对我不确定的点,我更倾向于计算它们的马氏距离(Mahalanobis distance),而不是使用L2范数作为距离,这个计算是在一个点a属于集合A,另一个点b属于集合B时进行的:

d(a, b) = sqrt( transpose(mu_a-mu_b) * C * (mu_a-mu_b))

其中 C = inv(cov_a + cov_b)

其中 mu_a(或 mu_b)和 cov_a(或 cov_b)是不确定点 a(或 b)的二维均值和2*2协方差矩阵。


展示你的代码尝试、输入和期望输出。 - depperm
脑海中唯一的想法是使用6D向量作为输入(用于存储位置和其协方差矩阵的四个分量),并定义自己的距离函数。 - floflo29
2个回答

0

您可以使用列表推导式简单地使用自己的距离函数实现KNN解决方案。以下是使用OpenCV库中内置的Mahalanobis距离实现的示例:

import numpy as np
import cv2

np_gallery=np.array(gallery)
np_query=np.array(query)

K=12

ids=[]

def insertionsort(comp_list):
    for i in range( 1, len(comp_list)):
    tmp = comp_list[i]
    k = min(i,K)
    while k > 0 and tmp[1] < comp_list[k - 1][1]:
        comp_list[k] = comp_list[k - 1]
        k -= 1
    comp_list[k] = tmp

def search():
    for q in np_query:
        c = [(i,cv2.Mahalanobis(q, x, icovar)) for i, x in enumerate(np_gallery)]
        insertionsort(c)
        ids.append(map(lambda tup: tup[0], c[0:K]))

或者

def search():
    for q in np_query:
        c = [(i,cv2.Mahalanobis(q, x, icovar)) for i, x in enumerate(np_gallery)]
        ids.append(map(lambda tup: tup[0], sorted(c, key=lambda tup: tup[1])[0:K]))

在第一种情况下,我使用了一种插入排序的变体,考虑到参数K。当N >> K时,这可能更有效。

0

最终我使用了自定义距离:

def my_mahalanobis_distance(x, y):
    '''
    x: array of shape (4,) x[0]: mu_x_1, x[1]: mu_x_2, 
                            x[2]: cov_x_11, x[3]: cov_x_22
    y: array of shape (4,) y[0]: mu_ y_1, y[1]: mu_y_2,
                            y[2]: cov_y_11, y[3]: cov_y_22 
    '''     



    return sp.spatial.distance.mahalanobis(x[:2], y[:2], 
                                           np.linalg.inv(np.diag(x[2:]) 
                                           + np.diag(y[2:])))

因此一个点有4个特征:

  • xy坐标
  • xy方差(在我的情况下协方差矩阵是对角线)

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接