Python中有my_sample = random.sample(range(100), 10)
以从[0,100)
的范围内无重复随机采样。
假设我已经随机采样了n
个数,现在我想要再采样一个数而不重复(不包括之前采样的n
个数),如何高效地实现?
更新:从“相对高效”更改为“超级高效”(但忽略常数因子)
Python中有my_sample = random.sample(range(100), 10)
以从[0,100)
的范围内无重复随机采样。
假设我已经随机采样了n
个数,现在我想要再采样一个数而不重复(不包括之前采样的n
个数),如何高效地实现?
更新:从“相对高效”更改为“超级高效”(但忽略常数因子)
如果您事先知道想要多个不重叠的样本,最简单的方法是在list(range(100))
上执行random.shuffle()
(Python 3中可以跳过Python 2中的list()
),然后根据需要削片。
s = list(range(100))
random.shuffle(s)
first_sample = s[-10:]
del s[-10:]
second_sample = s[-10:]
del s[-10:]
# etc
否则,@Chronial的答案相当有效。
读者注意:请先查看原始答案以了解逻辑,然后再理解本答案。
为了完整起见:这是necromancer的答案的概念,但改编成以禁止数字列表作为输入。这与我之前的答案中的代码相同,但我们在生成数字之前从forbid
构建状态。
O(f+k)
和内存O(f+k)
。显然,如果没有对forbid
格式(排序/集合)的要求,则这是可能的最快速度。我认为这在某种程度上使其成为赢家^^。forbid
是一个集合,则重复猜测方法更快,时间复杂度为O(k⋅n/(n-(f+k)))
,当f+k
不是非常接近n
时,这非常接近于O(k)
。forbid
已排序,则我的荒谬算法更快,时间复杂度为:import random
def sample_gen(n, forbid):
state = dict()
track = dict()
for (i, o) in enumerate(forbid):
x = track.get(o, o)
t = state.get(n-i-1, n-i-1)
state[x] = t
track[t] = x
state.pop(n-i-1, None)
track.pop(o, None)
del track
for remaining in xrange(n-len(forbid), 0, -1):
i = random.randrange(remaining)
yield state.get(i, i)
state[i] = state.get(remaining - 1, remaining - 1)
state.pop(remaining - 1, None)
使用方法:
gen = sample_gen(10, [1, 2, 4, 8])
print gen.next()
print gen.next()
print gen.next()
print gen.next()
如果样本数量远小于总体数量,只需抽样,检查是否已被选择,并在此过程中重复。这听起来可能很傻,但您有指数级递减的选择相同数字的可能性,因此如果有一小部分未选择,则比O(n)
快得多。
Python使用Mersenne Twister作为其伪随机数生成器,这是好的足够好的。我们可以完全使用其他方法以能够以可预测的方式生成不重叠的数字。
二次剩余,x² mod p
,当 2x < p
且 p
是素数时是唯一的。
如果你"翻转"剩余部分,p - (x² % p)
, 同时也满足 p = 3 mod 4
,结果将是其余空间。
这不是一个非常令人信服的数字分布,所以您可以增加功率,添加一些升级常数,然后分布就很好了。
首先我们需要生成质数:
from itertools import count
from math import ceil
from random import randrange
def modprime_at_least(number):
if number <= 2:
return 2
number = (number // 4 * 4) + 3
for number in count(number, 4):
if all(number % factor for factor in range(3, ceil(number ** 0.5)+1, 2)):
return number
[None] * 10**6
需要更长的时间,因为它只计算一次,所以这不是一个真正的问题。O(log number)
,如果进行二进制搜索,则为O(log number of cached primes)
。事实上,如果使用galloping,您可以将其降至O(log log number)
,这基本上是常数(log log googol = 2
)。def sample_generator(up_to):
prime = modprime_at_least(up_to+1)
# Fudge to make it less predictable
fudge_power = 2**randrange(7, 11)
fudge_constant = randrange(prime//2, prime)
fudge_factor = randrange(prime//2, prime)
def permute(x):
permuted = pow(x, fudge_power, prime)
return permuted if 2*x <= prime else prime - permuted
for x in range(prime):
res = (permute(x) + fudge_constant) % prime
res = permute((res * fudge_factor) % prime)
if res < up_to:
yield res
并检查它是否正常工作:
set(sample_generator(10000)) ^ set(range(10000))
#>>> set()
现在,这个算法的可爱之处在于,如果忽略主要性测试(大约为O(√n)
,其中n
是元素数量),它的时间复杂度为O(k)
,其中k
是样本大小,而且内存使用为O(1)
!从技术上讲,这是O(√n + k)
,但实际上它是O(k)
。
您不需要一个经过验证的伪随机数生成器。这个伪随机数生成器比线性同余生成器(它很受欢迎;Java使用它)好得多,但它没有Mersenne Twister那么可靠。
您不需要先用其他函数生成任何项。这通过数学避免了重复,而不是检查。下一节中我将展示如何消除此限制。
短方法必须不足够(k
必须接近 n
)。如果 k
只有 n
的一半,就按照我的最初建议进行。
极大地节省内存。这需要恒定的内存……甚至不到 O(k)
!
生成下一个项目所需的时间是恒定的。从恒定的角度来看,这实际上相当快:它不像内置的 Mersenne Twister 那样快,但它在2倍以内。
酷。
要删除此要求:
您不首先使用其他函数生成任何项。这通过数学避免重复,而不是检查。
我已经制作了最佳的算法,在时间和空间复杂度上都是最简单的扩展。 这是我的先前生成器的简单扩展。
以下是摘要(n
是数字池的长度,k
是“外部”键的数量):
O(√n)
; 对于所有合理的输入,O(log log n)
这是我的算法在算法复杂性方面技术上唯一不完美的因素,由于O(√n)
成本。实际上,这不会成为问题,因为预计算将其降至O(log log n)
,这几乎等同于常数时间。
如果您以任何固定百分比耗尽可迭代物,则成本为分摊免费。
这不是一个实际的问题。
O(1)
密钥生成时间显然,这是无法改进的。
O(k)
密钥生成时间O(k)
。O(k)
如果假定外部密钥是完全独立的,则每个密钥表示一个不同的信息项。 因此,所有密钥必须存储。 算法恰好在看到密钥时会丢弃密钥,因此随着生成器的寿命,内存成本将清除。def sample_generator(up_to, previously_chosen=set(), *, prune=True):
prime = modprime_at_least(up_to+1)
# Fudge to make it less predictable
fudge_power = 2**randrange(7, 11)
fudge_constant = randrange(prime//2, prime)
fudge_factor = randrange(prime//2, prime)
def permute(x):
permuted = pow(x, fudge_power, prime)
return permuted if 2*x <= prime else prime - permuted
for x in range(prime):
res = (permute(x) + fudge_constant) % prime
res = permute((res * fudge_factor) % prime)
if res in previously_chosen:
if prune:
previously_chosen.remove(res)
elif res < up_to:
yield res
if res in previously_chosen:
previously_chosen.remove(res)
set
中随时向previously_chosen
添加内容。实际上,您还可以从集合中删除以便将其重新添加到潜在池中,但仅当sample_generator
尚未使用prune=False
跳过它时才起作用。
所以这就是它的全部内容。很容易看出它满足所有要求,并且很容易看出这些要求是绝对的。请注意,如果您没有一个集合,它仍然通过将输入转换为集合来满足最坏情况,尽管它会增加开销。
我很好奇这个PRNG从统计上来说到底有多好。
一些快速搜索引导我创建了以下三个测试,它们似乎都显示出良好的结果!
首先是一些随机数:
N = 1000000
my_gen = list(sample_generator(N))
target = list(range(N))
random.shuffle(target)
control = list(range(N))
random.shuffle(control)
0
到10⁶-1
,其中一个使用我们有趣的伪随机数生成器,另一个使用梅森旋转算法作为基准。第三个是对照组。
from collections import Counter
def birthdat_calc(randoms):
return Counter(abs(r1-r2)//10000 for r1, r2 in zip(randoms, randoms[1:]))
def birthday_compare(randoms_1, randoms_2):
birthday_1 = sorted(birthdat_calc(randoms_1).items())
birthday_2 = sorted(birthdat_calc(randoms_2).items())
return sum(abs(n1 - n2) for (i1, n1), (i2, n2) in zip(birthday_1, birthday_2))
print(birthday_compare(my_gen, target), birthday_compare(control, target))
#>>> 9514 10136
这比每个方差都要小。
def permutations_calc(randoms):
permutations = Counter()
for items in zip(*[iter(randoms)]*5):
sorteditems = sorted(items)
permutations[tuple(sorteditems.index(item) for item in items)] += 1
return permutations
def permutations_compare(randoms_1, randoms_2):
permutations_1 = permutations_calc(randoms_1)
permutations_2 = permutations_calc(randoms_2)
keys = sorted(permutations_1.keys() | permutations_2.keys())
return sum(abs(permutations_1[key] - permutations_2[key]) for key in keys)
print(permutations_compare(my_gen, target), permutations_compare(control, target))
#>>> 5324 5368
def runs_calc(randoms):
runs = Counter()
run = 0
for item in randoms:
if run == 0:
run = 1
elif run == 1:
run = 2
increasing = item > last
else:
if (item > last) == increasing:
run += 1
else:
runs[run] += 1
run = 0
last = item
return runs
def runs_compare(randoms_1, randoms_2):
runs_1 = runs_calc(randoms_1)
runs_2 = runs_calc(randoms_2)
keys = sorted(runs_1.keys() | runs_2.keys())
return sum(abs(runs_1[key] - runs_2[key]) for key in keys)
print(runs_compare(my_gen, target), runs_compare(control, target))
#>>> 1270 975
这里的方差非常大,在多次执行中,我看到了两者都比较均匀的分布。因此,此测试通过。
有人向我提到了线性同余生成器,称其可能会“更有成果”。我自己实现了一个糟糕的线性同余生成器,以验证这个说法是否准确。
据我所知,LCG与普通生成器一样,并不是为了循环而设计的。因此,我查阅的大多数参考资料(如维基百科)仅涵盖了定义周期的内容,而没有介绍如何制作具有特定周期的强大LCG。这可能会影响结果。
下面开始:
from operator import mul
from functools import reduce
# Credit https://dev59.com/RmQn5IYBdhLWcg3wPlCj#16996439
# Meta: Also Tobias Kienzler seems to have credit for my
# edit to the post, what's up with that?
def factors(n):
d = 2
while d**2 <= n:
while not n % d:
yield d
n //= d
d += 1
if n > 1:
yield n
def sample_generator3(up_to):
for modulier in count(up_to):
modulier_factors = set(factors(modulier))
multiplier = reduce(mul, modulier_factors)
if not modulier % 4:
multiplier *= 2
if multiplier < modulier - 1:
multiplier += 1
break
x = randrange(0, up_to)
fudge_constant = random.randrange(0, modulier)
for modfact in modulier_factors:
while not fudge_constant % modfact:
fudge_constant //= modfact
for _ in range(modulier):
if x < up_to:
yield x
x = (x * multiplier + fudge_constant) % modulier
我们不再检查质数,但我们需要对因子做一些奇怪的事情。
modulier ≥ up_to > multiplier, fudge_constant > 0
a - 1
必须被 modulier
中的每个因子整除...fudge_constant
必须与 modulier
互质请注意,这些不是线性同余发生器(LCG)的规则,而是具有完整周期的 LCG 规则,其显然等于模数(modulier
)。
我是这样做的:
尝试至少使用每个modulier
,当满足条件时停止。
multiplier
成为去除重复项的
的乘积multiplier
不小于modulier
,则继续下一个modulier
fudge_constant
成为小于modulier
的随机选择数字fudge_constant
中删除在
中的因子fudge_constant
和multiplier
比完美生成器更常见。print(birthday_compare(lcg, target), birthday_compare(control, target))
#>>> 22532 10650
print(permutations_compare(lcg, target), permutations_compare(control, target))
#>>> 17968 5820
print(runs_compare(lcg, target), runs_compare(control, target))
#>>> 8320 662
好的,我们开始吧。这应该是最快的非概率算法。它的运行时间为O(k⋅log²(s) + f⋅log(f)) ⊂ O(k⋅log²(f+k) + f⋅log(f)))
,空间为O(k+f)
。f
是禁止数字的数量,s
是禁止数字中最长连续串的长度。对于期望值更复杂,但显然受f
的限制。如果您假设s^log₂(s)
大于f
或者对s
再次使用概率方法感到不满意,可以将对forbidden[pos:]
的对数部分改为二分搜索,以获得O(k⋅log(f+k) + f⋅log(f))
。
实际实现中,forbid
列表中的插入是O(n)
,因此其时间复杂度为O(k⋅(k+f)+f⋅log(f))
。通过用blist sortedlist替换该列表,可以轻松解决此问题。
我还添加了一些注释,因为这个算法太过复杂了。 lin
部分与log
部分相同,但需要s
而不是log²(s)
的时间。
import bisect
import random
def sample(k, end, forbid):
forbidden = sorted(forbid)
out = []
# remove the last block from forbidden if it touches end
for end in reversed(xrange(end+1)):
if len(forbidden) > 0 and forbidden[-1] == end:
del forbidden[-1]
else:
break
for i in xrange(k):
v = random.randrange(end - len(forbidden) + 1)
# increase v by the number of values < v
pos = bisect.bisect(forbidden, v)
v += pos
# this number might also be already taken, find the
# first free spot
##### linear
#while pos < len(forbidden) and forbidden[pos] <=v:
# pos += 1
# v += 1
##### log
while pos < len(forbidden) and forbidden[pos] <= v:
step = 2
# when this is finished, we know that:
# • forbidden[pos + step/2] <= v + step/2
# • forbidden[pos + step] > v + step
# so repeat until (checked by outer loop):
# forbidden[pos + step/2] == v + step/2
while (pos + step <= len(forbidden)) and \
(forbidden[pos + step - 1] <= v + step - 1):
step = step << 1
pos += step >> 1
v += step >> 1
if v == end:
end -= 1
else:
bisect.insort(forbidden, v)
out.append(v)
return out
O(f+k)
,时间复杂度为n/(n-(f+k))
(其中n/(n-(f+k))
是“猜测”的期望次数):
k=10
和一个相当大的n=10000
绘制了图表(对于更大的n
,情况只会更加极端)。我必须说:我只是因为它似乎是一个有趣的挑战而实现了这个算法,但连我自己都对这个算法的极端性感到惊讶:
f/n
,我的一行代码可能也更快(但对于大的n
,它仍然具有相当可怕的空间需求)。f
因子。
dict
)和其他大量算法使用类似的方法,它们似乎做得很好。p=n-f/n
的几何分布。因此,标准偏差(=期望平均值与实际平均值之间的预期偏差量)为:
这基本上与平均值相同 (√f⋅n < √n² = n
)。
****编辑**:
我刚意识到s
实际上也是n/(n-(f+k))
。因此,我的算法更精确的运行时间为O(k⋅log²(n/(n-(f+k))) + f⋅log(f))
。这很好,因为根据上面的图表,它证明了我的直觉,即它比O(k⋅log(f+k) + f⋅log(f))
快得多。但请放心,这并不会改变上面的结果,因为f⋅log(f)
是运行时间中绝对占主导地位的部分。
好的,最后一次尝试;-)虽然会改变基本序列,但这不需要额外的空间,并且每次 sample(n)
调用所需的时间与 n
成正比:
class Sampler(object):
def __init__(self, base):
self.base = base
self.navail = len(base)
def sample(self, n):
from random import randrange
if n < 0:
raise ValueError("n must be >= 0")
if n > self.navail:
raise ValueError("fewer than %s unused remain" % n)
base = self.base
for _ in range(n):
i = randrange(self.navail)
self.navail -= 1
base[i], base[self.navail] = base[self.navail], base[i]
return base[self.navail : self.navail + n]
小驱动程序:
s = Sampler(list(range(100)))
for i in range(9):
print s.sample(10)
print s.sample(1)
print s.sample(1)
random.shuffle()
,在选择了n
个元素后暂停。 base
未被销毁,但会被重新排列。O(range)
的内存,但如果我不能自己写出更好的代码,我会接受这个努力。再次感谢! - necromancerrange(n)
中的n
非常大,它也只会使用很少的内存(在Python 2中,如果你使用xrange(n)
,情况也差不多)。 - Tim Petersdef sample(n, base, forbidden):
# base is iterable, forbidden is a set.
# Every element of forbidden must be in base.
# forbidden is updated.
from random import random
nusable = len(base) - len(forbidden)
assert nusable >= n
result = []
if n == 0:
return result
for elt in base:
if elt in forbidden:
continue
if nusable * random() < n:
result.append(elt)
forbidden.add(elt)
n -= 1
if n == 0:
return result
nusable -= 1
assert False, "oops!"
这是一个小驱动程序:
base = list(range(100))
forbidden = set()
for i in range(10):
print sample(10, base, forbidden)
O(size(base))
?与 @Chronialis 的答案相比,您的空间复杂度更低,但平均时间复杂度相同?(附注:只采样1个将足以简化逻辑)。 - necromancerO(len(base))
的时间。但是如果你用完了基础集合(就像我的示例驱动程序一样),在所有调用结束时,len(forbidden) == len(base)
,你最终仍然会得到与 base
相同大小的集合。抱歉,我不明白 "sampling just 1" 是什么意思。如果你只想要一个大小为 1 的样本,请将 n
设为 1
传递给我的 sample()
函数。如果你将 1
硬编码进去,代码会变得简单一些,但并不会有太大改观。"难点" 仍然是跳过前面调用返回的所有元素。 - Tim PetersO(len(base))
时间,每次请求大小为n
的样本时需要O(n)
时间。如果你愿意让“base”被销毁,它不需要额外的空间。 - Tim PetersO(range)
的解决方案,而我正试图避免这种情况。 - necromancerdef shuffle_gen(src):
""" yields random items from base without repetition. Clobbers `src`. """
for remaining in xrange(len(src), 0, -1):
i = random.randrange(remaining)
yield src[i]
src[i] = src[remaining - 1]
然后可以使用 itertools.islice
进行切片:
>>> import itertools
>>> sampler = shuffle_gen(range(100))
>>> sample1 = list(itertools.islice(sampler, 10))
>>> sample1
[37, 1, 51, 82, 83, 12, 31, 56, 15, 92]
>>> sample2 = list(itertools.islice(sampler, 80))
>>> sample2
[79, 66, 65, 23, 63, 14, 30, 38, 41, 3, 47, 42, 22, 11, 91, 16, 58, 20, 96, 32, 76, 55, 59, 53, 94, 88, 21, 9, 90, 75, 74, 29, 48, 28, 0, 89, 46, 70, 60, 73, 71, 72, 93, 24, 34, 26, 99, 97, 39, 17, 86, 52, 44, 40, 49, 77, 8, 61, 18, 87, 13, 78, 62, 25, 36, 7, 84, 2, 6, 81, 10, 80, 45, 57, 5, 64, 33, 95, 43, 68]
>>> sample3 = list(itertools.islice(sampler, 20))
>>> sample3
[85, 19, 54, 27, 35, 4, 98, 50, 67, 69]
xrange()
的第二个参数应该是0,而不是1(例如,list(xrange(4, 1, -1)
是[4, 3, 2]
- range/xrange
总是在stop
参数之前停止。 - Tim Petersimport random
def shuffle_gen(n):
# this is used like a range(n) list, but we don’t store
# those entries where state[i] = i.
state = dict()
for remaining in xrange(n, 0, -1):
i = random.randrange(remaining)
yield state.get(i,i)
state[i] = state.get(remaining - 1,remaining - 1)
# Cleanup – we don’t need this information anymore
state.pop(remaining - 1, None)
用法:
out = []
gen = shuffle_gen(100)
for n in range(100):
out.append(gen.next())
print out, len(set(out))
state.pop(...)
),在倒数第二行将“get”替换为“pop”。然后它就是一个好答案 - LOL;-) - Tim Petersget
和一个 pop
是因为可能会出现 i=remaining-1
的情况。如果只用 pop
,我们会删除该项,然后再重新添加它。我想说的是,remaining
要么很大,这种情况很少见,要么很小,我们很快就会完成,所以“泄漏”问题并不太严重。但我想做到彻底 :)。 - Chronial编辑:请参阅@TimPeters和@Chronial提供的更清晰版本。一个小修改把它推到了前面。
以下是我认为最有效的增量采样解决方案。与其使用以前采样数字的列表,调用者所需维护的状态包括一个由增量采样器使用的字典和剩余范围内数字的计数。
以下是一个演示实现。与其他解决方案相比:
O(log(number_previously_sampled))
O(number_previously_sampled)
代码:
import random
def remove (i, n, state):
if i == n - 1:
if i in state:
t = state[i]
del state[i]
return t
else:
return i
else:
if i in state:
t = state[i]
if n - 1 in state:
state[i] = state[n - 1]
del state[n - 1]
else:
state[i] = n - 1
return t
else:
if n - 1 in state:
state[i] = state[n - 1]
del state[n - 1]
else:
state[i] = n - 1
return i
s = dict()
for n in range(100, 0, -1):
print remove(random.randrange(n), n, s)
O(1)
,因此获取大小为k
的样本是O(k)
。 - Tim PetersO(1)
。这是期望的时间复杂度。最坏情况下的时间复杂度是O(len(dict))
,但这种情况几乎不会出现。但是,要相信O(1)
的说法,你需要对概率有信心;-) - Tim PetersO(number_previously_sampled)
而不是 O(log(number_previously_sampled))
。 - Chronialfrom random import randrange
class Sampler:
def __init__(self, n):
self.n = n # number remaining from original range(n)
# i is a key iff i < n and i already returned;
# in that case, state[i] is a value to return
# instead of i.
self.state = dict()
def get(self):
n = self.n
if n <= 0:
raise ValueError("range exhausted")
result = i = randrange(n)
state = self.state
# Most of the fiddling here is just to get
# rid of state[n-1] (if it exists). It's a
# space optimization.
if i == n - 1:
if i in state:
result = state.pop(i)
elif i in state:
result = state[i]
if n - 1 in state:
state[i] = state.pop(n - 1)
else:
state[i] = n - 1
elif n - 1 in state:
state[i] = state.pop(n - 1)
else:
state[i] = n - 1
self.n = n-1
return result
s = Sampler(100)
allx = [s.get() for _ in range(100)]
assert sorted(allx) == list(range(100))
from collections import Counter
c = Counter()
for i in range(6000):
s = Sampler(3)
one = tuple(s.get() for _ in range(3))
c[one] += 1
for k, v in sorted(c.items()):
print(k, v)
同时附上样例输出:
(0, 1, 2) 1001
(0, 2, 1) 991
(1, 0, 2) 995
(1, 2, 0) 1044
(2, 0, 1) 950
(2, 1, 0) 1019
通过肉眼观察,这个分布是好的(如果您有疑问,请运行卡方检验)。这里的一些解决方案不会给出每个排列的等概率性(尽管它们以等概率返回每个k子集),因此在这方面与random.sample()
不同。
itertools.slice()
来达到这个目的会让用户的生活变得更加复杂,并且为代码读者掩盖了意图。尽管如此,像这样的代码仍然是它的核心。感谢大家的参与! :-) - Tim PetersO()
行为。因此,在任何真正的部署中,我都会有一个循环收集 - 并返回 - 每次调用的k
个样本。更快,更合适。 - Tim Peters
[0, x)
范围内采样整数?你期望的x
是多少? - Chronialrandom.sample
的源代码。 - Eric