不复制内存的情况下重复一个PyTorch张量

8

pytorch是否支持无需分配大量内存就能重复张量的操作?

假设我们有一个张量

t = torch.ones((1,1000,1000))
t10 = t.repeat(10,1,1)

重复 t 10 次将需要占用 10 倍的内存。有没有一种方法可以创建一个张量 t10,而不会分配更多的内存? 这里 是一个相关的问题,但没有答案。
1个回答

13

您可以使用torch.expand函数。

t = torch.ones((1, 1000, 1000))
t10 = t.expand(10, 1000, 1000)

请记住,t10只是对t的引用。例如,对t10[0,0,0]所做的更改将导致t[0,0,0]t10[:,0,0]的每个成员发生相同的更改。
除了直接访问之外,对t10执行的大多数操作都会导致内存被复制,这将打破引用并使用更多的内存。例如:更改设备(.cpu().to(device=...).cuda()),更改数据类型(.float().long().to(dtype=...))或使用.contiguous()

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接