如何使用AdaBoost提升基于Keras的神经网络?

17
假设我为一个二元分类问题拟合了以下神经网络:

Assuming I fit the following neural network for a binary classification problem:

model = Sequential()
model.add(Dense(21, input_dim=19, init='uniform', activation='relu'))
model.add(Dense(80, init='uniform', activation='relu'))
model.add(Dense(80, init='uniform', activation='relu'))
model.add(Dense(1, init='uniform', activation='sigmoid'))
# Compile model
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
# Fit the model
model.fit(x2, training_target, nb_epoch=10, batch_size=32, verbose=0,validation_split=0.1, shuffle=True,callbacks=[hist])

我该如何使用AdaBoost来提升神经网络?Keras是否有相应的命令?

3个回答

18

可以按照以下方式进行: 首先创建一个模型(为了可重复性,将其制作为函数):

def simple_model():                                           
    # create model
    model = Sequential()
    model.add(Dense(25, input_dim=x_train.shape[1], kernel_initializer='normal', activation='relu'))
    model.add(Dropout(0.2, input_shape=(x_train.shape[1],)))
    model.add(Dense(10, kernel_initializer='normal', activation='relu'))
    model.add(Dense(1, kernel_initializer='normal'))
    # Compile model
    model.compile(loss='mean_squared_error', optimizer='adam')
    return model

然后将它放在sklearn包装器中:

ann_estimator = KerasRegressor(build_fn= simple_model, epochs=100, batch_size=10, verbose=0)

然后最终提升它:

boosted_ann = AdaBoostRegressor(base_estimator= ann_estimator)
boosted_ann.fit(rescaledX, y_train.values.ravel())# scale your training data 
boosted_ann.predict(rescaledX_Test)

2
你会如何重新调整数据? - Oz0234
我们可以将Keras数据生成器传递给SkLearn包装器吗? - Adelov

3
Keras本身不实现adaboost。但是,Keras模型与scikit-learn兼容,因此您可以从那里使用AdaBoostClassifier:link。在编译模型后,将您的model用作base_estimator,并fit AdaBoostClassifier实例而不是model
然而,这样做,您将无法使用传递给fit的参数,例如epochs或batch_size的数量,因此将使用默认值。如果默认值不够好,您可能需要构建自己的类,在模型上实现scikit-learn接口,并将适当的参数传递给fit

你好,感谢您的回答。当我插入以下代码时:bdt.fit(x2, training_target)``` 其中model是我编译后的keras网络,它给我返回了一个错误:TypeError: Cannot clone object '' (type ): it does not seem to be a scikit-learn estimator as it does not implement a 'get_params' methods. - ishido
显然,单独的Keras分类器不兼容scikit-learn。请参阅此文章以了解如何使它们一起工作:https://keras.io/scikit-learn-api/ - Ishamael

1

1
欢迎来到Stack Overflow!这是一个边缘仅链接答案。您应该扩展您的答案,包括尽可能多的信息,并仅将链接用于参考。 - Filnor

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接