方法“train_test_split”(scikit Learn)中的参数“stratify”

163

我正在尝试使用scikit Learn包中的train_test_split函数,但是我在stratify参数上遇到了麻烦。以下是代码:

from sklearn import cross_validation, datasets 

X = iris.data[:,:2]
y = iris.target

cross_validation.train_test_split(X,y,stratify=y)

然而,我一直遇到以下问题:

raise TypeError("Invalid parameters passed: %s" % str(options))
TypeError: Invalid parameters passed: {'stratify': array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2])}

有人知道发生了什么吗?下面是函数文档。

[...]

stratify:类似数组或None(默认为None)

如果不是None,则使用此作为标签数组以分层方式拆分数据。

新增于版本0.17:stratify拆分

[...]


没有了,全部解决了。 - Daneel Olivaw
6个回答

503

stratify 参数使得一个分割,产生的样本比例与提供给参数 stratify 的值的比例相同。

例如,如果变量 y 是一个二元分类变量,其取值为 01,且有25% 的零和 75% 的一,则 stratify=y 确保您的随机分割将有 25% 的 0 和 75% 的 1


186
这并不真正回答问题,但对于理解它的工作原理非常有用。非常感谢。 - Reed Jessen
10
我仍然难以理解为什么这种分层是必要的:如果数据中存在阶级不平衡,那么在对数据进行随机分割时,这种不平衡不会平均分布吗? - Holger Brandl
22
平均而言,它会被保留;使用分层抽样,它将被确保保留。 - Yonatan
13
对于非常小或非常不平衡的数据集,随机分割可能会完全从其中一个分组中排除某个类别。 - cddt
2
@HolgerBrandl 不错的问题!也许我们可以先使用 stratify 将数据集分成训练集和测试集。然后,为了纠正不平衡,您最终需要在训练集上运行过采样或欠采样。许多 Sklearn 分类器都有一个称为 class-weight 的参数,您可以将其设置为 balanced。最后,您还可以选择比准确度更适合不平衡数据集的指标。尝试使用 F1 或 ROC 下面的面积。 - Claude COULOMBE
显示剩余3条评论

111

对于通过谷歌搜索进入这里的未来的自己:

train_test_split现在在model_selection中,因此:

from sklearn.model_selection import train_test_split

# given:
# features: xs
# ground truth: ys

x_train, x_test, y_train, y_test = train_test_split(xs, ys,
                                                    test_size=0.33,
                                                    random_state=0,
                                                    stratify=ys)

使用它的方法是设置random_state,以实现可重复性。


84

Scikit-Learn只是告诉你它不认识参数“stratify”,并不意味着你使用了错误的参数。这是因为该参数在版本0.17中被添加,正如你引用文档所示。

所以你只需要更新Scikit-Learn就可以了。


1
我遇到了相同的错误,尽管我已经安装了0.21.2版本的scikit-learn。scikit-learn 0.21.2 py37h2a6a0b8_0 conda-forge - KareemJ

21
在这种情况下,分层意味着train_test_split方法返回的训练和测试子集具有与输入数据集相同比例的类标签。

12
我能给出的答案是,分层保留了目标列中数据分布的比例,并在train_test_split中呈现相同的分布比例。例如,如果问题是二元分类问题,目标列的比例为80%=“是”,20%=“否”。由于目标列中有4倍多的“是”比“否”,如果不进行分层拆分,我们可能会遇到只有“是”落入训练集,而所有“否”都落入测试集的麻烦。(即,训练集可能没有“否”在其目标列中)
因此,通过分层,训练集的目标列具有“80%的'是'和20%的'否'”,而测试集的目标列也分别具有“80%的'是'和20%的'否'”。
因此,Stratify 在训练集和测试集中实现了目标(标签)的均匀分布 - 就像在原始数据集中一样。
from sklearn.model_selection import train_test_split
X_train, y_train, X_test, y_test = train_test_split(features, target, test-size = 0.25, stratify = target, random_state = 43)

6

尝试运行这段代码,它“只是工作”:

from sklearn import cross_validation, datasets 

iris = datasets.load_iris()

X = iris.data[:,:2]
y = iris.target

x_train, x_test, y_train, y_test = cross_validation.train_test_split(X,y,train_size=.8, stratify=y)

y_test

array([0, 0, 0, 0, 2, 2, 1, 0, 1, 2, 2, 0, 0, 1, 0, 1, 1, 2, 1, 2, 0, 2, 2,
       1, 2, 1, 1, 0, 2, 1])

@user5767535,正如您所看到的,它在我的Ubuntu机器上运行良好,使用的是Python 3.5的Anaconda发行版,sklearn版本为'0.17'。我只能建议您再次检查是否正确输入了代码并更新您的软件。 - Sergey Bushmanov
2
您IP地址为143.198.54.68,由于运营成本限制,当前对于免费用户的使用频率限制为每个IP每72小时10次对话,如需解除限制,请点击左下角设置图标按钮(手机用户先点击左上角菜单按钮)。 - Sergey Bushmanov

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接