当您使用gensim训练w2v模型时,它会存储每个单词的vocab(词表)和index(索引)信息。gensim正是通过这些信息将单词映射到对应的向量。
如果您要微调已有的w2v模型,则需要确保新模型的vocab与预训练向量的vocab保持一致(即每个单词的索引指向它自己的向量)。
参见下面附带的代码片段。
import os
import pickle
import numpy as np
import gensim
from gensim.models import Word2Vec, KeyedVectors
from gensim.models.callbacks import CallbackAny2Vec
import operator
# Create the output directory for the saved model. exist_ok=True keeps the
# script re-runnable: a plain os.mkdir raises FileExistsError on the second run.
os.makedirs("model_dir", exist_ok=True)
def load_vectors(token2id, path, limit=None, dim=300):
    """Load pre-trained word vectors from a GloVe-style text file into a matrix.

    Each usable line is expected to look like ``<token> <v1> ... <v_dim>``,
    separated by single spaces. The token is lower-cased and looked up in
    ``token2id``; the matching row of the result receives the vector. When
    several file tokens lower-case to the same key, the last one in the file
    wins.

    Args:
        token2id: Mapping from lower-cased token to row index. Indices are
            used directly as row numbers of a ``len(token2id)``-row matrix.
        path: Path of the embedding text file (read as UTF-8, undecodable
            bytes ignored).
        limit: Maximum number of vectors to load, or ``None`` for no limit.
        dim: Number of components per vector in the file.

    Returns:
        A float32 array of shape ``(len(token2id), dim)``. Rows whose token
        never appears in the file stay all-zero.
    """
    vectors = np.zeros((len(token2id), dim), dtype='f')
    loaded = 0
    with open(path, encoding="utf8", errors='ignore') as f:
        for line in f:
            # Very short lines (e.g. a "<count> <dim>" header) carry no vector.
            if len(line) <= 100:
                continue
            # ">=" so that at most `limit` vectors are loaded (not limit + 1).
            if limit is not None and loaded >= limit:
                break
            # Split from the right: the last `dim` fields are the numbers, and
            # whatever precedes them is the token, even if it contains spaces.
            # rstrip() also drops the trailing newline from the last number.
            token, *values = line.rstrip().rsplit(' ', dim)
            if len(values) != dim:
                # Malformed / truncated line: skip instead of failing on assignment.
                continue
            row = token2id.get(token.lower())
            if row is None:
                # Token is outside the requested vocabulary; nothing to fill.
                continue
            vectors[row] = np.array(values, 'f')
            loaded += 1
    return vectors
# ---------------------------------------------------------------------------
# Fine-tune pre-trained GloVe vectors on a custom corpus with gensim Word2Vec.
#
# Steps:
#   1. Read the corpus, counting word frequencies and assigning each new word
#      the next free integer id.
#   2. Extend the vocabulary with every token of the embedding file.
#   3. Persist the vocabulary, load the pre-trained vectors by id, and copy
#      them into a fresh Word2Vec model before training on the corpus.
#
# NOTE: `size`, `iter`, `index2entity`, `trainables` and `KeyedVectors.add`
# are gensim 3.x names; gensim 4 renamed or removed them, so pin gensim<4 or
# port these calls before upgrading.
# ---------------------------------------------------------------------------
embedding_name = "glove.840B.300d.txt"
data = "<training data(new line separated tect file)>"
embedding_dim = 300

token2id = {}
vocab_freq_dict = {}
training_examples = []

# 1. Corpus pass. The file is streamed line by line and closed by `with`.
with open("{}".format(data), 'r', encoding="utf-8") as corpus_file:
    for line in corpus_file:
        # Drop empty strings produced by blank lines or repeated spaces so
        # that '' never becomes a vocabulary entry or a training token.
        words = [w for w in line.strip().split(" ") if w]
        if not words:
            continue
        training_examples.append(words)
        for word in words:
            vocab_freq_dict[word] = vocab_freq_dict.get(word, 0) + 1
            if word not in token2id:
                # Ids are handed out consecutively starting at 0.
                token2id[word] = len(token2id)

# Ids are contiguous, so the highest one in use is len - 1 (-1 when the corpus
# is empty, which makes the first embedding token receive id 0).
max_token_id = len(token2id) - 1

# 2. Embedding pass: add every embedding token that the corpus did not contain.
with open(embedding_name, encoding="utf8", errors='ignore') as f:
    for o in f:
        # Very short lines (e.g. a header) carry no vector.
        if len(o) <= 100:
            continue
        # Split from the right so a token that itself contains spaces is kept
        # whole instead of being cut at its first space.
        token = o.rstrip().rsplit(' ', embedding_dim)[0].lower()
        if token not in token2id:
            max_token_id += 1
            token2id[token] = max_token_id
            # Embedding-only tokens get a nominal frequency of 1.
            vocab_freq_dict[token] = 1

# 3. Persist the vocabulary so the same token -> id mapping can be reused.
with open("vocab_freq_dict", "wb") as vocab_file:
    pickle.dump(vocab_freq_dict, vocab_file)
with open("token2id", "wb") as token2id_file:
    pickle.dump(token2id, token2id_file)

vectors = load_vectors(token2id, embedding_name)
vec = KeyedVectors(embedding_dim)
# token2id was filled in increasing-id order, so iterating the dict yields the
# tokens in the same order as the rows of `vectors`.
vec.add(list(token2id), vectors, replace=True)
vectors = None  # release the large matrix; `vec` now holds the data

params = dict(min_count=1, workers=14, iter=6, size=embedding_dim)
model = Word2Vec(**params)
model.build_vocab_from_freq(vocab_freq_dict)

# Word2Vec orders its vocabulary on its own; idxmap translates each model row
# back to the row of the pre-trained matrix that holds the same token.
idxmap = np.array([token2id[w] for w in model.wv.index2entity])
model.wv.vectors[:] = vec.vectors[idxmap]
model.trainables.syn1neg[:] = vec.vectors[idxmap]

model.train(training_examples, total_examples=len(training_examples), epochs=model.epochs)
output_path = 'model_dir/final_model.model'
model.save(output_path)
如果您有任何疑问,请留言评论。