分层 K 折交叉验证:IndexError:数组索引过多

9
使用sklearn的StratifiedKFold函数,有人能帮我理解这里的错误吗?
我猜它与我的标签输入数组有关,我注意到当我打印它们(在这个例子中的前16个)时,索引从0到15,但是多了一个我没想到的额外的0。也许我只是Python新手,但那看起来很奇怪。
有人看到这里的错误了吗?
文档:http://scikit-learn.org...StratifiedKFold.html 代码:
import nltk
import sklearn

print('The nltk version is {}.'.format(nltk.__version__))
print('The scikit-learn version is {}.'.format(sklearn.__version__))

print type(skew_gendata_targets.values), skew_gendata_targets.values.shape
print skew_gendata_targets.head(16)

skew_sfold10 = cross_validation.StratifiedKFold(skew_gendata_targets.values, n_folds=10, shuffle=True, random_state=20160121)

结果

The nltk version is 3.1.
The scikit-learn version is 0.17.
<type 'numpy.ndarray'> (500L, 1L)
    0
0   0
1   0
2   0
3   0
4   0
5   0
6   0
7   0
8   0
9   0
10  0
11  0
12  0
13  0
14  1
15  0
---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
<ipython-input-373-653b6010b806> in <module>()
      8 print skew_gendata_targets.head(16)
      9 
---> 10 skew_sfold10 = cross_validation.StratifiedKFold(skew_gendata_targets.values, n_folds=10, shuffle=True, random_state=20160121)
     11 
     12 #print '\nSkewed Generated Dataset (', len(skew_gendata_data), ')'

d:\Program Files\Anaconda2\lib\site-packages\sklearn\cross_validation.pyc in __init__(self, y, n_folds, shuffle, random_state)
    531         for test_fold_idx, per_label_splits in enumerate(zip(*per_label_cvs)):
    532             for label, (_, test_split) in zip(unique_labels, per_label_splits):
--> 533                 label_test_folds = test_folds[y == label]
    534                 # the test split can be too big because we used
    535                 # KFold(max(c, self.n_folds), self.n_folds) instead of

IndexError: too many indices for array
1个回答

13

检查 skew_gendata_targets.values 的形状。您会发现它不是 StratifiedKFold 预期的 1 维数组 (形状为 (500,)),而是一个(500,1) 数组。SKlearn 将其视为不同的数据类型,而不是强制将它们转换为相同的类型。如果这有帮助,请告诉我。


在问题的输出中有打印内容:print type(skew_gendata_targets.values), skew_gendata_targets.values.shape,它是一个(500,1)的numpy数组。我是一个Matlab迷被扔进了Python的坑里,不知道500x1和500xnada矩阵/数组/东西之间的区别。至少在Matlab世界里没有区别。 - David Parks
2
是的,这很不幸,也有些令人困惑。在进行'*'等操作时,区别非常重要。在一种情况下,Pandas/numpy将执行逐元素乘法,而在另一种情况下,它将执行矩阵乘法。希望在将其强制转换为(500,)数组后,StratifiedKFold操作能够正常工作。 - Brian
1
我明白了,重新塑造矩阵是Matlab程序员可以理解的内容,这似乎已经解决了问题:np.reshape(skew_gendata_targets.values,[500,]),谢谢! - David Parks

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接