ValueError: 处理多标签指示器和二进制数据混合时出错

4

我正在使用带有scikit-learn封装器的Keras。特别地,我想使用GridSearchCV进行超参数优化。

这是一个多类问题,即目标变量只能在n个类集合中选择一个标签。例如,目标变量可以是'Class1','Class2' ... 'Classn'。

# self._arch creates my model
nn = KerasClassifier(build_fn=self._arch, verbose=0)
clf = GridSearchCV(
  nn,
  param_grid={ ... },
  # I use f1 score macro averaged
  scoring='f1_macro',
  n_jobs=-1)

# self.fX is the data matrix
# self.fy_enc is the target variable encoded with one-hot format
clf.fit(self.fX.values, self.fy_enc.values)

问题在于,在交叉验证期间计算分数时,验证样本的真实标签被编码为one-hot,而由于某种原因,预测结果会折叠为二进制标签(当目标变量仅有两个类别时)。例如,以下是堆栈跟踪的最后一部分:
...........................................................................
/Users/fbrundu/.pyenv/versions/3.6.0/lib/python3.6/site-packages/sklearn/metrics/classification.py in _check_targets(y_true=array([[ 0.,  1.],
       [ 0.,  1.],
       [ 0... 0.,  1.],
       [ 0.,  1.],
       [ 0.,  1.]]), y_pred=array([1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 0, 1, 1, 1,...0, 1, 0, 0, 1, 0, 0, 0,
       0, 0, 0, 0, 1, 1]))
     77     if y_type == set(["binary", "multiclass"]):
     78         y_type = set(["multiclass"])
     79
     80     if len(y_type) > 1:
     81         raise ValueError("Can't handle mix of {0} and {1}"
---> 82                          "".format(type_true, type_pred))
        type_true = 'multilabel-indicator'
        type_pred = 'binary'
     83
     84     # We can't have more than one value on y_type => The set is no more needed
     85     y_type = y_type.pop()
     86

ValueError: Can't handle mix of multilabel-indicator and binary

我应该如何指示Keras/sklearn以one-hot编码返回预测结果?

1
当您直接使用fy而不对值进行编码时会发生什么。在多类问题中,这不应该是一个问题。在我看来,目标的一位有效编码仅在多标签问题中是必要的。 - Vivek Kumar
1个回答

5

在 Vivek 的评论后,我使用了原始(而非 one-hot-encoded)目标数组,并在我的 Keras 模型中进行了配置(请参见代码),使用了 sparse_categorical_crossentropy 损失,按照这个问题的评论建议。

arch.compile(
  optimizer='sgd',
  loss='sparse_categorical_crossentropy',
  metrics=['accuracy'])

如果您已经解决了问题,请接受您的答案并关闭问题。 - Vivek Kumar
@VivekKumar 如果你了解 SO 的规则,你不能在问题发布后的前 2 天内接受答案。 - gc5
抱歉,我的错。 - Vivek Kumar
@fbrundu,你用它解决了你的问题吗?我仍然有问题。当我使用loss='sparse_categorical_crossentropy'; metrics=['f1_score']时,我的F1分数超过了1,这显然是错误的。 - Jundong
@Jundong 是的,我解决了我的问题。不幸的是,我无法在这里提供帮助,因为我已经没有原始代码了。 - gc5
1
@fbrundu 谢谢你的回复。我正在处理这个问题。我认为即使 y_true 没有编码,"categorical_crossentropy" 也可以起作用。这与文档中所说的不同。 - Jundong

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接