如何在Python 3中从deque中获取random.sample()?

12

我有一个包含元组的 collections.deque(),想从中进行随机采样。在 Python 2.7 中,我可以使用 batch = random.sample(my_deque, batch_size)

但是在 Python 3.4 中,会出现错误 TypeError: Population must be a sequence or set. For dicts, use list(d).

有什么最佳解决方案或建议的方法可以在 Python 3 中有效地从 deque 进行采样呢?


这个有帮助吗?https://dev59.com/4kvSa4cB1Zd3GeqPcCeA - Nihal Rp
如果deque足够短,我会使用sample(list(the_deque), k) - kennytm
1
奇怪。在Python 3.5上,random.sample(deq, size)对我有效。已确认在3.4上无效。 - juanpa.arrivillaga
2个回答

14

显而易见的方法-将其转换为列表。

batch = random.sample(list(my_deque), batch_size))

但是您可以避免创建整个列表。

idx_batch = set(sample(range(len(my_deque)), batch_size))
batch = [val for i, val in enumerate(my_deque) if i in idx_batch] 

顺便提一句(已编辑)

实际上,在Python >= 3.5中,random.sample应该可以很好地使用deque,因为该类已更新以匹配Sequence接口。

In [3]: deq = collections.deque(range(100))

In [4]: random.sample(deq, 10)
Out[4]: [12, 64, 84, 77, 99, 69, 1, 93, 82, 35]

请注意!正如Geoffrey Irving在下面的评论中正确地指出的那样,您最好将队列转换为列表,因为队列是实现为链表,使得每个索引访问在队列大小上都是O(n),因此抽取m个随机值将需要O(m * n)时间。


2
请注意,根据文档,random.sample将是二次时间复杂度,因为每个deque下标在中间都是O(n)。 - Geoffrey Irving

6
在Python≥3.5中,deque上的sample()函数效果很好,速度也相当快。
在Python 3.4中,您可以使用以下代码,其运行速度大约与sample()相同:
sample_indices = sample(range(len(deq)), 50)
[deq[index] for index in sample_indices]

在我的 MacBook 上使用 Python 3.6.8,这种解决方案比 Eli Korvigo 的解决方案快了超过44倍。 :)
我使用了一个包含100万个项的双端队列(deque),并随机取样了50个项:
from random import sample
from collections import deque

deq = deque(maxlen=1000000)
for i in range(1000000):
    deq.append(i)

sample_indices = set(sample(range(len(deq)), 50))

%timeit [deq[i] for i in sample_indices]
1.68 ms ± 23.8 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

%timeit sample(deq, 50)
1.94 ms ± 60.7 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

%timeit sample(range(len(deq)), 50)
44.9 µs ± 549 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

%timeit [val for index, val in enumerate(deq) if index in sample_indices]
75.1 ms ± 410 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

尽管如此,正如其他人指出的那样,deque并不适合随机访问。如果您想要实现回放记忆,可以使用类似这样的旋转列表:

class ReplayMemory:
    def __init__(self, max_size):
        self.buffer = [None] * max_size
        self.max_size = max_size
        self.index = 0
        self.size = 0

    def append(self, obj):
        self.buffer[self.index] = obj
        self.size = min(self.size + 1, self.max_size)
        self.index = (self.index + 1) % self.max_size

    def sample(self, batch_size):
        indices = sample(range(self.size), batch_size)
        return [self.buffer[index] for index in indices]

如果有一百万个项目,取样50个项目速度非常快:

%timeit mem.sample(50)
#58 µs ± 691 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接