在PyTorch中,张量的形状“torch.Size([])”和“torch.Size([1])”之间的区别是什么?

5

我是pytorch的新手。在玩弄张量时,我观察到两种类型的张量:

tensor(58)
tensor([57.3895])

我打印了它们的形状,输出分别为 -
torch.Size([])
torch.Size([1])

这两者有什么区别?

3个回答

6
您可以这样操作只有单个标量值的张量:
import torch

t = torch.tensor(1)
print(t, t.shape) # tensor(1) torch.Size([])

t = torch.tensor([1])
print(t, t.shape) # tensor([1]) torch.Size([1])

t = torch.tensor([[1]])
print(t, t.shape) # tensor([[1]]) torch.Size([1, 1])

t = torch.tensor([[[1]]])
print(t, t.shape) # tensor([[[1]]]) torch.Size([1, 1, 1])

t = torch.unsqueeze(t, 0)
print(t, t.shape) # tensor([[[[1]]]]) torch.Size([1, 1, 1, 1])

t = torch.unsqueeze(t, 0)
print(t, t.shape) # tensor([[[[[1]]]]]) torch.Size([1, 1, 1, 1, 1])

t = torch.unsqueeze(t, 0)
print(t, t.shape) # tensor([[[[[[1]]]]]]) torch.Size([1, 1, 1, 1, 1, 1])

#squize dimension with id 0
t = torch.squeeze(t,dim=0)
print(t, t.shape) # tensor([[[[[1]]]]]) torch.Size([1, 1, 1, 1, 1])

#back to beginning.
t = torch.squeeze(t)
print(t, t.shape) # tensor(1) torch.Size([])

print(type(t)) # <class 'torch.Tensor'>
print(type(t.data)) # <class 'torch.Tensor'>

张量具有大小或形状,它们其实是同一个概念,表示为torch.Size。您可以编写help(torch.Size)以获取更多信息。每当您编写t.shapet.size()时,都将获得该大小信息。
张量的概念是它们可以具有不同的兼容大小维度,包括torch.Size([])
每当您展开张量时,它会添加另一个尺寸为1的维度。每当您挤压张量时,它会删除1个维度或在一般情况下删除所有维度为1的维度。

5

第一个张量的大小为0,第二个张量的大小为1,PyTorch会尝试使两者兼容(0大小可以类比于float或类似物,尽管我还没有真正遇到需要显式使用它的情况,除了@javadr在下面的答案中展示的情况)。

通常你会使用list来初始化它,在这里查看更多信息。


我认为这不是错误,它意味着一个维度为0的张量。 - javadr
1
请查看 fill_。它只接受一个0维张量作为参数。 因此,您不能像 t.fill_(torch.tensor([1])) 这样使用它。这是完全错误的,会产生以下错误:RuntimeError: fill_ only supports 0-dimension value tensor but got tensor with 1 dimensions. 但是,您可以像下面这样使用它:t.fill_(torch.tensor(1)) - javadr

3

请查看pytorchtensor的文档:

Docstring:
tensor(data, dtype=None, device=None, requires_grad=False, pin_memory=False) -> Tensor

Constructs a tensor with :attr:`data`.

然后它描述了数据是什么:

Args:
    data (array_like): Initial data for the tensor. Can be a list, tuple,
        NumPy ``ndarray``, scalar, and other types.

如您所见,data 可以是一个标量(即零维数据)。

因此,针对您的问题,tensor(58) 是一个零维张量,而 tensor([58]) 则是一个一维张量。


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接