多线程矩阵乘法

5

我最近开始学习Java中的多线程。由于我在我的大学正在编写一个数字计算程序,所以我决定通过编写多线程矩阵乘法来进行一些初步尝试。

这是我的代码。请记住,这只是作为第一次尝试而制作的,并不是非常干净。

    public class MultithreadingTest{

        public static void main(String[] args) {
            // TODO Auto-generated method stub
            double[][] matrix1 = randomSquareMatrix(2000);
            double[][] matrix2 = randomSquareMatrix(2000);

            matrixMultiplication(matrix1,matrix2,true);
            matrixMultiplicationSingleThread(matrix1, matrix2);
            try {
                matrixMultiplicationParallel(matrix1,matrix2, true);
            } catch (InterruptedException | ExecutionException e) {
                // TODO Auto-generated catch block
                e.printStackTrace();
            }
            try {
                matrixMultiplicationParallel2(matrix1,matrix2, true);
            } catch (InterruptedException | ExecutionException e) {
                // TODO Auto-generated catch block
                e.printStackTrace();
            }

        }

        public static double[][] randomSquareMatrix(int n){
            double[][] mat = new double[n][n];
            Random rand = new Random();
            for(int i=0; i<n; i++) for(int j=0; j<n; j++) mat[i][j]=rand.nextInt(10);
            return mat;
        }
        public static void printSquareMat(double[][] mat){
            int n=mat.length;
            for(int i=0; i<n; i++){ for(int j=0; j<n; j++) System.out.print(mat[i][j]+" "); System.out.print("\n");}
            System.out.print("\n");
        }

        public static void average(double[][] matrix)
        {
            int n=matrix.length;
            double sum=0;
            for(int i=0; i<n; i++) for(int j=0; j<n; j++) sum+=matrix[i][j];

            System.out.println("Average of all Elements of Matrix : "+(sum/(n*n)));
        }

        public static void matrixMultiplication(double[][] matrix1, double[][] matrix2, boolean printMatrix){

            int n=matrix1.length;
            double[][] resultMatrix = new double[n][n];

            double startTime = System.currentTimeMillis();

            for(int i=0; i<n; i++)for(int j=0; j<n; j++)for(int k=0; k<n; k++) resultMatrix[i][j]+=matrix1[i][k]*matrix2[k][j];


            if (printMatrix && n<=5)for(int i=0; i<n; i++){for(int j=0; j<n; j++) System.out.print(resultMatrix[i][j]+" ");System.out.print("\n"); }

            System.out.print("\n");
            System.out.println(((System.currentTimeMillis()-startTime)/1000)+
                    " seconds for matrix of size "+n+" in main thread.");
            average(resultMatrix);
        }

        public static void matrixMultiplicationSingleThread(double[][] m1, double[][] m2)
        {
            int n=m1.length;
            double startTime = System.currentTimeMillis();
            Thread t = new Thread(new multiSingle(m1,m2));
            t.start();
            try {
                t.join();
            } catch (InterruptedException e) {
                // TODO Auto-generated catch block
                System.out.println("Error");
                e.printStackTrace();
            }
            System.out.print("\n");
            System.out.println(((System.currentTimeMillis()-startTime)/1000)+
                    " seconds for matrix of size "+n+" in external Thread.");

        }

        public static void matrixMultiplicationParallel(double[][] matrix1, double[][] matrix2, boolean printMatrix) throws InterruptedException, ExecutionException{

            int n=matrix1.length;
            double[][] resultMatrix=new double[n][n];
            double tmp;
            ExecutorService exe = Executors.newFixedThreadPool(2);
            Future<Double>[][] result = new Future[n][n];
            double startTime = System.currentTimeMillis();
            for(int i=0; i<n; i++)
            {
                for(int j=0; j<=i; j++)
                {
                    tmp=matrix2[i][j];
                    matrix2[i][j]=matrix2[j][i];
                    matrix2[j][i]=tmp;
                }
            }

            for(int i=0; i<n; i++)
            {
                for(int j=0; j<n; j++)
                {
                    result[i][j] = exe.submit(new multi(matrix1[i],matrix2[j]));
                }
            }

            exe.shutdown();
            exe.awaitTermination(1, TimeUnit.DAYS);

            for(int i=0; i<n; i++)
            {
                for(int j=0; j<n; j++)
                {
                    resultMatrix[i][j] = result[i][j].get();
                }
            }
            for(int i=0; i<n; i++)
            {
                for(int j=0; j<=i; j++)
                {
                    tmp=matrix2[i][j];
                    matrix2[i][j]=matrix2[j][i];
                    matrix2[j][i]=tmp;
                }
            }
            if (printMatrix && n<=5)for(int i=0; i<n; i++){for(int j=0; j<n; j++) System.out.print(resultMatrix[i][j]+" ");System.out.print("\n"); }

            System.out.print("\n");
            System.out.println(((System.currentTimeMillis()-startTime)/1000)+
                    " seconds for matrix of size "+n+" multithreaded with algorithm 1.");
            average(resultMatrix);
        }

        public static void matrixMultiplicationParallel2(double[][] matrix1, double[][] matrix2, boolean printMatrix) throws InterruptedException, ExecutionException{

            int n=matrix1.length;
            double[][] resultMatrix=new double[n][n];
            double tmp;
            ExecutorService exe = Executors.newFixedThreadPool(2);
            Future<Double>[][] result = new Future[n][n];
            double startTime = System.currentTimeMillis();


            for(int i=0; i<n; i++)
            {
                for(int j=0; j<n; j++)
                {
                    result[i][j] = exe.submit(new multi2(i,j,matrix1,matrix2));
                }
            }

            exe.shutdown();

            exe.awaitTermination(1, TimeUnit.DAYS);


            for(int i=0; i<n; i++)
            {
                for(int j=0; j<n; j++)
                {
                    resultMatrix[i][j] = result[i][j].get();
                }
            }

            if (printMatrix && n<=5)for(int i=0; i<n; i++){for(int j=0; j<n; j++) System.out.print(resultMatrix[i][j]+" ");System.out.print("\n"); }

            System.out.print("\n");
            System.out.println(((System.currentTimeMillis()-startTime)/1000)+
                    " seconds for matrix of size "+n+" multithreaded with algorithm 2.");
            average(resultMatrix);
        }

        public static class multi implements Callable<Double>{

            multi(double[] vec1, double[] vec2){
                this.vec1=vec1; this.vec2=vec2;
            }
            double result;
            double[] vec1, vec2;

            @Override
            public Double call() {
                result=0;
                for(int i=0; i<vec1.length; i++) result+=vec1[i]*vec2[i];
                return result;
            }
        }

        public static class multi2 implements Callable<Double>{

            multi2(int a, int b, double[][] vec1, double[][] vec2){
                this.a=a; this.b=b; this.vec1=vec1; this.vec2=vec2;
            }
            int a,b;
            double result;
            double[][] vec1, vec2;

            @Override
            public Double call() {
                result=0;
                for(int i=0; i<vec1.length; i++) result+=vec1[a][i]*vec2[i][b];
                return result;
            }
        }

        public static class multiSingle implements Runnable{

            double[][] matrix1, matrix2;

            multiSingle(double[][] m1, double[][] m2){
                matrix1=m1;
                matrix2=m2;
            }
            public static void matrixMultiplication(double[][] matrix1, double[][] matrix2, boolean printMatrix){

                int n=matrix1.length;
                double[][] resultMatrix = new double[n][n];

                for(int i=0; i<n; i++)for(int j=0; j<n; j++)for(int k=0; k<n; k++) resultMatrix[i][j]+=matrix1[i][k]*matrix2[k][j];

                MultithreadingTest.average(resultMatrix);
            }

            @Override
            public void run() {
                matrixMultiplication(matrix1, matrix2, false);
            }
        }

    }

我有两个关于多线程的一般性问题,希望可以在此不开新话题进行讨论。

  1. 是否有一种方法可以编写不需要额外实现可运行或可调用线程类的代码?我查看了使用匿名内部类和lambda的方法,但据我所知,我无法通过这种方式传递参数到线程中,因为run()和call()没有任何参数,除非参数是final。但是,假设我编写了一个矩阵操作类,我宁愿不为我想要在线程中运行的每个操作编写额外的类。
  2. 假设我的类执行许多多线程操作,在每个方法中创建新的线程池并关闭它将浪费大量资源。因此,我想在我的类作为成员时创建一个线程池,需要时进行实例化并使用invokeAll方法。但是,如果删除了我的对象,会发生什么情况?因为我从未关闭线程池,所以会出现问题吗?在C++中,我会使用析构函数来处理这个问题。还是gc在这种情况下全权负责?

现在直接涉及我的代码:

我以四种不同的方式实现矩阵乘法,作为在我的主线程中运行的方法,作为在新线程中运行的方法,但仍未多线程(确保我的主线程中没有任何后台任务减慢它的速度),以及两种不同的多线程矩阵乘法。第一个版本将第二个矩阵转置,将乘法作为向量-向量乘法提交,并将矩阵转置回其原始形式。第二个版本直接使用矩阵,并额外使用两个索引来定义矩阵的行和列进行向量-向量乘法。

对于所有版本,我测量了乘法所需的时间,并计算了结果矩阵的平均值,以查看结果是否相同。

我在两台计算机上运行了此代码,都是相同的JVM和Windows 10。第一台是我的笔记本电脑,i5第5代,2.6 GHz双核,第二台是我的台式电脑,i5第4代,4.2 GHz四核。

我预计我的桌面电脑会快得多。我还预计多线程版本需要大约单线程版本的一半/四分之一的时间,但仍然需要更长,因为要创建线程等额外工作。最后,我预计第二个多线程版本,不会两次转置一个矩阵,应该更快,因为操作较少。

运行代码后,我的结果有点混乱,请有人解释一下:

对于单线程的方法,对于矩阵大小为3000的情况,我的笔记本电脑需要大约340秒。因此我认为我的主线程没有执行昂贵的后台任务。另一方面,我的台式电脑需要440秒。那么问题来了,为什么我的笔记本电脑虽然速度明显较慢,但却比我的台式电脑快这么多呢?即使第五代处理器比第四代更快,由于我的台式电脑以1.6倍于笔记本电脑的速度运行,我仍然希望它更快。这两个处理器之间的差异不可能那么大。
对于多线程的方法,我的笔记本电脑需要大约34秒。如果多线程是完美的,则理论上应该不到一半的时间。为什么在两个线程上快了十倍?我的台式电脑也是如此。使用四个线程,乘法计算只需16秒,而不是440秒。这就像我的台式电脑以与我的笔记本相同的速度工作,只不过是在四个线程而不是两个线程上。
现在来比较两种多线程方法,将一个矩阵转置两次的版本在我的笔记本电脑上需要大约34秒,而直接使用矩阵的版本需要大约200秒。这听起来很现实,因为它比单线程方法慢了一半以上。但是为什么它比第一个版本慢那么多呢?我会认为两次转置矩阵比获取矩阵元素所需的附加时间更慢。我是否遗漏了什么或者使用矩阵真的比使用向量慢那么多?
希望有人能够回答这些问题。很抱歉写了这么长的文章。
此致 Thorsten
2个回答

3
这个谜题的答案是:矩阵乘法所需时间主要由从RAM到CPU缓存的数据传输时间决定。你可能有4个核心,但只有一个RAM总线,所以如果它们都在等待内存访问而相互阻塞,使用更多核心(多线程)将不会带来任何好处。
你应该尝试的第一个实验是:使用矩阵转置和向量乘法编写单线程版本。你会发现它非常快——可能与具有转置的多线程版本一样快。
原始单线程版本之所以如此慢,是因为它必须为每个要相乘的列中的单元格加载一个缓存块。如果使用矩阵转置,则所有这些单元格在内存中都是连续的,并且加载一个块可以获得一堆单元格。
因此,如果您想优化矩阵乘法,请先优化缓存效率的内存访问,然后将工作分配给几个线程——不超过您拥有的核心数量的两倍。任何超过这个数目的线程只会浪费时间和资源,包括上下文切换等方面。
关于您的其他问题:
1)使用捕获创建它们的范围中的变量的lambda表达式很方便,例如:
for(int i=0; i<n; i++)
{
    for(int j=0; j<n; j++)
    {
        final double[] v1 = matrix1[i];
        final double[] v2 = matrix2[j];
        result[i][j] = exe.submit(() -> vecdot(v1,v2));
    }
}

2) 垃圾回收器会处理它。您无需显式关闭线程池以释放任何资源。


我尝试了单线程-转置版本,在我的笔记本电脑上比多线程版本慢了3秒。 - Thorsten Schmitz

1

在创建线程时,您必须小心地将开销最小化。一个好的方法是使用ForkJoin框架来使用线程池分解问题。该框架

  • 重用现有的线程池。
  • 分解任务,直到有足够的任务使线程池保持繁忙,但不再多余。

每个核心只有一个浮点运算单元,因此您的可扩展性将基于您拥有的核心数。

我建议您阅读Java中的Fork Join矩阵乘法,但我找不到这段代码的原始来源。

http://gee.cs.oswego.edu/dl/papers/fj.pdf

http://gee.cs.oswego.edu/dl/cpjslides/fj.pdf 关于使用ForkJoin框架的内容。


1
嗨,谢谢提供的信息,我会看一下这个框架。关于处理器,我目前了解到的是第五代主要是第四代的缩小版(22纳米->14纳米),并且性能应该比第四代好约5%。但两者都有Turbo Boost,并且我给出的时钟速度是Turbo Boost的最高速度。它们通常运行在2和3.4 GHz。 - Thorsten Schmitz
@ThorstenSchmitz 你说得对,我发现Skylake(第六代)比Haswell在相同的GHz下快了多达20%。我还没有尝试过Broadwell。 - Peter Lawrey

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接