Multithreaded matrix multiplication

5
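I've written code for multithreaded matrix multiplication. I believe my approach is correct, but I'm not 100% sure. Regarding the threads, I don't understand why I can't just run (new MatrixThread(...)).start() instead of using an ExecutorService.
Also, when I benchmark the multithreaded approach against the classical approach, the classical one is much faster...
What am I doing wrong?
Matrix class: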
import java.util.*;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;

class Matrix
{
   private int dimension;
   private int[][] template;

   public Matrix(int dimension)
   {
      this.template = new int[dimension][dimension];
      this.dimension = template.length;
   }

   public Matrix(int[][] array) 
   {
      this.dimension = array.length;
      this.template = array;      
   }

   public int getMatrixDimension() { return this.dimension; }

   public int[][] getArray() { return this.template; }

   public void fillMatrix()
   {
      Random randomNumber = new Random();
      for(int i = 0; i < dimension; i++)
      {
         for(int j = 0; j < dimension; j++)
         {
            template[i][j] = randomNumber.nextInt(10) + 1;
         }
      }
   }

   @Override
   public String toString()
   {
      String retString = "";
      for(int i = 0; i < this.getMatrixDimension(); i++)
      {
         for(int j = 0; j < this.getMatrixDimension(); j++)
         {
            retString += " " + this.getArray()[i][j];
         }
         retString += "\n";
      }
      return retString;
   }

   public static Matrix classicalMultiplication(Matrix a, Matrix b)
   {      
      int[][] result = new int[a.dimension][b.dimension];
      for(int i = 0; i < a.dimension; i++)
      {
         for(int j = 0; j < b.dimension; j++)
         {
            for(int k = 0; k < b.dimension; k++)
            {
               result[i][j] += a.template[i][k] * b.template[k][j];
            }
         }
      }
      return new Matrix(result);
   }

   public Matrix multiply(Matrix multiplier) throws InterruptedException
   {
      Matrix result = new Matrix(dimension);
      ExecutorService es = Executors.newFixedThreadPool(dimension*dimension);
      for(int currRow = 0; currRow < multiplier.dimension; currRow++)
      {
         for(int currCol = 0; currCol < multiplier.dimension; currCol++)
         {            
            //(new MatrixThread(this, multiplier, currRow, currCol, result)).start();            
            es.execute(new MatrixThread(this, multiplier, currRow, currCol, result));
         }
      }
      es.shutdown();
      es.awaitTermination(2, TimeUnit.DAYS);
      return result;
   }

   private class MatrixThread extends Thread
   {
      private Matrix a, b, result;
      private int row, col;      

      private MatrixThread(Matrix a, Matrix b, int row, int col, Matrix result)
      {         
         this.a = a;
         this.b = b;
         this.row = row;
         this.col = col;
         this.result = result;
      }

      @Override
      public void run()
      {
         int cellResult = 0;
         for (int i = 0; i < a.getMatrixDimension(); i++)
            cellResult += a.template[row][i] * b.template[i][col];

         result.template[row][col] = cellResult;
      }
   }
} 

Main class:

import java.util.Scanner;

public class MatrixDriver
{
   private static final Scanner kb = new Scanner(System.in);

   public static void main(String[] args) throws InterruptedException
   {      
      Matrix first, second;
      long timeLastChanged,timeNow;
      double elapsedTime;

      System.out.print("Enter value of n (must be a power of 2):");
      int n = kb.nextInt();

      first = new Matrix(n);
      first.fillMatrix();      
      second = new Matrix(n);
      second.fillMatrix();

      timeLastChanged = System.currentTimeMillis();
      //System.out.println("Product of the two using threads:\n" +
                                                        first.multiply(second);
      timeNow = System.currentTimeMillis();
      elapsedTime = (timeNow - timeLastChanged)/1000.0;
      System.out.println("Threaded took "+elapsedTime+" seconds");

      timeLastChanged = System.currentTimeMillis();
      //System.out.println("Product of the two using classical:\n" +
                                  Matrix.classicalMultiplication(first,second);
      timeNow = System.currentTimeMillis();
      elapsedTime = (timeNow - timeLastChanged)/1000.0;
      System.out.println("Classical took "+elapsedTime+" seconds");
   }
} 

P.S. Please let me know if any further clarification is needed.


Your code is missing the "Multiply" method. - Dror Helper
1
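Why multithread it like this? This is entirely CPU-bound; it's not as if you have a thread blocked on I/O. - matt b
Multithreading might work well, but it depends more on how many CPUs you have (e.g., in your example, a 10x10 times 10x10 multiplication creates 100 threads... you probably only have 2-8 CPUs) and on how large the matrices are (do they fit in L2/L3 cache?). Native libraries such as MKL and OpenCL do this much better. - basszero
Matt b: Multiple hardware threads?? Although probably nowhere near n^2 of them. - Tom Hawtin - tackline
Extending Thread is almost always a bad idea. In this case, the code doesn't even start the threads. The fact that Thread implements Runnable is unfortunate. - Tom Hawtin - tackline
@matt: If you have a huge matrix and multiple processors/cores, it could be beneficial. Most likely, though, the reason is that it's a homework assignment. - erikkallen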
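The remark that the code "doesn't even start the threads" can be checked with a small experiment. The following is a minimal sketch, not part of the original post: the class and method names are hypothetical, and it assumes Java 8+ for the lambda syntax. It hands a Thread object to an executor and then inspects that object's state.

```java
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

class ThreadAsRunnableDemo {
    /** Runs a Thread object through an executor and reports the Thread object's own state afterwards. */
    static Thread.State stateAfterExecutorRan() {
        // Stand-in for a MatrixThread: a Thread object whose run() does some trivial work.
        Thread task = new Thread(() -> { /* per-cell work would go here */ });
        ExecutorService pool = Executors.newSingleThreadExecutor();
        try {
            // The pool's own worker thread calls task.run(); task.start() is never invoked.
            pool.submit(task).get();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException(e);
        } catch (ExecutionException e) {
            throw new IllegalStateException(e.getCause());
        } finally {
            pool.shutdown();
        }
        return task.getState();
    }

    public static void main(String[] args) {
        System.out.println(stateAfterExecutorRan()); // prints NEW
    }
}
```

The Thread object is treated purely as a Runnable, so it stays in state NEW even though its run() method has executed. By contrast, calling start() directly on each MatrixThread does launch real threads, but multiply would then have to keep a reference to every thread and join() each one before returning result; otherwise it may return a partially filled matrix.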
3 Answers

6
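Creating threads involves a lot of overhead, even when using an ExecutorService. I suspect your multithreaded approach is so slow because you're spending 99% of the time creating new threads and only 1% or less doing the actual math.
Typically, to solve this problem, you batch a bunch of operations together and run them on a single thread. I'm not 100% sure how to do that in this case, but I suggest breaking your matrix into smaller chunks (say, 10 smaller matrices) and running those on threads, instead of running each cell in its own thread.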
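One way the chunking idea might look is sketched below. This is only an illustration under stated assumptions: it works on plain int[][] arrays rather than the question's Matrix class, it splits the result into contiguous bands of rows (one band per task) instead of true sub-matrices, and the class name BlockedMultiply and the blocks parameter are hypothetical. It assumes Java 8+ and square matrices of equal size.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

class BlockedMultiply {
    /** Multiplies square matrices a and b; each task computes one contiguous band of result rows. */
    static int[][] multiply(int[][] a, int[][] b, int blocks) {
        int n = a.length;
        int[][] result = new int[n][n];
        if (n == 0) return result;
        int tasks = Math.max(1, Math.min(blocks, n)); // never more tasks than rows
        int rowsPerBlock = (n + tasks - 1) / tasks;   // ceiling division
        ExecutorService pool = Executors.newFixedThreadPool(tasks);
        try {
            List<Future<?>> futures = new ArrayList<>();
            for (int start = 0; start < n; start += rowsPerBlock) {
                final int from = start;
                final int to = Math.min(n, start + rowsPerBlock);
                futures.add(pool.submit(() -> {
                    // Each task writes only rows [from, to), so tasks never touch the same cell.
                    for (int i = from; i < to; i++) {
                        for (int j = 0; j < n; j++) {
                            int sum = 0;
                            for (int k = 0; k < n; k++) sum += a[i][k] * b[k][j];
                            result[i][j] = sum;
                        }
                    }
                }));
            }
            for (Future<?> f : futures) f.get(); // wait for every band to finish
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException(e);
        } catch (ExecutionException e) {
            throw new IllegalStateException(e.getCause());
        } finally {
            pool.shutdown();
        }
        return result;
    }
}
```

Calling BlockedMultiply.multiply(a, b, 10) would create at most 10 tasks regardless of the matrix size, rather than one per cell. Creating a pool on every call keeps the sketch self-contained; reusing one pool across calls would avoid that repeated cost.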

5
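You're creating a lot of threads. Not only is it expensive to create threads, but for a CPU-bound application you don't want more threads than you have available processors (if you do, you have to spend processing power switching between threads, which is also likely to cause cache misses, which are very expensive).
It's also unnecessary to send a thread to execute; all it needs is a Runnable. You'll get a big performance boost by applying these changes: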
  1. Make the ExecutorService a static member, size it for the current processor, and send it a ThreadFactory so it doesn't keep the program running after main has finished. (It would probably be architecturally cleaner to send it as a parameter to the method rather than keeping it as a static field; I leave that as an exercise for the reader. ☺)

    private static final ExecutorService workerPool = 
        Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors(), new ThreadFactory() {
            public Thread newThread(Runnable r) {
                Thread t = new Thread(r);
                t.setDaemon(true); 
                return t;
            }
        });
    
  2. Make MatrixThread implement Runnable rather than inherit Thread. Threads are expensive to create; POJOs are very cheap. You can also make it static which makes the instances smaller (as non-static classes get an implicit reference to the enclosing object).

    private static class MatrixThread implements Runnable
    
  3. From change (1), you can no longer awaitTermination to make sure all tasks are finished (as this worker pool is shared and never shut down). Instead, use the submit method which returns a Future<?>. Collect all the future objects in a list, and when you've submitted all the tasks, iterate over the list and call get for each object.

Your multiply method should now look something like this:

public Matrix multiply(Matrix multiplier) throws InterruptedException {
    Matrix result = new Matrix(dimension);
    List<Future<?>> futures = new ArrayList<Future<?>>();
    for(int currRow = 0; currRow < multiplier.dimension; currRow++) {
        for(int currCol = 0; currCol < multiplier.dimension; currCol++) {            
            Runnable worker = new MatrixThread(this, multiplier, currRow, currCol, result);
            futures.add(workerPool.submit(worker));
        }
    }
    for (Future<?> f : futures) {
        try {
            f.get();
        } catch (ExecutionException e){
            throw new RuntimeException(e); // shouldn't happen, but might do
        }
    }
    return result;
}

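But is it faster than the single-threaded version? Well, on my arguably crappy box, the multithreaded version is slower for values of n < 1024.
This is just scratching the surface, though. The real problem is that you create a lot of MatrixThread instances - your memory consumption is O(n²), which is a very bad sign. Moving the inner for loop into MatrixThread.run would improve performance severalfold (ideally, you don't create more tasks than you have worker threads).
Edit: As I have more pressing things to do, I couldn't resist optimizing this further. I came up with this (horribly ugly piece of code) that creates only O(n) jobs: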
 public Matrix multiply(Matrix multiplier) throws InterruptedException {
     Matrix result = new Matrix(dimension);
     List<Future<?>> futures = new ArrayList<Future<?>>();
     for(int currRow = 0; currRow < multiplier.dimension; currRow++) {
         Runnable worker = new MatrixThread2(this, multiplier, currRow, result);
         futures.add(workerPool.submit(worker)); 
     }
     for (Future<?> f : futures) {
         try {
             f.get();
         } catch (ExecutionException e){
             throw new RuntimeException(e); // shouldn't happen, but might do
         }
     }
     return result;
 }


private static class MatrixThread2 implements Runnable
{
   private Matrix self, mul, result;
   private int row; // the column is a local variable in run()

   private MatrixThread2(Matrix a, Matrix b, int row, Matrix result)
   {         
      this.self = a;
      this.mul = b;
      this.row = row;
      this.result = result;
   }

   @Override
   public void run()
   {
      for(int col = 0; col < mul.dimension; col++) {
         int cellResult = 0;
         for (int i = 0; i < self.getMatrixDimension(); i++)
            cellResult += self.template[row][i] * mul.template[i][col];
         result.template[row][col] = cellResult;
      }
   }
}

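There's still room for improvement, but the multithreaded version can now compute pretty much anything you have the patience to wait for, and it will finish faster than the single-threaded version.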


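Thanks a lot for the help! The code was a bit confusing, but I think I was able to figure it out. For some reason the non-threaded version is still faster when I run the code, but the difference is far more reasonable than before. Thanks! - Alex Wood
Splitting the work into parts always carries some overhead. For small values of n the multithreaded version will probably always be slower, but the larger n gets, the better the multithreaded version does. This solution still has a lot of overhead, because it creates n tasks (and thus has O(n) synchronization overhead). If you can split the multiplication into at most some fixed number of tasks (say, available processors * 2 or so), the program will be faster for large n. - gustafc
Also, for small values of n you could just do the non-threaded multiplication, since that will most likely always be faster. - gustafc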
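The two suggestions in these comments (a fixed number of tasks, and a sequential fallback for small n) could be combined roughly as follows. This is a hedged sketch rather than the answerer's code: the class name HybridMultiply, the helper multiplyRow, and the cut-off value 128 are all hypothetical, and the right cut-off has to be measured on the target machine. It assumes Java 8+ and square int[][] matrices.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

class HybridMultiply {
    // Hypothetical cut-off; below it the sequential loop is assumed to win.
    static final int DEFAULT_THRESHOLD = 128;
    static final int PROCESSORS = Runtime.getRuntime().availableProcessors();

    // Shared pool of daemon threads, so it does not keep the JVM alive after main returns.
    private static final ExecutorService POOL = Executors.newFixedThreadPool(PROCESSORS, r -> {
        Thread t = new Thread(r);
        t.setDaemon(true);
        return t;
    });

    /** Computes one row of the product into out. */
    static void multiplyRow(int[][] a, int[][] b, int[][] out, int row) {
        int n = a.length;
        for (int col = 0; col < n; col++) {
            int sum = 0;
            for (int k = 0; k < n; k++) sum += a[row][k] * b[k][col];
            out[row][col] = sum;
        }
    }

    static int[][] multiply(int[][] a, int[][] b) {
        return multiply(a, b, DEFAULT_THRESHOLD);
    }

    /** Uses the plain loop when n is below threshold, otherwise at most PROCESSORS * 2 tasks. */
    static int[][] multiply(int[][] a, int[][] b, int threshold) {
        int n = a.length;
        int[][] out = new int[n][n];
        if (n < threshold) {
            for (int row = 0; row < n; row++) multiplyRow(a, b, out, row);
            return out;
        }
        int tasks = Math.min(n, PROCESSORS * 2);
        List<Future<?>> futures = new ArrayList<>();
        for (int t = 0; t < tasks; t++) {
            final int first = t;
            // Task t handles rows t, t + tasks, t + 2 * tasks, ... so no two tasks share a row.
            futures.add(POOL.submit(() -> {
                for (int row = first; row < n; row += tasks) multiplyRow(a, b, out, row);
            }));
        }
        try {
            for (Future<?> f : futures) f.get();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException(e);
        } catch (ExecutionException e) {
            throw new IllegalStateException(e.getCause());
        }
        return out;
    }
}
```

With this shape the number of submitted tasks depends only on the processor count, not on n, so the per-task submission and synchronization cost stays constant as the matrices grow.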

1

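First, you should use a newFixedThreadPool sized to the number of cores you have; for a quad-core processor, use 4. Second, don't create a new thread pool for every matrix.

If you make the ExecutorService a static member variable, the threaded version runs faster almost every time at a matrix size of 512.

Also, changing MatrixThread to implement Runnable rather than extend Thread speeds up execution as well; on my machine, the threaded version is 2x as fast as the non-threaded one at 512.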

