如何让程序更快?[Keypad_Sticky_Note]

3

按键式便签

小黄人们将一些布尔教授的秘密安全地锁在了某处。或者他们这么认为。实际上,他们非常自信,甚至在锁的按键上贴了一个密码提示便利贴。

这个锁需要你在按键上输入一对非负整数 (a, b)。由于这些整数可能高达 20 亿,因此你需要查看提示便利贴以获得帮助。

提示便利贴上写有两个数字,但即使是小黄人也知道不要把密码放在那里。他们实际上写下了密码整数对 (a, b) 的和(称其为 s)和按位异或(xor,标记为 x)。这样,他们只需要记住一个数字。如果他们对减法感到困难,他们可以使用按位异或。

即,我们有 s = a+b 和 x = a^b(其中 ^ 是按位异或运算)。

使用你的自动化黑客设备,每次猜测尝试需要几毫秒。由于你很快就会被发现,所以你想知道在尝试所有组合之前可能需要多长时间。由于提示便利贴,你现在可以在不必要输入它们到按键上的情况下消除某些组合,并且可以在最坏情况下找出破解锁所需的确切时间。

编写一个名为 answer(s, x) 的函数,该函数查找具有目标和和异或的数字对 (a, b) 的数量。

例如,如果 s=10 和 x=4,则 (a, b) 的可能值为 (3, 7) 和 (7, 3),因此 answer 将返回 2。

如果 s=5 和 x=3,则没有可能的值,因此 answer 将返回 0。

s 和 x 至少为 0,最多为 20 亿。

语言

要提供 Python 解决方案,请编辑 solution.py 要提供 Java 解决方案,请编辑 solution.java

测试用例

输入: (int) s = 10 (int) x = 4 输出: (int) 2

输入: (int) s = 0 (int) x = 0 输出: (int) 1

public static int answer(int s, int x) {
    List<Integer> num = new ArrayList<>();
    int a;
    int b;
    int sum;
    int finalans;

    for(int i = 0; i <=s; i++){
        for(int e = 0; e <= s; e++){
            sum = i + e;
            if(sum == s){
                if((i^e) == x){
                    if(!num.contains(i)){
                        num.add(i);
                    }
                    if(!num.contains(e)){
                        num.add(e);
                    }
                }
            }
        }
    }

    finalans = num.size();
    if((finalans%2) == 0){
        return finalans*2;
    } else if(!((finalans%2) == 0)){
        return finalans;
    }
    return 0;

}

我的代码可以运行,但是当s和x变得很大时,程序运行时间太长了。我该如何让这个程序运行更快?


这个任务来自哪里? - user2418306
循环花费了太长时间... - MAC
1
Google foobar第3级挑战 - rkatakam
3
由于这两个操作都是交换律,所以您不需要嵌套循环。从[0; s]中取出 a,计算 b = s - a,检查异或,并在必要时添加到列表中。返回 list.size() * 2 - user2418306
你是怎么收到邀请的? - user2418306
3
通常情况下,你不能仅仅编写一个朴素的解决方案然后进行优化来解决算法问题。你必须先想出一种高效的方法,然后再进行编码。难点在于思考,而不是编码。 - Paul Hankin
5个回答

1
你可以通过意识到对于一个输入状态(异或数字,和数字,进位)只有有限数量的输出状态(进位)。您可以使用递归来计算组合的总数,并使用记忆化使递归更有效。您可以使用if条件来处理每个状态。下面是我的解决方案,可以在O(m)时间内解决问题,其中m是您数字数据类型中二进制位数的数量。由于问题规定m = 32(整数),因此这实际上是一个O(1)解决方案。

如果您有任何问题,请告诉我。我尝试在代码中添加有用的注释以解释各种情况。

public class SumAndXor {
    public static void main(String[] args) {
        int a = 3;
        int b = 7;

        int sum = a + b;
        int xor = a ^ b;

        System.out.println(answer(sum, xor));
    }

    private static final int NOT_SET = -1;

    // Driver
    public static int answer(int sum, int xor) {
        int numBitsPerInt = Integer.toBinaryString(Integer.MAX_VALUE).length() + 1;
        int[][] cache = new int[numBitsPerInt][2];

        for (int i = 0; i < numBitsPerInt; ++i) {
            cache[i][0] = NOT_SET;
            cache[i][1] = NOT_SET;
        }

        return answer(sum, xor, 0, 0, cache);
    }

    // Recursive helper
    public static int answer(int sum, int xor, int carry, int index, int[][] cache) {
        // Return memoized value if available
        if (cache[index][carry] != NOT_SET) {
            return cache[index][carry];
        }

        // Base case: nothing else to process
        if ((sum >> index) == 0 && (xor >> index) == 0 && carry == 0) {
            return 1;
        }

        // Get least significant bits
        int sumLSB = (sum >> index) & 1;
        int xorLSB = (xor >> index) & 1;

        // Recursion
        int result = 0;

        if (carry == 0) {
            if (xorLSB == 0 && sumLSB == 0) {
                // Since the XOR is 0, the binary digits are either [0, 0] or [1, 1]. Since the
                // sum is 0 and the incoming carry is 0, both [0, 0] and [1, 1] are valid. We
                // recurse with a carry of 0 to represent [0, 0], and we recurse with a carry of
                // 1 to represent [1, 1].
                result = answer(sum, xor, 0, index + 1, cache) + answer(sum, xor, 1, index + 1, cache);
            } else if (xorLSB == 0 && sumLSB == 1) {
                // Since the XOR is 0, the binary digits are either [0, 0] or [1, 1]. Since the
                // sum is 1 and the incoming carry is 0, neither [0, 0] nor [1, 1] is valid.
                result = 0;
            } else if (xorLSB == 1 && sumLSB == 0) {
                // Since the XOR is 1, the binary digits are either [0, 1] or [1, 0]. Since the
                // sum is 0 and the incoming carry is 0, neither [0, 1] nor [1, 0] is valid.
                result = 0;
            } else if (xorLSB == 1 && sumLSB == 1) {
                // Since the XOR is 1, the binary digits are either [0, 1] or [1, 0]. Since the
                // sum is 1 and the incoming carry is 0, both [0, 1] and [1, 0] is valid. We
                // recurse with a carry of 0 to represent [0, 1], and we recurse with a carry
                // of 0 to represent [1, 0].
                result = 2 * answer(sum, xor, 0, index + 1, cache);
            }
        } else {
            if (xorLSB == 0 && sumLSB == 0) {
                // Since the XOR is 0, the binary digits are either [0, 0] or [1, 1]. Since the
                // sum is 0 and the incoming carry is 1, neither [0, 0] nor [1, 1] is valid.
                result = 0;
            } else if (xorLSB == 0 && sumLSB == 1) {
                // Since the XOR is 0, the binary digits are either [0, 0] or [1, 1]. Since the
                // sum is 1 and the incoming carry is 1, both [0, 0] and [1, 1] are valid. We
                // recurse with a carry of 0 to represent [0, 0], and we recurse with a carry of
                // 1 to represent [1, 1].
                result = answer(sum, xor, 0, index + 1, cache) + answer(sum, xor, 1, index + 1, cache);
            } else if (xorLSB == 1 && sumLSB == 0) {
                // Since the XOR is 1, the binary digits are either [0, 1] or [1, 0]. Since the
                // sum is 0 and the incoming carry is 1, both [0, 1] and [1, 0] are valid. We
                // recurse with a carry of 0 to represent [0, 1], and we recurse with a carry
                // of 0 to represent [1, 0].
                result = 2 * answer(sum, xor, 1, index + 1, cache);
            } else if (xorLSB == 1 && sumLSB == 1) {
                // Since the XOR is 1, the binary digits are either [0, 1] or [1, 0]. Since the
                // sum is 1 and the incoming carry is 1, neither [0, 1] nor [1, 0] is valid.
                result = 0;
            }
        }

        cache[index][carry] = result;

        return result;
    }
}

1
谷歌说它运行的算法是O(n^2),但谷歌希望它变成O(lg n),所以这需要太长时间。如果你问我,这个等级3的挑战对我来说太难了。我曾经做过更简单的等级4题目。解决这个问题的方法与你想象的完全不同。事实上,在正确答案中,你甚至不会给(a,b)赋值,也不会将(a,b)与(S,x)进行比较。直到你看到并理解解决方案,才能反其道而行之。
无论如何,将正确答案绘制在二维图表或Excel电子表格中(使用S作为行,x作为列,将零留空),可以有所帮助。然后,寻找模式。数据点实际上形成了一个谢尔宾斯基三角形(参见http://en.wikipedia.org/wiki/Sierpinski_triangle)。
您还会注意到,每个列中大于零的数据点在该列的所有实例中都相同,因此,给定您的x值,只要对应于您的S值的行与三角形中的数据点相交,您就自动知道最终答案应该是什么。 您只需要确定S值(行)是否在列x处与三角形相交。明白了吗?
即使列中的值也形成了从0到x的模式:1、2、2、4、2、4、4、8、2、4、4、8、4、8、8、16...我相信您可以弄清楚。
这是“给定x的最终值”方法以及大部分剩余的代码(使用Python…Java过于冗长和复杂)。 您只需要编写三角形遍历算法(我不会透露这个,但这是正确方向上的坚实推动)。
def final(x, t):
    if x > 0:
        if x % 2: # x is odd
            return final(x / 2, t * 2)
        else: # x is even
            return final(x / 2, t)
    else:
        return t

def mid(l, r):
    return (l + r) / 2

def sierpinski_traverse(s_mod_xms, x. lo, hi, e, l, r):
    # you can do this in 16 lines of code to end with...
    if intersect:
        # always start with a t-value of 1 when first calling final in case x=0
        return final(x, 1)
    else:
        return 0

def answer(s, x):
    print final(x, 1)

    if s < 0 or x < 0 or s > 2000000000 or x > 2000000000 or s < x or s % 2 != x % 2:
        return 0
    if x == 0:
        return 1

    x_modulus_size = 2 ** int(math.log(x, 2) + 2)
    s_mod_xms = s % x_modulus_size
    lo_root = x_modulus_size / 4
    hi_root = x_modulus_size / 2
    exp = x_modulus_size / 4    # exponent of 2 (e.g. 2 ** exp)

    return sierpinski_traverse(s_mod_xms, x, lo_root, hi_root, exp, exp, 2 * exp)


if __name__ == '__main__':
    answer(10, 4)

我已经完成了这个挑战。如果你还是想不出来,给我发一条消息,我会把代码发给你。 - jppodo
谣言是真的吗?通过第三关有什么“奖励”吗? - Patrick Zawadzki
您将有机会与招聘人员交谈。这基本上是一种快捷方式,可以让您的简历排在队列前面。 - jppodo

0

尝试将num更改为HashSet。您还可以清理末尾的if/else。

例如:

public static int answer(int s, int x) {
    HashSet<Integer> num = new HashSet<>();
    int a;
    int b;
    int sum;
    int finalans;

    for(int i = 0; i <=s; i++){
        for(int e = 0; e <= s; e++){
            sum = i + e;
            if(sum == s){
                if((i^e) == x){
                    num.add(i);
                    num.add(e);
                }
            }
        }
    }

    finalans = num.size();
    if((finalans%2) == 0){
        return finalans*2;
    } else {
        return finalans;
    }        
}

这仍然超过了Google Foobar挑战的时间限制。这并没有真正产生影响。 - rkatakam

0

你的算法中大部分步骤都执行了过多的工作:

  • 你对所有非负整数进行了线性扫描,直到s。由于问题是对称的,扫描到s/2就足够了。
  • 你进行了第二次线性扫描,找到每个a对应的另一个整数b,满足a+b=s。简单的代数运算表明,只有一个这样的b,即s-a,因此根本不需要进行线性扫描。
  • 你进行了第三次线性扫描,检查是否已经找到一对(a,b)。如果只循环到s/2,则始终保持a≤b,因此不会发生重复计数。

最后,我可以想出一个简单的优化来节省一些工作:

如果 s 是偶数,则 ab 都是偶数或者都是奇数。因此,在这种情况下,a ^ b 是偶数。
如果 s 是奇数,则 ab 中至少有一个是奇数,因此 a ^ b 是奇数。
在执行任何操作之前,您可以添加此检查。
public static int answer(int s, int x) {
    int result = 0;
    if (s % 2 == x % 2) {
        for (int a = 0; a <= s / 2; a++) {
            int b = s - a;
            if ((a ^ b) == x) {
                result += 2;
            }
        }
        // we might have double counted the pair (s/2, s/2)
        // decrement the count if needed
        if (s % 2 == 0 && ((s / 2) ^ (s / 2)) == x) {
            result--;
        }
    }
    return result;
}

编译仍然需要太长时间,而且在某些情况下无法解决问题。我不得不修改它,以使程序不会因为异或运算符而出错,然后它就可以工作了,只是仍然需要很长时间。 - rkatakam
这些就是我所做的确切修改,但你的算法在我的IDE中似乎运行得足够快,但在挑战中,它会提示“时间超限”。 - rkatakam

0
为了更好地解释我的先前答案,请看整个图像...确切地说是三角形遍历算法,它的工作原理类似于二分查找,只不过有三个选项而不是两个(“三元”搜索?)。查看最大三角形内包含S和x的3个最大三角形。然后,选择包含S和x的三角形。接下来,在新选择的三角形中查看三个最大的三角形,并选择包含S和x的那一个。重复此过程,直到到达单个点。如果该点不为零,则返回我指定的“最终”值。在选择三角形并且行S不与数据点相交时,还有一些if-else语句可以加快速度。

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接