假设在PyTorch中我有model1
和model2
,它们具有相同的架构。 它们使用相同的数据进行训练或者其中一个模型是另一个模型的早期版本,但这对问题来说并不重要。现在我想将model
的权重设置为model1
和model2
权重的平均值。 在PyTorch中我该怎么做?
假设在PyTorch中我有model1
和model2
,它们具有相同的架构。 它们使用相同的数据进行训练或者其中一个模型是另一个模型的早期版本,但这对问题来说并不重要。现在我想将model
的权重设置为model1
和model2
权重的平均值。 在PyTorch中我该怎么做?
beta = 0.5 #The interpolation parameter
params1 = model1.named_parameters()
params2 = model2.named_parameters()
dict_params2 = dict(params2)
for name1, param1 in params1:
if name1 in dict_params2:
dict_params2[name1].data.copy_(beta*param1.data + (1-beta)*dict_params2[name1].data)
model.load_state_dict(dict_params2)
摘自pytorch论坛。您可以获取参数,进行转换并重新加载,但请确保维度匹配。
此外,我对您的研究结果非常感兴趣。