Keras分类器的predict_proba()方法与predict()方法不匹配。

3

我正在使用Keras Theano后端处理一个涉及14个类别的分类问题。我想要预测的类别以及相关概率。问题是从predict_proba()得到的概率似乎与从predict()得到的预测类别不匹配,以下是代码和1个样本的输出结果。

PPRANK = ['pp1', 'pp2', 'pp3', 'pp4', 'pp5', 'pp6', 'pp7', 'pp8', 'pp9', 'pp10', 'pp11', 'pp12', 'pp13', 'pp14', 'pp15']

FEATURES = (PPRANK)

# fix random seed for reproducibility
seed = 7
np.random.seed(seed)

data_df = pd.DataFrame.from_csv("data.csv")
X = np.array(data_df[FEATURES].values)
Y = (data_df["bres"].replace(14,13).values)


# define baseline model
def baseline_model():
    # create model
    model = Sequential()
    model.add(Dense(8, input_dim=(len(FEATURES)), init='normal', activation='relu'))
    model.add(Dense(14, init='normal', activation='softmax'))
    # Compile model
    model.compile(loss='sparse_categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
    return model
#build model
estimator = KerasClassifier(build_fn=baseline_model, nb_epoch=200, batch_size=5, verbose=0)

#split train and test
X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=0.1, random_state=seed)
estimator.fit(X_train, Y_train)

#get probabilities
predictions = estimator.predict_proba(X_test)

#convert expon to floats
probs = [[] for x in range(21)]
tick2 = 0
for i in range( len( predictions ) ):
    tick = 0
    for x in xrange(14):
        (predictions[i][(tick)]) = '%.4f' % (predictions[i][(tick)])
        probs[(tick2)].append((predictions[i][(tick)]))
        tick += 1
    tick2 += 1

# pprint probabilities
pp = pprint.PrettyPrinter(indent=0)
pp.pprint(probs)

#print class predictions
print estimator.predict(X_test)
print Y_test

结果:

#probabilities
[0.00000, 0.00030, 0.02360, 0.04329, 0.00019, 0.00069, 0.00120, 0.00030, 0.00559, 0.00410, 0.00510, 0.91549, 0.0, 0.0]
#predicted class
11
#actual class
13

predict_proba() 中显示12具有最高的概率,而不是从 predict() 中显示的11。

1个回答

4

Python数组(以及这里的类)的索引从0开始计数,而不是从1开始。再看一下,按照人们的习惯,0.91是第12个值,但它的索引为11,因此predictpredict_proba是一致的。

至于为什么不是13,预测可能是错误的(但请检查您是否有相同类型的错误)。


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接