Matlab:对矩阵中每一行进行Argmax和点积

3
我有两个矩阵,X在R^(n*m)W在R^(k*m),其中k<<n
x_i是X的第i行,w_j是W的第j行。我需要为每个x_i找到最大化<w_j,x_i>的j值。
我无法想到遍历X中所有行的方法,但是否有一种方法可以不用每次遍历整个W就能找到最大的点积?
一个朴素的实现方式如下:
n = 100;
m = 50;
k = 10;
X = rand(n,m);
W = rand(k,m);
Y = zeros(n, 1);

for i = 1 : n
  max_ind = 1;
  max_val = dot(W(1,:), X(i,:));
  for j = 2 : k
       cur_val = dot(W(j,:),X(i,:));

       if cur_val > max_val
          max_val = cur_val;
          max_ind = j;
       end

   end

   Y(i,:) = max_ind;
end

请与我们分享迭代代码?也可以添加一个示例案例吗? - Divakar
我已经添加了我能想到的朴素实现。 - ginge
1
+1 用于可重现的示例 - Luis Mendo
2个回答

2

使用基于bsxfun的方法可以加速您的IT技术处理。

[~,Y] = max(sum(bsxfun(@times,X,permute(W,[3 2 1])),2),[],3)

在我的系统上,使用您的数据集,我使用这种方法得到了100倍以上的加速。


可以考虑另外两种“相邻”的方法,但它们似乎并没有比前面的方法有更大的改进 -

[~,Y] = max(squeeze(sum(bsxfun(@times,X,permute(W,[3 2 1])),2)),[],2)

并且

[~,Y] = max(squeeze(sum(bsxfun(@times,X',permute(W,[2 3 1]))))')

2

点积本质上是矩阵乘法:

[~, Y] = max(W*X');

@Divakar 谢谢 :-) 你的 bsxfun 也很好,而且花费的时间大致相同 (+1),非常感谢。 - Luis Mendo
哦,我怀疑这一点,特别是对于大数据集,bsxfun 的速度会变慢。 - Divakar

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接