如何进行不重复的逐步抽样?

32

Python中有my_sample = random.sample(range(100), 10)以从[0,100)的范围内无重复随机采样。

假设我已经随机采样了n个数,现在我想要再采样一个数而不重复(不包括之前采样的n个数),如何高效地实现?

更新:从“相对高效”更改为“超级高效”(但忽略常数因子)


1
你是否只想在 [0, x) 范围内采样整数?你期望的 x 是多少? - Chronial
[0, n) 对我来说完全可行。我可以让任何问题适应它。 - necromancer
这是你需要的吗?让另一个问题适应它会花费大量时间,并且考虑到您所要求的紧密边界,这非常重要。 - Chronial
1
你可能想查看 random.sample 的源代码。 - Eric
2
这个线程真是太棒了!一个简单的问题最终需要支付300分赏金来表达对惊人答案的感激之情。其中一个人提供了4个答案,另一个人提供了3个答案,而原帖作者提供了一个作为正确答案基础的回答。还有一个惊人的类似论文的答案,实际上包含了多个子答案。希望大家都能得出满意的结论。谢谢大家。 :-) - necromancer
13个回答

26

如果您事先知道想要多个不重叠的样本,最简单的方法是在list(range(100))上执行random.shuffle()(Python 3中可以跳过Python 2中的list()),然后根据需要削片。

s = list(range(100))
random.shuffle(s)
first_sample = s[-10:]
del s[-10:]
second_sample = s[-10:]
del s[-10:]
# etc

否则,@Chronial的答案相当有效。


+1 谢谢,我仍在寻找完美的解决方案,可以传入之前采样对象的列表,并且能够“智能地”(@Chronial的答案使用了蛮力)采样下一个对象。 - necromancer
@no_answer_not_upvoted 我认为你已经看到了所有的可能性。如果你的range()较小,请使用我的答案;如果你的样本大小较小,请使用Veedrac的答案。如果它们都很大,请在你的问题中说明,并希望有人给你一个更复杂的算法。但请注意,在前两种情况下这会更慢。 - Chronial
2
@no_answer_not_upvoted,祝你好运;-) 它需要检查您传入的每个“禁止”值,并且需要检查基本列表中的每个值,以确保每个值都不在禁止列表中。如果没有更智能的数据结构作为起点,它至少需要花费与两个列表大小之和成正比的时间。 - Tim Peters
2
@Chronial,不,这里删除非常快:这就是为什么它从列表的末尾删除。不需要移动任何项目。CPython只是减少了列表的长度,并且在尾部解除引用指针。 - Tim Peters
@Chronial,我在谈论你的答案中的集合差异。这个答案在适用的情况下是可以的。 - necromancer
显示剩余3条评论

11

读者注意:请先查看原始答案以了解逻辑,然后再理解本答案。

为了完整起见:这是necromancer的答案的概念,但改编成以禁止数字列表作为输入。这与我之前的答案中的代码相同,但我们在生成数字之前从forbid构建状态。

  • 这需要时间O(f+k)和内存O(f+k)。显然,如果没有对forbid格式(排序/集合)的要求,则这是可能的最快速度。我认为这在某种程度上使其成为赢家^^。
  • 如果forbid是一个集合,则重复猜测方法更快,时间复杂度为O(k⋅n/(n-(f+k))),当f+k不是非常接近n时,这非常接近于O(k)
  • 如果forbid已排序,则我的荒谬算法更快,时间复杂度为:
    O(k⋅(log(f+k)+log²(n/(n-(f+k))))
import random
def sample_gen(n, forbid):
    state = dict()
    track = dict()
    for (i, o) in enumerate(forbid):
        x = track.get(o, o)
        t = state.get(n-i-1, n-i-1)
        state[x] = t
        track[t] = x
        state.pop(n-i-1, None)
        track.pop(o, None)
    del track
    for remaining in xrange(n-len(forbid), 0, -1):
        i = random.randrange(remaining)
        yield state.get(i, i)
        state[i] = state.get(remaining - 1, remaining - 1)
        state.pop(remaining - 1, None)

使用方法:

gen = sample_gen(10, [1, 2, 4, 8])
print gen.next()
print gen.next()
print gen.next()
print gen.next()

10

简便方法

如果样本数量远小于总体数量,只需抽样,检查是否已被选择,并在此过程中重复。这听起来可能很傻,但您有指数级递减的选择相同数字的可能性,因此如果有一小部分未选择,则比O(n)快得多。


漫长的道路

Python使用Mersenne Twister作为其伪随机数生成器,这是好的足够好的。我们可以完全使用其他方法以能够以可预测的方式生成不重叠的数字。

这是秘密:

  • 二次剩余,x² mod p,当 2x < pp 是素数时是唯一的。

  • 如果你"翻转"剩余部分,p - (x² % p), 同时也满足 p = 3 mod 4,结果将是其余空间。

  • 这不是一个非常令人信服的数字分布,所以您可以增加功率,添加一些升级常数,然后分布就很好了。


首先我们需要生成质数:

from itertools import count
from math import ceil
from random import randrange

def modprime_at_least(number):
    if number <= 2:
        return 2

    number = (number // 4 * 4) + 3
    for number in count(number, 4):
        if all(number % factor for factor in range(3, ceil(number ** 0.5)+1, 2)):
            return number

你可能会担心生成质数的成本。对于10⁶个元素,这需要1/10毫秒的时间。运行[None] * 10**6需要更长的时间,因为它只计算一次,所以这不是一个真正的问题。
此外,该算法不需要质数的精确值;它只需要比输入数字大一个常数因子的东西。这可以通过保存值列表并搜索它们来实现。如果进行线性扫描,那么时间复杂度为O(log number),如果进行二进制搜索,则为O(log number of cached primes)。事实上,如果使用galloping,您可以将其降至O(log log number),这基本上是常数(log log googol = 2)。
然后我们实现生成器。
def sample_generator(up_to):
    prime = modprime_at_least(up_to+1)

    # Fudge to make it less predictable
    fudge_power = 2**randrange(7, 11)
    fudge_constant = randrange(prime//2, prime)
    fudge_factor = randrange(prime//2, prime)

    def permute(x):
        permuted = pow(x, fudge_power, prime) 
        return permuted if 2*x <= prime else prime - permuted

    for x in range(prime):
        res = (permute(x) + fudge_constant) % prime
        res = permute((res * fudge_factor) % prime)

        if res < up_to:
            yield res

并检查它是否正常工作:

set(sample_generator(10000)) ^ set(range(10000))
#>>> set()

现在,这个算法的可爱之处在于,如果忽略主要性测试(大约为O(√n),其中n是元素数量),它的时间复杂度为O(k),其中k是样本大小,而且内存使用为O(1)!从技术上讲,这是O(√n + k),但实际上它是O(k)


要求:

  1. 您不需要一个经过验证的伪随机数生成器。这个伪随机数生成器比线性同余生成器(它很受欢迎;Java使用它)好得多,但它没有Mersenne Twister那么可靠。

  2. 您不需要先用其他函数生成任何项。这通过数学避免了重复,而不是检查。下一节中我将展示如何消除此限制。

  3. 短方法必须不足够(k 必须接近 n)。如果 k 只有 n 的一半,就按照我的最初建议进行。

优势:

  1. 极大地节省内存。这需要恒定的内存……甚至不到 O(k)

  2. 生成下一个项目所需的时间是恒定的。从恒定的角度来看,这实际上相当快:它不像内置的 Mersenne Twister 那样快,但它在2倍以内。

  3. 酷。


要删除此要求:

您不首先使用其他函数生成任何项。这通过数学避免重复,而不是检查。

我已经制作了最佳的算法,在时间和空间复杂度上都是最简单的扩展。 这是我的先前生成器的简单扩展。

以下是摘要(n是数字池的长度,k是“外部”键的数量):

初始化时间O(√n); 对于所有合理的输入,O(log log n)

这是我的算法在算法复杂性方面技术上唯一不完美的因素,由于O(√n)成本。实际上,这不会成为问题,因为预计算将其降至O(log log n),这几乎等同于常数时间。

如果您以任何固定百分比耗尽可迭代物,则成本为分摊免费。

这不是一个实际的问题。

分摊O(1)密钥生成时间

显然,这是无法改进的。

最坏情况下O(k)密钥生成时间

如果您有从外部生成的密钥,只要满足它不是此生成器已生成的密钥,这些就被称为“外部密钥”。 外部密钥被假定为完全随机。 因此,任何能够从池中选择项目的函数都可以这样做。
由于外部密钥可以有任意数量且可以是完全随机的,因此最佳算法的最坏情况为O(k)
最坏情况下空间复杂度为O(k) 如果假定外部密钥是完全独立的,则每个密钥表示一个不同的信息项。 因此,所有密钥必须存储。 算法恰好在看到密钥时会丢弃密钥,因此随着生成器的寿命,内存成本将清除。
算法
好吧,它是我的两个算法。 它实际上非常简单:
def sample_generator(up_to, previously_chosen=set(), *, prune=True):
    prime = modprime_at_least(up_to+1)

    # Fudge to make it less predictable
    fudge_power = 2**randrange(7, 11)
    fudge_constant = randrange(prime//2, prime)
    fudge_factor = randrange(prime//2, prime)

    def permute(x):
        permuted = pow(x, fudge_power, prime) 
        return permuted if 2*x <= prime else prime - permuted

    for x in range(prime):
        res = (permute(x) + fudge_constant) % prime
        res = permute((res * fudge_factor) % prime)

        if res in previously_chosen:
            if prune:
                previously_chosen.remove(res)

        elif res < up_to:
            yield res

更改很简单,只需添加:
if res in previously_chosen:
    previously_chosen.remove(res)

您可以通过添加到传递的set中随时向previously_chosen添加内容。实际上,您还可以从集合中删除以便将其重新添加到潜在池中,但仅当sample_generator尚未使用prune=False跳过它时才起作用。

所以这就是它的全部内容。很容易看出它满足所有要求,并且很容易看出这些要求是绝对的。请注意,如果您没有一个集合,它仍然通过将输入转换为集合来满足最坏情况,尽管它会增加开销。


测试RNG的质量

我很好奇这个PRNG从统计上来说到底有多好。

一些快速搜索引导我创建了以下三个测试,它们似乎都显示出良好的结果!

首先是一些随机数:

N = 1000000

my_gen = list(sample_generator(N))

target = list(range(N))
random.shuffle(target)

control = list(range(N))
random.shuffle(control)

这些是包含100万个数字的“洗牌”列表,范围从010⁶-1,其中一个使用我们有趣的伪随机数生成器,另一个使用梅森旋转算法作为基准。第三个是对照组。
这是一个测试,它考察了沿着直线的两个随机数之间的平均距离。将差异与对照组进行比较:
from collections import Counter

def birthdat_calc(randoms):
    return Counter(abs(r1-r2)//10000 for r1, r2 in zip(randoms, randoms[1:]))

def birthday_compare(randoms_1, randoms_2):
    birthday_1 = sorted(birthdat_calc(randoms_1).items())
    birthday_2 = sorted(birthdat_calc(randoms_2).items())

    return sum(abs(n1 - n2) for (i1, n1), (i2, n2) in zip(birthday_1, birthday_2))

print(birthday_compare(my_gen, target), birthday_compare(control, target))
#>>> 9514 10136

这比每个方差都要小。


这是一个测试,它依次获取5个数字并查看元素的顺序。它们应该均匀分布在所有120种可能的顺序中。
def permutations_calc(randoms):
    permutations = Counter()        

    for items in zip(*[iter(randoms)]*5):
        sorteditems = sorted(items)
        permutations[tuple(sorteditems.index(item) for item in items)] += 1

    return permutations

def permutations_compare(randoms_1, randoms_2):
    permutations_1 = permutations_calc(randoms_1)
    permutations_2 = permutations_calc(randoms_2)

    keys = sorted(permutations_1.keys() | permutations_2.keys())

    return sum(abs(permutations_1[key] - permutations_2[key]) for key in keys)

print(permutations_compare(my_gen, target), permutations_compare(control, target))
#>>> 5324 5368

这仍然比每个方差都小。
这是一个测试,用于检测“运行”的长度,也就是连续递增或递减的部分。
def runs_calc(randoms):
    runs = Counter()

    run = 0
    for item in randoms:
        if run == 0:
            run = 1

        elif run == 1:
            run = 2
            increasing = item > last

        else:
            if (item > last) == increasing:
                run += 1

            else:
                runs[run] += 1
                run = 0

        last = item

    return runs

def runs_compare(randoms_1, randoms_2):
    runs_1 = runs_calc(randoms_1)
    runs_2 = runs_calc(randoms_2)

    keys = sorted(runs_1.keys() | runs_2.keys())

    return sum(abs(runs_1[key] - runs_2[key]) for key in keys)

print(runs_compare(my_gen, target), runs_compare(control, target))
#>>> 1270 975

这里的方差非常大,在多次执行中,我看到了两者都比较均匀的分布。因此,此测试通过。


有人向我提到了线性同余生成器,称其可能会“更有成果”。我自己实现了一个糟糕的线性同余生成器,以验证这个说法是否准确。

据我所知,LCG与普通生成器一样,并不是为了循环而设计的。因此,我查阅的大多数参考资料(如维基百科)仅涵盖了定义周期的内容,而没有介绍如何制作具有特定周期的强大LCG。这可能会影响结果。

下面开始:

from operator import mul
from functools import reduce

# Credit https://dev59.com/RmQn5IYBdhLWcg3wPlCj#16996439
# Meta: Also Tobias Kienzler seems to have credit for my
#       edit to the post, what's up with that?
def factors(n):
    d = 2
    while d**2 <= n:
        while not n % d:
            yield d
            n //= d
        d += 1
    if n > 1:
       yield n

def sample_generator3(up_to):
    for modulier in count(up_to):
        modulier_factors = set(factors(modulier))
        multiplier = reduce(mul, modulier_factors)
        if not modulier % 4:
            multiplier *= 2

        if multiplier < modulier - 1:
            multiplier += 1
            break

    x = randrange(0, up_to)

    fudge_constant = random.randrange(0, modulier)
    for modfact in modulier_factors:
        while not fudge_constant % modfact:
            fudge_constant //= modfact

    for _ in range(modulier):
        if x < up_to:
            yield x

        x = (x * multiplier + fudge_constant) % modulier

我们不再检查质数,但我们需要对因子做一些奇怪的事情。

  • modulier ≥ up_to > multiplier, fudge_constant > 0
  • a - 1 必须被 modulier 中的每个因子整除...
  • ...而 fudge_constant 必须与 modulier 互质

请注意,这些不是线性同余发生器(LCG)的规则,而是具有完整周期的 LCG 规则,其显然等于模数(modulier)。

我是这样做的:

尝试至少使用每个modulier,当满足条件时停止。
  • 生成其因子的集合,
  • multiplier成为去除重复项的的乘积
  • 如果multiplier不小于modulier,则继续下一个modulier
  • fudge_constant成为小于modulier的随机选择数字
  • fudge_constant中删除在中的因子
这不是生成它的非常好的方法,但我不明白为什么它会影响数字的质量,除了低fudge_constantmultiplier比完美生成器更常见。
无论如何,结果令人震惊:
print(birthday_compare(lcg, target), birthday_compare(control, target))
#>>> 22532 10650

print(permutations_compare(lcg, target), permutations_compare(control, target))
#>>> 17968 5820

print(runs_compare(lcg, target), runs_compare(control, target))
#>>> 8320 662

总之,我的随机数生成器很好,而线性同余发生器不好。考虑到Java使用线性同余发生器(尽管只使用低位),我认为我的版本应该足够了。

9

好的,我们开始吧。这应该是最快的非概率算法。它的运行时间为O(k⋅log²(s) + f⋅log(f)) ⊂ O(k⋅log²(f+k) + f⋅log(f))),空间为O(k+f)f是禁止数字的数量,s是禁止数字中最长连续串的长度。对于期望值更复杂,但显然受f的限制。如果您假设s^log₂(s)大于f或者对s再次使用概率方法感到不满意,可以将对forbidden[pos:]的对数部分改为二分搜索,以获得O(k⋅log(f+k) + f⋅log(f))

实际实现中,forbid列表中的插入是O(n),因此其时间复杂度为O(k⋅(k+f)+f⋅log(f))。通过用blist sortedlist替换该列表,可以轻松解决此问题。

我还添加了一些注释,因为这个算法太过复杂了。 lin部分与log部分相同,但需要s而不是log²(s)的时间。

import bisect
import random

def sample(k, end, forbid):
    forbidden = sorted(forbid)
    out = []
    # remove the last block from forbidden if it touches end
    for end in reversed(xrange(end+1)):
        if len(forbidden) > 0 and forbidden[-1] == end:
            del forbidden[-1]
        else:
            break

    for i in xrange(k):
        v = random.randrange(end - len(forbidden) + 1)
        # increase v by the number of values < v
        pos = bisect.bisect(forbidden, v)
        v += pos
        # this number might also be already taken, find the
        # first free spot
        ##### linear
        #while pos < len(forbidden) and forbidden[pos] <=v:
        #    pos += 1
        #    v += 1
        ##### log
        while pos < len(forbidden) and forbidden[pos] <= v:
            step = 2
            # when this is finished, we know that:
            # • forbidden[pos + step/2] <= v + step/2
            # • forbidden[pos + step]   >  v + step
            # so repeat until (checked by outer loop):
            #   forbidden[pos + step/2] == v + step/2
            while (pos + step <= len(forbidden)) and \
                  (forbidden[pos + step - 1] <= v + step - 1):
                step = step << 1
            pos += step >> 1
            v += step >> 1

        if v == end:
            end -= 1
        else:
            bisect.insort(forbidden, v)
        out.append(v)
    return out

现在来比较一下Veedrac提出的“hack”(以及Python中的默认实现),其空间复杂度为O(f+k),时间复杂度为n/(n-(f+k))(其中n/(n-(f+k))是“猜测”的期望次数): O(f+k*(n/(n-(f+k))) 我刚才为k=10和一个相当大的n=10000绘制了图表(对于更大的n,情况只会更加极端)。我必须说:我只是因为它似乎是一个有趣的挑战而实现了这个算法,但连我自己都对这个算法的极端性感到惊讶: enter image description here 让我们放大看一下: enter image description here 是的——你甚至可以更快地猜出第9998个数字。请注意,正如你在第一个图中所看到的那样,即使对于更大的f/n,我的一行代码可能也更快(但对于大的n,它仍然具有相当可怕的空间需求)。
为了强调这一点:在这里你所花费的时间唯一的事情就是生成集合,因为那是Veedrac方法中的f因子。 enter image description here 所以我希望我的时间不是浪费了,我成功地说服了你,Veedrac的方法是最好的方法。我可以理解为什么那个概率部分会让你感到困扰,但也许考虑一下哈希映射(= Python dict)和其他大量算法使用类似的方法,它们似乎做得很好。
你可能会担心重复次数的方差。如上所述,这遵循p=n-f/n几何分布。因此,标准偏差(=期望平均值与实际平均值之间的预期偏差量)为: enter image description here

这基本上与平均值相同 (√f⋅n < √n² = n)。

****编辑**:
我刚意识到s实际上也是n/(n-(f+k))。因此,我的算法更精确的运行时间为O(k⋅log²(n/(n-(f+k))) + f⋅log(f))。这很好,因为根据上面的图表,它证明了我的直觉,即它比O(k⋅log(f+k) + f⋅log(f))快得多。但请放心,这并不会改变上面的结果,因为f⋅log(f)是运行时间中绝对占主导地位的部分。


哇...暂时+1..我保证尽快完全理解它!!非常感谢! - necromancer
如果你有任何问题,请随意提问。 - Chronial
@TimPeters 是的,显然我没有任何有用的算法可写 ^^。 - Chronial
@Chronial,你的努力几乎肯定会得到赏金,但我不确定所有这些复杂性是否是必要或有帮助的。请看看我的答案。它可以在没有任何循环的情况下进行增量采样。期待你的批评。 - necromancer
@Chronial,如约而至,赏金归你。享受这10K+的奖励吧 :-) 我仍然认为,与我刚刚发布的相当简单的解决方案相比,你的回答并不正确。在考虑任何批评之后,我可能会接受我的答案。期待听取你的想法。 - necromancer
显示剩余4条评论

8

好的,最后一次尝试;-)虽然会改变基本序列,但这不需要额外的空间,并且每次 sample(n) 调用所需的时间与 n 成正比:

class Sampler(object):
    def __init__(self, base):
        self.base = base
        self.navail = len(base)
    def sample(self, n):
        from random import randrange
        if n < 0:
            raise ValueError("n must be >= 0")
        if n > self.navail:
            raise ValueError("fewer than %s unused remain" % n)
        base = self.base
        for _ in range(n):
            i = randrange(self.navail)
            self.navail -= 1
            base[i], base[self.navail] = base[self.navail], base[i]
        return base[self.navail : self.navail + n]

小驱动程序:

s = Sampler(list(range(100)))
for i in range(9):
    print s.sample(10)
    print s.sample(1)
print s.sample(1)

实际上,这实现了可恢复的random.shuffle(),在选择了n个元素后暂停。 base未被销毁,但会被重新排列。

非常感谢 :-) 至少 +1 直到我理解为止。我确实看到它使用了 O(range) 的内存,但如果我不能自己写出更好的代码,我会接受这个努力。再次感谢! - necromancer
LOL - 你还没有定义你的问题;-) 如果你的范围大小比你期望提取的样本总大小要大得多,那么@Veedrac的方法是一个绝佳的选择。至少在Python 3中,即使range(n)中的n非常大,它也只会使用很少的内存(在Python 2中,如果你使用xrange(n),情况也差不多)。 - Tim Peters
你们让我左右为难。这就是我对这两个解决方案的感觉。唉。。 - necromancer
哎呀,我最开始忘记加上+1了吗?好的,现在已经加上了。也可以随意多次使用哦;-) - necromancer
我已经添加了自己的答案。欢迎批评指正。谢谢! - necromancer

7
这里有一种不显式构建差集的方法,但它使用了 @Veedrac 的“接受/拒绝”逻辑形式。如果你不想在操作过程中改变基础序列,恐怕这是不可避免的。
def sample(n, base, forbidden):
    # base is iterable, forbidden is a set.
    # Every element of forbidden must be in base.
    # forbidden is updated.
    from random import random
    nusable = len(base) - len(forbidden)
    assert nusable >= n
    result = []
    if n == 0:
        return result
    for elt in base:
        if elt in forbidden:
            continue
        if nusable * random() < n:
            result.append(elt)
            forbidden.add(elt)
            n -= 1
            if n == 0:
                return result
        nusable -= 1
    assert False, "oops!"

这是一个小驱动程序:

base = list(range(100))
forbidden = set()
for i in range(10):
    print sample(10, base, forbidden)

+1 谢谢,如果我理解正确的话,这是 O(size(base))?与 @Chronialis 的答案相比,您的空间复杂度更低,但平均时间复杂度相同?(附注:只采样1个将足以简化逻辑)。 - necromancer
每次调用需要 O(len(base)) 的时间。但是如果你用完了基础集合(就像我的示例驱动程序一样),在所有调用结束时,len(forbidden) == len(base),你最终仍然会得到与 base 相同大小的集合。抱歉,我不明白 "sampling just 1" 是什么意思。如果你只想要一个大小为 1 的样本,请将 n 设为 1 传递给我的 sample() 函数。如果你将 1 硬编码进去,代码会变得简单一些,但并不会有太大改观。"难点" 仍然是跳过前面调用返回的所有元素。 - Tim Peters
顺便提一下,你应该再考虑一下我的“shuffle”答案。如果你要取多个样本,这绝对是最有效的方法:第一次调用需要O(len(base))时间,每次请求大小为n的样本时需要O(n)时间。如果你愿意让“base”被销毁,它不需要额外的空间。 - Tim Peters
毫无疑问,但如果范围很大而我只选择了少量样本,则会过度。这也是一个O(range)的解决方案,而我正试图避免这种情况。 - necromancer
1
如果你讨厌那个“hack”,我强烈建议不要使用Python的sample()函数 - Chronial
显示剩余8条评论

6
你可以实现一个洗牌生成器,基于维基百科的“Fisher-Yates shuffle#Modern method”
def shuffle_gen(src):
    """ yields random items from base without repetition. Clobbers `src`. """
    for remaining in xrange(len(src), 0, -1):
        i = random.randrange(remaining)
        yield src[i]
        src[i] = src[remaining - 1]

然后可以使用 itertools.islice 进行切片:

>>> import itertools
>>> sampler = shuffle_gen(range(100))
>>> sample1 = list(itertools.islice(sampler, 10))
>>> sample1
[37, 1, 51, 82, 83, 12, 31, 56, 15, 92]
>>> sample2 = list(itertools.islice(sampler, 80))
>>> sample2
[79, 66, 65, 23, 63, 14, 30, 38, 41, 3, 47, 42, 22, 11, 91, 16, 58, 20, 96, 32, 76, 55, 59, 53, 94, 88, 21, 9, 90, 75, 74, 29, 48, 28, 0, 89, 46, 70, 60, 73, 71, 72, 93, 24, 34, 26, 99, 97, 39, 17, 86, 52, 44, 40, 49, 77, 8, 61, 18, 87, 13, 78, 62, 25, 36, 7, 84, 2, 6, 81, 10, 80, 45, 57, 5, 64, 33, 95, 43, 68]
>>> sample3 = list(itertools.islice(sampler, 20))
>>> sample3
[85, 19, 54, 27, 35, 4, 98, 50, 67, 69]

Eric,这基本上与我早期的一个答案相同。请注意,这里xrange()的第二个参数应该是0,而不是1(例如,list(xrange(4, 1, -1)[4, 3, 2]- range/xrange总是在stop参数之前停止。 - Tim Peters
1
@TimPeters:以前遇到过这种情况,不过因为某些原因决定更改...现在已经修复了。是的,从算法上讲,它们是相同的,但我认为将其作为迭代器实现更加清晰。 - Eric
@Eric,我刚刚添加了自己的答案。你有什么想法吗? - necromancer

6
这是我对Knuth洗牌算法的版本,该算法最初由Tim Peters发布,由Eric优化,然后被necromancer进一步优化以节省空间。基于Eric的版本,因为我确实觉得他的代码非常漂亮 :)

import random
def shuffle_gen(n):
    # this is used like a range(n) list, but we don’t store
    # those entries where state[i] = i.
    state = dict()
    for remaining in xrange(n, 0, -1):
        i = random.randrange(remaining)
        yield state.get(i,i)
        state[i] = state.get(remaining - 1,remaining - 1)
        # Cleanup – we don’t need this information anymore
        state.pop(remaining - 1, None)

用法:

out = []
gen = shuffle_gen(100)
for n in range(100):
    out.append(gen.next())
print out, len(set(out))

顺便说一下,这段代码太啰嗦了;-) 你可以摆脱最后一行(state.pop(...)),在倒数第二行将“get”替换为“pop”。然后它就是一个好答案 - LOL;-) - Tim Peters
2
@TimPeters 哈哈,不是的。之所以有一个 get 和一个 pop 是因为可能会出现 i=remaining-1 的情况。如果只用 pop,我们会删除该项,然后再重新添加它。我想说的是,remaining 要么很大,这种情况很少见,要么很小,我们很快就会完成,所以“泄漏”问题并不太严重。但我想做到彻底 :)。 - Chronial
1
啊!懂了。我很抱歉尝试画蛇添足;-) - Tim Peters

5

编辑:请参阅@TimPeters和@Chronial提供的更清晰版本。一个小修改把它推到了前面。

以下是我认为最有效的增量采样解决方案。与其使用以前采样数字的列表,调用者所需维护的状态包括一个由增量采样器使用的字典和剩余范围内数字的计数。

以下是一个演示实现。与其他解决方案相比:

  • 没有循环(没有标准Python / Veedrac hack;Python实现和Veedrac分享信用)
  • 时间复杂度为O(log(number_previously_sampled))
  • 空间复杂度为O(number_previously_sampled)

代码:

import random

def remove (i, n, state):
  if i == n - 1:
    if i in state:
      t = state[i]
      del state[i]
      return t
    else:
      return i
  else:
    if i in state:
      t = state[i]
      if n - 1 in state:
        state[i] = state[n - 1]
        del state[n - 1]
      else:
        state[i] = n - 1
      return t
    else:
      if n - 1 in state:
        state[i] = state[n - 1]
        del state[n - 1]
      else:
        state[i] = n - 1
      return i

s = dict()
for n in range(100, 0, -1):
  print remove(random.randrange(n), n, s)

1
相对于 Veedrac Hack,它有什么优势?跟踪少量辅助变量比跟踪整个字典要少,所以我不知道你在哪些情况下会使用这个而不是我的。 - Veedrac
1
真的很酷!我打算发布一个更易于使用和理解的重写版本 - 我不想“进入”它,我只是想为后人发布它;-) 关于时间复杂度的声明,这并没有意义:每个元素提取都是O(1),因此获取大小为k的样本是O(k) - Tim Peters
1
不,CPython中字典访问的时间复杂度是O(1)。这是期望的时间复杂度。最坏情况下的时间复杂度是O(len(dict)),但这种情况几乎不会出现。但是,要相信O(1)的说法,你需要对概率有信心;-) - Tim Peters
1
@TimPeters 这里的整数并不是来自连续的范围,所以在循环过程中肯定有可能遇到最坏情况。但是如果你假设这一点,那么算法的时间复杂度是 O(number_previously_sampled) 而不是 O(log(number_previously_sampled)) - Chronial
1
@TimPeters 天文数字般的不可能 - 是的,但这也不能阻止 no_answer_not_upvoted 不喜欢 Veedrac 的回答 ;). 顺便说一句:我认为我赢得了漂亮比赛 :P. - Chronial
显示剩余16条评论

5
这是@necromancer很棒解决方案的重写版本。将其包装在一个类中,使其更易于正确使用,并使用更多字典方法来减少代码行数。
from random import randrange

class Sampler:
    def __init__(self, n):
        self.n = n # number remaining from original range(n)
        # i is a key iff i < n and i already returned;
        # in that case, state[i] is a value to return
        # instead of i.
        self.state = dict()

    def get(self):
        n = self.n
        if n <= 0:
            raise ValueError("range exhausted")
        result = i = randrange(n)
        state = self.state
        # Most of the fiddling here is just to get
        # rid of state[n-1] (if it exists).  It's a
        # space optimization.
        if i == n - 1:
            if i in state:
                result = state.pop(i)
        elif i in state:
            result = state[i]
            if n - 1 in state:
                state[i] = state.pop(n - 1)
            else:
                state[i] = n - 1
        elif n - 1 in state:
            state[i] = state.pop(n - 1)
        else:
            state[i] = n - 1
        self.n = n-1
        return result

这是一个基本的驱动程序:
s = Sampler(100)
allx = [s.get() for _ in range(100)]
assert sorted(allx) == list(range(100))

from collections import Counter
c = Counter()
for i in range(6000):
    s = Sampler(3)
    one = tuple(s.get() for _ in range(3))
    c[one] += 1
for k, v in sorted(c.items()):
    print(k, v)

同时附上样例输出:

(0, 1, 2) 1001
(0, 2, 1) 991
(1, 0, 2) 995
(1, 2, 0) 1044
(2, 0, 1) 950
(2, 1, 0) 1019

通过肉眼观察,这个分布是好的(如果您有疑问,请运行卡方检验)。这里的一些解决方案不会给出每个排列的等概率性(尽管它们以等概率返回每个k子集),因此在这方面与random.sample()不同。


+1 感谢您的优雅重写。我是 Python 的新手,所以这也帮助我学习。 - necromancer
1
TimPeters,你可以选择正确的答案(希望在这个答案和Chronial的改写之间),我几乎更喜欢后者,但由于他已经有了300个声望并且你做出了巨大贡献。非常感谢您的贡献 - 一个简单的问题变成了一个有点传奇的线程。多个单一作者的答案,自我回答和@Veedrac的半论文最终进入社区维基。 - necromancer
没有竞争 - @Chronial的最漂亮!这非常有趣,但是最干净的代码获胜 :-) 话虽如此,我不会使用他的版本 - 或者我的。 我想要一个界面,让我指定一次采样取多少个数据。操纵(e.g.)itertools.slice()来达到这个目的会让用户的生活变得更加复杂,并且为代码读者掩盖了意图。尽管如此,像这样的代码仍然是它的核心。感谢大家的参与! :-) - Tim Peters
@TimPeters,你为什么没有在这里使用生成器呢?它似乎完美地适合这种情况。或者我漏掉了一些缺点吗? - Chronial
@TimPeters 我认为Chronial的“无if”版本确实非常出色,所以我接受了它。我想为您的贡献增加另一个赏金,但StackOverflow赏金的工作方式是,我在这个问题上可以给出的第二个赏金的最低金额是500分,这对我的声誉来说是太大的打击了。很高兴有你的答案,尤其是来自THE Tim Peters!! 这是一种特权 :-) 期待着关注您的其他答案。 PS:有人仅仅基于您写的东西提出了一个问题,就获得了将近1000点的声誉:http://stackoverflow.com/questions/228181/zen-of-python :) - necromancer
一切都很好 - 我在这里将“声望”视为FarmVille中“升级”的方式 - 这是规则的一个奇特产物;-) 我没有使用生成器,因为对于大多数可信用的用途来说,这是一种错误的方法:一次实现一种函数几乎总是不适合要求“k”个一次。而关心随机抽样的人通常非常关心挂钟速度,而不仅仅是O()行为。因此,在任何真正的部署中,我都会有一个循环收集 - 并返回 - 每次调用的k个样本。更快,更合适。 - Tim Peters

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接