我成功地将pyspark的输出适配到了pyLDAvis中。
以下代码需要稍作清理,但它可以正常运行。
import numpy as np
import pyLDAvis
from pyspark.ml.clustering import LDA
from pyspark.ml.feature import StopWordsRemover,Tokenizer, RegexTokenizer, CountVectorizer, IDF
from pyspark.sql.functions import udf, col, size, explode, regexp_replace, trim, lower, lit
from pyspark.sql.types import ArrayType, StringType, DoubleType, IntegerType, LongType
def format_data_to_pyldavis(df_filtered, count_vectorizer, transformed, lda_model):
    """Collect Spark LDA results into the dict that is unpacked into ``pyLDAvis.prepare``.

    Args:
        df_filtered: DataFrame with a ``words_filtered`` array column.
        count_vectorizer: Object exposing a ``vocabulary`` sequence (the
            calling script passes the fitted CountVectorizer model).
        transformed: DataFrame with a ``topicDistribution`` column whose
            values provide ``toArray()``.
        lda_model: Model providing ``topicsMatrix()``.

    Returns:
        A dict with the keys ``topic_term_dists``, ``doc_topic_dists``,
        ``doc_lengths``, ``vocab`` and ``term_frequency``.
    """
    # Count how often every word occurs across all documents.
    exploded = df_filtered.select(explode(df_filtered.words_filtered).alias("words"))
    counts_by_word = {}
    for row in exploded.groupby("words").count().collect():
        counts_by_word[row['words']] = row['count']

    # Order the counts to line up with the vocabulary; a vocabulary word
    # missing from the counts raises KeyError.
    term_frequency = [counts_by_word[term] for term in count_vectorizer.vocabulary]

    # Transpose of the model's topics matrix.
    topic_term_dists = np.array(lda_model.topicsMatrix().toArray()).T

    # One row per document, pulled to the driver through pandas.
    topic_column = transformed.select(["topicDistribution"]).toPandas()['topicDistribution']
    doc_topic_dists = np.array([vector.toArray() for vector in topic_column])

    # Number of entries in each document's words_filtered array.
    length_rows = df_filtered.select(size(df_filtered.words_filtered)).collect()
    doc_lengths = [row[0] for row in length_rows]

    return {
        'topic_term_dists': topic_term_dists,
        'doc_topic_dists': doc_topic_dists,
        'doc_lengths': doc_lengths,
        'vocab': count_vectorizer.vocabulary,
        'term_frequency': term_frequency,
    }
def filter_bad_docs(data):
    """Drop documents whose topic distribution does not sum to 1.

    ``data`` is modified in place: ``data['doc_topic_dists']`` and
    ``data['doc_lengths']`` are replaced by lists that contain only the
    documents whose distribution sums to 1 within floating-point tolerance.
    Rows that are all zeros, contain NaN/inf, or sum to any other value are
    removed together with their document length.

    Args:
        data: Dict with ``'doc_topic_dists'`` (iterable of per-document
            numeric rows) and ``'doc_lengths'`` (iterable of the same length).

    Returns:
        The number of documents that were dropped.
    """
    kept_dists = []
    kept_lengths = []
    bad = 0
    for dist, length in zip(data['doc_topic_dists'], data['doc_lengths']):
        # Compare with a tolerance: an exact ``!= 1`` test discards valid
        # distributions that are off by rounding error. A NaN or zero sum is
        # never close to 1, so those rows are rejected by the same test.
        if np.isclose(np.sum(dist), 1.0):
            kept_dists.append(dist)
            kept_lengths.append(length)
        else:
            bad += 1
    data['doc_topic_dists'] = kept_dists
    data['doc_lengths'] = kept_lengths
    return bad
Before running the code below, create a Spark DataFrame named df_filtered whose words_filtered column holds the list of raw words.
It can be the output of StopWordsRemover.
# Fit the vocabulary on the words_filtered column, then convert each document
# into a term-count vector in the "features" column.
count_vectorizer = CountVectorizer(
    inputCol="words_filtered",
    outputCol="features",
    minDF=0.05,
    maxDF=0.5,
).fit(df_filtered)
df_counted = count_vectorizer.transform(df_filtered)

# Re-weight the term counts with IDF into "features_tfidf".
idf = IDF(
    inputCol="features",
    outputCol="features_tfidf",
)
idf_model = idf.fit(df_counted)
df_tfidf = idf_model.transform(df_counted)

# Train LDA (k=2) on the TF-IDF features and score every document.
lda = LDA(
    k=2,
    maxIter=20,
    featuresCol='features_tfidf',
)
lda_model = lda.fit(df_tfidf)
transformed = lda_model.transform(df_tfidf)

# Gather the pyLDAvis inputs, drop unusable documents in place, then prepare
# and display the visualization.
data = format_data_to_pyldavis(df_filtered, count_vectorizer, transformed, lda_model)
filter_bad_docs(data)
py_lda_prepared_data = pyLDAvis.prepare(**data)
pyLDAvis.display(py_lda_prepared_data)