LSTM中input_shape和batch_input_shape有什么区别?

10

这只是设置同一件事的不同方式,还是它们实际上具有不同的含义?它是否与网络配置有关?

在一个简单的例子中,我无法观察到以下两者之间的任何区别:

model = Sequential()
model.add(LSTM(1, batch_input_shape=(None,5,1), return_sequences=True))
model.add(LSTM(1, return_sequences=False))

model = Sequential()
model.add(LSTM(1, input_shape=(5,1), return_sequences=True))
model.add(LSTM(1, return_sequences=False))

然而,当我将批次大小设置为12 batch_input_shape=(12,5,1) 并在拟合模型时使用 batch_size=10时,我遇到了一个错误。

ValueError: Cannot feed value of shape (10, 5, 1) for Tensor 'lstm_96_input:0', which has shape '(12, 5, 1)'

这显然是有道理的。但是我看不出限制模型批次大小的意义所在。

我有遗漏什么吗?

1个回答

11

这两种方式实际上是等价的,你的实验结果证实了这一点,也可以参考这个讨论

然而,在模型层面上限制批次大小没有意义。有时候,限制批次大小是必要的,比如在一个状态化LSTM中,批次中的最后一个单元状态将被记住并用于初始化后续批次。这确保了客户端不会向网络提供不同的批次大小。以下是示例代码:

# Expected input batch shape: (batch_size, timesteps, data_dim)
# Note that we have to provide the full batch_input_shape since the network is stateful.
# the sample of index i in batch k is the follow-up for the sample i in batch k-1.
model = Sequential()
model.add(LSTM(32, return_sequences=True, stateful=True,
               batch_input_shape=(batch_size, timesteps, data_dim)))

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接