在Keras中如何返回分类预测中的概率值?

12

我试图制作一个简单的概念验证,以便查看给定预测的不同类别的概率。

然而,我尝试的所有方法似乎只输出预测的类别,即使我使用softmax激活。由于我是机器学习的新手,所以我不确定我是否犯了一个简单的错误,或者这是Keras中没有提供的功能。

我正在使用Keras + TensorFlow。我已经改编了Keras为MNIST数据集分类提供的基本示例之一

我的代码与示例完全相同,只是有几行(已注释)额外的代码将模型导出到本地文件。

'''Trains a simple deep NN on the MNIST dataset.
Gets to 98.40% test accuracy after 20 epochs
(there is *a lot* of margin for parameter tuning).
2 seconds per epoch on a K520 GPU.
'''

from __future__ import print_function

import keras
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import Dense, Dropout
from keras.optimizers import RMSprop

import h5py # added import because it is required for model.save
model_filepath = 'test_model.h5' # added filepath config

batch_size = 128
num_classes = 10
epochs = 20

# the data, shuffled and split between train and test sets
(x_train, y_train), (x_test, y_test) = mnist.load_data()

x_train = x_train.reshape(60000, 784)
x_test = x_test.reshape(10000, 784)
x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train /= 255
x_test /= 255
print(x_train.shape[0], 'train samples')
print(x_test.shape[0], 'test samples')

# convert class vectors to binary class matrices
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)

model = Sequential()
model.add(Dense(512, activation='relu', input_shape=(784,)))
model.add(Dropout(0.2))
model.add(Dense(512, activation='relu'))
model.add(Dropout(0.2))
model.add(Dense(num_classes, activation='softmax'))

model.summary()

model.compile(loss='categorical_crossentropy',
          optimizer=RMSprop(),
          metrics=['accuracy'])

history = model.fit(x_train, y_train,
                batch_size=batch_size,
                epochs=epochs,
                verbose=1,
                validation_data=(x_test, y_test))
score = model.evaluate(x_test, y_test, verbose=0)
print('Test loss:', score[0])
print('Test accuracy:', score[1])

model.save(model_filepath) # added saving model
print('Model saved') # added log

然后这个内容的第二部分是一个简单的脚本,应该导入模型,针对给定的一些数据进行分类预测,并打印出每个类别的概率。(我使用了包含在Keras代码库中的相同mnist类别,以尽可能简单的例子来说明)

import keras
from keras.datasets import mnist
from keras.models import Sequential
import keras.backend as K

import numpy

# loading model saved locally in test_model.h5
model_filepath = 'test_model.h5'
prev_model = keras.models.load_model(model_filepath)

# these lines are copied from the example for loading MNIST data
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train.reshape(60000, 784)

# for this example, I am only taking the first 10 images
x_slice = x_train[slice(1, 11, 1)]

# making the prediction
prediction = prev_model.predict(x_slice)

# logging each on a separate line
for single_prediction in prediction:
    print(single_prediction)

如果我运行第一个脚本导出模型,然后运行第二个脚本对一些示例进行分类,我将得到以下输出:

[ 1.  0.  0.  0.  0.  0.  0.  0.  0.  0.]
[ 0.  0.  0.  0.  1.  0.  0.  0.  0.  0.]
[ 0.  1.  0.  0.  0.  0.  0.  0.  0.  0.]
[ 0.  0.  0.  0.  0.  0.  0.  0.  0.  1.]
[ 0.  0.  1.  0.  0.  0.  0.  0.  0.  0.]
[ 0.  1.  0.  0.  0.  0.  0.  0.  0.  0.]
[ 0.  0.  0.  1.  0.  0.  0.  0.  0.  0.]
[ 0.  1.  0.  0.  0.  0.  0.  0.  0.  0.]
[ 0.  0.  0.  0.  1.  0.  0.  0.  0.  0.]
[ 0.  0.  0.  1.  0.  0.  0.  0.  0.  0.]

这对于查看每个类别的预测非常好,但是如果我想看到每个示例的每个类别的相对概率怎么办?我正在寻找更像这样的东西:

[ 0.94 0.01 0.02 0. 0. 0.01 0. 0.01 0.01 0.]
[ 0. 0. 0. 0. 0.51 0. 0. 0. 0.49 0.]
...
换句话说,我需要知道每个预测的确定性程度,而不仅仅是预测本身。我以为在模型中使用softmax激活函数可以看到相关概率,但我似乎找不到Keras文档中会给我概率而不是预测答案的任何内容。我是犯了某种愚蠢的错误,还是这个功能不可用?

你的Keras版本是什么? - desertnaut
我正在使用Keras 2.0.9版本。 - user9040452
1
无法重现您的问题;我的Keras 2.0.9中的“predict”返回概率,正如它应该做的那样。 - desertnaut
嗯,你有没有使用Theano或CNTK而不是TensorFlow?也许这是TensorFlow的一个bug? - user9040452
不,TensorFlow...也许这是一个四舍五入的细节问题,当你在Python 3中打印(我正在使用Python 2)时? - desertnaut
2个回答

11

结果表明问题出在我的预测脚本没有完全规范化数据。

我的预测脚本应该包含以下几行:

# these lines are copied from the example for loading MNIST data
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train.reshape(60000, 784)
x_train = x_train.astype('float32') # this line was missing
x_train /= 255 # this line was missing too

因为数据没有转换为浮点数,并且除以了255(这样它就在0和1之间),所以它只显示为1和0。


我遇到了同样的问题,似乎得到的是类别而不是概率。我在训练时使用了训练生成器,它会自动对我的数据进行归一化处理,但是在进行自定义推理调用时,我没有对数据进行归一化处理。感谢您的帮助! - dzubke

5

Keras的predict函数返回概率值,而非类别。

在我的系统配置中无法重现你的问题:

Python version 2.7.12
Tensorflow version 1.3.0
Keras version 2.0.9
Numpy version 1.13.3

这是您已加载的模型(按照您的代码进行了20个周期的训练)对于x_slice的预测输出:

print(prev_model.predict(x_slice))
# Result: 
[[  1.00000000e+00   3.31656316e-37   1.07806675e-21   7.11765177e-30
    2.48000320e-31   5.34837679e-28   3.12470132e-24   4.65175406e-27
    8.66994134e-31   5.26426367e-24]
 [  0.00000000e+00   5.34361977e-30   3.91144999e-35   0.00000000e+00
    1.00000000e+00   0.00000000e+00   1.05583665e-36   1.01395577e-29
    0.00000000e+00   1.70868685e-29]
 [  3.99137559e-38   1.00000000e+00   1.76682222e-24   9.33333581e-31
    3.99846307e-15   1.17745576e-24   1.87529709e-26   2.18951752e-20
    3.57518280e-17   1.62027896e-28]
 [  6.48006586e-26   1.48974980e-17   5.60530329e-22   1.81973780e-14
    9.12573406e-10   1.95987500e-14   8.08566866e-27   1.17901132e-12
    7.33970447e-13   1.00000000e+00]
 [  2.01602060e-16   6.58242856e-14   1.00000000e+00   6.84244084e-09
    1.19809885e-16   7.94907624e-14   3.10690434e-19   8.02848586e-12
    4.68330721e-11   5.14736501e-15]
 [  2.31014903e-35   1.00000000e+00   6.02224725e-21   2.35928828e-23
    7.50006509e-15   4.06930881e-22   1.13288827e-24   4.20440718e-17
    4.95182972e-17   1.85492109e-18]
 [  0.00000000e+00   0.00000000e+00   0.00000000e+00   1.00000000e+00
    0.00000000e+00   6.30200370e-27   0.00000000e+00   5.19937755e-33
    1.63205659e-31   1.21508034e-20]
 [  1.44608573e-26   1.00000000e+00   1.78712268e-18   6.84598301e-19
    1.30042071e-11   2.53873986e-14   5.83169942e-17   1.20201071e-12
    2.21844570e-14   3.75015198e-15]
 [  0.00000000e+00   6.29184453e-34   9.22474943e-29   0.00000000e+00
    1.00000000e+00   3.05067233e-34   1.43097161e-28   1.34234082e-29
    4.28647272e-36   9.29760838e-34]
 [  4.68828449e-30   5.55172479e-20   3.26705529e-19   9.99999881e-01
    3.49577992e-22   1.27715460e-11   4.99185615e-36   1.19164204e-20
    4.21086124e-16   1.52631387e-07]]

我怀疑在打印时存在一些四舍五入问题(或者您已经训练了更多的epochs,导致训练集中的概率非常接近1)...
为了确信您确实获得了概率而不是类别预测结果,我建议您尝试使用仅训练了一个epoch的模型来获取预测结果;通常情况下,您应该看到较少的1.0 - 这里是一个经过epochs=1训练的model的情况:
print(model.predict(x_slice))
# Result: 

[[  9.99916673e-01   5.36548761e-08   6.10747229e-05   8.21199933e-07
    6.64725164e-08   6.78853041e-07   9.09637220e-06   4.56192402e-06
    1.62688798e-06   5.23997733e-06]
 [  7.59836894e-07   1.78043920e-05   1.79073555e-04   2.95592145e-05
    9.98031914e-01   1.75839632e-05   5.90557102e-06   1.27705920e-03
    3.94643757e-06   4.36416740e-04]
 [  4.48473330e-08   9.99895334e-01   2.82608235e-05   5.33154832e-07
    9.78453227e-06   1.58954310e-06   3.38150176e-06   5.26260410e-05
    8.09341054e-06   3.28643267e-07]
 [  7.38236849e-07   4.80247072e-05   2.81726116e-05   4.77648537e-05
    7.21933879e-03   2.52177160e-05   3.88786475e-07   3.56770557e-04
    2.83472677e-04   9.91990149e-01]
 [  5.03611082e-05   2.69402866e-04   9.92011130e-01   4.68175858e-03
    9.57477605e-05   4.26214538e-04   7.66683661e-05   7.05923303e-04
    1.45670515e-03   2.26032615e-04]
 [  1.36330849e-10   9.99994516e-01   7.69141934e-07   1.44130311e-07
    9.52201333e-07   1.45219332e-07   4.43408908e-07   6.93398249e-07
    2.18685204e-06   1.50741769e-07]
 [  2.39427478e-09   3.75754922e-07   3.89349816e-06   9.99889374e-01
    1.85837867e-09   1.16176770e-05   1.89989760e-11   3.12301523e-07
    1.13220040e-05   8.29571582e-05]
 [  1.45760115e-08   9.99900222e-01   3.67058942e-06   4.04857201e-06
    1.97999962e-05   7.85745397e-06   8.13850420e-06   1.87294081e-05
    2.81870762e-05   9.38157609e-06]
 [  7.52560858e-09   8.84437856e-09   9.71140025e-07   5.20911703e-10
    9.99986649e-01   3.12135370e-07   1.06521384e-05   1.25693066e-06
    7.21853368e-08   5.21001624e-08]
 [  8.67672298e-08   2.17907742e-04   2.45352840e-06   9.95455265e-01
    1.43749105e-06   1.51766278e-03   1.83744309e-08   3.83995541e-07
    9.90309782e-05   2.70584645e-03]]

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接