Keras模型用于拟合多项式

9

我从一个四次多项式中生成了一些数据,并希望创建一个Keras回归模型来拟合这个多项式。问题在于,拟合后的预测结果似乎基本上是线性的。由于这是我第一次使用神经网络,我认为我可能犯了一个非常微小和愚蠢的错误。

以下是我的代码:

model = Sequential()
model.add(Dense(units=200, input_dim=1))
model.add(Activation('relu'))
model.add(Dense(units=45))
model.add(Activation('relu'))
model.add(Dense(units=1))

model.compile(loss='mean_squared_error',
              optimizer='sgd')

model.fit(x_train, y_train, epochs=20, batch_size=50)

loss_and_metrics = model.evaluate(x_test, y_test, batch_size=100)

classes = model.predict(x_test, batch_size=1)

x_trainy_train是包含第前9900个条目的numpy数组此文件

我尝试了不同的batch_sizes、epochs数、层大小和训练数据量。什么都没有帮助。

请指出您发现的任何不合理之处!


你能提供 x_train、y_train 和 x_test 吗?这样读者就可以运行你的代码了。 - Miriam Farber
我在原问题中添加了数据文件的链接。 - FloodLuszt
1个回答

11

神经网络通常不擅长外推多项式函数。然而,如果你的训练和测试数据来自同样的范围,你可能会取得不错的结果。我生成了一些数据并使用了你的代码:

import numpy as np
x_train=np.random.rand(9000)
y_train=x_train**4+x_train**3-x_train
x_train=x_train.reshape(len(x_train),1)

x_test=np.linspace(0,1,100)
y_test=x_test**4+x_test**3-x_test
x_test=x_test.reshape(len(x_test),1)


model = Sequential()
model.add(Dense(units=200, input_dim=1))
model.add(Activation('relu'))
model.add(Dense(units=45))
model.add(Activation('relu'))
model.add(Dense(units=1))

model.compile(loss='mean_squared_error',
              optimizer='sgd')

model.fit(x_train, y_train, epochs=40, batch_size=50, verbose=1)

loss_and_metrics = model.evaluate(x_test, y_test, batch_size=100)

classes = model.predict(x_test, batch_size=1)

test=x_test.reshape(-1)
plt.plot(test,classes,c='r')
plt.plot(test,y_test,c='b')
plt.show()
请注意,我将epochs增加到40以获得更多的迭代和更准确的结果。我还设置verbose=1以便能够看到损失的行为。损失确实下降到7.4564e-04,下面是我得到的结果。红线是网络的预测值,蓝线是正确的值。您可以看到它们非常接近。

enter image description here


谢谢!对于我选择的多项式,上面的代码并不真正起作用。但是将激活函数从ReLU更改为sigmoid(10 * x)可以产生可接受的结果。也许首先尝试拟合多项式来进行神经网络并不是一个好的选择。 - FloodLuszt
1
@FloodLuszt 在一维 Relu 的单层中执行分段线性回归。也许你最好只使用一个层,并且可能显著增加节点数。基本上,用线段拟合曲线的好坏取决于函数的凸度[因此,梯度变化越大,需要添加的节点就越多]。请参见我的答案 https://stats.stackexchange.com/a/375658/27556 - seanv507

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接