如何在Scikit-Learn的随机森林分类器中设置子样本大小?特别是对于不平衡数据。

7

目前,我正在使用Sklearn实现RandomForestClassifier来处理我的不平衡数据。我对Sklearn中RF的工作原理并不是很清楚。以下是我的疑虑:

  1. 根据文档,似乎没有办法为每个树学习器设置子样本大小(即小于原始数据大小)。但是在随机森林算法中,我们需要为每棵树获取样本和特征的子集。我不确定是否可以通过Sklearn实现?如果可以,如何实现?

以下是Sklearn中RandomForestClassifier的描述。

“随机森林是一种元估计器,它在数据集的各个子样本上拟合多个决策树分类器,并使用平均值来提高预测准确性并控制过度拟合。如果bootstrap=True(默认情况下),则子样本大小始终与原始输入样本大小相同,但是如果bootstrap=True,则使用替换抽样。”

我之前发现了一个类似的问题。但是这个问题没有得到太多答案。

如何使用SciKit-Learn随机森林子样本大小等于原始训练数据大小?

  1. 对于不平衡的数据,如果我们可以通过Sklearn进行子样本拾取(即解决问题#1),那么我们可以做平衡随机森林吗?即对于每个树学习器,它将从较少的类中选择一个子集,并选择相同数量的样本从较多的类中组成一个完整的训练集,以使两个类具有平等的分布。然后重复此过程多次(即# of trees)。

谢谢! Cheng


1
对于第一个问题,看起来你无法为每个树选择子样本的大小。至于不平衡数据问题,这就是“class_weight”参数的作用所在。 - ktdrv
谢谢您的回答。但根据我的理解,“class_weight”参数旨在调整预测误差,使错误预测较少的类别受到更多惩罚。但它不能使每个树学习器实现两个类别之间的平衡采样。 - Cheng Fang
您还可以调整fit方法的sample_weight参数。除此之外,您可能需要手动复制较少频繁类别的样本。 - ktdrv
1个回答

9

虽然没有明显的方法,但你可以通过在sklearn.ensemble.forest中的抽样方法上进行修改来实现。

使用set_rf_samples(n)可以强制树对n行进行子采样,并调用reset_rf_samples()来对整个数据集进行采样。

适用于版本 < 0.22.0

from sklearn.ensemble import forest

def set_rf_samples(n):
    """ Changes Scikit learn's random forests to give each tree a random sample of
    n random rows.
    """
    forest._generate_sample_indices = (lambda rs, n_samples:
        forest.check_random_state(rs).randint(0, n_samples, n))

def reset_rf_samples():
    """ Undoes the changes produced by set_rf_samples.
    """
    forest._generate_sample_indices = (lambda rs, n_samples:
        forest.check_random_state(rs).randint(0, n_samples, n_samples))
  

对于版本 >=0.22.0

现在有一个可用的参数,请访问https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.RandomForestRegressor.html

max_samples: int or float, default=None

   If bootstrap is True, the number of samples to draw from X to train each base estimator.

   If None (default), then draw X.shape[0] samples.

   If int, then draw max_samples samples.

   If float, then draw max_samples * X.shape[0] samples. Thus, max_samples should be in the interval (0, 1).

参考资料:fast.ai机器学习课程


今天我尝试使用这个方法,但在scikit-learn>=0.22上无法正常工作,因为"forest"模块现在已经移动到"_forest"中,需要进行一些更改才能使其正常工作。 - mediumnok

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接