pytorch
是否支持无需分配大量内存就能重复张量的操作?
假设我们有一个张量
t = torch.ones((1,1000,1000))
t10 = t.repeat(10,1,1)
重复
t
10 次将需要占用 10 倍的内存。有没有一种方法可以创建一个张量 t10
,而不会分配更多的内存?
这里 是一个相关的问题,但没有答案。您可以使用torch.expand
函数。
t = torch.ones((1, 1000, 1000))
t10 = t.expand(10, 1000, 1000)
t10
只是对t
的引用。例如,对t10[0,0,0]
所做的更改将导致t[0,0,0]
和t10[:,0,0]
的每个成员发生相同的更改。t10
执行的大多数操作都会导致内存被复制,这将打破引用并使用更多的内存。例如:更改设备(.cpu()
、.to(device=...)
、.cuda()
),更改数据类型(.float()
、.long()
、.to(dtype=...)
)或使用.contiguous()
。