从torch.Tensor中删除一个元素

11

我正在尝试从张量中删除一个项目。

在下面的示例中,我如何从张量中删除第三个项目?

tensor([[-5.1949, -6.2621, -6.2051, -5.8983, -6.3586, -6.2434, -5.8923, -6.1901,
         -6.5713, -6.2396, -6.1227, -6.4196, -3.4311, -6.8903, -6.1248, -6.3813,
         -6.0152, -6.7449, -6.0523, -6.4341, -6.8579, -6.1961, -6.5564, -6.6520,
         -5.9976, -6.3637, -5.7560, -6.7946, -5.4101, -6.1310, -3.3249, -6.4584,
         -6.2202, -6.3663, -6.9293, -6.9262]], grad_fn=<SqueezeBackward1>)
4个回答

7

我认为使用索引来实现这个更易读。

t[t!=t[0,3]]

与下面的cat解决方案相同。 要小心:这通常适用于浮点数,但要注意,如果数组中[0,3]处的值出现多次,则会删除所有该项的出现。

4

您可以使用NumPy的r_索引技巧

y = x[:, np.r_[:3, 4:36]]

2
你可以先通过索引筛选数组,然后将它们连接起来。
t.shape
torch.Size([1, 36])

t = torch.cat((t[:,:3], t[:,4:]), axis = 1)

t.shape
torch.Size([1, 35])

1
如果你想按索引删除多个元素,你可以按照以下步骤进行操作。
# Removing elements
x = torch.arange(10)
y = torch.arange(4)*2
print(x)
print(y)
x[y] = -1
print(x[x != -1])

这会产生输出
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
tensor([0, 2, 4, 6])
tensor([1, 3, 5, 7, 8, 9])

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接