Matlab中的高效分类算法

7
我有一张大小为RGB uint8(576,720,3)的图像,我想将每个像素分类到一组颜色中。我使用rgb2lab从RGB转换为LAB空间,然后删除了L层,因此现在它是一个double(576,720,2),由AB组成。
现在,我想将其分类到另一张图像上训练的某些颜色,并计算它们各自的AB表示:
Cluster 1: -17.7903  -13.1170
Cluster 2: -30.1957   40.3520
Cluster 3:  -4.4608   47.2543
Cluster 4:  46.3738   36.5225
Cluster 5:  43.3134  -17.6443
Cluster 6:  -0.9003    1.4042
Cluster 7:   7.3884   11.5584

现在,为了将每个像素分类/标记到1-7个聚类中,我目前执行以下操作(伪代码):
clusters;
for each x
  for each y
    ab = im(x,y,2:3);
    dist = norm(ab - clusters); // norm of dist between ab and each cluster
    [~, idx] = min(dist);
  end
end

然而,由于图像分辨率和我手动循环每个x和y,这非常慢(52秒)。
是否有一些内置函数可以执行相同的工作?肯定有。
总结一下:我需要一种分类方法,将像素图像分类到已定义的一组聚类中。

@Divakar 是的,这实际上非常有趣。我的第一次尝试:52秒。 我的第一次尝试,但迁移到使用并行计算(4个池):10秒。 方法#1:0.06秒。 相当惊人。 - casparjespersen
使用第二种方法,你有机会尝试过吗?对于这些数字我可能有点过于兴奋了,抱歉打扰了。 - Divakar
呵呵,太酷了 :) 我真的很享受矩阵编程甚至比并行计算表现更好的事实!有没有什么限制,使得矩阵可以有多大才能保持在内存中?我还没有尝试过第二种方法,但我可以在今天晚些时候去尝试一下。 - casparjespersen
使用第一种方法,您很快就会达到内存带宽限制,但是使用第二种方法,对于大数据大小,它应该能够更好地保持。如果您想测试第二种方法的运行时比较,特别是针对大数据大小,请告诉我!顺便说一下,在矩阵编程中实现了这种速度提升的神奇技巧,我们称之为向量化,其中最通用的工具是bsxfun - Divakar
@Divakar方法1的时间范围为0.06-0.09秒,而方法2的时间范围为0.04-0.06秒。因此,方法2更快一些。 - casparjespersen
显示剩余2条评论
2个回答

11

方法一

对于一个大小为 N x 2 的点/像素数组,您可以避免像Luis的其他解决方案建议的使用permute,这可能会使事情变得有点慢,而是采用一种"展开排列"版本,同时让bsxfun朝向一个2D数组而不是一个3D数组,这样的性能肯定更好。

因此,假设聚类被排序为一个N x 2大小的数组,您可以尝试这种基于bsxfun的其他方法 -

%// Get a's and b's
im_a = im(:,:,2);
im_b = im(:,:,3);

%// Get the minimum indices that correspond to the cluster IDs
[~,idx]  = min(bsxfun(@minus,im_a(:),clusters(:,1).').^2 + ...
    bsxfun(@minus,im_b(:),clusters(:,2).').^2,[],2);
idx = reshape(idx,size(im,1),[]);

方法 #2

您可以尝试另一种方法,利用MATLAB中的快速矩阵乘法,并基于这个聪明的解决方案 -

d = 2; %// dimension of the problem size

im23 = reshape(im(:,:,2:3),[],2);

numA = size(im23,1);
numB = size(clusters,1);

A_ext = zeros(numA,3*d);
B_ext = zeros(numB,3*d);
for id = 1:d
    A_ext(:,3*id-2:3*id) = [ones(numA,1), -2*im23(:,id), im23(:,id).^2 ];
    B_ext(:,3*id-2:3*id) = [clusters(:,id).^2 ,  clusters(:,id), ones(numB,1)];
end
[~, idx] = min(A_ext * B_ext',[],2); %//'
idx = reshape(idx, size(im,1),[]); %// Desired IDs

矩阵乘法计算距离矩阵有什么作用?

考虑两个矩阵 A 和 B,我们想要计算它们之间的距离矩阵。为了方便下面的解释,假设 A 是一个 3 x 2 的矩阵,B 是一个 4 x 2 的矩阵,这表示我们在处理 X-Y 点。如果 A 是一个 N x 3 的矩阵,B 是一个 M x 3 的矩阵,则它们是 X-Y-Z 点。

现在,如果我们需要手动计算距离矩阵的第一个元素,它应该如下所示:

first_element = ( A(1,1) – B(1,1) )^2 + ( A(1,2) – B(1,2) )^2         

这将是 -

first_element = A(1,1)^2 + B(1,1)^2 -2*A(1,1)* B(1,1)   +  ...
                A(1,2)^2 + B(1,2)^2 -2*A(1,2)* B(1,2)    … Equation  (1)

现在,根据我们提出的矩阵乘法,如果你在早期代码的循环结束后检查 A_extB_ext 的输出,它们将如下所示 -

enter image description here

enter image description here

因此,如果您在A_extB_ext的转置之间执行矩阵乘法,则产品的第一个元素将是A_extB_ext的第一行之间逐元素相乘的总和,即这些的总和 -

enter image description here

这个结果与早先的“方程式(1)”获得的结果完全相同。对于所有元素的A与与A在同一列中的B的所有元素进行计算,这将持续进行。因此,我们将得到完整的平方距离矩阵。就是这样了!

向量化变体

基于矩阵乘法的距离矩阵计算的向量化变体是可能的,但是它们并没有表现出任何大的性能改进。下面列出了两种这样的变化。

变体#1

[nA,dim] = size(A);
nB = size(B,1);

A_ext = ones(nA,dim*3);
A_ext(:,2:3:end) = -2*A;
A_ext(:,3:3:end) = A.^2;

B_ext = ones(nB,dim*3);
B_ext(:,1:3:end) = B.^2;
B_ext(:,2:3:end) = B;

distmat = A_ext * B_ext.';

变体#2

[nA,dim] = size(A);
nB = size(B,1);

A_ext = [ones(nA*dim,1) -2*A(:) A(:).^2];
B_ext = [B(:).^2 B(:) ones(nB*dim,1)];

A_ext = reshape(permute(reshape(A_ext,nA,dim,[]),[1 3 2]),nA,[]);
B_ext = reshape(permute(reshape(B_ext,nB,dim,[]),[1 3 2]),nB,[]);

distmat = A_ext * B_ext.';

因此,这些也可以被视为实验版本。


抱歉我的线性代数有点生疏。我希望你能创建一个详细的智能解决方案说明,因为你一直在发布它,而我并没有完全理解。特别是 HelpAHelpBhelpA * helpB'。为什么要做 ones(numA,1)?为什么在 -2*im23(:,id) 中使用 -2?为什么按那个顺序创建 helpAHelpB 的值? helpA * helpB' 的目的是什么? - kkuilla
@kkuilla 请看一下编辑中的解释是否有意义? - Divakar
@Divakar 对 bsxfun 的改进很不错,回答也非常详细!+1 - Luis Mendo
哦,太好了。谢谢你。+50 :-)。我没有看到你用那个我记不住名字的规则重新写了(A(1,1)-B(1,1))^2 +((A(1,2)-B(1,2)^2)。非常好的解释。 - kkuilla
1
@kkuilla 很有趣的是,即使我也忘记了它的名字,小学已经过去很久了,但我认为它可能被命名为减法平方的扩展 :) - Divakar

4
使用pdist2(统计工具箱)以向量化的方式计算距离:
ab = im(:,:,2:3);                              % // get A, B components
ab = reshape(ab, [size(im,1)*size(im,2) 2]);   % // reshape into 2-column
dist = pdist2(clusters, ab);                   % // compute distances
[~, idx] = min(dist);                          % // find minimizer for each pixel
idx = reshape(idx, size(im,1), size(im,2));    % // reshape result

如果您没有统计工具箱,您可以将第三行替换为
dist = squeeze(sum(bsxfun(@minus, clusters, permute(ab, [3 2 1])).^2, 2));

这会给出平方距离而不是距离,但对于最小化而言并不重要。

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接