在PyTorch中,张量大小前面的"*"运算符是用来展开张量的。

4

我现在正在学习如何在PyTorch中构建神经网络。下面是从.py文件中剪切出来的代码:

x = torch.unsqueeze(torch.linspace(-1, 1, 1000), dim=1)
y = x.pow(2) + 0.1*torch.normal(torch.zeros(*x.size()))

我对在x.size()之前使用*运算符的效用感到困惑。我尝试删除它并绘制散点图,结果证明与未删除*的图形相同。

我还查看了https://pytorch.org/docs/stable/tensors.htmlsize的官方文档,但是我无法理解。

torch.size项目文档中的图片

如果您能帮忙,我将不胜感激。

2个回答

3

*在Python中的使用方式表示(参数)解包。当你将它前置于一个可迭代对象(即x.size()返回的内容)时,它会解包并将其项作为位置参数传递给函数。例如:

def f(a1, a2):
    print(a1, a2)

f(*["Hello", "World"])

你可以查看另一个示例和更详细的描述的文档链接。

非常感谢!这个答案很容易理解! - YOLOv4

2
*在这里不影响结果的原因是因为torch.zero接受可变数量的参数类似列表或元组的集合,如此处所述。这并不意味着*本身是无用的。
然后,由于torch.Size()类是python元组的子类,因此可以使用*对其进行解包。(x.size()将返回一个torch.Size()对象)
因此,总之,x.size()将给您(1000, 1),而*x.size()在参数中将给您1000, 1,两者都被torch.zeros()接受。

你的回答非常精确清晰!非常感谢! - YOLOv4

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接