在Matlab中将多个常数乘以矩阵,并将它们转换为块对角矩阵

4
我有三个常数a1、a2和a3。我有一个矩阵A。我想做的是得到a1*A、a2*A和a3*A三个矩阵。然后我想把它们转换成一个对角块矩阵。对于三个常数的情况,这很容易。我可以让b1=a1*A,b2=a2*A,b3=a3*A,然后在matlab中使用blkdiag(b1, b2, b3)。
如果我有n个常数a1 ... an,怎么办?如何在没有任何循环的情况下完成此操作?我知道可以通过kronecker乘积来完成,但这非常耗时,而且需要进行大量不必要的0 * constant计算。
谢谢。
3个回答

5

讨论和代码

这可能是一种使用 bsxfun(@plus 的方法,它有助于在函数格式中编写的线性索引

function out = bsxfun_linidx(A,a)
%// Get sizes
[A_nrows,A_ncols] = size(A);
N_a = numel(a);

%// Linear indexing offsets between 2 columns in a block & between 2 blocks
off1 = A_nrows*N_a;
off2 = off1*A_ncols+A_nrows;

%// Get the matrix multiplication results
vals = bsxfun(@times,A,permute(a,[1 3 2])); %// OR vals = A(:)*a_arr;

%// Get linear indices for the first block
block1_idx = bsxfun(@plus,[1:A_nrows]',[0:A_ncols-1]*off1);  %//'

%// Initialize output array base on fast pre-allocation inspired by -
%// http://undocumentedmatlab.com/blog/preallocation-performance
out(A_nrows*N_a,A_ncols*N_a) = 0; 

%// Get linear indices for all blocks and place vals in out indexed by them
out(bsxfun(@plus,block1_idx(:),(0:N_a-1)*off2)) = vals;

return;

如何使用:要使用上面列出的函数代码,假设您已将 a1、a2、a3、....、an 存储在向量 a 中,则可以像这样执行 out = bsxfun_linidx(A,a) 以获得所需的输出到 out。
基准测试:本节将运行时性能方面的方法与其他两个答案中列出的方法进行比较或基准测试。
其他答案被转换为函数形式,如下所示-
function B = bsxfun_blkdiag(A,a)
B = bsxfun(@times, A, reshape(a,1,1,[])); %// step 1: compute products as a 3D array
B = mat2cell(B,size(A,1),size(A,2),ones(1,numel(a))); %// step 2: convert to cell array
B = blkdiag(B{:}); %// step 3: call blkdiag with comma-separated list from cell array

and,

function out = kron_diag(A,a_arr)
out = kron(diag(a_arr),A);

为了比较,测试了四种Aa大小的组合,它们分别是:

  • A500 x 500a1 x 10
  • A200 x 200a1 x 50
  • A100 x 100a1 x 100
  • A50 x 50a1 x 200

使用的基准测试代码如下:

%// Datasizes
N_a = [10  50  100 200];
N_A = [500 200 100 50];

timeall = zeros(3,numel(N_a)); %// Array to store runtimes
for iter = 1:numel(N_a)
    
    %// Create random inputs
    a = randi(9,1,N_a(iter));
    A = rand(N_A(iter),N_A(iter));
    
    %// Time the approaches
    func1 = @() kron_diag(A,a);
    timeall(1,iter) = timeit(func1); clear func1
    
    func2 = @() bsxfun_blkdiag(A,a);
    timeall(2,iter) = timeit(func2); clear func2
    
    func3 = @() bsxfun_linidx(A,a);
    timeall(3,iter) = timeit(func3); clear func3
end

%// Plot runtimes against size of A
figure,hold on,grid on
plot(N_A,timeall(1,:),'-ro'),
plot(N_A,timeall(2,:),'-kx'),
plot(N_A,timeall(3,:),'-b+'),
legend('KRON + DIAG','BSXFUN + BLKDIAG','BSXFUN + LINEAR INDEXING'),
xlabel('Datasize (Size of A) ->'),ylabel('Runtimes (sec)'),title('Runtime Plot')

%// Plot runtimes against size of a
figure,hold on,grid on
plot(N_a,timeall(1,:),'-ro'),
plot(N_a,timeall(2,:),'-kx'),
plot(N_a,timeall(3,:),'-b+'),
legend('KRON + DIAG','BSXFUN + BLKDIAG','BSXFUN + LINEAR INDEXING'),
xlabel('Datasize (Size of a) ->'),ylabel('Runtimes (sec)'),title('Runtime Plot')

我得到的运行时图如下:

enter image description here

enter image description here

结论: 如你所见,根据你处理的数据大小,可以考虑使用其中一种基于bsxfun的方法!


干得好!也许使用 timeit 进行基准测试会更可靠? - Luis Mendo
感谢您在基准测试中包含了我的方法! - Luis Mendo

5

下面是另一种方法:

  1. 使用 bsxfun 将乘积计算为3D数组;
  2. 将其转换为每个单元格中包含一个乘积(矩阵)的单元格数组;
  3. 从单元格数组生成逗号分隔列表,调用 blkdiag

假设 A 表示您的矩阵,a 表示您的常量向量。那么所需的结果 B 可以通过以下方式获得:

B = bsxfun(@times, A, reshape(a,1,1,[])); %// step 1: compute products as a 3D array
B = mat2cell(B,size(A,1),size(A,2),ones(1,numel(a))); %// step 2: convert to cell array
B = blkdiag(B{:}); %// step 3: call blkdiag with comma-separated list from cell array

如果矩阵A是一个方阵,且其中a的数量小于行数或列数,则这可能是最佳选择!非常棒的发现/代码,干得好! - Divakar

3
这里有一种使用kron的方法,似乎比Divakar基于bsxfun的解决方案更快且更节省内存。我不确定这是否与您的方法不同,但时间似乎很好。值得在不同方法之间进行一些测试,以确定哪种方法对您的问题更有效。
A=magic(4);

a1=1;
a2=2;
a3=3;

kron(diag([a1 a2 a3]),A)

你可以使用 kron(diag(sparse(a_arr)),A) 来节省更多的内存。 - knedlsepp

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接