CodeSprint 2的补码挑战运行速度过慢

3

在原始的InterviewStreet Codesprint上,有一个问题需要计算两个补码表示之间数字中1的数量。我通过迭代方式通过了所有精度测试用例,但只能在正确的时间内通过两个测试用例。有一个提示提到可以找到递推关系,所以我尝试使用递归,但最终它花费的时间与迭代相同。所以,是否有人可以找到比我提供的代码更快的方法?输入文件的第一个数是文件中测试用例的数目。我已经在代码后面提供了一个示例输入文件。

import java.util.Scanner;

public class Solution {

    public static void main(String[] args) {

        Scanner scanner = new Scanner(System.in);
        int numCases = scanner.nextInt();
        for (int i = 0; i < numCases; i++) {
            int a = scanner.nextInt();
            int b = scanner.nextInt();
            System.out.println(count(a, b));
        }
    }

    /**
     * Returns the number of ones between a and b inclusive
     */
    public static int count(int a, int b) {
        int count = 0;
        for (int i = a; i <= b; i++) {
            if (i < 0)
                count += (32 - countOnes((-i) - 1, 0));
            else
                count += countOnes(i, 0);
        }

        return count;
    }

    /**
     * Returns the number of ones in a
     */
    public static int countOnes(int a, int count) {
        if (a == 0)
            return count;
        if (a % 2 == 0)
            return countOnes(a / 2, count);
        else
            return countOnes((a - 1) / 2, count + 1);
    }
}

输入:

3
-2 0
-3 4
-1 4

Output:
63
99
37

你试过这个技巧吗? - Yu-Han Lyu
1个回答

2
第一步是替换
public static int countOnes(int a, int count) {
    if (a == 0)
        return count;
    if (a % 2 == 0)
        return countOnes(a / 2, count);
    else
        return countOnes((a - 1) / 2, count + 1);
}

使用更快的实现方式,例如著名的位操作技巧,可以使递归深度达到log2 a。

public static int popCount(int n) {
    // count the set bits in each bit-pair
    // 11 -> 10, 10 -> 01, 0* -> 0*
    n -= (n >>> 1) & 0x55555555;
    // count bits in each nibble
    n = ((n >>> 2) & 0x33333333) + (n & 0x33333333);
    // count bits in each byte
    n = ((n >> 4) & 0x0F0F0F0F) + (n & 0x0F0F0F0F);
    // accumulate the counts in the highest byte and shift
    return (0x01010101 * n) >> 24;
    // Java guarantees wrap-around, so we can use int here,
    // in C, one would need to use unsigned or a 64-bit type
    // to avoid undefined behaviour
}

这段代码使用了四次移位、五次按位与运算、一次减法、两次加法和一次乘法,共计十三个非常廉价的指令。

但是,除非范围非常小,否则我们可以比数每个数字的位数更好地完成任务。

首先考虑非负数。从0到2k-1的数字最多有k个位被设置为1。每个位在这些数字中都恰好出现一半次,因此总位数为k*2^(k-1)。现在假设2^k <= a < 2^(k+1)。数字0 <= n <= a中的总位数等于数字0 <= n < 2^k中的位数与数字2^k <= n <= a中的位数之和。第一个数的位数如上所述为k*2^(k-1)。在第二部分中,我们有a - 2^k + 1个数字,每个数字都设置了2k位,并忽略前导位,它们的位与数字0 <= n <= (a - 2^k)中的位相同,因此

totalBits(a) = k*2^(k-1) + (a - 2^k + 1) + totalBits(a - 2^k)

现在来讨论负数。在二进制补码中,-(n+1) = ~n,所以数字 -a <= n <= -1 是数字 0 <= m <= (a-1) 的补码,且数字 -a <= n <= -1 中设置的位数总数为 a*32 - totalBits(a-1)
对于范围内的位数总数 a <= n <= b,我们需要加或减,具体取决于范围两端的符号是否相同。
// if n >= 0, return the total of set bits for
// the numbers 0 <= k <= n
// if n < 0, return the total of set bits for
// the numbers n <= k <= -1
public static long totalBits(int n){
    if (n < 0) {
        long a = -(long)n;
        return (a*32 - totalBits((int)(a-1)));
    }
    if (n < 3) return n;
    int lg = 0, mask = n;
    // find the highest set bit in n and its position
    while(mask > 1){
        ++lg;
        mask >>= 1;
    }
    mask = 1 << lg;
    // total bit count for 0 <= k < 2^lg
    long total = 1L << lg-1;
    total *= lg;
    // add number of 2^lg bits
    total += n+1-mask;
    // add number of other bits for 2^lg <= k <= n
    total += totalBits(n-mask);
    return total;
}

// return total set bits for the numbers a <= n <= b
public static long totalBits(int a, int b) {
    if (b < a) throw new IllegalArgumentException("Invalid range");
    if (a == b) return popCount(a);
    if (b == 0) return totalBits(a);
    if (b < 0) return totalBits(a) - totalBits(b+1);
    if (a == 0) return totalBits(b);
    if (a > 0) return totalBits(b) - totalBits(a-1);
    // Now a < 0 < b
    return totalBits(a) + totalBits(b);
}

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接