LightGBM多类别分类

18

我正在尝试使用Python中的LightGBM模型来对多分类问题(3类)进行分类。 我使用了以下参数。

params = {'task': 'train',
    'boosting_type': 'gbdt',
    'objective': 'multiclass',
    'num_class':3,
    'metric': 'multi_logloss',
    'learning_rate': 0.002296,
    'max_depth': 7,
    'num_leaves': 17,
    'feature_fraction': 0.4,
    'bagging_fraction': 0.6,
    'bagging_freq': 17}

数据集的所有类别特征都使用 LabelEncoder 进行了标签编码。在运行下面显示的带有 early_stoppingcv 后,我训练了模型。

lgb_cv = lgbm.cv(params, d_train, num_boost_round=10000, nfold=3, shuffle=True, stratified=True, verbose_eval=20, early_stopping_rounds=100)

nround = lgb_cv['multi_logloss-mean'].index(np.min(lgb_cv['multi_logloss-mean']))
print(nround)

model = lgbm.train(params, d_train, num_boost_round=nround)

训练完成后,我使用这个模型进行了预测:

preds = model.predict(test)
print(preds)             

我得到了一个嵌套数组的输出,格式如下。

[[  7.93856847e-06   9.99989550e-01   2.51164967e-06]
 [  7.26332978e-01   1.65316511e-05   2.73650491e-01]
 [  7.28564308e-01   8.36756769e-06   2.71427325e-01]
 ..., 
 [  7.26892634e-01   1.26915179e-05   2.73094674e-01]
 [  5.93217601e-01   2.07172044e-04   4.06575227e-01]
 [  5.91722491e-05   9.99883828e-01   5.69994435e-05]]

由于preds中的每个列表表示类别概率,因此我使用np.argmax()来查找类别,如下所示...

predictions = []

for x in preds:
    predictions.append(np.argmax(x))

在分析预测结果时,我发现我的预测只包含两个类别 - 0 和 1。第二大的类别是2,在训练集中出现频率很高,但是在预测结果中却没有出现。评估结果表明准确率约为 78%

那么,为什么我的模型没有对任何情况进行类别2的预测呢?我使用的参数有问题吗?

这难道不是解释模型所做的预测的正确方法吗?我需要更改参数吗?


我不确定这段代码出了什么问题,但我猜测你的问题似乎是二元分类,但你正在使用多类分类指标来衡量准确性。我建议你使用binary_logloss来解决这个问题。你可以在这里找到更多相关信息。 - Aditya
我在我的目标中有3个类。我已经进行了交叉检查。 - Sreeram TP
4个回答

4
尝试通过交换类0和类2进行故障排除,然后重新运行训练和预测过程。
如果新的预测结果只包含类1和类2(根据您提供的数据最有可能):
分类器可能没有学习第三类;也许它的特征与某个较大类别的特征重叠,在为了最小化目标函数而缺省使用了较大类别。尝试提供一个平衡的训练集(每个类别相同数量的样本)并重试。
如果新的预测结果包含所有3个类别:
则您的代码中出现了问题。需要更多信息才能确定出错原因。
希望这可以帮到您。

0

从您提供的输出来看,预测结果似乎没有问题。

模型产生了三个概率值,就像您展示的那样,仅从您提供的第一个输出 [ 7.93856847e-06 9.99989550e-01 2.51164967e-06] 来看,第二类的概率更高,所以我看不出有什么问题。

类别0是第一类,类别1实际上是第二类,类别2是第三类。所以我想这里没有问题。


模型甚至在训练过的样本上也无法预测任何输入样本的第三类别。 - Sreeram TP

-1

解决方案是:

best_preds_svm = [np.argmax(line) for line in preds]

然后您可以打印具有最合理结果的类。


-2
import pandas as pd

pd.DataFrame(preds).apply(lambda x: np.argmax(x), axis=1)

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接