为什么当a和b引用相同数据时,a.storage()和b.storage()返回false?

5
>>> a = torch.arange(12).reshape(2, 6)
>>> a
tensor([[ 0,  1,  2,  3,  4,  5],
        [ 6,  7,  8,  9, 10, 11]])
>>> b = a[1:, :]
>>> b.storage() is a.storage()
False

但是

>>> b[0, 0] = 999
>>> b, a # both tensors are changed
(tensor([[999,   7,   8,   9,  10,  11]]),
 tensor([[  0,   1,   2,   3,   4,   5],
         [999,   7,   8,   9,  10,  11]]))

存储张量数据的确切对象是什么?如何检查两个张量是否共享内存?

1个回答

5

torch.Tensor.storage() 每次调用都会返回 torch.Storage 的一个新实例。您可以在以下内容中看到这一点。

a.storage() is a.storage()
# False

为了比较指向底层数据的指针,您可以使用以下操作:
a.storage().data_ptr() == b.storage().data_ptr()
# True

在这个PyTorch论坛帖子中,讨论了如何确定PyTorch张量是否共享内存。


请注意a.data_ptr()a.storage().data_ptr()之间的区别。第一个返回张量的第一个元素的指针,而第二个似乎指向底层数据(而不是切片视图)的内存地址,尽管它没有被记录在文档中
了解上述内容后,我们可以理解为什么a.data_ptr()b.data_ptr()不同。考虑以下代码:
import torch

a = torch.arange(4, dtype=torch.int64)
b = a[1:]
b.data_ptr() - a.data_ptr()
# 8
< p > b 的第一个元素的地址比 a 的第一个元素的地址多 8,因为我们切片删除了第一个元素,并且每个元素是 8 字节(dtype 是 64 位整数)。

如果我们使用与上面相同的代码,但使用 8 位整数数据类型,则内存地址将增加一。


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接