PyTorch张量中的“大于”运算符“>”是什么意思?

3

我有一个张量it的定义如下:

import torch
it = torch.tensor([0,  0,  0,  0,  0,  0,  0,  0,  0,  0], device='cuda:0')

根据这个定义,那么 it > 0 是什么意思呢?

2
它创建一个布尔值张量,其中每个 i 都设置为 True,如果 it[i] > 0,则为True,否则为 False。 - Gabriel M
2个回答

8

使用 > 运算符与使用 torch.gt() 函数是相同的。

换句话说,

it > 0

是等同于

torch.gt(it, 0)

它返回一个与it相同shapeByteTensor(布尔张量),其中out[i]为True,如果it[i] > 0,否则为False。


1
如问题所示,it是一个由10个元素组成的一维张量。当我们写it > 0时,张量it的每个元素都与0进行比较,输出根据数字是否大于0设置为TrueFalse。结果也是一个由TrueFalse值组成的一维布尔张量。
在您的情况下,您将获得一个像这样的一维张量:[False,False,False,False,False,False,False,False,False,False,],因为it中没有任何元素等于0。
更简单地说,如果result是一个变量(实际上是一个一维张量)用于存储与it具有相同形状的输出,则方程result= it > 0可以写成:
if it[i]>0:
    result[i]= True
else:
    result[i]= False 

但是当它被执行为result= it > 0时,执行速度比编写自己的for/while循环要快得多。

希望这可以帮助你。


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接