如何将一个包含整数的PyTorch张量转换为布尔类型的张量?

18

我想将一个整数张量转换为布尔值张量。

具体来说,我希望能够有一个函数,将tensor([0,10,0,16])转换为tensor([0,1,0,1])

在Tensorflow中,这很容易实现,只需使用tf.cast(x, tf.bool)即可。

我希望将所有大于0的整数转换为1,将所有等于0的整数转换为0。这相当于大多数语言中的!!

由于pytorch似乎没有专门的布尔类型可供转换,那么这里最好的方法是什么?

编辑:我正在寻找一种向量化的解决方案,而不是通过循环遍历每个元素。


1
对每个元素调用 bool(int)。或者在 numpy 中:使用 array.astype(...) - Thomas Lang
1
这是一个简单的解决方案,需要使用for循环,是的。但是是否有矢量化的解决方案呢? - Ross
astype 版本几乎肯定是矢量化的。 - Thomas Lang
1
@ThomasLang 在 Pytorch 中没有 .astype,因此一个人必须 转换为 numpy->转换类型 -> 加载到 pytorch,在我看来这是低效的。 - Umang Gupta
5个回答

26
你需要生成一个给定整数张量的布尔掩码。为此,可以使用简单比较运算符(>)或 torch.gt() 检查张量中的值是否大于0,并得到所需结果。
# input tensor
In [76]: t   
Out[76]: tensor([ 0, 10,  0, 16])

# generate the needed boolean mask
In [78]: t > 0      
Out[78]: tensor([0, 1, 0, 1], dtype=torch.uint8)

# sanity check
In [93]: mask = t > 0      

In [94]: mask.type()      
Out[94]: 'torch.ByteTensor'

注意: 在 PyTorch 1.4+ 版本中,上述操作将返回 'torch.BoolTensor'

In [9]: t > 0  
Out[9]: tensor([False,  True, False,  True])

# alternatively, use `torch.gt()` API
In [11]: torch.gt(t, 0)
Out[11]: tensor([False,  True, False,  True])

如果您确实需要单个位(无论是0还是1),请使用以下方式进行转换:
In [14]: (t > 0).type(torch.uint8)   
Out[14]: tensor([0, 1, 0, 1], dtype=torch.uint8)

# alternatively, use `torch.gt()` API
In [15]: torch.gt(t, 0).int()
Out[15]: tensor([0, 1, 0, 1], dtype=torch.int32)

这个变化的原因已经在这个功能请求问题中讨论过:issues/4764 - 引入torch.BoolTensor ...

TL;DR: 简单的一句话总结

t.bool().int()

1
在PyTorch 1.4.0中,这将返回'torch.BoolTensor'。 - draupnie

6

PyTorch的to(dtype)方法具有方便的数据类型别名。您可以直接调用bool

>>> t.bool()
tensor([False,  True, False,  True])

>>> t.bool().int()
tensor([0, 1, 0, 1], dtype=torch.int32)

2

另一种选择是直接执行:

temp = torch.tensor([0,10,0,16])
temp.bool()
#Returns
tensor([False,  True, False,  True])

2

Convert boolean to number value:

a = torch.tensor([0,4,0,0,5,0.12,0.34,0,0])
print(a.gt(0)) # output in boolean dtype
# output: tensor([False,  True, False, False,  True,  True,  True, False, False])

print(a.gt(0).to(torch.float32)) # output in float32 dtype
# output: tensor([0., 1., 0., 0., 1., 1., 1., 0., 0.])

2
您可以按照以下方式使用比较:

您可以按照以下方式使用比较:

 >>> a = tensor([0,10,0,16])
 >>> result = (a == 0)
 >>> result
 tensor([ True, False,  True, False])

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接