我的FB HackerCup代码对于大输入太慢了

16

我正在使用Python解决Facebook Hackercup上的查找最小值问题,我的代码可以很好地处理样本输入,但对于大输入(10^9),需要花费数小时才能完成。

那么,这个问题的解决方案是否可能无法在6分钟内使用Python计算?或者我的方法太糟糕了?

问题陈述:

在发送表情符号后,John决定玩玩数组。您知道黑客喜欢玩数组吗?John有一个从零开始索引的数组m,其中包含n个非负整数。然而,他只知道数组的前k个值,并且他想弄清剩下的值。

John知道以下内容:对于每个索引i,其中k <= i < nm [i]是不包含在前*k*m值中的最小非负整数。

例如,如果k = 3n = 4,并且已知m的值为[2, 3, 0],他可以推断出m[3] = 1
John非常忙碌,致力于使世界更加开放和连接,因此他没有时间去计算数组的其余部分。你的任务是帮助他。
给定m的前k个值,计算该数组的第n个值(即m[n - 1])。
由于nk的值可能非常大,我们使用伪随机数生成器来计算m的前k个值。给定正整数abcr,可以按以下方式计算m的已知值:
m[0] = a
m[i] = (b * m[i - 1] + c) % r, 0 < i < k

输入

  • 第一行包含一个整数T(T <= 20),表示测试用例的数量。

  • 接下来是T个测试用例,每个测试用例由两行组成。

  • 每个测试用例的第一行包含两个以空格分隔的整数n和k(1 <= k <= 10^5k < n <= 10^9)。

  • 每个测试用例的第二行包含四个以空格分隔的整数a、b、c、r(0 <= a、b、c <= 10^9,1 <= r <= 10^9)。

我尝试了两种方法,但都无法在6分钟内返回结果,以下是我的两种方法:

第一种:

import sys
cases=sys.stdin.readlines()
def func(line1,line2):
    n,k=map(int,line1.split())
    a,b,c,r =map(int,line2.split())
    m=[None]*n                     #initialize the list
    m[0]=a
    for i in xrange(1,k):          #set the first k values using the formula
        m[i]= (b * m[i - 1] + c) % r
    #print m    
    for j in range(0,n-k):         #now set the value of m[k], m[k+1],.. upto m[n-1]

        temp=set(m[j:k+j])     # create a set from the K values relative to current index
        i=-1                   #start at 0, lowest +ve integer
        while True:           
            i+=1
            if i not in temp:  #if that +ve integer is not present in temp
                m[k+j]=i       
                break

    return m[-1]

for ind,case in enumerate(xrange(1,len(cases),2)):
    ans=func(cases[case],cases[case+1])
    print "Case #{0}: {1}".format(ind+1,ans)  

第二个:

import sys
cases=sys.stdin.readlines()
def func(line1,line2):
    n,k=map(int,line1.split())
    a,b,c,r =map(int,line2.split())
    m=[None]*n                       #initialize
    m[0]=a                  
    for i in xrange(1,k):            #same as above          
        m[i]= (b * m[i - 1] + c) % r

    #instead of generating a set in each iteration , I used a 
    # dictionary this time.
    #Now, if the count of an item is 0 then it
    #means the item is not present in the previous K items
    #and can be added as the min value


    temp={}
    for x in m[0:k]:                   
        temp[x]=temp.get(x,0)+1       

    i=-1
    while True:
            i+=1
            if i not in temp:
                m[k]=i          #set the value of m[k]
                break
    for j in range(1,n-k):      #now set the values of m[k+1] to m[n-1]
        i=-1
        temp[m[j-1]] -= 1       #decrement it's value, as it is now out of K items
        temp[m[k+j-1]]=temp.get(m[k+j-1],0)+1   # new item added to the current K-1 items

        while True:
            i+=1
            if i not in temp or temp[i]==0:  #if i not found in dict or it's val is 0
                m[k+j]=i                     
                break

    return m[-1]

for ind,case in enumerate(xrange(1,len(cases),2)):
    ans=func(cases[case],cases[case+1])
    print "Case #{0}: {1}".format(ind+1,ans)  

第二种方法中的最后一个for循环也可以写成:
for j in range(1,n-k):
    i=-1
    temp[m[j-1]] -= 1
    if temp[m[j-1]]==0:
        temp.pop(m[j-1])      #same as above but pop the key this time
    temp[m[k+j-1]]=temp.get(m[k+j-1],0)+1

    while True:
        i+=1
        if i not in temp:
            m[k+j]=i
            break

示例输入:

5
97 39
34 37 656 97
186 75
68 16 539 186
137 49
48 17 461 137
98 59
6 30 524 98
46 18
7 11 9 46

输出:

Case #1: 8
Case #2: 38
Case #3: 41
Case #4: 40
Case #5: 12

我已经尝试过 codereview,但还没有人回复。


1
你应该将问题陈述复制到问题中。 - Alexey Frunze
1
你介意为我们解释一下这个问题算法的描述吗? :P - phant0m
鉴于输入的描述,如何可能存在10^9这样大的输入? - phant0m
1
哦,所以109是10^9,105是10^5吗?还是其他的正确?你可以使用<sup>9</sup>。 - phant0m
11
您可以将标题改写为“糟糕的算法对于大规模输入来说速度太慢”。 - mmgp
显示剩余2条评论
3个回答

14

运行至多k+1步后,数组中最后k+1个数字将是0...k(以某种顺序)。随后,序列可以预测:m[i] = m[i-k-1]。因此,解决这个问题的方法是运行你的朴素实现k+1步。然后你会得到一个有2k+1个元素的数组(前k个是从随机序列生成的,另外k+1个是从迭代中生成的)。

现在,最后k+1个元素将无限重复。因此,你可以立即返回m[n]的结果:m[k + (n-k-1) % (k+1)]

下面是一些实现代码。

import collections

def initial_seq(k, a, b, c, r):
    v = a
    for _ in xrange(k):
        yield v
        v = (b * v + c) % r

def find_min(n, k, a, b, c, r):
    m = [0] * (2 * k + 1)
    for i, v in enumerate(initial_seq(k, a, b, c, r)):
        m[i] = v
    ks = range(k+1)
    s = collections.Counter(m[:k])
    for i in xrange(k, len(m)):
        m[i] = next(j for j in ks if not s[j])
        ks.remove(m[i])
        s[m[i-k]] -= 1
    return m[k + (n - k - 1) % (k + 1)]


print find_min(97, 39, 34, 37, 656, 97)
print find_min(186, 75, 68, 16, 539, 186)
print find_min(137, 49, 48, 17, 461, 137)
print find_min(1000000000, 100000, 48, 17, 461, 137)

这四个案例在我的电脑上运行了4秒钟,最后一个案例具有可能的最大值n


为什么最后的 k+1 个数字要满足这个条件? - phant0m
现在尝试在1到10^9之间选择随机数作为a、b、c和d,它不应该起作用。您假设在前k个元素中存在一个小于等于k的数字。 - phant0m
@phantom 在一个包含k个数字的列表中,最小的非负整数总是在0和k之间。 - Paul Hankin
算了,我在想最大下界而不是最小的 -.- - phant0m
我不明白你在做什么,为什么在查找最小值时计算所有k的值,但是ks是什么?我如何预测其余的m? - Alexander Fuchs
2
这个程序运行得很好,但是对于 print find_min(1000000000, 100000,99999, 1, 99999, 100000) 我的系统崩溃了。但这肯定非常有帮助。+1 - Ashwini Chaudhary

12

以下是我的 O(k) 解决方案,基于与上述相同的思路,但运行速度更快。

import os, sys

f = open(sys.argv[1], 'r')

T = int(f.readline())

def next(ary, start):
    j = start
    l = len(ary)
    ret = start - 1
    while j < l and ary[j]:
        ret = j
        j += 1
    return ret

for t in range(T):
    n, k = map(int, f.readline().strip().split(' '))
    a, b, c, r = map(int, f.readline().strip().split(' '))

    m = [0] * (4 * k)
    s = [0] * (k+1)
    m[0] = a
    if m[0] <= k:
        s[m[0]] = 1
    for i in xrange(1, k):
        m[i] = (b * m[i-1] + c) % r
        if m[i] < k+1:
            s[m[i]] += 1

    p = next(s, 0)
    m[k] = p + 1
    p = next(s, p+2)

    for i in xrange(k+1, n):
        if m[i-k-1] > p or s[m[i-k-1]] > 1:
            m[i] = p + 1
            if m[i-k-1] <= k:
                s[m[i-k-1]] -= 1
            s[m[i]] += 1
            p = next(s, p+2)
        else:
            m[i] = m[i-k-1]
        if p == k:
            break

    if p != k:
        print 'Case #%d: %d' % (t+1, m[n-1])
    else:
        print 'Case #%d: %d' % (t+1, m[i-k + (n-i+k+k) % (k+1)])

关键点在于,m[i]永远不会超过k,如果我们记住了0到p之间前k个数字中找到的连续数字,那么p就永远不会减少。
如果数字m[i-k-1]大于p,则显然应该将m[i]设置为p+1,并且p至少会增加1。
如果数字m[i-k-1]小于或等于p,则我们应该考虑在m[i-k:i]中是否存在相同的数字,如果不存在,则m[i]应该设置为m[i-k-1],如果存在,则应像“m[i-k-1]-larger-than-p”情况一样将m[i]设置为p+1。
每当p等于k时,循环开始,循环大小为(k+1),因此我们现在可以跳出计算并打印出答案。

0

我通过添加地图来提高性能。

import sys, os
import collections

def min(str1, str2):
    para1 = str1.split()
    para2 = str2.split()

    n = int(para1[0])
    k = int(para1[1])
    a = int(para2[0])
    b = int(para2[1])
    c = int(para2[2])
    r = int(para2[3])

    m = [0] * (2*k+1)
    m[0] = a

    s = collections.Counter()

    s[a] += 1
    rs = {}
    for i in range(k+1):
        rs[i] = 1

    for i in xrange(1,k):
        v = (b * m[i - 1] + c) % r
        m[i] = v
        s[v] += 1
        if v < k:
            if v in rs:
                rs[v] -= 1
                if rs[v] == 0:
                    del rs[v]

    for j in xrange(0,k+1):
        for t in rs:
            if not s[t]:
                m[k+j] = t
                if m[j] < k:
                    if m[j] in rs:
                        rs[m[j]] += 1
                    else:
                        rs[m[j]] = 0

                rs[t] -= 1
                if rs[t] == 0:
                    del rs[t]

                s[t] = 1
                break

        s[m[j]] -= 1

    return m[k + ((n-k-1)%(k+1))]

if __name__=='__main__':
    lines = []
    user_input = raw_input()
    num = int(user_input)

    for i in xrange(num):
        input1 = raw_input()
        input2 = raw_input()
        print "Case #%s: %s"%(i+1, min(input1, input2))

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接