我在Pytorch中实现自定义激活函数(例如Swish)时遇到了问题。如何才能正确地实现和使用自定义激活函数呢?
我在Pytorch中实现自定义激活函数(例如Swish)时遇到了问题。如何才能正确地实现和使用自定义激活函数呢?
根据您所需要的内容,有四种可能性,您需要自己回答两个问题:
Q1) 您的激活函数是否具有可学习的参数?
如果是是,则必须将您的激活函数创建为nn.Module
类,因为您需要存储这些权重。
如果是否,则可以自由地创建一个普通函数或类,具体取决于哪个对您更方便。
Q2) 您的激活函数是否可以表示为现有PyTorch函数的组合?
如果是是,则可以将其简单地编写为现有PyTorch函数的组合,并且不需要创建定义梯度的backward
函数。
如果是否,则需要手动编写梯度。
示例1: SiLU函数
SiLU函数f(x)= x * sigmoid(x)
没有任何可学习的权重,完全可以使用现有的PyTorch函数来编写,因此可以将其定义为一个函数:
def silu(x):
return x * torch.sigmoid(x)
然后,您可以像使用torch.relu
或任何其他激活函数一样简单地使用它。
示例2:具有学习斜率的SiLU
在这种情况下,您有一个学习的参数——斜率,因此您需要将其制作成一个类。
class LearnedSiLU(nn.Module):
def __init__(self, slope = 1):
super().__init__()
self.slope = slope * torch.nn.Parameter(torch.ones(1))
def forward(self, x):
return self.slope * x * torch.sigmoid(x)
示例3:使用backward函数
如果您有需要创建自己的梯度函数的东西,您可以参考这个例子:Pytorch:定义自定义函数
您可以编写一个自定义的激活函数,例如加权Tanh:
class weightedTanh(nn.Module):
def __init__(self, weights = 1):
super().__init__()
self.weights = weights
def forward(self, input):
ex = torch.exp(2*self.weights*input)
return (ex-1)/(ex+1)
如果你使用autograd
兼容的操作,就不需要担心反向传播。
注:backpropagation即反向传播算法。
SinActivation
子类,继承自nn.Module
,用于实现sin
激活函数。class SinActivation(torch.nn.Module):
def __init__(self):
super(SinActivation, self).__init__()
return
def forward(self, x):
return torch.sin(x)