在两个Matlab矩阵之间查找相等的行

4

我在Matlab中有一个大小为GxN的矩阵index和一个大小为MxN的矩阵A

在提出我的问题之前,让我先举个例子。

clear
N=3;
G=2;
M=5;

index=[1  2  3;
       13 14 15]; %GxN

A=[1  2  3; 
   5  6  7; 
   21 22 23; 
   1  2  3;
   13 14 15]; %MxN

我希望您能帮助构建一个大小为GxM的矩阵Response,如果行A(m,:)等于index(g,:),则Response(g,m)=1,否则为零。请参考上面的例子。
Response= [1 0 0 1 0; 
           0 0 0 0 1]; %GxM

这段代码可以满足我的需求(取自于我之前的一个问题,链接:previous question of mine - 仅为澄清:当前的问题是不同的)

Response=permute(any(all(bsxfun(@eq, reshape(index.', N, [], G), permute(A, [2 3 4 1])), 1), 2), [3 4 1 2]);

然而,对于我的真正矩阵大小(N=19, M=500, G=524288),该命令非常缓慢。我知道我无法获得巨大的速度,但任何可以改善这一点的东西都是受欢迎的。


我非常怀疑这个可以改进的空间不大。你可以尝试将代码分解成几部分并计时,而不是使用一行代码。 - Ander Biguri
3个回答

7

MATLAB有众多函数用于处理集合, 包括setdiff, intersect, union等。在这种情况下,可以使用ismember函数:

[~, Loc] = ismember(A,index,'rows');

这将会给出:

Loc =
     1
     0
     0
     1
     2

并且Response将按以下方式构建:

Response = (1:size(index,1) == Loc).';

Response =
  2×5 logical array
   1   0   0   1   0
   0   0   0   0   1


谢谢,但我不知道如何从那里获取“Response”。 - TEX

7

方法一:计算距离

如果您拥有统计工具箱:

Response = ~(pdist2(index, A));

或者:

Response = ~(pdist2(index, A, 'hamming'));

这样做是因为pdist2计算每一对行之间的距离。相等的行之间的距离为0。逻辑非~对于那些行对给出1,否则为0

方法二:将行缩减为唯一整数标签

在我的机器上,这种方法更快:

[~,~,u] = unique([index; A], 'rows');
Response = bsxfun(@eq, u(1:G), u(G+1:end).');

它通过减少行到唯一的整数标签(使用unique的第三个输出)来实现,然后比较后者而不是前者。

对于您的尺寸值,在我的计算机上大约需要1秒钟:

clear
N = 19; M = 500; G = 524288;
index = randi(5,G,N); A = randi(5,M,N);
tic
[~,~,u] = unique([index; A], 'rows');
Response = bsxfun(@eq, u(1:G), u(G+1:end).');
toc

提供

Elapsed time is 1.081043 seconds.

findgroups 可能比使用 unique 的第三个输出更快。过去没有进行速度测试,但认为它执行相同的操作。 - Wolfie
@Wolfie 我认为 findgroups 要求分组变量是向量。因此,这里的矩阵必须被拆分成它们的列,这需要时间。另外,findgroups 内部使用 unique 的第三个输出,所以我怀疑它并不更快。 - Luis Mendo
啊,那就不用理我了,我没意识到这两个是相互关联的!第二个选项比我的重塑方法快得多,这让我很惊讶。 - Wolfie
@Wolfie 是的,就OP所提供的尺寸而言,你的方法在我的电脑上比我第二种方法慢了5倍。我想当尺寸很大时,即使需要时间,尽早减少一个维度(如我第二种方法)也是有益的。 - Luis Mendo

3

您可以重新构造矩阵,使每行代表第三个维度。然后,我们可以使用隐式扩展 (查看 R2016b 或更早版本的 bsxfun) 来判断所有元素是否相等,并使用 all 在行上进行聚合 (即对于给定的行,如果不全部都相等,则为 false)。

Response = all( reshape( index, [], 1, size(index,2) ) == reshape( A, 1, [], size(A,2) ), 3 ); 

你甚至可以通过在另一个维度中使用all来避免一些重塑,但对我来说,用这种方式更容易形象化。


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接