在TensorFlow 1.X中,您可以使用占位符动态更改批量大小。例如:
我尝试了以下方法,但都不起作用。
dataset.batch(batch_size=tf.placeholder())
查看完整示例
在TensorFlow 2.0中如何实现呢?我尝试了以下方法,但都不起作用。
import numpy as np
import tensorflow as tf
def new_gen_function():
for i in range(100):
yield np.ones(2).astype(np.float32)
batch_size = tf.Variable(5, trainable=False, dtype=tf.int64)
train_ds = tf.data.Dataset.from_generator(new_gen_function, output_types=(tf.float32)).batch(
batch_size=batch_size)
for data in train_ds:
print(data.shape[0])
batch_size.assign(10)
print(batch_size)
输出
5
<tf.Variable 'Variable:0' shape=() dtype=int64, numpy=10>
5
<tf.Variable 'Variable:0' shape=() dtype=int64, numpy=10>
5
<tf.Variable 'Variable:0' shape=() dtype=int64, numpy=10>
5
...
...
我正在使用Gradient tape自定义训练循环来训练模型。如何实现这一点?
tf.keras.Input
替换tf.placeholder
。你可以参考这个链接https://dev59.com/xVMH5IYBdhLWcg3wmwPw获取更多信息。谢谢! - user11530462