如何在Pytorch中创建上三角矩阵?

3

这是一个简单的问题,但有没有一种原生的方法在Pytorch中从现有矩阵创建一个上三角矩阵?我考虑使用掩码,但即使使用掩码也需要创建上三角矩阵。


4
我自己从未使用过,但你看过torch.triu()了吗? - JoshVarty
谢谢!看起来他们实际上移植了所有的np功能 :) - information_interchange
2个回答

4
import torch
upper_tri = torch.ones(rol, col).triu()

Eg:

>> mat = torch.ones(3, 3).triu()
>> print(mat)
tensor([[1., 1., 1.],
        [0., 1., 1.],
        [0., 0., 1.]])

1
import torch

l = torch.tril(torch.ones(row, column))

这将返回大小为(行,列)的矩阵的下三角部分。

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接