为什么JIT在边界检查消除方面表现如此糟糕?

3

我正在测试HotSpot JIT消除数组边界检查的能力。这里有两个版本的相同堆排序实现,一个使用普通的数组索引,另一个使用sun.misc.Unsafe API,没有边界检查:

public class HeapSort {
    // copied from http://en.wikibooks.org/wiki/Algorithm_Implementation/Sorting/Heapsort#C
    static int heapSortSimple(int[] arr) {
        int t;
        int n = arr.length, parent = n / 2, index, childI;
        while (true) {
            if (parent > 0) {
                t = arr[--parent]; // 1, i. e. first indexing
            } else {
                if (--n == 0) break;
                t = arr[n]; // 2
                arr[n] = arr[0]; // 3, 4
            }
            index = parent;
            childI = (index << 1) + 1;
            while (childI < n) {
                int childV = arr[childI]; // 5
                int right;
                if (childI + 1 < n && (right = arr[childI + 1]) > childV) { // 6
                    childI++;
                    childV = right;
                }
                if (childV > t) {
                    arr[index] = childV; // 7
                    index = childI;
                    childI = (index << 1) + 1;
                } else {
                    break;
                }
            }
            arr[index] = t; // 8
        }
        return arr[arr.length - 1];
    }

    static int heapSortUnsafe(int[] arr) {
        int t;
        long n = arr.length * INT_SCALE, parent = (arr.length / 2) * INT_SCALE, index, childI;
        while (true) {
            if (parent > 0) {
                t = U.getInt(arr, INT_BASE + (parent -= INT_SCALE));
            } else {
                if ((n -= INT_SCALE) == 0) break;
                t = U.getInt(arr, INT_BASE + n);
                U.putInt(arr, INT_BASE + n, U.getInt(arr, INT_BASE));
            }
            index = parent;
            childI = (index << 1) + INT_SCALE;
            while (childI < n) {
                int childV = U.getInt(arr, INT_BASE + childI);
                int right;
                if (childI + INT_SCALE < n &&
                        (right = U.getInt(arr, INT_BASE + (childI + INT_SCALE))) > childV) {
                    childI += INT_SCALE;
                    childV = right;
                }
                if (childV > t) {
                    U.putInt(arr, INT_BASE + index, childV);
                    index = childI;
                    childI = (index << 1) + INT_SCALE;
                } else {
                    break;
                }
            }
            U.putInt(arr, INT_BASE + index, t);
        }
        return arr[arr.length - 1];
    }

    @OutputTimeUnit(TimeUnit.MICROSECONDS)
    @BenchmarkMode(Mode.AverageTime)
    @Warmup(iterations = 5, time = 1)
    @Measurement(iterations = 10, time = 1)
    @State(Scope.Thread)
    @Threads(1)
    @Fork(1)
    public static class Benchmarks {
        static final int N = 1024;
        int[] a = new int[N];

        @Setup(Level.Invocation)
        public void fill() {
            Random r = ThreadLocalRandom.current();
            for (int i = 0; i < N; i++) {
                a[i] = r.nextInt();
            }
        }

        @GenerateMicroBenchmark
        public static int simple(Benchmarks st) {
            int[] arr = st.a;
            return heapSortSimple(arr);
        }

        @GenerateMicroBenchmark
        public static int unsafe(Benchmarks st) {
            int[] arr = st.a;
            return heapSortUnsafe(arr);
        }
    }

    public static void main(String[] args) {
        Benchmarks bs = new Benchmarks();

        // verify simple sort
        bs.fill();
        int[] a1 = bs.a;
        int[] a2 = a1.clone();
        Arrays.sort(a2);
        heapSortSimple(a1);
        if (!Arrays.equals(a2, a1))
            throw new AssertionError();

        // let JIT to generate optimized assembly
        for (int i = 0; i < 10000; i++) {
            bs.fill();
            heapSortSimple(bs.a);
        }

        // verify unsafe sort
        bs.fill();
        a1 = bs.a;
        a2 = a1.clone();
        Arrays.sort(a2);
        heapSortUnsafe(a1);
        if (!Arrays.equals(a2, a1))
            throw new AssertionError();

        for (int i = 0; i < 10000; i++) {
            bs.fill();
            heapSortUnsafe(bs.a);
        }
    }

    static final Unsafe U;
    static final long INT_BASE;
    static final long INT_SCALE = 4;

    static {
        try {
            Field f = Unsafe.class.getDeclaredField("theUnsafe");
            f.setAccessible(true);
            U = (Unsafe) f.get(null);
        } catch (Exception e) {
            throw new IllegalStateException(e);
        }
        INT_BASE = U.arrayBaseOffset(int[].class);
    }
}

不安全版本在英特尔SB和AMD K10 CPU上始终比安全版本快13%

我查看了生成的汇编代码:

  • 所有索引操作(1-8)的下限检查都被消除了
  • 仅对操作5消除了上限检查,对2和3的检查合并了
  • 对于操作4(arr [0]),在每次迭代中都会检查arr.length!= 0

显然,所有边界检查分支都得到完美预测,这就是为什么使用简单索引的堆排序慢于不安全版本仅有13%。

我认为至少对于操作1、2和3,索引从小于数组长度的某个值稳定递减到零,JIT肯定会优化边界检查。问题是标题:为什么HotSpot JIT在这种情况下边界检查消除的效果不佳?


一旦涉及到多线程,就会发生疯狂的事情。例如,即使我们忽略NaN!= NaN,x == x也可能为false,因为x可以在两次访问之间更改 - Richard Tingle
@OlegEstekhin 在操作1、2和3之前,索引被检查为正数,我认为零边界检查已经被消除了,尽管如果JIT生成两个相同的检查,这将是无意义的。 - leventov
1
@RichardTingle arr 是排序方法中的本地变量。将 final 添加到本地变量不会影响运行时性能,因为它在字节码中没有反映。在排序方法的汇编中,arr 地址和 arr.length 都保存在寄存器中,并且在排序循环期间不会更新。 - leventov
@leventov,您使用的是哪个Java版本?您用于启动JVM的参数是什么? - Gábor Bakos
@GáborBakos 那个时刻的最新Java 8版本。参数 - 不记得了。 - leventov
显示剩余5条评论
1个回答

1
我不认为所有的格子都是受限制的。
传递零长度数组将导致 IOOB。在循环之前尝试使用if (n==0) return

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接