在Julia中,使用多维数组广播一个二维数组

3

我有一个大小为m x n的2D矩阵A,其中n可能非常大(例如n>10000),以及一个大小为m x m x n的多维矩阵B。因此,针对A中的每一列i,我想计算A[:,i]'*B[:,:,i]。以下是我尝试使用Julia中的广播特性编写的代码。但是,我的代码性能相当慢。我想知道是否可以改进我的代码性能。请问是否有人有改进代码性能的想法?

using LinearAlgebra;
m = 500;
n = 20000; # this could be a very large number.

vecA = rand(m,n);
matB = rand(m,m,n);


combinedAB = Array{Array{Float64,2},2}(undef,n,1);
for ii in eachindex(combinedAB)
  combinedAB[ii] = [vecA[:,ii] matB[:,:,ii]];
end

# this is the result.
res = broadcast(eAB -> dotProd(eAB), combinedAB);

function dotProd(matZ::Array{Float64,2})
   return sum(broadcast(dot,matZ[:,1],matZ[:,2:end]),dims=1);
end
1个回答

4

这在您的情况下速度足够快吗?

res = [a'*b for (a, b) in zip(eachcol(vecA), eachslice(matB, dims=3))]

我没有足够的RAM来测试您的输入值,但是根据我对较小数据进行的测试,它应该在约3秒内运行。
我还假设您真的想要计算a的伴随(这是您在问题中写下的内容;如果您使用实数,则使用'或transpose不应该有影响)
代码之间的关键区别(在我的解决方案中,它被隐藏在引擎盖下,因为它更短)在于我的解决方案不会分配中间数组,而是使用视图。

谢谢您关注我的问题。确实,您的解决方案非常快且简短!但我们能进一步改进它吗?我正在使用真实数据,应该用transpose代替'。感谢您的解释。 - nhavt
实际上,伴随和转置是相同的,不应该有区别。只有当你使用复数时(因为虚部会被不同地处理),它才会有所不同。最快的方法可能是三重嵌套循环,使用 @inbounds@simd 和精心设计的遍历来避免 CPU 缓存未命中。但我认为所提出的解决方案速度要么相同,要么几乎相同(即速度数量级相同),而且肯定更简单。 - Bogumił Kamiński

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接