下面是一个更通用的版本(参考 @Bruno_Lubascher 的答案):它把张量沿某一维度当作队列(代码中称为 deque,实际按先进先出 FIFO 使用),可以控制队列的最大长度,支持一次批量推入多个元素,并且可以指定沿哪个维度推入。
def push_to_deque(deque, x, deque_size=None, dim=0):
    """Push the contents of ``x`` into the tensor ``deque``, treated as a FIFO along ``dim``.

    ``deque`` is handled as a (set of) first-in/first-out queue(s) laid out
    along dimension ``dim``. The entries of ``x`` along that dimension are
    appended at the end, and the oldest entries are dropped so that the
    result holds at most ``deque_size`` entries along ``dim``. If ``x`` alone
    holds more than ``deque_size`` entries, only its last ``deque_size``
    entries are kept.

    Args:
        deque: tensor acting as the queue storage. It is not modified; a new
            tensor is returned.
        x: tensor of new entries; its other dimensions must be compatible
            with ``deque`` for ``torch.cat``.
        deque_size: maximum number of entries kept along ``dim``. Defaults to
            the current size of ``deque`` along ``dim``.
        dim: dimension along which entries are pushed. Negative values count
            from the last dimension.

    Returns:
        A new tensor: the retained (newest) tail of ``deque`` followed by the
        retained (newest) tail of ``x``, concatenated along ``dim``.

    Raises:
        IndexError: if ``dim`` is out of range for ``deque``.
        ValueError: if ``deque_size`` is negative.
    """
    ndim = deque.dim()
    if not -ndim <= dim < ndim:
        raise IndexError(
            f"dim {dim} is out of range for a {ndim}-dimensional deque"
        )
    # Normalise a negative dim: the leading-slice tuple below needs a
    # non-negative count of dimensions that precede ``dim``.
    dim %= ndim

    if deque_size is None:
        deque_size = deque.shape[dim]
    if deque_size < 0:
        raise ValueError(f"deque_size must be non-negative, got {deque_size}")

    old_len = deque.shape[dim]
    new_len = x.shape[dim]
    # Keep the newest min(new_len, deque_size) entries of x, then fill any
    # remaining capacity with the newest entries already in the deque.
    keep_new = min(new_len, deque_size)
    keep_old = min(deque_size - keep_new, old_len)

    # Explicit non-negative start indices avoid the ``slice(-0, None)``
    # pitfall, which would select everything instead of nothing. Dimensions
    # after ``dim`` are taken whole implicitly by the shorter index tuple.
    lead = (slice(None),) * dim
    old_tail = deque[lead + (slice(old_len - keep_old, None),)]
    new_tail = x[lead + (slice(new_len - keep_new, None),)]
    return torch.cat((old_tail, new_tail), dim=dim)
示例(需先 `import torch`):
>>>
>>> batch_size, vector_size = 1, 2
>>> deque_size = 4
>>>
>>> deque = torch.empty((batch_size, 0, vector_size))
>>>
>>> vals = torch.arange(10).view((batch_size, 5, vector_size))
>>> deque = push_to_deque(deque, vals, deque_size=deque_size, dim=1)
>>> deque
tensor([[[2., 3.],
[4., 5.],
[6., 7.],
[8., 9.]]])
>>>
>>> vals = torch.arange(10, 20).view((batch_size, 5, vector_size))
>>> deque = push_to_deque(deque, vals, deque_size=deque_size, dim=1)
>>> deque
tensor([[[12., 13.],
[14., 15.],
[16., 17.],
[18., 19.]]])
>>> vals = torch.arange(20, 24).view((batch_size, 2, vector_size))
>>> deque = push_to_deque(deque, vals, deque_size=deque_size, dim=1)
>>> deque
tensor([[[16., 17.],
[18., 19.],
[20., 21.],
[22., 23.]]])
>>>
>>> deque = torch.zeros(batch_size, 10, vector_size)
>>> vals = torch.arange(4).view((batch_size, 2, vector_size))
>>> deque = push_to_deque(deque, vals, deque_size=deque_size, dim=1)
>>> deque
tensor([[[0., 0.],
[0., 0.],
[0., 1.],
[2., 3.]]])