为什么我的Java自底向上归并排序如此缓慢?

8

我花了好几个小时试图弄清楚,为什么我的Java版排序算法比递归归并排序慢两倍,而同样的算法在C和C++版本中却比递归版本快40-50%。我不断删减代码,直到只剩下一个简单的循环和合并操作,但它仍然慢两倍。为什么只有Java这么慢?

供参考,这是自下而上合并排序的样子:

/**
 * Bottom-up merge sort over the whole of {@code a}.
 *
 * <p>Starts with runs of width 1 and doubles the width on every pass; within a
 * pass, each pair of neighbouring runs is handed to {@code merge}. The last run
 * of a pass may be shorter than {@code width}, so its upper bound is clamped to
 * the final index of the array.
 *
 * @param a    array to sort
 * @param aux  scratch array passed through to {@code merge}
 *             (NOTE(review): presumably must be at least as long as {@code a} - confirm against merge)
 * @param comp ordering passed through to {@code merge}
 */
public static <T> void sort(T[] a, T[] aux, Comparator<T> comp) {
    final int length = a.length;
    int width = 1;
    while (width < length) {
        final int stride = width + width;
        for (int lo = 0; lo < length - width; lo += stride) {
            final int mid = lo + width - 1;
            final int hi = Math.min(lo + stride - 1, length - 1);
            merge(a, aux, lo, mid, hi, comp);
        }
        width = stride;
    }
}

这里是递归版本的代码:
/**
 * Top-down (recursive) merge sort of the index range {@code lo..hi} of {@code a}.
 *
 * <p>Splits the range at its midpoint, sorts both halves recursively and then
 * hands the two halves to {@code merge}.
 *
 * @param a    array to sort
 * @param aux  scratch array passed through to {@code merge}
 * @param lo   first index of the range
 * @param hi   last index of the range (inclusive, as implied by the {@code mid + 1} split)
 * @param comp ordering passed through to {@code merge}
 */
public static <T> void sort(T[] a, T[] aux, int lo, int hi, Comparator<T> comp) {
    // Base case: a range of zero or one element is already sorted. Without this
    // guard the recursion never terminates and ends in a StackOverflowError.
    if (hi <= lo) {
        return;
    }
    // Written as lo + (hi - lo) / 2 so the midpoint cannot overflow int.
    int mid = lo + (hi - lo) / 2;
    sort(a, aux, lo, mid, comp);
    sort(a, aux, mid + 1, hi, comp);
    merge(a, aux, lo, mid, hi, comp);
}

这些算法基本上就是从这个网站上照搬过来的。作为最后的手段,我干脆直接复制粘贴了网上现成的代码,但它同样比递归版本慢两倍。

Java有什么“特别之处”我错过了吗?

编辑:根据要求,下面是一些代码:

import java.util.*;
import java.lang.*;
import java.io.*;

/**
 * Element type used by the sorting benchmark: a sort key plus the position the
 * element was created at.
 */
class Test {
    /** Sort key; the only field {@code TestComparator} looks at. */
    public int value;
    /** Position assigned when the element is created; the benchmark compares it across both sorted arrays. */
    public int index;
}

/**
 * Orders {@code Test} elements by ascending {@code value}; {@code index} is ignored,
 * so elements with equal values compare as equal.
 */
class TestComparator implements Comparator<Test> {
    /**
     * @return -1 if {@code a.value} is smaller, 1 if it is larger, 0 if both are equal
     */
    @Override
    public int compare(Test a, Test b) {
        final int left = a.value;
        final int right = b.value;
        return (left < right) ? -1 : ((left > right) ? 1 : 0);
    }
}


class Merge<T> {
    private static <T> void Merge(T[] array, int start, int mid, int end, Comparator<T> comp, T[] buffer) {
        java.lang.System.arraycopy(array, start, buffer, 0, (mid - start));
        int A_count = 0, B_count = 0, insert = 0;
        while (A_count < (mid - start) && B_count < (end - mid)) {
            if (comp.compare(array[mid + B_count], buffer[A_count]) >= 0)
                array[start + insert++] = buffer[A_count++];
            else
                array[start + insert++] = array[mid + B_count++];
        }
        java.lang.System.arraycopy(buffer, A_count, array, start + insert, (mid - start) - A_count);
    }

    private static <T> void SortR(T[] array, int start, int end, T[] buffer, Comparator<T> comp) {
        if (end - start <= 2) {
            if (end - start == 2) {
                if (comp.compare(array[start], array[end - 1]) > 0) {
                    T swap = array[start];
                    array[start] = array[end - 1];
                    array[end - 1] = swap;
                }
            }

            return;
        }

        int mid = start + (end - start)/2;
        SortR(array, start, mid, buffer, comp);
        SortR(array, mid, end, buffer, comp);
        Merge(array, start, mid, end, comp, buffer);
    }

    public static <T> void Recursive(T[] array, Comparator<T> comp) {
        @SuppressWarnings("unchecked")
        T[] buffer = (T[]) new Object[array.length];
        SortR(array, 0, array.length, buffer, comp);
    }

    public static <T> void BottomUp(T[] array, Comparator<T> comp) {
        @SuppressWarnings("unchecked")
        T[] buffer = (T[]) new Object[array.length];

        int size = array.length;
        for (int index = 0; index < size - 1; index += 2) {
            if (comp.compare(array[index], array[index + 1]) > 0) {
                T swap = array[index];
                array[index] = array[index + 1];
                array[index + 1] = swap;
            }
        }

        for (int length = 2; length < size; length += length)
            for (int index = 0; index < size - length; index += length + length)
                Merge(array, index, index + length, Math.min(index + length + length, size), comp, buffer);
    }
}


/**
 * Thin static wrapper around one shared {@link Random} instance.
 */
class SortRandom {
    /**
     * The shared generator. Left {@code null} until the first call to
     * {@link #nextInt(int)}, which then creates it with {@code new Random()};
     * a caller may also assign a generator of its own beforehand.
     */
    public static Random rand;

    /**
     * Returns the next value from the shared generator's {@code nextInt(max)},
     * creating the generator first if none exists yet.
     *
     * @param max bound forwarded unchanged to {@link Random#nextInt(int)}
     */
    public static int nextInt(int max) {
        Random generator = rand;
        if (generator == null) {
            generator = new Random();
            rand = generator;
        }
        return generator.nextInt(max);
    }

    /**
     * Same as {@link #nextInt(int)} with a bound of {@link Integer#MAX_VALUE}.
     */
    public static int nextInt() {
        return nextInt(Integer.MAX_VALUE);
    }
}

/**
 * Benchmark driver: for a series of growing array sizes, sorts the same random
 * input once with {@code Merge.BottomUp} and once with {@code Merge.Recursive},
 * prints how the two wall-clock times relate, and checks that both results agree.
 */
class Sorter {
    /**
     * Runs the comparison for sizes 0, 32768, 65536, ... below 1,500,000.
     *
     * @throws java.lang.Exception if the two sorted arrays differ at any position,
     *         either by compared value or by the elements' {@code index} field
     */
    public static void main (String[] args) throws java.lang.Exception {
        final int maxSize = 1500000;
        final int step = 2048 * 16;
        final TestComparator order = new TestComparator();

        int total = 0;
        while (total < maxSize) {
            Test[] bottomUpInput = new Test[total];
            Test[] recursiveInput = new Test[total];

            // Both arrays reference the very same element objects, in the same order.
            for (int i = 0; i < total; i++) {
                Test item = new Test();
                item.value = SortRandom.nextInt();
                item.index = i;
                bottomUpInput[i] = item;
                recursiveInput[i] = item;
            }

            long started = System.currentTimeMillis();
            Merge.BottomUp(bottomUpInput, order);
            double bottomUpMillis = System.currentTimeMillis() - started;

            started = System.currentTimeMillis();
            Merge.Recursive(recursiveInput, order);
            double recursiveMillis = System.currentTimeMillis() - started;

            // Recursive time expressed as a percentage of the bottom-up time.
            double percent = recursiveMillis / bottomUpMillis * 100.0;
            if (bottomUpMillis < recursiveMillis) {
                System.out.format("%f%% faster\n", percent - 100.0);
            } else {
                System.out.format("%f%% as fast\n", percent);
            }

            System.out.println("verifying...");
            for (int i = 0; i < total; i++) {
                boolean sameValue = order.compare(bottomUpInput[i], recursiveInput[i]) == 0;
                if (!sameValue || recursiveInput[i].index != bottomUpInput[i].index) {
                    throw new Exception();
                }
            }
            System.out.println("correct!");

            total += step;
        }
    }
}

以下是C++版本:

#include <algorithm>
#include <cassert>
#include <cstdlib>
#include <cstring>
#include <ctime>
#include <iostream>
#include <utility>

// Element type used by the sorting benchmark: `value` is the sort key and
// `index` records the position the element was created at.
struct Test {
    size_t value;
    size_t index;
};

// "Less than" predicate for Test: true when item1's value is strictly smaller
// than item2's. The index field does not take part in the comparison.
bool TestCompare(Test item1, Test item2) {
    const bool first_is_smaller = item1.value < item2.value;
    return first_is_smaller;
}

// Stable merge sorts over plain arrays: a top-down variant (Recursive) and a
// bottom-up variant (BottomUp) that share one merge routine. `compare(a, b)`
// must return true when a is ordered strictly before b.
namespace Merge {
    namespace detail {
        // Owns a heap array of `count` default-initialized T and releases it when
        // it goes out of scope, so the scratch space is freed even if a comparison
        // or an element copy throws part-way through a sort.
        template <typename T>
        class ScratchBuffer {
        public:
            explicit ScratchBuffer(int count) : data_(new T[count]) {}
            ~ScratchBuffer() { delete[] data_; }
            T *get() const { return data_; }
        private:
            ScratchBuffer(const ScratchBuffer &);            // non-copyable
            ScratchBuffer &operator=(const ScratchBuffer &); // non-assignable
            T *data_;
        };
    }

    // Merges the adjacent sorted runs [start, mid) and [mid, end) of `array` in
    // place. The left run is copied to the front of `buffer` (which must hold at
    // least mid - start elements) and merged back from there; right-run elements
    // that are never overtaken are already in their final position, so only a
    // left-run remainder is copied back. Ties take the left element first, which
    // keeps the sort stable.
    template <typename T, typename Comparison>
    void Merge(T array[], int start, int mid, int end, Comparison compare, T buffer[]) {
        const int left_size = mid - start;
        const int right_size = end - mid;
        std::copy(array + start, array + mid, buffer);

        int A_count = 0, B_count = 0, insert = 0;
        while (A_count < left_size && B_count < right_size) {
            if (!compare(array[mid + B_count], buffer[A_count]))
                array[start + insert++] = buffer[A_count++];
            else
                array[start + insert++] = array[mid + B_count++];
        }
        std::copy(buffer + A_count, buffer + left_size, array + start + insert);
    }

    // Recursively sorts [start, end). Ranges of up to two elements are handled
    // directly (one compare-and-swap for two elements) instead of recursing.
    template <typename T, typename Comparison>
    void SortR(T array[], int start, int end, T buffer[], Comparison compare) {
        if (end - start <= 2) {
            if (end - start == 2)
                if (compare(array[end - 1], array[start]))
                    std::swap(array[start], array[end - 1]);
            return;
        }

        // start + (end - start)/2 rather than (start + end)/2 to avoid overflow.
        int mid = start + (end - start)/2;
        SortR(array, start, mid, buffer, compare);
        SortR(array, mid, end, buffer, compare);
        Merge(array, start, mid, end, compare, buffer);
    }

    // Sorts array[0, size) in place with a top-down merge sort.
    template <typename T, typename Comparison>
    void Recursive(T array[], int size, Comparison compare) {
        detail::ScratchBuffer<T> scratch(size);
        SortR(array, 0, size, scratch.get(), compare);
    }

    // Sorts array[0, size) in place with a bottom-up merge sort: one pass that
    // orders adjacent pairs, followed by merge passes over runs of length
    // 2, 4, 8, ...
    //
    // Run boundaries are computed with subtraction-based checks so that no
    // intermediate value exceeds `size`; the straightforward
    // `index + length + length` and `length += length` are signed overflow
    // (undefined behavior) once size exceeds 2^30.
    template <typename T, typename Comparison>
    void BottomUp(T array[], int size, Comparison compare) {
        detail::ScratchBuffer<T> scratch(size);
        T *buffer = scratch.get();

        for (int index = 0; index < size - 1; index += 2) {
            if (compare(array[index + 1], array[index]))
                std::swap(array[index], array[index + 1]);
        }

        // Double the run length each pass, clamping to size instead of overflowing.
        for (int length = 2; length < size;
                length = (length > size - length) ? size : length + length) {
            int index = 0;
            while (index < size - length) {
                const int mid = index + length;                                // < size by the loop condition
                const int end = (size - mid > length) ? mid + length : size;   // min(mid + length, size)
                Merge(array, index, mid, end, compare, buffer);
                index = end;
            }
        }
    }
}

// Benchmark driver: for a series of growing array sizes, sorts the same random
// input once with Merge::BottomUp and once with Merge::Recursive, prints how the
// two CPU times relate, and asserts that both results agree element by element.
int main() {
    srand(time(NULL));
    int max_size = 1500000;
    for (int total = 0; total < max_size; total += 2048 * 16) {
        Test *array1 = new Test[total];
        Test *array2 = new Test[total];

        // Fill both arrays with identical copies of the same random elements.
        for (int index = 0; index < total; index++) {
            Test item;
            item.value = rand();
            item.index = index;

            array1[index] = item;
            array2[index] = item;
        }

        // Elapsed time is (clock after) - (clock before). The subtraction is
        // required: without it the variable just holds a second absolute
        // timestamp and the ratio printed below is meaningless.
        double time1 = clock() * 1.0/CLOCKS_PER_SEC;
        Merge::BottomUp(array1, total, TestCompare);
        time1 = clock() * 1.0/CLOCKS_PER_SEC - time1;

        double time2 = clock() * 1.0/CLOCKS_PER_SEC;
        Merge::Recursive(array2, total, TestCompare);
        time2 = clock() * 1.0/CLOCKS_PER_SEC - time2;

        // Recursive time expressed as a percentage of the bottom-up time. For
        // inputs too small to register on clock() both times can be 0, in which
        // case the ratio prints as nan.
        if (time1 >= time2)
           std::cout << time2/time1 * 100.0 << "% as fast" << std::endl;
        else
            std::cout << time2/time1 * 100.0 - 100.0 << "% faster" << std::endl;

        std::cout << "verifying... ";
        for (int index = 0; index < total; index++) {
            assert(array1[index].value == array2[index].value);
            assert(array2[index].index == array1[index].index);
        }
        std::cout << "correct!" << std::endl;

        delete[] array1;
        delete[] array2;
    }
    return 0;
}

这两个版本之间的差异并不像原始版本那么大,但是C++迭代版本更快,而Java迭代版本则更慢。(是的,我意识到这些版本有点糟糕,分配的内存比实际使用的还要多)。
更新2: 当我将自底向上的归并排序转换为后序遍历时,它与递归版本中数组访问的顺序非常相似,最终运行速度比递归版本快了约10%。因此,看起来这与缓存未命中有关,而不是微型基准测试或不可预测的JVM。
仅影响Java版本的原因可能是因为Java缺乏C++版本中使用的自定义值类型。我将在C++版本中单独分配所有Test类,并查看性能如何。我正在开发的排序算法不能轻松地适应这种类型的遍历,但如果C++版本的性能也下降,我可能没有太多选择。
更新3: 不,将C++版本切换到分配的类似乎对其性能没有任何明显影响。看起来这确实是Java特有的问题。

1
C++版本的自底向上排序比递归排序快50%,而Java版本的自底向上排序比递归排序慢两倍。 - BonzaiThePenguin
1
这取决于您的机器架构和JVM。可能还有其他一些因素。 - maxx777
1
你调用sort()几次进行测量?如果只调用一次,我猜这是JIT:在递归版本中,由于多次调用sort(),因此它会被JIT编译,而在自底向上的版本中则不会。 - axtavt
大约25次左右,尽管迭代的次数似乎并不重要。 - BonzaiThePenguin
@AnubianNoob:你完全错了。或者说你的认识落后了15年?有些任务在Java中可能天生就比较慢,但这并不是JVM造成的。 - maaartinus
显示剩余10条评论
1个回答

2

有趣的问题。我没能弄清楚为什么自底向上版本比递归版本慢,尽管当数组长度为2的幂时,两者的行为完全相同。

不过在我的测试中,自底向上版本至少只比递归版本略慢一点,而不是慢两倍。

Benchmark                             Mode          Mean   Mean error    Units
RecursiveVsBottomUpSort.bottomUp      avgt        64.436        0.376    us/op
RecursiveVsBottomUpSort.recursive     avgt        58.902        0.552    us/op

代码:

/**
 * Micro-benchmark comparing a bottom-up and a recursive merge sort on an
 * {@code int[]} of {@link #N} elements. Both variants share the same
 * {@code merge} routine and the same pair of arrays held in this state object.
 */
@OutputTimeUnit(TimeUnit.MICROSECONDS)
@BenchmarkMode(Mode.AverageTime)
@Warmup(iterations = 5, time = 1)
@Measurement(iterations = 10, time = 1)
@State(Scope.Thread)
@Threads(1)
@Fork(1)
public class RecursiveVsBottomUpSort {

    /** Number of elements sorted per benchmark call (a power of two). */
    static final int N = 1024;
    /** Array that gets sorted in place. */
    int[] a = new int[N];
    /** Scratch array used by {@code merge}. */
    int[] aux = new int[N];

    /**
     * Overwrites every element of {@code a} with a fresh random int.
     * NOTE(review): the Level.Invocation setup level presumably makes this run
     * before every single benchmark call - confirm against the JMH version in use.
     */
    @Setup(Level.Invocation)
    public void fill() {
        Random r = ThreadLocalRandom.current();
        for (int i = 0; i < N; i++) {
            a[i] = r.nextInt();
        }
    }

    /**
     * Bottom-up merge sort of {@code st.a}: run width {@code n} starts at 1 and
     * doubles each pass; each pass merges neighbouring runs, clamping the upper
     * bound of the last run to the final index.
     *
     * @return the last element of the sorted array
     *         (presumably returned so the sort's result is consumed - TODO confirm)
     */
    @GenerateMicroBenchmark
    public static int bottomUp(RecursiveVsBottomUpSort st) {
        int[] a = st.a, aux = st.aux;
        int N = a.length;
        for (int n = 1; n < N; n = n + n) {
            for (int i = 0; i < N - n; i += n + n) {
                merge(a, aux, i, i + n - 1, Math.min(i + n + n - 1, N - 1));
            }
        }
        return a[N - 1];
    }

    /**
     * Recursive merge sort of {@code st.a} over the full index range 0..N-1.
     *
     * @return the last element of the sorted array
     *         (presumably returned so the sort's result is consumed - TODO confirm)
     */
    @GenerateMicroBenchmark
    public static int recursive(RecursiveVsBottomUpSort st) {
        sort(st.a, st.aux, 0, N - 1);
        return st.a[N - 1];
    }

    /**
     * Sorts the inclusive range {@code a[lo..hi]}: splits at the midpoint, sorts
     * both halves recursively, then merges them. A single-element range
     * ({@code lo == hi}) is the base case.
     */
    static void sort(int[] a, int[] aux, int lo, int hi) {
        if (lo == hi)
            return;
        int mid = lo + (hi - lo) / 2;
        sort(a, aux, lo, mid);
        sort(a, aux, mid + 1, hi);
        merge(a, aux, lo, mid, hi);
    }

    /**
     * Merges {@code a[lo..mid]} and {@code a[mid+1..hi]} back into {@code a[lo..hi]}.
     *
     * <p>The left half is copied to the same positions in {@code aux}; the right
     * half is copied in reverse order, so the largest elements of both halves meet
     * in the middle of {@code aux[lo..hi]}. The merge then reads from both ends
     * towards the middle, which is why the loop needs no bounds checks on
     * {@code i} and {@code j}.
     */
    static void merge(int[] a, int[] aux, int lo, int mid, int hi) {
        // Left half: straight copy.
        System.arraycopy(a, lo, aux, lo, mid + 1 - lo);

        // Right half: copied back-to-front (aux[mid+1] receives a[hi], aux[hi] receives a[mid+1]).
        for (int j = mid+1; j <= hi; j++)
            aux[j] = a[hi-j+mid+1];

        // i walks up from the left end, j walks down from the right end.
        int i = lo, j = hi;
        for (int k = lo; k <= hi; k++)
            if (aux[j] < aux[i]) a[k] = aux[j--];
            else                      a[k] = aux[i++];
    }
}

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接