如何将张量保存到TFRecord?

5

我想将一个张量保存为TFRecord格式。我尝试了以下代码:

import tensorflow as tf

x = tf.constant([[2.0, 3.0, 3.0],
                 [1.0, 5.0, 9.0]], dtype='float32')

x2 = tf.io.serialize_tensor(x)

# I understand that I can parse it using this:
# x3 = tf.io.parse_tensor(x2, 'float32')

record_file = 'temp.tfrecord'
with tf.io.TFRecordWriter(record_file) as writer:
    writer.write(x2)

这给我报错:

TypeError: write(): incompatible function arguments. The following argument types are supported:
    1. (self: tensorflow.python._pywrap_record_io.RecordWriter, record: str) -> None

我知道这可能是一个基础问题,但我在 TensorFlow 网站上阅读了一份指南,并在 StackOverflow 上搜索,但没有找到答案。

1个回答

4
问题在于你需要使用张量`x2`的实际值,而不是张量对象本身:
import tensorflow as tf

x = tf.constant([[2.0, 3.0, 3.0],
                 [1.0, 5.0, 9.0]], dtype='float32')
x2 = tf.io.serialize_tensor(x)
record_file = 'temp.tfrecord'
with tf.io.TFRecordWriter(record_file) as writer:
    # Get value with .numpy()
    writer.write(x2.numpy())
# Read from file
parse_tensor_f32 = lambda x: tf.io.parse_tensor(x, tf.float32)
ds = (tf.data.TFRecordDataset('temp.tfrecord')
      .map(parse_tensor_f32))
for x3 in ds:
    tf.print(x3)
# [[2 3 3]
#  [1 5 9]]

在最近的TensorFlow版本中,目前有一种实验性的tf.data.experimental.TFRecordWriter也可以完成这项任务。
import tensorflow as tf

x = tf.constant([[2.0, 3.0, 3.0],
                 [1.0, 5.0, 9.0]], dtype='float32')
# Write to file
ds = (tf.data.Dataset.from_tensors(x)
      .map(tf.io.serialize_tensor))
writer = tf.data.experimental.TFRecordWriter('temp.tfrecord')
writer.write(ds)
# Read from file
parse_tensor_f32 = lambda x: tf.io.parse_tensor(x, tf.float32)
ds2 = (tf.data.TFRecordDataset('temp.tfrecord')
       .map(parse_tensor_f32))
for x2 in ds2:
    tf.print(x2)
# [[2 3 3]
#  [1 5 9]]

1
我收到了“Tensor”对象没有“numpy”属性的错误,有什么建议吗? - JayJona

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接