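I am trying to perform multiclass text classification with xgboost in Python (via its sklearn wrapper), but it sometimes fails with an error saying the feature names do not match. Oddly, it does sometimes work (maybe 1 run in 4), but this unpredictability makes it hard for me to rely on the approach, even though it has shown encouraging results without any real preprocessing.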
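I have included some illustrative sample data in the code, similar to the data I will be using. My current code is below.
The code has been updated to reflect maxymoo's suggestion: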
import xgboost as xgb
import numpy as np
from sklearn.model_selection import KFold  # sklearn.cross_validation has been removed from scikit-learn
from sklearn.metrics import accuracy_score
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.pipeline import make_pipeline  # needed for make_pipeline below

rng = np.random.RandomState(31337)

# Class labels and the matching toy documents
y = np.array([0, 1, 2, 1, 0, 3, 1, 2, 3, 0])
X = np.array(['milk honey bear bear honey tigger',
              'tom jerry cartoon mouse cat cat WB',
              'peppa pig mommy daddy george peppa pig pig',
              'cartoon jerry tom silly',
              'bear honey hundred year woods',
              'ben holly elves fairies gaston fairy fairies castle king',
              'tom and jerry mouse WB',
              'peppa pig daddy pig rebecca rabit',
              'elves ben holly little kingdom king big people',
              'pot pot pot pot jar winnie pooh disney tigger bear'])

# The vectorizer sits inside the pipeline, so each fold builds its vocabulary from its own training split
xgb_model = make_pipeline(CountVectorizer(), xgb.XGBClassifier())
kf = KFold(n_splits=2, shuffle=True, random_state=rng)
for train_index, test_index in kf.split(X):
    xgb_model.fit(X[train_index], y[train_index])
    predictions = xgb_model.predict(X[test_index])
    actuals = y[test_index]
    accuracy = accuracy_score(actuals, predictions)
    print(accuracy)
The error I usually get is:
Traceback (most recent call last):
  File "main.py", line 95, in <module>
    predictions = xgb_model.predict(X[test_index])
  File "//anaconda/lib/python2.7/site-packages/xgboost-0.6-py2.7.egg/xgboost/sklearn.py", line 465, in predict
    ntree_limit=ntree_limit)
  File "//anaconda/lib/python2.7/site-packages/xgboost-0.6-py2.7.egg/xgboost/core.py", line 939, in predict
    self._validate_features(data)
  File "//anaconda/lib/python2.7/site-packages/xgboost-0.6-py2.7.egg/xgboost/core.py", line 1179, in _validate_features
    data.feature_names))
ValueError: feature_names mismatch: ['f0', 'f1', 'f2', 'f3', 'f4', 'f5', 'f6', 'f7', 'f8', 'f9', 'f10', 'f11', 'f12', 'f13', 'f14', 'f15', 'f16', 'f17', 'f18', 'f19', 'f20', 'f21', 'f22', 'f23', 'f24', 'f25', 'f26'] ['f0', 'f1', 'f2', 'f3', 'f4', 'f5', 'f6', 'f7', 'f8', 'f9', 'f10', 'f11', 'f12', 'f13', 'f14', 'f15', 'f16', 'f17', 'f18', 'f19', 'f20', 'f21', 'f22', 'f23', 'f24']
expected f26, f25 in input data
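The error lists only the highest-numbered names (f25, f26) as missing, which hints that the test matrix may have the right width while its last columns happen to be all zeros in that fold. A minimal diagnostic sketch, assuming scikit-learn is installed and using a fixed, hypothetical split (the indices are chosen for illustration and are not the KFold split from the code above), compares the width CountVectorizer declares with the width implied only by the nonzero column indices of the sparse test matrix:

```python
import numpy as np
from sklearn.feature_extraction.text import CountVectorizer

X = np.array(['milk honey bear bear honey tigger',
              'tom jerry cartoon mouse cat cat WB',
              'peppa pig mommy daddy george peppa pig pig',
              'cartoon jerry tom silly',
              'bear honey hundred year woods',
              'ben holly elves fairies gaston fairy fairies castle king',
              'tom and jerry mouse WB',
              'peppa pig daddy pig rebecca rabit',
              'elves ben holly little kingdom king big people',
              'pot pot pot pot jar winnie pooh disney tigger bear'])

# Hypothetical fixed split, chosen only for illustration
train_index = np.array([0, 2, 4, 6, 8])
test_index = np.array([1, 3, 5, 7, 9])

vec = CountVectorizer()
X_train = vec.fit_transform(X[train_index])  # scipy.sparse CSR matrix
X_test = vec.transform(X[test_index])        # same vocabulary, so same declared width

# Width declared by the matrix shape (one column per vocabulary term)
n_declared = X_test.shape[1]
# Width implied only by the column indices that actually hold nonzeros
n_from_indices = int(X_test.indices.max()) + 1

# Vocabulary terms in column order; the tail past n_from_indices never occurs in the test fold
terms = sorted(vec.vocabulary_, key=vec.vocabulary_.get)
unused_tail = terms[n_from_indices:n_declared]

print(X_train.shape[1], n_declared, n_from_indices, unused_tail)
# → 25 25 23 ['woods', 'year']
```

The two widths differ because the alphabetically last terms ("woods", "year") do not appear in any test document. Whether xgboost 0.6 derives the DMatrix column count from these indices for a CSR input, rather than from the matrix shape, is an assumption this sketch does not check; it only shows that such a trailing gap can arise with this data.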
Any suggestions would be much appreciated!