这不是对原始OP的答案,而是针对使用sklearn API并遇到此问题的人的答案。
对于那些使用sklearn API的人,特别是使用sklearn中的cross_val
方法之一的人,有两个解决方案可供考虑。
Sklearn API解决方案
对我有效的解决方案是将分类字段转换为pandas中的category
数据类型。
如果您正在使用pandas df,则LightGBM应自动将其视为分类。从文档中可以看出:
在Python包中,将从pandas分类中提取整数代码
这应该相当于在Dataset对象中设置分类变量的sklearn API中的等效项。
但请记住,LightGBM并不正式支持sklearn API的任何非核心参数,他们明确地表示了这一点:
在sklearn中不支持**kwargs,可能会导致意外问题。
自适应解决方案
另一种更可靠的解决方法是创建自己的包装类,在底层实现核心数据集/训练,但暴露出适合cv方法的fit/predict接口。这样,您只需要编写少量代码就可以获得lightGBM的全部功能。
以下是此解决方案的示例。
class LGBMSKLWrapper:
def __init__(self, categorical_variables, params):
self.categorical_variables = categorical_variables
self.params = params
self.model = None
def fit(self, X, y):
my_dataset = ltb.Dataset(X, y, categorical_feature=self.categorical_variables)
self.model = ltb.train(params=self.params, train_set=my_dataset)
def predict(self, X):
return self.model.predict(X)
上述代码允许您在创建对象时加载参数,然后在客户端调用“fit”时将其传递给训练。