在xgb中使用f-score

9

我正在尝试在xgb分类器中使用scikit-learn中的f-score作为评估指标。以下是我的代码:

clf = xgb.XGBClassifier(max_depth=8, learning_rate=0.004,
                            n_estimators=100,
                            silent=False,   objective='binary:logistic',
                            nthread=-1, gamma=0,
                            min_child_weight=1, max_delta_step=0, subsample=0.8,
                            colsample_bytree=0.6,
                            base_score=0.5,
                            seed=0, missing=None)
scores = []
predictions = []
for train, test, ans_train, y_test in zip(trains, tests, ans_trains, ans_tests):
        clf.fit(train, ans_train, eval_metric=xgb_f1,
                    eval_set=[(train, ans_train), (test, y_test)],
                    early_stopping_rounds=900)
        y_pred = clf.predict(test)
        predictions.append(y_pred)
        scores.append(f1_score(y_test, y_pred))

def xgb_f1(y, t):
    t = t.get_label()
    return "f1", f1_score(t, y)

但是出现了一个错误:无法处理二进制和连续值的混合
1个回答

6
问题在于f1_score试图比较非二进制与二进制目标,并且默认情况下,此方法执行二进制平均。从文档中可以看到 "average : string, [None, ‘binary’ (default), ‘micro’, ‘macro’, ‘samples’, ‘weighted’]"。
无论如何,错误提示说你的预测是连续的,类似于这样的数组[0.001, 0.7889,0.33...],但你的目标是二进制的[0,1,0...]。因此,如果您知道阈值,请在将结果发送到f1_score函数之前对其进行预处理。阈值的通常值为0.5
经过测试的评估函数示例,不再输出错误:
def xgb_f1(y, t, threshold=0.5):
    t = t.get_label()
    y_bin = [1. if y_cont > threshold else 0. for y_cont in y] # binarizing your output
    return 'f1',f1_score(t,y_bin)

如@smci所建议的,一种更简洁/更高效的解决方案可能是:
def xgb_f1(y, t, threshold=0.5):
    t = t.get_label()
    y_bin = (y > threshold).astype(int) # works for both type(y) == <class 'numpy.ndarray'> and type(y) == <class 'pandas.core.series.Series'>
    return 'f1',f1_score(t,y_bin)

2
理想情况下,您应该在函数中将阈值参数进行参数化:xgb_f1(..., threshold=0.5),并将其默认设置为0.5。不要在函数内部突然出现没有解释的魔法数字。 - smci
顺便提一下,在您的列表推导式中,您不需要浮点数0.,1.,我认为您甚至不需要1 if cond else 0 ...表达式,您可以直接使用int(y_cont > threshold)。或者如果是pandas系列,则可以使用y_cont.gt(threshold).astype(int) - smci
1
感谢您的输入!老实说,自从回答以来已经过了很长时间,我甚至不记得“冗长”是有意为之还是只是匆忙想出答案的产物。坦白地说,在这种情况下我不知道该怎么做。我完全理解您的观点,但我有一种感觉,这里的额外冗长有助于“可读性”,使值明确而不仅仅是布尔转换。我知道这最终是主观感受,所以我会等待看看其他人是否通过投票支持您的建议,如果是,我会进行更改;-)这听起来对您好吗? - Guiem Bosch
1
当然,没问题。在这种情况下,我会展示给初学者的冗长代码习惯用法和更高级/更有效的代码,每个上面都有注释说明。 - smci

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接