为什么在序列分类中(DistilBertForSequenceClassification),要选择第一个隐藏状态?

9
HuggingFace进行序列分类的最后几层中,他们采用了变压器输出序列长度的第一个隐藏状态用于分类。
hidden_state = distilbert_output[0]  # (bs, seq_len, dim) <-- transformer output
pooled_output = hidden_state[:, 0]  # (bs, dim)           <-- first hidden state
pooled_output = self.pre_classifier(pooled_output)  # (bs, dim)
pooled_output = nn.ReLU()(pooled_output)  # (bs, dim)
pooled_output = self.dropout(pooled_output)  # (bs, dim)
logits = self.classifier(pooled_output)  # (bs, dim)

选择第一个隐藏状态与选择最后一个、平均值或使用Flatten层相比,是否有任何好处?

1个回答

6

是的,这与BERT的训练方式直接相关。具体来说,我鼓励您查看原始BERT论文,在其中作者介绍了[CLS]标记的含义:

[CLS]是添加在每个输入示例前面的特殊符号[...]。

具体而言,它用于分类目的,因此对于任何分类任务的微调,第一个且最简单的选择。您相关的代码片段所做的基本上只是提取此[CLS]标记。

不幸的是,Huggingface库的DistilBERT文档没有明确提到这一点,而是需要查看他们的BERT文档,在那里他们还强调了一些与[CLS]标记类似的问题,类似于您的担忧:

除了使用MLM,BERT还使用了下一句预测(NSP)目标来训练,使用[CLS]标记作为序列近似。用户可以使用此标记(在使用特殊标记构建的序列中的第一个标记)获得序列预测而不是标记预测。然而,对序列进行平均可能比使用[CLS]标记产生更好的结果。

如果对序列的嵌入求平均可以得到更好的结果,为什么这些作者没有采用这种方法呢? - avocado
1
我认为选择另一种方法会更加计算密集,因此可能只会带来微小的收益,不值得。 - dennlinger

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接