Torch - 像numpy repeat一样重复张量

16

我正在尝试用两种方法在torch中重复张量。例如,将张量{1,2,3,4}重复3次以产生以下结果;

{1,2,3,4,1,2,3,4,1,2,3,4}
{1,1,1,2,2,2,3,3,3,4,4,4}

有一个内置的 torch:repeatTensor 函数,它将生成两个中的第一个(类似于 numpy.tile()),但我找不到第二个的函数(类似于 numpy.repeat())。我确定我可以对第一个调用 sort 来得到第二个,但我认为这对于较大的数组来说可能是计算上昂贵的?

谢谢。


2
repeatTensor和expandAs是你的好朋友。 - smhx
6个回答

22

1
这应该是正确的答案! - jef

8

引用https://discuss.pytorch.org/t/how-to-tile-a-tensor/13853的内容 -

z = torch.FloatTensor([[1,2,3],[4,5,6],[7,8,9]])
1 2 3
4 5 6
7 8 9
z.transpose(0,1).repeat(1,3).view(-1, 3).transpose(0,1)
1 1 1 2 2 2 3 3 3
4 4 4 5 5 5 6 6 6
7 7 7 8 8 8 9 9 9
这将让你直观地了解它的工作原理。

6
a = torch.Tensor([1,2,3,4])

为了得到[1., 2., 3., 4., 1., 2., 3., 4., 1., 2., 3., 4.],我们需要在第一维度上将张量重复三次:
a.repeat(3)

为了得到[1,1,1,2,2,2,3,3,3,4,4,4],我们需要在张量中添加一个维度,并在第二个维度上将其重复三次以获得一个4 x 3的张量,然后我们可以对其进行扁平化处理。
b = a.reshape(4,1).repeat(1,3).flatten()

或者

b = a.reshape(4,1).repeat(1,3).view(-1)

能解释一下你的答案吗? - André Schild
@AndréSchild 对不起,现在好了吗? - abhshkdz
你的回答中是否包含严重的Python语法错误?例如 a = torch.Tensor{1,2,3,4}a:repeatTensor(3)?我运行了这些代码,但是出现了语法错误。 - Charlie Parker
这是在 Lua torch (http://torch.ch) 时代,而不是 Pytorch :) - abhshkdz

1
这是一个通用的函数,可以重复张量中的元素。
def repeat(tensor, dims):
    if len(dims) != len(tensor.shape):
        raise ValueError("The length of the second argument must equal the number of dimensions of the first.")
    for index, dim in enumerate(dims):
        repetition_vector = [1]*(len(dims)+1)
        repetition_vector[index+1] = dim
        new_tensor_shape = list(tensor.shape)
        new_tensor_shape[index] *= dim
        tensor = tensor.unsqueeze(index+1).repeat(repetition_vector).reshape(new_tensor_shape)
    return tensor

如果你有


foo = tensor([[1, 2],
              [3, 4]])

通过调用 repeat(foo, [2,1]),您将获得:
tensor([[1, 2],
        [1, 2],
        [3, 4],
        [3, 4]])

所以您将沿着第0维度复制了每个元素,并在第1维度上保留元素不变。

1
使用 einops:
from einops import repeat

repeat(x, 'i -> (repeat i)', repeat=3)
# like {1,2,3,4,1,2,3,4,1,2,3,4}

repeat(x, 'i -> (i repeat)', repeat=3)
# like {1,1,1,2,2,2,3,3,3,4,4,4}

这段代码对于任何框架(如numpy、torch、tf等)都能起到相同的作用。

0

你能不能试试这样:

import torch as pt

#1 work as numpy tile

b = pt.arange(10)
print(b.repeat(3))

#2 work as numpy tile

b = pt.tensor(1).repeat(10).reshape(2,-1)
print(b)

#3 work as numpy repeat

t = pt.tensor([1,2,3])
t.repeat(2).reshape(2,-1).transpose(1,0).reshape(-1)

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接