如何进行多标签分层抽样?

8
我正在处理多标签数据,并希望使用分层抽样。假设我有10个类别,我们称它们为'ABCDEFGHIJ'。我有一个数据框,其中有10列对应于每个标签,包含有关条目的其他信息。我可以在n_entry * 10矩阵中提取那些10列,我将其称为label_values。
例如,label_values的一行如[0,0,1,1,0,0,0,0,0,0],这一行表示该条目具有标签C和标签D。
我想将我的数据拆分成训练集和验证集,并且希望在训练集和验证集中有相同比例的每个标签。为了进行拆分,我以前使用Sklearn train_test_split函数(在需要分层之前),它恰好有一个stratify参数。当前行为是将多标签行为转换为多类行为(我们认为[A,B]是与A类和B类完全不同的新品牌)。因此,有一些只有1个元素的类别,这会触发错误:
ValueError("The least populated class in y has only 1"
                         " member, which is too few. The minimum"
                         " number of groups for any class cannot"
                         " be less than 2.")

来自sklearn/model_selection/_split.py文件中的StratifiedShuffleSplit类的_iter_indices方法:

if np.min(class_counts) < 2:
        raise ValueError("The least populated class in y has only 1"
                         " member, which is too few. The minimum"
                         " number of groups for any class cannot"
                         " be less than 2.")

我的解决办法是覆盖此方法以删除此检查。这有效,我在训练和验证之间获得了更好的标签重新分配。然而,我的其中一个具有2个元素的标签完全在训练集中。这正常吗?
另一个问题:这是处理此问题的正确方式吗?还是您认为有更好的方法来获取多标签的分层训练测试拆分?
2个回答

12
正如您所注意到的,scikit-learn的train_test_split()的分层并不考虑标签本身,而是将其视为“标签集”。这对于多标签数据根本行不通,因为唯一组合的数量随标签数呈指数增长。在您的示例中,有1024种不同的可能标签组合。您需要至少两倍才能执行双向拆分,即使这样,每个拆分也只会得到每个组合的一个示例。
禁用检查的拆分可能相对有效,因为重复的标签集能够分层,但对于唯一的标签集,您只是允许scikit-learn随机拆分它们,这是无用和无效的。
2011年,Sechidis、Tsoumakas和Vlahavas提出了一种算法迭代分层, 它通过单独考虑每个标签来拆分多标签数据集,从具有最少正例的标签开始,逐步到最好表示的标签。
目前有两种实现可供使用:
1.iterative-stratification 2.scikit-multilearn的iterative_train_test_split() 假设您想要这些3个标签(L1,L2,L3)样本的双向拆分:
L1 L2 L3
--------
0  0  0
0  0  1
0  1  0
0  1  1
1  0  0 
1  0  1
1  1  0
1  1  1

有8个独特的标签集,每个标签都有4个正面例子。迭代分层会尝试给你两个包含来自每个标签平衡数量的例子的拆分,而不是随机拆分。一个示例拆分可能如下:

Split 1
-------
L1 L2 L3
0  0  1
0  1  0
1  0  1
1  1  0

Split 2
-------
L1 L2 L3
0  0  0
0  1  1
1  0  0
1  1  1

请注意,尽管每个标签集仍然是唯一的,但现在每个标签都在拆分中保持了良好的平衡。


3
最简单的解决方案是使用带有skmultilearn的多标签分层。快速示例如下:
from skmultilearn.model_selection import iterative_train_test_split
t_train, y_train, t_test, y_test = iterative_train_test_split(X, y, test_size = 0.2)

请考虑迭代分层是较慢的,对于大型数据集可能非常耗时。

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接