Java 8 matrix * vector multiplication

I am wondering whether there is a more concise way in Java 8 to do the following with streams:

public static double[] multiply(double[][] matrix, double[] vector) {
    int rows = matrix.length;
    int columns = matrix[0].length;

    double[] result = new double[rows];

    for (int row = 0; row < rows; row++) {
        double sum = 0;
        for (int column = 0; column < columns; column++) {
            sum += matrix[row][column]
                    * vector[column];
        }
        result[row] = sum;
    }
    return result;
}
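For reference, the nested loop computes each entry of the result as the dot product of one matrix row with the vector, where r = matrix.length and c = matrix[0].length:

```latex
\text{result}_i = \sum_{j=0}^{c-1} \text{matrix}_{ij}\,\text{vector}_j, \qquad i = 0, \dots, r-1
```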

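EDIT: I received a very good answer, but it is roughly 10 times slower than the old implementation, so I'm adding my test code here in case someone wants to investigate: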

@Test
public void profile() {
    long start;
    long stop;
    int tenmillion = 10000000;
    double[] vector = { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 };

    double[][] matrix = new double[tenmillion][10];

    for (int i = 0; i < tenmillion; i++) {
        matrix[i] = vector.clone();
    }
    start = System.currentTimeMillis();
    multiply(matrix, vector);
    stop = System.currentTimeMillis();
    System.out.println("multiply took " + (stop - start) + " ms");
}
2 Answers

A straightforward way to do this with a Stream would be:
public static double[] multiply(double[][] matrix, double[] vector) {
    return Arrays.stream(matrix)
            .mapToDouble(row -> IntStream.range(0, row.length)
                    .mapToDouble(col -> row[col] * vector[col])
                    .sum())
            .toArray();
}

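This creates a Stream of the matrix rows (a Stream<double[]>) and then maps each row to a double: the dot product of that row with the vector array.
We have to stream over the indexes to compute the product because, unfortunately, there is no built-in way to zip two streams together.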
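To make the "zip by index" idea explicit, here is a minimal sketch that pulls the per-row work into a helper. The helper is named `dot`, a hypothetical name that is not part of the answer. It walks both arrays through a single IntStream of indexes. It assumes every row has the same length as the vector.

```java
import java.util.Arrays;
import java.util.stream.IntStream;

public class StreamMatVec {

    // Dot product of two arrays, "zipping" them through a stream of indexes,
    // since java.util.stream has no built-in zip operation.
    // Assumes b is at least as long as a.
    public static double dot(double[] a, double[] b) {
        return IntStream.range(0, a.length)
                .mapToDouble(i -> a[i] * b[i])
                .sum();
    }

    // Matrix * vector: one dot product per row.
    public static double[] multiply(double[][] matrix, double[] vector) {
        return Arrays.stream(matrix)
                .mapToDouble(row -> dot(row, vector))
                .toArray();
    }

    public static void main(String[] args) {
        double[][] m = { { 1, 2 }, { 3, 4 } };
        double[] v = { 1, 1 };
        System.out.println(Arrays.toString(multiply(m, v))); // prints [3.0, 7.0]
    }
}
```

Extracting `dot` does not change what the pipeline computes. It only names the zip step so that it can be reused and tested on its own.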

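Hi Tunaki - thanks - it works! I'm a bit surprised by the performance, though. For a matrix with 10 million rows it takes 10 times as long to finish... so I think I'll stick with the "old" loop. Thanks anyway. - Ole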

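The way you are measuring performance is not reliable, and hand-writing micro-benchmarks is generally not a good idea. For example, when compiling the code the JVM may choose to reorder execution, so the start and stop variables might not be assigned where you expect them to be, which skews the measurement. It is also very important to warm up the JVM so that the JIT compiler can apply all of its optimizations. Garbage collection (GC) can also introduce significant variation in application throughput and response time. I strongly recommend using dedicated tools such as JMH and Caliper for micro-benchmarking.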
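I also wrote some benchmark code with JVM warm-up, random data sets and a higher number of iterations. In this run the Java 8 stream version came out somewhat faster than the for loops.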
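import java.util.Random;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.BiFunction;
import java.util.stream.IntStream;

/**
 * Hand-rolled benchmark comparing a stream-based and a loop-based matrix * vector multiplication.
 */
public class MatrixMultiplicationBenchmark {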
    private static AtomicLong start = new AtomicLong();
    private static AtomicLong stop = new AtomicLong();
    private static Random random = new Random();

    /**
     * Main method that warms up each implementation and then runs the benchmark.
     *
     * @param args main class args
     */
    public static void main(String[] args) {
        // Warm up with more iterations and a smaller data set
        System.out.println("Warming up...");
        IntStream.range(0, 10_000_000).forEach(i -> run(10, MatrixMultiplicationBenchmark::multiplyWithStreams));
        IntStream.range(0, 10_000_000).forEach(i -> run(10, MatrixMultiplicationBenchmark::multiplyWithForLoops));

        // Run with fewer iterations and a larger data set
        startWatch("Running MatrixMultiplicationBenchmark::multiplyWithForLoops...");
        IntStream.range(0, 10).forEach(i -> run(10_000_000, MatrixMultiplicationBenchmark::multiplyWithForLoops));
        endWatch("MatrixMultiplicationBenchmark::multiplyWithForLoops");

        startWatch("Running MatrixMultiplicationBenchmark::multiplyWithStreams...");
        IntStream.range(0, 10).forEach(i -> run(10_000_000, MatrixMultiplicationBenchmark::multiplyWithStreams));
        endWatch("MatrixMultiplicationBenchmark::multiplyWithStreams");
    }

    /**
     * Creates a random matrix and vector and passes them to the given implementation.
     *
     * @param size         number of matrix rows (each row has 10 columns)
     * @param multiplyImpl implementation to use
     */
    public static void run(int size, BiFunction<double[][], double[], double[]> multiplyImpl) {
        // creating random matrix and vector
        double[][] matrix = new double[size][10];
        double[] vector = random.doubles(10, 0.0, 10.0).toArray();
        IntStream.range(0, size).forEach(i -> matrix[i] = random.doubles(10, 0.0, 10.0).toArray());

        // applying matrix and vector to the given implementation. Returned value should not be ignored in test cases.
        double[] result = multiplyImpl.apply(matrix, vector);
    }

    /**
     * Multiplies the given vector and matrix using Java 8 streams.
     *
     * @param matrix the matrix
     * @param vector the vector to multiply
     *
     * @return result after multiplication.
     */
    public static double[] multiplyWithStreams(final double[][] matrix, final double[] vector) {
        final int rows = matrix.length;
        final int columns = matrix[0].length;

        return IntStream.range(0, rows)
                .mapToDouble(row -> IntStream.range(0, columns)
                        .mapToDouble(col -> matrix[row][col] * vector[col])
                        .sum()).toArray();
    }

    /**
     * Multiplies the given vector and matrix using vanilla for loops.
     *
     * @param matrix the matrix
     * @param vector the vector to multiply
     *
     * @return result after multiplication.
     */
    public static double[] multiplyWithForLoops(double[][] matrix, double[] vector) {
        int rows = matrix.length;
        int columns = matrix[0].length;

        double[] result = new double[rows];

        for (int row = 0; row < rows; row++) {
            double sum = 0;
            for (int column = 0; column < columns; column++) {
                sum += matrix[row][column] * vector[column];
            }
            result[row] = sum;
        }
        return result;
    }

    private static void startWatch(String label) {
        System.out.println(label);
        start.set(System.currentTimeMillis());
    }

    private static void endWatch(String label) {
        stop.set(System.currentTimeMillis());
        System.out.println(label + " took " + ((stop.longValue() - start.longValue()) / 1000) + "s");
    }
}

Here is the output:

Warming up...
Running MatrixMultiplicationBenchmark::multiplyWithForLoops...
MatrixMultiplicationBenchmark::multiplyWithForLoops took 100s
Running MatrixMultiplicationBenchmark::multiplyWithStreams...
MatrixMultiplicationBenchmark::multiplyWithStreams took 89s
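One thing worth noting about the benchmark above: `run` builds the random matrix inside the timed region. Each measured iteration therefore also pays for generating 10 million random rows, which may well outweigh the multiplication itself.

Below is a minimal sketch of a somewhat fairer hand-rolled timing. It is not the answerer's code, and it makes these assumptions:
- The matrix is 1,000,000 × 10 and is built once, outside the timed region.
- The helpers `averageNanos` and `closeEnough` are hypothetical names.
- The stream variant is the `Arrays.stream` version from the first answer.

Each implementation is warmed up, only the multiplication is timed with `System.nanoTime`, and results are written to a volatile field. This is still a micro-benchmark with all the pitfalls described above.

```java
import java.util.Arrays;
import java.util.Random;
import java.util.function.BiFunction;
import java.util.stream.IntStream;

public class FairerMatVecTiming {

    // Written at the end of each timing run so the results stay observable;
    // this makes it less likely that the JIT treats the calls as dead code.
    static volatile double blackhole;

    public static double[] multiplyWithStreams(double[][] matrix, double[] vector) {
        return Arrays.stream(matrix)
                .mapToDouble(row -> IntStream.range(0, row.length)
                        .mapToDouble(col -> row[col] * vector[col])
                        .sum())
                .toArray();
    }

    public static double[] multiplyWithForLoops(double[][] matrix, double[] vector) {
        double[] result = new double[matrix.length];
        for (int row = 0; row < matrix.length; row++) {
            double sum = 0;
            for (int col = 0; col < vector.length; col++) {
                sum += matrix[row][col] * vector[col];
            }
            result[row] = sum;
        }
        return result;
    }

    // Hypothetical helper: warms up `impl`, then returns the average nanoseconds per call.
    // Only the multiplication is timed; the data is passed in ready-made.
    // Assumes a non-empty matrix (reads element 0 of each result).
    public static long averageNanos(BiFunction<double[][], double[], double[]> impl,
                                    double[][] matrix, double[] vector, int warmups, int runs) {
        double sink = 0;
        for (int i = 0; i < warmups; i++) {
            sink += impl.apply(matrix, vector)[0];
        }
        long start = System.nanoTime();
        for (int i = 0; i < runs; i++) {
            sink += impl.apply(matrix, vector)[0];
        }
        long elapsed = System.nanoTime() - start;
        blackhole = sink;
        return elapsed / runs;
    }

    // Hypothetical helper: element-wise comparison with a relative tolerance.
    public static boolean closeEnough(double[] x, double[] y, double relTol) {
        if (x.length != y.length) return false;
        for (int i = 0; i < x.length; i++) {
            if (Math.abs(x[i] - y[i]) > relTol * Math.max(1.0, Math.abs(x[i]))) return false;
        }
        return true;
    }

    public static void main(String[] args) {
        Random random = new Random(42); // fixed seed so the data is the same on every run
        double[] vector = random.doubles(10, 0.0, 10.0).toArray();
        double[][] matrix = new double[1_000_000][];
        for (int i = 0; i < matrix.length; i++) {
            matrix[i] = random.doubles(10, 0.0, 10.0).toArray();
        }

        // Check the two versions agree before timing them. DoubleStream.sum() may use
        // compensated summation, so compare with a tolerance instead of ==.
        System.out.println("results agree: " + closeEnough(
                multiplyWithForLoops(matrix, vector), multiplyWithStreams(matrix, vector), 1e-9));

        long loops = averageNanos(FairerMatVecTiming::multiplyWithForLoops, matrix, vector, 20, 20);
        long streams = averageNanos(FairerMatVecTiming::multiplyWithStreams, matrix, vector, 20, 20);
        System.out.println("for loops: " + loops / 1_000_000 + " ms per call");
        System.out.println("streams:   " + streams / 1_000_000 + " ms per call");
    }
}
```

Swapping the order of the two measurements, or running each in a separate JVM, is a cheap way to see how much JIT state affects the numbers. A JMH benchmark would take care of forking and warm-up for you.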

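Hi - thanks a lot for the feedback. That's very encouraging. For a moment I thought I was stuck with the basics :). I need to experiment with it some more. Thanks again! - Ole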
Better still is to use a benchmarking framework like JMH, which takes care of all the warm-up phases. - Tunaki

Content provided by Stack Overflow (translated from the English original).