Time complexity of the divide-and-conquer algorithm for matrix multiplication

5

[image] I understand that the algorithm uses 8 multiplications and 4 additions, and that its time complexity is: [image]

Each n/2 * n/2 matrix gets multiplied. I have a few questions:

  1. Is every n * n matrix eventually reduced to size n = 1 by applying T(n/2) repeatedly? If so, returning 1*6 for a11*b11 in the following matrices seems meaningless:

[image]

So the base case should be n == 2, executing the else part, because the operation below looks legitimate.

[image]

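  2. Why does the addition part take O(n^2) time? I mean, we are not dealing with matrix addition at all, only with plain numbers, because every matrix has been reduced to 2 * 2 size like the one below: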

[image]

So the addition part should contribute only 4? (Why is it O(n^2)?)
3 Answers

0

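1) The matrices are eventually reduced to 1*1 matrices. That does not really matter, though: you could just as well set the base case at n == 2, because multiplying two 2*2 matrices still takes constant time, O(1), so the overall complexity stays the same.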

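2) The addition part costs O(n^2) because each sub-result has (n^2)/4 elements and there are 4 of them to compute, so at the top level you actually perform n^2 scalar additions, which gives O(n^2).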
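To illustrate point 1), here is a minimal sketch of what a base case at n == 2 could look like. The class and method names (`TwoByTwoBaseCase`, `multiply2x2`) are made up for this illustration and do not come from the question. The point is only that the work is a fixed 8 scalar multiplications and 4 scalar additions, no matter how large the original matrices were.

```java
import java.util.Arrays;

// Hypothetical helper for illustration only: a base case at n == 2.
final class TwoByTwoBaseCase {

    // Multiplies two 2x2 matrices directly. This always costs 8 scalar
    // multiplications and 4 scalar additions, i.e. constant time.
    static int[][] multiply2x2(int[][] a, int[][] b) {
        return new int[][] {
            { a[0][0] * b[0][0] + a[0][1] * b[1][0], a[0][0] * b[0][1] + a[0][1] * b[1][1] },
            { a[1][0] * b[0][0] + a[1][1] * b[1][0], a[1][0] * b[0][1] + a[1][1] * b[1][1] }
        };
    }

    public static void main(String[] args) {
        int[][] a = { { 1, 2 }, { 3, 4 } };
        int[][] b = { { 5, 6 }, { 7, 8 } };
        // Prints [[19, 22], [43, 50]]
        System.out.println(Arrays.deepToString(multiply2x2(a, b)));
    }
}
```

Stopping at n == 2 instead of n == 1 only changes the constant hidden in the base case, so the recurrence and its solution are unaffected.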


0
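If I understand the question correctly, it can be partly answered as follows.
  1. Indeed, all matrices are eventually reduced to 1*1 matrices; this should not be too surprising, since matrix multiplication is ultimately defined in terms of multiplication in the underlying ring.

  2. The addition part has complexity O(n^2) at each level of the recursion, because it is performed on the results of the recursively evaluated multiplications, and those results are matrices rather than single numbers.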
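As a rough worked version of point 2, assuming n is a power of 2 and the recursion bottoms out at n = 1, the scalar additions can be counted level by level. At depth k there are 8^k subproblems of size n/2^k, and each one performs 4 block additions on (n/2^{k+1}) x (n/2^{k+1}) blocks:

```latex
\begin{aligned}
\text{additions at depth } k &= 8^k \cdot 4\left(\frac{n}{2^{k+1}}\right)^2 = 2^k n^2,\\
\text{total additions} &= \sum_{k=0}^{\log_2 n - 1} 2^k n^2 = n^2 (n - 1),\\
\text{total multiplications} &= 8^{\log_2 n} = n^3,\\
T(n) &= 8\,T(n/2) + \Theta(n^2) = \Theta(n^3).
\end{aligned}
```

The k = 0 term is the n^2 additions done at the top level, which is exactly the Theta(n^2) term in the recurrence. The last line is also what the Master theorem gives with a = 8, b = 2 and f(n) = Theta(n^2), since n^{log_2 8} = n^3 dominates n^2.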


1
This is not Strassen's algorithm, just the ordinary divide-and-conquer technique. - sagar_jeevan

0

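From the algorithm alone it may not be clear why the addition step takes Theta(n^2) time. I had the same confusion and thought addition should take constant time. In the addMatrices() method, if we make the following change for 2*2 matrices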

C[rowC][columnC] = A[0][0] + B[0][0];

then it gives the same result.

But once we take 4*4 matrices, we can see addMatrices() calls in the call stack that add multiple elements from matrices A and B. That is why the addition has to run inside a loop.

It is much easier to understand once you implement the program. I have tried to explain it; see the method comments for details.

package matrix;

/***
 * Square matrix multiplication (dimension 2^x) using the divide-and-conquer technique
 * 
 * @author kmandal
 *
 */
public class MatrixMultiplication {

    public static void main(String[] args) {
        int[][] A = { { 1, 2 }, { 3, 4 } };
        int[][] B = { { 5, 6 }, { 7, 8 } };
        int C[][] = squareMatrixMultiplyRecursive(A, B);

        for (int i = 0; i < C.length; i++) {
            for (int j = 0; j < C.length; j++) {
                System.out.print(C[i][j] + "    ");
            }
            System.out.println();
        }
    }

    private static int[][] squareMatrixMultiplyRecursive(int[][] A, int[][] B) {
        return squareMatrixMultiplicationDNC(A, B, 0, 0, 0, 0, A.length);
    }

    /**
     * <pre>
     * Let A and B be 2 square matrices of dimension 2^x
     * A = [
     *      A00     A01
     *      A10     A11
     *      ]
     * ,
     * B = [
     *      B00     B01
     *      B10     B11
     *      ]
     * 
     * Let C be another matrix that stores the result of multiplying A and B.
     * 
     *  C = A.B;
     *  
     *  C = [
     *      C00     C01
     *      C10     C11
     *      ]
     *  
     *  where
     *  for C00 calculation, elements in 0th row of A and 0th column of B considered
     *  C00 = A00*B00+A01*B10;  
     *  
     *  for C01 calculation, elements in 0th row of A and 1st column of B considered
     *  C01 = A00*B01+A01*B11; 
     *  
     *  for C10 calculation, elements in 1st row of A and 0th column of B considered
     *  C10 = A10*B00+A11*B10; 
     *  
     *  for C11 calculation, elements in 1st row of A and 1st column of B considered
     *  C11 = A10*B01+A11*B11;
     * 
     * Here we are using index-based calculation,
     * hence the time complexity of the divide step is Theta(1).
     * 
     * We have divided the problem into 8 sub-problems of size n/2.
     * Hence the recurrence for the recursive part is: 8T(n/2).
     * 
     * Additionally we need to consider the cost of the matrix addition step,
     * which is Theta(n^2). For more details refer to the addMatrices() method.
     * 
     * Hence the recurrence relation becomes
     * T(n) = Theta(1) + 8T(n/2) + Theta(n^2);
     * 
     * Applying the Master theorem,
     * the time complexity of this algorithm becomes Theta(n^3)
     * </pre>
     * 
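     * @param A       left operand (the full matrix; sub-blocks are addressed by offsets)
     * @param B       right operand (the full matrix; sub-blocks are addressed by offsets)
     * @param rowA    row offset of the current block of A
     * @param columnA column offset of the current block of A
     * @param rowB    row offset of the current block of B
     * @param columnB column offset of the current block of B
     * @param size    dimension of the current blocks
     * @return a new size x size matrix holding the product of the two blocks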
     */
    private static int[][] squareMatrixMultiplicationDNC(int[][] A, int[][] B,
            int rowA, int columnA, int rowB, int columnB, int size) {
        int[][] C = new int[size][size];
        if (size == 1) {
            C[0][0] = A[rowA][columnA] * B[rowB][columnB];
        } else {
            int newSize = size / 2;
            // calculate C00 = A00*B00+A01*B10;
            addMatrices(
                    C,
                    squareMatrixMultiplicationDNC(A, B, rowA, columnA, rowB,
                            columnB, newSize),
                    squareMatrixMultiplicationDNC(A, B, rowA,
                            columnA + newSize, rowB + newSize, columnB, newSize),
                    0, 0);
            // calculate C01 = A00*B01+A01*B11;
            addMatrices(
                    C,
                    squareMatrixMultiplicationDNC(A, B, rowA, columnA, rowB,
                            columnB + newSize, newSize),
                    squareMatrixMultiplicationDNC(A, B, rowA,
                            columnA + newSize, rowB + newSize, columnB
                                    + newSize, newSize), 0, newSize);
            // calculate C10 = A10*B00+A11*B10;
            addMatrices(
                    C,
                    squareMatrixMultiplicationDNC(A, B, rowA + newSize,
                            columnA, rowB, columnB, newSize),
                    squareMatrixMultiplicationDNC(A, B, rowA + newSize, columnA
                            + newSize, rowB + newSize, columnB, newSize),
                    newSize, 0);
            // calculate C11 = A10*B01+A11*B11;
            addMatrices(
                    C,
                    squareMatrixMultiplicationDNC(A, B, rowA + newSize,
                            columnA, rowB, columnB + newSize, newSize),
                    squareMatrixMultiplicationDNC(A, B, rowA + newSize, columnA
                            + newSize, rowB + newSize, columnB + newSize,
                            newSize), newSize, newSize);

        }
        return C;
    }

    /**
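     * A matrix is represented by a 2-dimensional array, hence to add 2
     * matrices we need to fetch the same element from both matrices and then
     * add them. Traversing a 2D array means accessing elements by row and
     * column index, so we need a loop inside a loop. Hence the time complexity
     * of addition is Theta(n^2), where n is the dimension of A and B.
     * 
     * @param C       destination matrix that receives the sum
     * @param A       first block to add
     * @param B       second block to add
     * @param rowC    row offset in C where the sum is written
     * @param columnC column offset in C where the sum is written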
     */
    private static void addMatrices(int[][] C, int[][] A, int[][] B, int rowC,
            int columnC) {
        int n = A.length;
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                C[i + rowC][j + columnC] = A[i][j] + B[i][j];
            }
        }
    }
}

Output:
19    22    
43    50  
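
To see how the addition work grows for larger inputs such as the 4*4 case mentioned above, here is a small sketch that counts scalar operations instead of performing them. It assumes n is a power of 2 and mirrors the structure of the program above: 8 recursive products, plus 4 addMatrices() calls that each add an (n/2) x (n/2) block. The class and method names (`AdditionCounter`, `multiplications`, `additions`) are hypothetical and exist only for this illustration.

```java
// Hypothetical instrumentation for illustration; it counts operations only.
final class AdditionCounter {

    // Scalar multiplications: one at size 1, otherwise 8 recursive products.
    static long multiplications(int n) {
        return n == 1 ? 1 : 8 * multiplications(n / 2);
    }

    // Scalar additions: the 8 sub-products contribute their own additions,
    // and the 4 block additions at this level cost 4 * (n/2)^2 = n^2 in total.
    static long additions(int n) {
        if (n == 1) {
            return 0;
        }
        long half = n / 2;
        return 8 * additions(n / 2) + 4 * half * half;
    }

    public static void main(String[] args) {
        for (int n = 1; n <= 16; n *= 2) {
            System.out.println("n=" + n
                    + " multiplications=" + multiplications(n)
                    + " additions=" + additions(n));
        }
    }
}
```

For n = 2 this reports 8 multiplications and 4 additions, which matches the count in the question. For n = 4 it reports 64 multiplications and 48 additions, of which 16 (that is, n^2) happen at the top level alone.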
