如何检查两个Torch张量或矩阵是否相等?

93
我需要一个Torch命令,用于检查两个张量是否具有相同的内容,如果它们具有相同的内容则返回TRUE。
例如:
local tens_a = torch.Tensor({9,8,7,6});
local tens_b = torch.Tensor({9,8,7,6});

if (tens_a EQUIVALENCE_COMMAND tens_b) then ... end

在这个脚本中,我应该使用什么代替EQUIVALENCE_COMMAND

我尝试过用==,但不起作用。


1
为了允许浮点数差异,请参见检查PyTorch张量是否在epsilon内相等 - Tom Hale
6个回答

105
torch.eq(a, b)

eq()函数实现了==操作符,比较a中的每个元素和b中的对应元素是否相等(如果b是一个值),或者比较a中的每个元素和b中对应的元素是否相等(如果b是一个张量)。


@deltheil的另一种方法:

torch.all(tens_a.eq(tens_b))

7
如其他回答所述,使用当前版本的 PyTorch,.eq 方法返回一个张量,而.equal 实际上返回一个布尔值。 - fuzzyTew
有没有类似的版本?我想要一个计算范数 ||A - B|| 并检查它是否很小的东西(这与 A.allclose(B) 不同)。 - a06e

100

以下解决方案适用于我:

torch.equal(tensorA, tensorB)

来自文档

True表示两个张量具有相同的大小和元素,False则表示不同。


11
这个答案应该是解决这个问题的唯一答案,因为它能够匹配求助者所需的确切行为,而且如果张量不具有相同的形状,则不会进行任何计算,因此也是最有效的。 - Louis Lac
1
这是唯一正确的答案。 - orkenstein

32

要比较张量,您可以进行逐元素操作:

torch.eq 是逐元素的:

torch.eq(torch.tensor([[1., 2.], [3., 4.]]), torch.tensor([[1., 1.], [4., 4.]]))
tensor([[True, False], [False, True]])

或者使用torch.equal来比较整个张量是否相等:

torch.equal(torch.tensor([[1., 2.], [3, 4.]]), torch.tensor([[1., 1.], [4., 4.]]))
# False
torch.equal(torch.tensor([[1., 2.], [3., 4.]]), torch.tensor([[1., 2.], [3., 4.]]))
    # True

但是你可能会迷失,因为在某些时候有一些小差别你希望忽略。比如说,1.01.0000000001 这两个浮点数非常接近,你可能认为它们是相等的。对于这种比较,你可以使用torch.allclose

torch.allclose(torch.tensor([[1., 2.], [3., 4.]]), torch.tensor([[1., 2.000000001], [3., 4.]]))
# True
在某些情况下,检查逐个元素与元素的数量相等的重要性可能很大。如果您有两个张量dt1dt2,您可以使用dt1.nelement()获取dt1的元素数量。

应用以下公式,即可得出百分比:

print(torch.sum(torch.eq(dt1, dt2)).item()/dt1.nelement())

6
torch.allclose() 是我正在寻找的函数。 - Sachin Kumar
如何表示不等于? - Umair Javaid

11

如果您想忽略浮点数通常出现的小精度差异,请尝试以下方法。

torch.all(torch.lt(torch.abs(torch.add(tens_a, -tens_b)), 1e-12))

14
您可以使用torch.allclose()作为替代选项。 - irudyak

1

这种解决方案对我也很有效,而且似乎更自然。

torch.all(tensorA == tensorB)

输出结果为:

如果相等,则输出tensor(1, device='cuda:0', dtype=torch.uint8) 否则:输出tensor(0, device='cuda:0', dtype=torch.uint8)


-1
您可以将这两个张量转换为numpy数组:
local tens_a = torch.Tensor((9,8,7,6));
local tens_b = torch.Tensor((9,8,7,6));

a=tens_a.numpy()
b=tens_b.numpy()

然后就是类似这样的东西

np.sum(a==b)
4

会给你一个相当好的想法,关于它们是如何相等的。


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接