寻找K个最近邻居及其实现方式

18

我正在使用KNN和欧几里得距离对简单数据进行分类。我已经看到一个例子,展示了如何使用MATLAB的knnsearch函数来完成我想做的事情:

load fisheriris 
x = meas(:,3:4);
gscatter(x(:,1),x(:,2),species)
newpoint = [5 1.45];
[n,d] = knnsearch(x,newpoint,'k',10);
line(x(n,1),x(n,2),'color',[.5 .5 .5],'marker','o','linestyle','none','markersize',10)

上面的代码使用一个新的点,即 [5 1.45],并找到距离该新点最近的10个值。请问有人可以展示一份MATLAB算法,并详细解释 knnsearch 函数是做什么的吗?还有其他方法可以实现这个功能吗?


很简单。对于一个特定的点,我们找到数据和这个点之间最近的10个点,并返回那些属于您的数据的最近点。通常使用欧几里得距离,其中一个点的组件用于比较另一个点的组件。维基百科上的这篇文章特别有用:http://en.wikipedia.org/wiki/K-nearest_neighbors_algorithm - rayryeng
哦...你想自己实现这个过程吗?我可以为你提供答案。实际上,实现算法并不像你想象的那么难。请说明你需要什么。 - rayryeng
是的,我正在尝试自己实现“knnsearch”函数,就像我的代码示例一样。谢谢! - Young_DataAnalyst
没问题。我一会儿就为您写一个答案。我现在没有MATLAB来测试我的代码。等我有了,我会给您写一个答案。但是,为了让您开始,基本的步骤是找到测试点与数据矩阵中所有其他点之间的欧几里得距离。您将距离从小到大排序,然后选择产生最小距离的“k”个点。很快就会回来给您答案! - rayryeng
你好rayryeng,只是想澄清一下:在这种情况下,我的测试点是newpoint = [5 1.45],对吗?现在我将计算我的数据中其他点与该点之间的欧几里得距离,因此x = meas(:,3:4); fisheriris数据是Matlab示例数据,请加载并查看一下是否有机会。谢谢! - Young_DataAnalyst
完成了。看一下吧。祝你好运! - rayryeng
1个回答

40
K-最近邻(KNN)算法的基础是您有一个数据矩阵,它由N行和M列组成,其中N是我们拥有的数据点数,而M是每个数据点的维度。例如,如果我们将笛卡尔坐标放入数据矩阵中,这通常是一个N x 2N x 3矩阵。使用此数据矩阵,您提供一个查询点,并在该数据矩阵中搜索距离该查询点最近的k个点。
通常,我们使用查询点与数据矩阵中其余点之间的欧几里得距离来计算距离。但是,也会使用其他距离,如L1或城市块/曼哈顿距离。完成此操作后,您将拥有N个欧几里得或曼哈顿距离,这些距离表示查询与数据集中每个对应点之间的距离。找到这些距离后,您只需按升序排序距离并检索那些与数据集和查询之间距离最小的k个点即可。
假设您的数据矩阵存储在x中,而newpoint是一个样本点,其中它有M列(即1 x M),这是您将遵循的一般过程:
  1. 计算newpointx中每个点之间的欧几里得或曼哈顿距离。
  2. 按升序对这些距离进行排序。
  3. 返回x中最接近newpointk个数据点。
让我们慢慢地完成每个步骤。

步骤 #1

有一种方法可以使用 for 循环来实现:

N = size(x,1);
dists = zeros(N,1);
for idx = 1 : N
    dists(idx) = sqrt(sum((x(idx,:) - newpoint).^2));
end

如果你想要实现曼哈顿距离,这将非常简单:
N = size(x,1);
dists = zeros(N,1);
for idx = 1 : N
    dists(idx) = sum(abs(x(idx,:) - newpoint));
end
将是一个包含x中每个数据点与newpoint之间距离的N元素向量。我们对newpointx中的数据点进行逐元素相减,平方差值,然后将它们全部sum在一起。然后对这个总和进行平方根运算,就完成了欧几里得距离的计算。对于曼哈顿距离,您需要执行逐元素相减,取绝对值,然后将所有分量相加。这可能是最简单易懂的实现方法,但对于较大的数据集和更高维度的数据来说,可能效率不高...

另一个可能的解决方案是复制newpoint并使该矩阵与x具有相同的大小,然后对此矩阵进行逐元素相减,然后对每行进行列求和并进行平方根运算。因此,我们可以这样做:

N = size(x, 1);
dists = sqrt(sum((x - repmat(newpoint, N, 1)).^2, 2));

对于曼哈顿距离,您将执行以下操作:

N = size(x, 1);
dists = sum(abs(x - repmat(newpoint, N, 1)), 2);

repmat接受矩阵或向量,并在给定方向上重复它们一定次数。在我们的情况下,我们想要取出我们的newpoint向量,并将其在顶部重复N次,以创建一个N x M矩阵,其中每行长度为M。我们将这两个矩阵相减,然后平方每个元素。完成后,我们对每行的所有列进行sum,最后对所有结果取平方根。对于曼哈顿距离,我们执行减法,取绝对值,然后求和。

然而,在我看来,最有效的方法是使用bsxfun。这实际上是在单个函数调用中在内部执行了我们谈论过的复制操作。因此,代码将简单地如下所示:

dists = sqrt(sum(bsxfun(@minus, x, newpoint).^2, 2));

对我来说,这看起来更加简洁明了。对于曼哈顿距离,您可以执行以下操作:
dists = sum(abs(bsxfun(@minus, x, newpoint)), 2);

步骤 #2

现在我们有了距离,我们只需要对它们进行排序。我们可以使用 sort 来对距离进行排序:

[d,ind] = sort(dists);

d会按升序包含距离,而ind告诉您在未排序数组中每个值在排序结果中出现的位置。我们需要使用ind,提取该向量的前k个元素,然后使用ind索引到我们的x数据矩阵中返回那些最接近newpoint的点。

第3步

最后一步是现在返回那些最接近newpointk数据点。我们可以通过以下方法非常简单地完成:

ind_closest = ind(1:k);
x_closest = x(ind_closest,:);

ind_closest应该包含原始数据矩阵x中最接近newpoint的索引。具体而言,ind_closest包含你需要从x中抽样的哪些以获得最接近newpoint的点。x_closest将包含这些实际数据点。


为了方便复制粘贴,以下是代码的样子:

dists = sqrt(sum(bsxfun(@minus, x, newpoint).^2, 2));
%// Or do this for Manhattan
% dists = sum(abs(bsxfun(@minus, x, newpoint)), 2);
[d,ind] = sort(dists);
ind_closest = ind(1:k);
x_closest = x(ind_closest,:);

运行您的示例,让我们看看我们的代码实际效果:

load fisheriris 
x = meas(:,3:4);
newpoint = [5 1.45];
k = 10;

%// Use Euclidean
dists = sqrt(sum(bsxfun(@minus, x, newpoint).^2, 2));
[d,ind] = sort(dists);
ind_closest = ind(1:k);
x_closest = x(ind_closest,:);

通过检查ind_closestx_closest,我们得到以下结果:

>> ind_closest

ind_closest =

   120
    53
    73
   134
    84
    77
    78
    51
    64
    87

>> x_closest

x_closest =

    5.0000    1.5000
    4.9000    1.5000
    4.9000    1.5000
    5.1000    1.5000
    5.1000    1.6000
    4.8000    1.4000
    5.0000    1.7000
    4.7000    1.4000
    4.7000    1.4000
    4.7000    1.5000

如果你运行了knnsearch函数,你会发现变量nind_closest匹配。然而,变量d返回的是从newpoint到每个点x距离,而不是实际的数据点本身。如果你想要实际的距离,只需在我写的代码后面执行以下操作:
dist_sorted = d(1:k);

请注意,上面的答案仅在一批N示例中使用一个查询点。很频繁地,KNN会同时用于多个示例。假设我们有Q个查询点要在KNN中进行测试。这将导致一个k x M x Q矩阵,在每个示例或每个切片中,我们返回具有M维度的k个最接近的点。或者,我们可以返回k个最接近点的ID,从而得到一个Q x k矩阵。让我们计算两者。
一种天真的方法是在循环中应用上述代码并循环遍历每个示例。
像这样的东西可以工作,我们分配一个Q x k矩阵,并应用基于bsxfun的方法来将输出矩阵的每一行设置为数据集中k个最接近的点,其中我们将使用与之前相同的Fisher Iris数据集。我们还将保持与之前示例中相同的维度,并使用四个示例,因此Q = 4M = 2
%// Load the data and create the query points
load fisheriris;
x = meas(:,3:4);
newpoints = [5 1.45; 7 2; 4 2.5; 2 3.5];

%// Define k and the output matrices
Q = size(newpoints, 1);
M = size(x, 2);
k = 10;
x_closest = zeros(k, M, Q);
ind_closest = zeros(Q, k);

%// Loop through each point and do logic as seen above:
for ii = 1 : Q
    %// Get the point
    newpoint = newpoints(ii, :);

    %// Use Euclidean
    dists = sqrt(sum(bsxfun(@minus, x, newpoint).^2, 2));
    [d,ind] = sort(dists);

    %// New - Output the IDs of the match as well as the points themselves
    ind_closest(ii, :) = ind(1 : k).';
    x_closest(:, :, ii) = x(ind_closest(ii, :), :);
end

尽管这很好,但我们可以做得更好。有一种方法可以高效地计算两组向量之间的平方欧几里得距离。如果您想使用曼哈顿距离,请参考此博客。假设A是一个Q1 x M矩阵,其中每一行都是一个维度为M的点,具有Q1个点,并且B是一个Q2 x M矩阵,其中每一行也是一个维度为M的点,具有Q2个点,我们可以使用以下矩阵公式高效地计算距离矩阵D(i, j),其中第i行和第j列的元素表示A的第i行和B的第j行之间的距离:
nA = sum(A.^2, 2); %// Sum of squares for each row of A
nB = sum(B.^2, 2); %// Sum of squares for each row of B
D = bsxfun(@plus, nA, nB.') - 2*A*B.'; %// Compute distance matrix
D = sqrt(D); %// Compute square root to complete calculation

因此,如果我们让A成为查询点的矩阵,B成为由原始数据组成的数据集,我们可以通过逐行排序并确定每行中最小的k个位置来确定k个最近的点。我们还可以使用这种方法检索实际的数据点。
%// Load the data and create the query points
load fisheriris;
x = meas(:,3:4);
newpoints = [5 1.45; 7 2; 4 2.5; 2 3.5];

%// Define k and other variables
k = 10;
Q = size(newpoints, 1);
M = size(x, 2);

nA = sum(newpoints.^2, 2); %// Sum of squares for each row of A
nB = sum(x.^2, 2); %// Sum of squares for each row of B
D = bsxfun(@plus, nA, nB.') - 2*newpoints*x.'; %// Compute distance matrix
D = sqrt(D); %// Compute square root to complete calculation 

%// Sort the distances 
[d, ind] = sort(D, 2);

%// Get the indices of the closest distances
ind_closest = ind(:, 1:k);

%// Also get the nearest points
x_closest = permute(reshape(x(ind_closest(:), :).', M, k, []), [2 1 3]);

我们看到,用于计算距离矩阵的逻辑相同,但一些变量已更改以适应示例。我们还使用两个输入版本的sort对每行进行独立排序,因此ind将包含每行的ID,d将包含相应的距离。然后,我们通过简单地将此矩阵截断为k列来确定哪些索引最接近每个查询点。然后,我们使用permutereshape确定相关联的最接近点是什么。我们首先使用所有最接近的索引,并创建一个点矩阵,将所有ID堆叠在一起,以便我们获得一个Q * k x M矩阵。使用reshapepermute允许我们创建我们的3D矩阵,使其成为我们指定的k x M x Q矩阵。如果您想获取实际距离本身,我们可以索引d并获取所需内容。要做到这一点,您需要使用sub2ind获取线性索引,以便我们可以一次性地索引到d中。ind_closest的值已经给出了我们需要访问哪些列。我们需要访问的行只是1,k次,2,k次,等等,直到Qk是我们想要返回的点数:
row_indices = repmat((1:Q).', 1, k);
linear_ind = sub2ind(size(d), row_indices, ind_closest);
dist_sorted = D(linear_ind);

当我们运行上述代码来处理以上查询点时,我们获得的是以下索引、点和距离:
>> ind_closest

ind_closest =

   120   134    53    73    84    77    78    51    64    87
   123   119   118   106   132   108   131   136   126   110
   107    62    86   122    71   127   139   115    60    52
    99    65    58    94    60    61    80    44    54    72

>> x_closest

x_closest(:,:,1) =

    5.0000    1.5000
    6.7000    2.0000
    4.5000    1.7000
    3.0000    1.1000
    5.1000    1.5000
    6.9000    2.3000
    4.2000    1.5000
    3.6000    1.3000
    4.9000    1.5000
    6.7000    2.2000


x_closest(:,:,2) =

    4.5000    1.6000
    3.3000    1.0000
    4.9000    1.5000
    6.6000    2.1000
    4.9000    2.0000
    3.3000    1.0000
    5.1000    1.6000
    6.4000    2.0000
    4.8000    1.8000
    3.9000    1.4000


x_closest(:,:,3) =

    4.8000    1.4000
    6.3000    1.8000
    4.8000    1.8000
    3.5000    1.0000
    5.0000    1.7000
    6.1000    1.9000
    4.8000    1.8000
    3.5000    1.0000
    4.7000    1.4000
    6.1000    2.3000


x_closest(:,:,4) =

    5.1000    2.4000
    1.6000    0.6000
    4.7000    1.4000
    6.0000    1.8000
    3.9000    1.4000
    4.0000    1.3000
    4.7000    1.5000
    6.1000    2.5000
    4.5000    1.5000
    4.0000    1.3000

>> dist_sorted

dist_sorted =

    0.0500    0.1118    0.1118    0.1118    0.1803    0.2062    0.2500    0.3041    0.3041    0.3041
    0.3000    0.3162    0.3606    0.4123    0.6000    0.7280    0.9055    0.9487    1.0198    1.0296
    0.9434    1.0198    1.0296    1.0296    1.0630    1.0630    1.0630    1.1045    1.1045    1.1180
    2.6000    2.7203    2.8178    2.8178    2.8320    2.9155    2.9155    2.9275    2.9732    2.9732

与使用 knnsearch 进行比较,您需要指定一个点矩阵作为第二个参数,其中每一行都是查询点,并且您会发现此实现和 knnsearch 之间的索引和排序距离匹配。
希望这能帮到你。祝好运!

@年轻数据分析师 - 很高兴能够帮到你!如果我的回答有所帮助,请考虑接受我的答案:)。祝好运! - rayryeng
@Kamtal - 很酷 :)!我很高兴! - rayryeng
@rayryeng 如果x和newpoints都对称于同一对称平面,则knnsearch可能会返回该平面上的不对称索引。有没有一种实现对称性的方法? - June Wang
@rayryeng 这是一个粗略的草图。https://ibb.co/g4b0YD7 例如,#1(x,y,z)沿yz对称面镜像点是新点上的#59(-x,y,z)。#88(x1,y1,z1)沿yz平面镜像点为#69(-x1,y1,z1)。如果#88在新点上最近的点是#1,则我希望#69的最近点是#59。但是knnsearch()可能会返回B作为#69的最近点。是否有一种方法可以在搜索期间保持这种对称性呢? - June Wang
@rayryeng 我刚刚运行了几个测试,发现结果与点的搜索顺序有关。如果我在对称平面上重新排序源点索引,KNN就会返回对称性,否则不会。不确定为什么。谢谢。 - June Wang
显示剩余3条评论

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接