TFRecordWriter
似乎是最方便的选项,但不幸的是它只能写入每个元素一个张量的数据集。以下是您可以使用的几种解决方法。首先,由于所有张量具有相同的类型和类似的形状,因此您可以将它们全部连接成一个张量,然后在加载时再拆分回来:
import tensorflow as tf
a = tf.zeros((100, 512), tf.int32)
ds = tf.data.Dataset.from_tensor_slices((a, a, a, a[:, 0]))
print(ds)
def write_map_fn(x1, x2, x3, x4):
return tf.io.serialize_tensor(tf.concat([x1, x2, x3, tf.expand_dims(x4, -1)], -1))
ds = ds.map(write_map_fn)
writer = tf.data.experimental.TFRecordWriter('mydata.tfrecord')
writer.write(ds)
def read_map_fn(x):
xp = tf.io.parse_tensor(x, tf.int32)
xp.set_shape([1537])
return xp[:512], xp[512:1024], xp[1024:1536], xp[-1]
ds = tf.data.TFRecordDataset('mydata.tfrecord').map(read_map_fn)
print(ds)
但更普遍的情况是,您可以为每个张量(即Tensor)单独创建一个文件,然后读取它们:
import tensorflow as tf
a = tf.zeros((100, 512), tf.int32)
ds = tf.data.Dataset.from_tensor_slices((a, a, a, a[:, 0]))
for i, _ in enumerate(ds.element_spec):
ds_i = ds.map(lambda *args: args[i]).map(tf.io.serialize_tensor)
writer = tf.data.experimental.TFRecordWriter(f'mydata.{i}.tfrecord')
writer.write(ds_i)
NUM_PARTS = 4
parts = []
def read_map_fn(x):
return tf.io.parse_tensor(x, tf.int32)
for i in range(NUM_PARTS):
parts.append(tf.data.TFRecordDataset(f'mydata.{i}.tfrecord').map(read_map_fn))
ds = tf.data.Dataset.zip(tuple(parts))
print(ds)
可以将整个数据集保存在单个文件中,每个元素有多个独立的张量,即包含tf.train.Example
的TFRecords文件,但我不知道是否有一种方法可以在TensorFlow内部创建这些文件,而无需将数据从数据集中提取到Python中,再将其写入记录文件。