我在使用PyTorch实现RCNN文本分类,需要使用
请问为什么这个置换是必要或有用的呢?
相关代码: 完整代码。
permute
函数对张量的维度进行两次置换。第一次出现在LSTM层之后和tanh之前,第二次出现在线性层之后和最大池化层之前。请问为什么这个置换是必要或有用的呢?
相关代码: 完整代码。
def forward(self, x):
# x.shape = (seq_len, batch_size)
embedded_sent = self.embeddings(x)
# embedded_sent.shape = (seq_len, batch_size, embed_size)
lstm_out, (h_n,c_n) = self.lstm(embedded_sent)
# lstm_out.shape = (seq_len, batch_size, 2 * hidden_size)
input_features = torch.cat([lstm_out,embedded_sent], 2).permute(1,0,2)
# final_features.shape = (batch_size, seq_len, embed_size + 2*hidden_size)
linear_output = self.tanh(
self.W(input_features)
)
# linear_output.shape = (batch_size, seq_len, hidden_size_linear)
linear_output = linear_output.permute(0,2,1) # Reshaping fot max_pool
max_out_features = F.max_pool1d(linear_output, linear_output.shape[2]).squeeze(2)
# max_out_features.shape = (batch_size, hidden_size_linear)
max_out_features = self.dropout(max_out_features)
final_out = self.fc(max_out_features)
return self.softmax(final_out)
其他代码库中的相似代码
其他RCNN的实现使用permute
或transpose
,以下是示例: