如何批量编写TFRecords?

13

我有一个大约有四千万行的CSV文件,每一行都是一个训练实例。根据TFRecords的文档,我试图对数据进行编码并保存在TFRecord文件中。

我找到的所有示例(甚至是TensorFlow仓库中的示例)都展示了创建TFRecord的过程取决于类TFRecordWriter。该类具有一个write方法,它以序列化的字符串表示形式的数据作为输入,并将其写入磁盘。但是,这似乎是逐个训练实例完成的。

我该如何批量写入序列化数据?

假设我有一个函数:

  def write_row(sentiment, text, encoded):
    feature = {"one_hot": _float_feature(encoded),
               "label": _int64_feature([sentiment]),
               "text": _bytes_feature([text.encode()])}

    example = tf.train.Example(features=tf.train.Features(feature=feature))
    writer.write(example.SerializeToString())

将四千万次写入磁盘(每个样本一次)会非常缓慢。更有效的方法是将数据分批处理,每次写入50k或100k个样本(根据机器资源)。然而,在TFRecordWriter内部似乎没有可用的批处理方法。

大致上可以这样实现:

class MyRecordWriter:

  def __init__(self, writer):
    self.records = []
    self.counter = 0
    self.writer = writer

  def write_row_batched(self, sentiment, text, encoded):
    feature = {"one_hot": _float_feature(encoded),
               "label": _int64_feature([sentiment]),
               "text": _bytes_feature([text.encode()])}

    example = tf.train.Example(features=tf.train.Features(feature=feature))
    self.records.append(example.SerializeToString())
    self.counter += 1
    if self.counter >= 10000:
      self.writer.write(os.linesep.join(self.records))
      self.counter = 0
      self.records = []

但是当我读取使用该方法创建的文件时,我会收到以下错误:

tensorflow/core/framework/op_kernel.cc:1192] Invalid argument: Could not parse example input, value: '
��

label

��
one_hot����
��
注意:我可以更改编码过程,以便每个example proto包含数千个示例而不仅仅是一个,但是我不想在写入TFrecord文件时以这种方式预先批处理数据,因为当我想要使用该文件进行具有不同批大小的训练时,它会在我的训练流水线中引入额外的开销。
1个回答

7
TFRecords是一种二进制格式。使用以下代码将其视为文本文件:self.writer.write(os.linesep.join(self.records))。这是因为您正在使用操作系统相关的linesep(即\n\r\n)。
解决方案:只需编写记录。您要求批量编写它们。您可以使用缓冲区编写器。对于4000万行,您可能还希望考虑将数据拆分成单独的文件,以便更好地并行化。
使用TFRecordWriter时:文件已经被缓冲。
证据可以在源代码中找到:
  • tf_record.py调用pywrap_tensorflow.PyRecordWriter_New
  • PyRecordWriter调用Env::Default()->NewWritableFile
  • Env->NewWritableFile在匹配的FileSystem上调用NewWritableFile
  • 例如PosixFileSystem调用fopen
  • fopen返回一个流,如果它不是交互式设备,则默认完全缓冲
  • 这将取决于文件系统,但WritableFile注意到"实现必须提供缓冲区,因为调用者可能一次附加小片段到文件中。"

谢谢,这样清楚了很多。当你说使用缓冲写入器时,我相信标准的Python with open("path", "wb") 方法提供了一个不需要额外费用的缓冲写入器。然而我找不到任何方法来检查 TFRecordWriter 类是否在写入磁盘之前也进行了流缓冲... - Insectatorious
1
问题是如何将大量数据写入TFRecords,那么该怎么做呢? - bluesummers
1
很抱歉,我没有得到解决方案 - 我遇到了同样的问题。如何将缓冲区写入TFRecord?写入是通过tf.python_io.TFRecordWriter完成的 - 它没有缓冲区的参数或选项。这不是标准的open('...', 'b')操作。 - bluesummers
@bluesummers 这个问题没有提到 TFRecordWriter。我添加了一个部分来说明 TFRecordWriter 是带缓冲的。 - de1
4
可以,请问您需要怎样的TFRecordWriter批量写入的简单示例呢? - Maosi Chen
显示剩余2条评论

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接