A few notes on how the BERT architecture can be used for sentence embeddings.
Also, Christian Arteaga's comment makes a good point about choosing the right model for the right task.
I use the BERT model and tokenizer from Hugging Face directly, rather than the sentence_transformers wrapper, because that makes it clearer to anyone just starting out with NLP how they actually work.
BERT model - https://huggingface.co/transformers/v3.0.2/model_doc/bert.html
Note - this code is meant as an illustration, not production code; also see https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2
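As a quick illustration of what the tokenizer does before anything reaches the model, here is a minimal sketch (using the same bert-base-uncased checkpoint as below): encode turns a sentence into vocabulary ids and adds the special [CLS] and [SEP] tokens around the word pieces.

from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
ids = tokenizer.encode("John loves dogs")
# The id list starts with [CLS] (id 101) and ends with [SEP] (id 102)
print(ids)
print(tokenizer.convert_ids_to_tokens(ids))  # typically ['[CLS]', 'john', 'loves', 'dogs', '[SEP]']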
'''
Adapted and extended from
https://github.com/huggingface/transformers/issues/1950#issuecomment-558679189
'''
import torch
from transformers import BertTokenizer, BertModel
from sklearn.metrics.pairwise import cosine_similarity
def get_sentence_similarity(tokenizer, model, s1, s2):
    # Encode each sentence into token ids; this also adds the
    # special [CLS] and [SEP] tokens.
    s1 = tokenizer.encode(s1)
    s2 = tokenizer.encode(s2)
    print("1 len(s1) s1", len(s1), s1)
    print("1 len(s2) s2", len(s2), s2)
    # Add a batch dimension: shape (1, seq_len)
    s1 = torch.tensor(s1).unsqueeze(0)
    s2 = torch.tensor(s2).unsqueeze(0)
    with torch.no_grad():
        output_1 = model(s1)
        output_2 = model(s2)
    # The first element of the model output is the last hidden state,
    # shape (1, seq_len, hidden_size)
    logits_s1 = output_1[0].detach()
    logits_s2 = output_2[0].detach()
    print("logits_s1.shape", logits_s1.shape)
    print("logits_s2.shape", logits_s2.shape)
    # Drop the batch dimension: (seq_len, hidden_size)
    logits_s1 = torch.squeeze(logits_s1)
    logits_s2 = torch.squeeze(logits_s2)
    print("logits_s1.shape", logits_s1.shape)
    print("logits_s2.shape", logits_s2.shape)
    # Flatten each sentence's token embeddings into one long row vector
    a = logits_s1.reshape(1, logits_s1.numel())
    b = logits_s2.reshape(1, logits_s2.numel())
    print("a.shape", a.shape)
    print("b.shape", b.shape)
    # Zero-pad the shorter vector so both have the same length
    if a.shape[1] < b.shape[1]:
        pad_size = (0, b.shape[1] - a.shape[1])
        a = torch.nn.functional.pad(a, pad_size, mode='constant', value=0)
    else:
        pad_size = (0, a.shape[1] - b.shape[1])
        b = torch.nn.functional.pad(b, pad_size, mode='constant', value=0)
    print("After padding")
    print("a.shape", a.shape)
    print("b.shape", b.shape)
    cos_sim = cosine_similarity(a.numpy(), b.numpy())
    return cos_sim
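Note that flattening and zero-padding the token embeddings ties the score to token positions and sentence lengths. The more common approach for sentence embeddings (and the one described on the all-MiniLM-L6-v2 model card) is to mean-pool the token embeddings over the attention mask, which gives one fixed-size vector per sentence and needs no padding. Here is a minimal sketch of that variant; the function name get_sentence_similarity_mean_pooled is mine, not part of the original code.

def get_sentence_similarity_mean_pooled(tokenizer, model, s1, s2):
    # Tokenize both sentences as one batch; the attention mask marks
    # real tokens (1) vs. padding (0).
    batch = tokenizer([s1, s2], padding=True, return_tensors="pt")
    with torch.no_grad():
        token_embeddings = model(**batch)[0]  # (2, seq_len, hidden_size)
    # Average only over real tokens, using the attention mask
    mask = batch["attention_mask"].unsqueeze(-1).float()  # (2, seq_len, 1)
    summed = (token_embeddings * mask).sum(dim=1)         # (2, hidden_size)
    counts = mask.sum(dim=1)                              # (2, 1)
    sentence_embeddings = summed / counts
    # Both vectors now have the same fixed size, so no zero-padding is needed
    return cosine_similarity(sentence_embeddings[0:1].numpy(),
                             sentence_embeddings[1:2].numpy())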
if __name__ == "__main__":
    s1 = "John loves dogs"
    s2 = "dogs love John"
    # Plain BERT: its hidden states are not trained for sentence similarity
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    model = BertModel.from_pretrained("bert-base-uncased")
    model.eval()
    cos_sim = get_sentence_similarity(tokenizer, model, s1, s2)
    print("got cosine similarity", cos_sim)
    # A checkpoint fine-tuned for sentence embeddings; it uses the BERT
    # architecture, so BertModel can load it directly
    tokenizer = BertTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
    model = BertModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
    model.eval()
    cos_sim = get_sentence_similarity(tokenizer, model, s1, s2)
    print("got cosine similarity", cos_sim)