scikit-learn SGD中的class_weight参数是什么意思?

3

我经常使用scikit-learn,我想了解有关SGD中“class_weight”参数的一些见解。

我已经弄清楚了函数调用部分。

plain_sgd(coef, intercept, est.loss_function,
                 penalty_type, alpha, C, est.l1_ratio,
                 dataset, n_iter, int(est.fit_intercept),
                 int(est.verbose), int(est.shuffle), est.random_state,
                 pos_weight, neg_weight,
                 learning_rate_type, est.eta0,
                 est.power_t, est.t_, intercept_decay)

https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/linear_model/stochastic_gradient.py

在这之后,它会进入sgd_fast函数,但我不太擅长cpython。您能否就以下问题提供一些帮助:

  1. 在开发集中,我遇到了类偏差问题,其中正类别约为15k,负类别约为36k。使用class_weight是否可以解决这个问题?或者进行欠采样会更好。虽然我得到了更好的结果,但很难解释。
  2. 如果是,则它实际上是如何解决的?我的意思是它是应用于特征惩罚还是作为优化函数的权重。我该如何向普通人解释这个问题?
1个回答

6

class_weight可以帮助提高分类模型在不平衡数据上的ROC AUC或f1分数。

您可以尝试使用class_weight="auto"来选择与类别频率成反比的权重。您也可以尝试将自己的权重作为Python字典传递,其中类标签为键,权重为值。

可以通过交叉验证的网格搜索来调整权重。

内部是通过从class_weight派生sample_weight(取决于每个样本的类标签)来实现的。然后使用样本权重来缩放单个样本对用随机梯度下降训练的线性分类模型使用的损失函数的贡献。

特征惩罚是通过penaltyalpha超参数分别控制的。 sample_weight/class_weight对其没有影响。


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接