当数组维度为3时,混淆矩阵错误

8
这段代码:
from pandas_ml import ConfusionMatrix
y_actu = [1,2]
y_pred = [1,2]
cm = ConfusionMatrix(y_actu, y_pred)
cm.print_stats()

打印:
population: 2
P: 1
N: 1
PositiveTest: 1
NegativeTest: 1
TP: 1
TN: 1
FP: 0
FN: 0
TPR: 1.0
TNR: 1.0
PPV: 1.0
NPV: 1.0
FPR: 0.0
FDR: 0.0
FNR: 0.0
ACC: 1.0
F1_score: 1.0
MCC: 1.0
informedness: 1.0
markedness: 1.0
prevalence: 0.5
LRP: inf
LRN: 0.0
DOR: inf
FOR: 0.0
/opt/conda/lib/python3.5/site-packages/pandas_ml/confusion_matrix/bcm.py:332: RuntimeWarning: divide by zero encountered in double_scalars
  return(np.float64(self.TPR) / self.FPR)

这很正常。

然而,当我修改代码为以下内容时:

from pandas_ml import ConfusionMatrix
y_actu = [1,2,3]
y_pred = [1,2,3]
cm = ConfusionMatrix(y_actu, y_pred)
cm.print_stats()

使用:

y_actu = [1,2,3]
y_pred = [1,2,3]

这会导致出现以下错误:

OrderedDict([('Accuracy', 1.0), ('95% CI', (0.29240177382128668, nan)), ('No Information Rate', 'ToDo'), ('P-Value [Acc > NIR]', 0.29629629629629622), ('Kappa', 1.0), ("Mcnemar's Test P-Value", 'ToDo')])

ValueErrorTraceback (most recent call last)
<ipython-input-30-d8c5dc2bea73> in <module>()
      3 y_pred = [1,2,3]
      4 cm = ConfusionMatrix(y_actu, y_pred)
----> 5 cm.print_stats()

/opt/conda/lib/python3.5/site-packages/pandas_ml/confusion_matrix/abstract.py in print_stats(self, lst_stats)
    446         Prints statistics
    447         """
--> 448         print(self._str_stats(lst_stats))
    449 
    450     def get(self, actual=None, predicted=None):

/opt/conda/lib/python3.5/site-packages/pandas_ml/confusion_matrix/abstract.py in _str_stats(self, lst_stats)
    427         }
    428 
--> 429         stats = self.stats(lst_stats)
    430 
    431         d_stats_str = collections.OrderedDict([

/opt/conda/lib/python3.5/site-packages/pandas_ml/confusion_matrix/abstract.py in stats(self, lst_stats)
    390         d_stats = collections.OrderedDict()
    391         d_stats['cm'] = self
--> 392         d_stats['overall'] = self.stats_overall
    393         d_stats['class'] = self.stats_class
    394         return(d_stats)

/opt/conda/lib/python3.5/site-packages/pandas_ml/confusion_matrix/cm.py in __getattr__(self, attr)
     33         Returns (weighted) average statistics
     34         """
---> 35         return(self._avg_stat(attr))

/opt/conda/lib/python3.5/site-packages/pandas_ml/confusion_matrix/abstract.py in _avg_stat(self, stat)
    509             v = getattr(binary_cm, stat)
    510             print(v)
--> 511             s_values[cls] = v
    512         value = (s_values * self.true).sum() / self.population
    513         return(value)

/opt/conda/lib/python3.5/site-packages/pandas/core/series.py in __setitem__(self, key, value)
    771         # do the setitem
    772         cacher_needs_updating = self._check_is_chained_assignment_possible()
--> 773         setitem(key, value)
    774         if cacher_needs_updating:
    775             self._maybe_update_cacher()

/opt/conda/lib/python3.5/site-packages/pandas/core/series.py in setitem(key, value)
    767                     pass
    768 
--> 769             self._set_with(key, value)
    770 
    771         # do the setitem

/opt/conda/lib/python3.5/site-packages/pandas/core/series.py in _set_with(self, key, value)
    809             if key_type == 'integer':
    810                 if self.index.inferred_type == 'integer':
--> 811                     self._set_labels(key, value)
    812                 else:
    813                     return self._set_values(key, value)

/opt/conda/lib/python3.5/site-packages/pandas/core/series.py in _set_labels(self, key, value)
    826         if mask.any():
    827             raise ValueError('%s not contained in the index' % str(key[mask]))
--> 828         self._set_values(indexer, value)
    829 
    830     def _set_values(self, key, value):

/opt/conda/lib/python3.5/site-packages/pandas/core/series.py in _set_values(self, key, value)
    831         if isinstance(key, Series):
    832             key = key._values
--> 833         self._data = self._data.setitem(indexer=key, value=value)
    834         self._maybe_update_cacher()
    835 

/opt/conda/lib/python3.5/site-packages/pandas/core/internals.py in setitem(self, **kwargs)
   3166 
   3167     def setitem(self, **kwargs):
-> 3168         return self.apply('setitem', **kwargs)
   3169 
   3170     def putmask(self, **kwargs):

/opt/conda/lib/python3.5/site-packages/pandas/core/internals.py in apply(self, f, axes, filter, do_integrity_check, consolidate, **kwargs)
   3054 
   3055             kwargs['mgr'] = self
-> 3056             applied = getattr(b, f)(**kwargs)
   3057             result_blocks = _extend_blocks(applied, result_blocks)
   3058 

/opt/conda/lib/python3.5/site-packages/pandas/core/internals.py in setitem(self, indexer, value, mgr)
    685                         indexer.dtype == np.bool_ and
    686                         len(indexer[indexer]) == len(value)):
--> 687                     raise ValueError("cannot set using a list-like indexer "
    688                                      "with a different length than the value")
    689 

ValueError: cannot set using a list-like indexer with a different length than the value

我发现了一个类似的问题,链接如下:question。其中提到:

在赋值时使用本地列表是不被允许的,也不建议这样做。

什么是本地列表?我是否创建了一个?


你尝试过Scikit学习吗?http://stackoverflow.com/questions/43697980/is-there-something-already-implemented-in-python-to-calculate-tp-tn-fp-and-fn/43698347#43698347 - Sidon
@Sidon 谢谢,是的,我“稍微”研究了一下,pandas ml通过print_stats方法“开箱即用”地提供了许多有用的数据集统计信息,巧合的是,你链接的问题发布者也在询问这个。感谢提供链接,可视化效果很吸引人。 - blue-sky
4个回答

6
我建议使用scikit-learn中的confusion_matrix。您提到的其他指标,如精确度、召回率和F1分数也可以从sklearn.metrics中获得。
>>> from sklearn.metrics import confusion_matrix
>>> y_actu = [1,2,3]
>>> y_pred = [1,2,3]
>>> confusion_matrix(y_actu, y_pred)
array([[1, 0, 0],
       [0, 1, 0],
       [0, 0, 1]])

4
我还使用并推荐sklearn的confusion_matrix函数。我个人也会保留一个方便的"漂亮打印混淆矩阵"函数,其中包含一些额外的便利功能:
  • 类标签打印在混淆矩阵轴上
  • 混淆矩阵统计量归一化,使所有单元格总和为1
  • 混淆矩阵单元格颜色根据单元格值进行缩放
  • 混淆矩阵下面打印其他指标,如F-score等。
像这样: enter image description here 这是绘图函数,主要基于Scikit-Learn文档中的此示例
import matplotlib.pyplot as plt
import itertools
from sklearn.metrics import classification_report

def pretty_print_conf_matrix(y_true, y_pred, 
                             classes,
                             normalize=False,
                             title='Confusion matrix',
                             cmap=plt.cm.Blues):
    """
    Mostly stolen from: http://scikit-learn.org/stable/auto_examples/model_selection/plot_confusion_matrix.html#sphx-glr-auto-examples-model-selection-plot-confusion-matrix-py

    Normalization changed, classification_report stats added below plot
    """

    cm = confusion_matrix(y_true, y_pred)

    # Configure Confusion Matrix Plot Aesthetics (no text yet) 
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title, fontsize=14)
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)
    plt.ylabel('True label', fontsize=12)
    plt.xlabel('Predicted label', fontsize=12)

    # Calculate normalized values (so all cells sum to 1) if desired
    if normalize:
        cm = np.round(cm.astype('float') / cm.sum(),2) #(axis=1)[:, np.newaxis]

    # Place Numbers as Text on Confusion Matrix Plot
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, cm[i, j],
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black",
                 fontsize=12)


    # Add Precision, Recall, F-1 Score as Captions Below Plot
    rpt = classification_report(y_true, y_pred)
    rpt = rpt.replace('avg / total', '      avg')
    rpt = rpt.replace('support', 'N Obs')

    plt.annotate(rpt, 
                 xy = (0,0), 
                 xytext = (-50, -140), 
                 xycoords='axes fraction', textcoords='offset points',
                 fontsize=12, ha='left')    

    # Plot
    plt.tight_layout()

以下是使用鸢尾花数据生成绘图图像的示例:

from sklearn import datasets
from sklearn.svm import SVC

#get data, make predictions
(X,y) = datasets.load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X,y, train_size=0.5)

clf = SVC()
clf.fit(X_train,y_train)
y_test_pred = clf.predict(X_test)


# Plot Confusion Matrix
plt.style.use('classic')
plt.figure(figsize=(3,3))
pretty_print_conf_matrix(y_test, y_test_pred, 
                         classes= ['0', '1', '2'],
                         normalize=True, 
                         title='Confusion Matrix')

我喜欢你的函数,但是我需要更改哪个部分才能使数据正确显示而不会覆盖矩阵?请参见此处 - OscarVanL
如果您仍在寻找问题的解决方案,请尝试更改xytext =(-50,-160),fontsize = 12,rotation = 45的值,直到获得清晰的图像。 - user1410665

2
有趣的是,当我运行您的代码时,并没有得到您收到的错误,代码完美运行。建议您通过以下方式升级pandas_ml库:
pip install --upgrade pandas_ml

此外,您需要通过运行以下命令来升级pandas:
pip install --upgrade pandas

如果那样不起作用,您可以使用pandas本身来创建混淆矩阵:

import pandas as pd
y_actu = pd.Series([1, 2, 3], name='Actual')
y_pred = pd.Series([1, 2, 3], name='Predicted')
df_confusion = pd.crosstab(y_actu, y_pred)
print df_confusion

这将为您提供所需的表格。

谢谢,但您没有使用“cm = ConfusionMatrix(y_actu,y_pred)”,因此无法使用“print_stats()”打印统计信息? - blue-sky
使用Python 2,我能够直接运行你的代码并获得所需的统计结果。你使用的是哪个版本的Python? - Ajax1234
@blue-sky,请看一下我上面的答案。我通过升级pandas和pandas_ml两个库来解决了这个问题。现在ConfusionMatrix可以很好地处理多类别示例[1,2,3]了。 - Alexey Trofimov

1
似乎错误不是由于数组维度引起的:
from pandas_ml import ConfusionMatrix
y_actu = [1,2,2]
y_pred = [1,1,2]
cm = ConfusionMatrix(y_actu, y_pred)
cm.print_stats()

这个(二分类问题)运行良好。

可能多类分类问题的混淆矩阵出了问题。

更新: 我刚刚执行了以下步骤:

conda update pandas

获取pandas 0.20.1

然后
pip install -U pandas_ml

现在,多分类混淆矩阵的一切都很好:

from pandas_ml import ConfusionMatrix
y_actu = [1,2,3]
y_pred = [1,2,3]
cm = ConfusionMatrix(y_actu, y_pred)
cm.print_stats()

我得到了输出:
Class Statistics:

Classes                                       1         2         3
Population                                    3         3         3
P: Condition positive                         1         1         1
N: Condition negative                         2         2         2
Test outcome positive                         1         1         1
Test outcome negative                         2         2         2
TP: True Positive                             1         1         1
TN: True Negative                             2         2         2
FP: False Positive                            0         0         0
FN: False Negative                            0         0         0
TPR: (Sensitivity, hit rate, recall)          1         1         1
TNR=SPC: (Specificity)                        1         1         1
PPV: Pos Pred Value (Precision)               1         1         1
NPV: Neg Pred Value                           1         1         1
FPR: False-out                                0         0         0
FDR: False Discovery Rate                     0         0         0
FNR: Miss Rate                                0         0         0
ACC: Accuracy                                 1         1         1
F1 score                                      1         1         1
MCC: Matthews correlation coefficient         1         1         1
Informedness                                  1         1         1
Markedness                                    1         1         1
Prevalence                             0.333333  0.333333  0.333333
LR+: Positive likelihood ratio              inf       inf       inf
LR-: Negative likelihood ratio                0         0         0
DOR: Diagnostic odds ratio                  inf       inf       inf
FOR: False omission rate                      0         0         0

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接