Next sentence prediction using BERT


Google's BERT is pretrained on a next sentence prediction task, but I'm wondering whether it is possible to call the next sentence prediction function on new data.

The idea is: given sentence A and sentence B, I want a probabilistic label for whether sentence B follows sentence A. BERT is pretrained on a huge dataset, so I was hoping to use this next sentence prediction on new sentence data. But I can't seem to figure out whether this next sentence prediction function can be called and, if so, how. Thanks for your help!

2 Answers


Aerin's answer is outdated. The HuggingFace library (now called transformers) has changed a lot over the past few months. Here is an example of how to use the next sentence prediction (NSP) model and how to extract probabilities from it. Note that this only works well if you use a model with a pretrained head for the NSP task.

from torch.nn.functional import softmax
from transformers import BertForNextSentencePrediction, BertTokenizer


seq_A = 'I like cookies !'
seq_B = 'Do you like them ?'

# load pretrained model and a pretrained tokenizer
model = BertForNextSentencePrediction.from_pretrained('bert-base-cased')
tokenizer = BertTokenizer.from_pretrained('bert-base-cased')

# encode the two sequences. Particularly, make clear that they must be 
# encoded as "one" input to the model by using 'seq_B' as the 'text_pair'
encoded = tokenizer.encode_plus(seq_A, text_pair=seq_B, return_tensors='pt')
print(encoded)
# {'input_ids': tensor([[  101,   146,  1176, 18621,   106,   102,  2091,  1128,  1176,  1172, 136,   102]]),
#  'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1]]),
#  'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}
# NOTE how the token_type_ids are 0 for all tokens in seq_A and 1 for seq_B, 
# this way the model knows which token belongs to which sequence

# a model's output is a tuple, we only need the output tensor containing
# the relationships which is the first item in the tuple
seq_relationship_logits = model(**encoded)[0]

# we still need softmax to convert the logits into probabilities
# index 0: sequence B is a continuation of sequence A
# index 1: sequence B is a random sequence
probs = softmax(seq_relationship_logits, dim=1)

print(seq_relationship_logits)
print(probs)
# tensor([[9.9993e-01, 6.7607e-05]], grad_fn=<SoftmaxBackward>)
# very high value for index 0: high probability of seq_B being a continuation of seq_A
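
On more recent transformers releases, the forward pass returns a NextSentencePredictorOutput object, so the same logits can also be read via .logits. Below is a minimal sketch that wraps the computation above into a small reusable helper; the function name next_sentence_probability is my own, not part of the library:

import torch
from torch.nn.functional import softmax
from transformers import BertForNextSentencePrediction, BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
model = BertForNextSentencePrediction.from_pretrained('bert-base-cased')
model.eval()

def next_sentence_probability(seq_a, seq_b):
    # hypothetical helper: probability that seq_b follows seq_a (NSP index 0)
    encoded = tokenizer(seq_a, seq_b, return_tensors='pt')
    with torch.no_grad():
        logits = model(**encoded).logits  # shape [1, 2]
    return softmax(logits, dim=1)[0, 0].item()

print(next_sentence_probability('I like cookies !', 'Do you like them ?'))
# should print a value close to 1.0, matching the output shown above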

Yes, I meant "why". Thank you. Regarding BertForNextSentencePrediction: does it always operate on two sentences, both at training time and at prediction time? - stackoverflowuser2010
It seems to give a high score for almost any sentence in seq_B. For example, I tried seq_B = 'blah blah blah' with seq_A still 'I like cookies !' ... the model still output a very high value at index 0 of the probs tensor: tensor([[0.9704, 0.0296]]) - AruniRC
Hmm, it may have changed. The example for BertForNextSentencePrediction now has two sentences. - NatalieL
@BramVanroy Regarding your last comment: 'bert-base-cased' should have been trained for NSP and should have the weights available, is that right? - amiola
If I remember correctly, the weights of the NSP classification head are not available and were never released. But I guess that is easy enough to test yourself! - Bram Vanroy
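
As the comments above suggest, an easy sanity check is to compare the "is next" probability of a plausible continuation against a nonsense candidate such as the 'blah blah blah' example mentioned in the comments. A minimal sketch using the same model and tokenizer as in the answer above:

from torch.nn.functional import softmax
from transformers import BertForNextSentencePrediction, BertTokenizer

model = BertForNextSentencePrediction.from_pretrained('bert-base-cased')
tokenizer = BertTokenizer.from_pretrained('bert-base-cased')

seq_A = 'I like cookies !'
# compare a plausible continuation with the nonsense candidate from the comments
for seq_B in ['Do you like them ?', 'blah blah blah']:
    encoded = tokenizer.encode_plus(seq_A, text_pair=seq_B, return_tensors='pt')
    probs = softmax(model(**encoded)[0], dim=1)
    print(seq_B, probs[0, 0].item())
# if both probabilities come out very high, the absolute score alone is not very
# informative; comparing candidates against each other tends to be more meaningful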


Hugging Face has done this for you: https://github.com/huggingface/pytorch-pretrained-BERT/blob/master/pytorch_pretrained_bert/modeling.py#L854

class BertForNextSentencePrediction(BertPreTrainedModel):
    """BERT model with next sentence prediction head.
    This module comprises the BERT model followed by the next sentence classification head.
    Params:
        config: a BertConfig class instance with the configuration to build a new model.
    Inputs:
        `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
            with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts
            `extract_features.py`, `run_classifier.py` and `run_squad.py`)
        `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
            types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
            a `sentence B` token (see BERT paper for more details).
        `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
            selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
            input sequence length in the current batch. It's the mask that we typically use for attention when
            a batch has varying length sentences.
        `next_sentence_label`: next sentence classification loss: torch.LongTensor of shape [batch_size]
            with indices selected in [0, 1].
            0 => next sentence is the continuation, 1 => next sentence is a random sentence.
    Outputs:
        if `next_sentence_label` is not `None`:
            Outputs the next sentence classification loss.
        if `next_sentence_label` is `None`:
            Outputs the next sentence classification logits of shape [batch_size, 2].
    Example usage:
    ```python
    # Already been converted into WordPiece token ids
    input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
    input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
    token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
    config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
        num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
    model = BertForNextSentencePrediction(config)
    seq_relationship_logits = model(input_ids, token_type_ids, input_mask)
    ```
    """
    def __init__(self, config):
        super(BertForNextSentencePrediction, self).__init__(config)
        self.bert = BertModel(config)
        self.cls = BertOnlyNSPHead(config)
        self.apply(self.init_bert_weights)

    def forward(self, input_ids, token_type_ids=None, attention_mask=None, next_sentence_label=None):
        _, pooled_output = self.bert(input_ids, token_type_ids, attention_mask,
                                     output_all_encoded_layers=False)
        seq_relationship_score = self.cls(pooled_output)

        if next_sentence_label is not None:
            loss_fct = CrossEntropyLoss(ignore_index=-1)
            next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
            return next_sentence_loss
        else:
            return seq_relationship_score
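
Based on the docstring above, here is a minimal sketch of how this older class might be called with the legacy pytorch-pretrained-bert package, both without labels (returning the [batch_size, 2] relationship logits) and with next_sentence_label (returning the NSP classification loss). The toy input tensors are copied from the docstring's example and are not meaningful WordPiece ids; the next_sentence_label values are arbitrary and only for illustration:

import torch
from pytorch_pretrained_bert import BertForNextSentencePrediction

# load pretrained weights (assumes the legacy pytorch-pretrained-bert package is installed)
model = BertForNextSentencePrediction.from_pretrained('bert-base-uncased')
model.eval()

# toy inputs from the docstring: WordPiece ids, attention mask, and segment ids
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])

# without labels: next sentence classification logits of shape [batch_size, 2]
seq_relationship_logits = model(input_ids, token_type_ids, input_mask)

# with labels: the NSP cross-entropy loss
# (0 => sentence B is the continuation, 1 => sentence B is a random sentence)
next_sentence_label = torch.LongTensor([0, 1])
next_sentence_loss = model(input_ids, token_type_ids, input_mask, next_sentence_label)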
