Pytorch自定义激活函数?

14

我在Pytorch中实现自定义激活函数(例如Swish)时遇到了问题。如何才能正确地实现和使用自定义激活函数呢?

3个回答

24

根据您所需要的内容,有四种可能性,您需要自己回答两个问题:

Q1) 您的激活函数是否具有可学习的参数?

如果是,则必须将您的激活函数创建为nn.Module类,因为您需要存储这些权重。

如果是,则可以自由地创建一个普通函数或类,具体取决于哪个对您更方便。

Q2) 您的激活函数是否可以表示为现有PyTorch函数的组合?

如果是,则可以将其简单地编写为现有PyTorch函数的组合,并且不需要创建定义梯度的backward函数。

如果是,则需要手动编写梯度。

示例1: SiLU函数

SiLU函数f(x)= x * sigmoid(x)没有任何可学习的权重,完全可以使用现有的PyTorch函数来编写,因此可以将其定义为一个函数:

def silu(x):
    return x * torch.sigmoid(x)

然后,您可以像使用torch.relu或任何其他激活函数一样简单地使用它。

示例2:具有学习斜率的SiLU

在这种情况下,您有一个学习的参数——斜率,因此您需要将其制作成一个类。

class LearnedSiLU(nn.Module):
    def __init__(self, slope = 1):
        super().__init__()
        self.slope = slope * torch.nn.Parameter(torch.ones(1))

    def forward(self, x):
        return self.slope * x * torch.sigmoid(x)

示例3:使用backward函数

如果您有需要创建自己的梯度函数的东西,您可以参考这个例子:Pytorch:定义自定义函数


你的意思是:“……如果是这样,你别无选择,只能将你的激活函数创建为一个nn.Module类,因为你需要存储那些权重。” - Shiv Krishna Jaiswal
是的,那就是我想表达的意思。 - patapouf_ai

2
你可以编写一个自定义的激活函数,如下所示(例如,加权Tanh)。"最初的回答":

您可以编写一个自定义的激活函数,例如加权Tanh:

class weightedTanh(nn.Module):
    def __init__(self, weights = 1):
        super().__init__()
        self.weights = weights

    def forward(self, input):
        ex = torch.exp(2*self.weights*input)
        return (ex-1)/(ex+1)

如果你使用autograd兼容的操作,就不需要担心反向传播。

注:backpropagation即反向传播算法。


谢谢,但它给了我一个 -main_.Swish 不是模块子类的错误。 - ZeroMaxinumXZ
@ZeroMaxinumXZ请通过更新您的问题来分享您的代码。 - Wasi Ahmad
问题已经在 https://discuss.pytorch.org/t/custom-activation-functions/43055 得到解决。 - ZeroMaxinumXZ

1
我编写了以下的SinActivation子类,继承自nn.Module,用于实现sin激活函数。
class SinActivation(torch.nn.Module):
    def __init__(self):
        super(SinActivation, self).__init__()
        return
    def forward(self, x):
        return torch.sin(x)

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接