你好,我正在尝试构建一个图像输入管道。我预处理后的训练数据存储在一个 TFRecords 文件中,该文件是用以下代码创建的:
def _bytes_feature(value):
    """Wrap a single byte string in a ``tf.train.Feature``.

    The value is placed in a one-element ``BytesList``.
    """
    payload = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=payload)
def _int64_feature(value):
    """Wrap a single integer in a ``tf.train.Feature``.

    The value is placed in a one-element ``Int64List``.
    """
    payload = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=payload)
..
# `img` is a NumPy array with shape (50, 80) and dtype float64; store its raw
# buffer as a byte string. `tobytes()` returns the same bytes as the
# deprecated `tostring()` alias.
img_raw = img.tobytes()
img_label_text_raw = img_lable.encode("utf-8")
example = tf.train.Example(features=tf.train.Features(feature={
    'height': _int64_feature(height),  # image height (integer)
    'width': _int64_feature(width),  # image width (integer)
    'depth': _int64_feature(depth),  # number of RGB channels (integer)
    'image_data': _bytes_feature(img_raw),  # raw image data (byte string)
    'label_text': _bytes_feature(img_label_text_raw),  # raw label text (byte string)
    # The key must be spelled 'label': that is the feature name the reader
    # requests when parsing, and a misspelled key is reported as missing.
    'label': _int64_feature(lable_txt_to_int[img_lable]),  # label index (integer)
}))
writer.write(example.SerializeToString())
现在我尝试读取二进制数据,以便从中重构张量:
def read_and_decode(filename_queue, image_shape=None):
    """Read one serialized example from a TFRecords queue and batch it.

    Args:
        filename_queue: Queue of TFRecords file names, passed to
            ``tf.TFRecordReader.read``.
        image_shape: Optional sequence of Python ints giving the fixed
            shape of every image, e.g. ``[height, width, depth]``. When
            given, the decoded image is reshaped to this static shape.
            When omitted, the shape is taken from the record's
            'height'/'width'/'depth' features; those are only known at
            run time, so the static shape stays undefined and
            ``tf.train.shuffle_batch`` rejects it with
            "All shapes must be fully defined".

    Returns:
        A tuple ``(images, labels)`` of shuffled batches of size 2.
    """
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(
        serialized_example,
        # No defaults are given, so every key listed here is required.
        features={
            'label': tf.FixedLenFeature([], tf.int64),
            'height': tf.FixedLenFeature([], tf.int64),
            'width': tf.FixedLenFeature([], tf.int64),
            'depth': tf.FixedLenFeature([], tf.int64),
            'image_data': tf.FixedLenFeature([], tf.string)
        })
    label = features['label']

    # The image bytes were written from a float64 array, so decode as float64.
    image = tf.decode_raw(features['image_data'], tf.float64)
    if image_shape is None:
        # Dynamic shape from the record itself. The parsed features are
        # already int64, so no cast is needed before packing them.
        # NOTE(review): tf.pack is the pre-1.0 name of tf.stack; switch if
        # the file is moved to a newer TensorFlow release.
        dynamic_shape = tf.pack(
            [features['height'], features['width'], features['depth']])
        image = tf.reshape(image, dynamic_shape)
    else:
        # A shape made of Python ints gives the tensor a fully defined
        # static shape, which shuffle_batch needs to build its queue.
        image = tf.reshape(image, list(image_shape))

    images, labels = tf.train.shuffle_batch(
        [image, label],
        batch_size=2,
        capacity=30,
        num_threads=1,
        min_after_dequeue=10)
    return images, labels
很遗憾,这个不起作用。我收到了以下错误消息:
ValueError: Tensor conversion requested dtype string for Tensor with dtype int64: 'Tensor("ParseSingleExample/Squeeze_label:0", shape=(), dtype=int64)' ... TypeError: Input 'bytes' of 'DecodeRaw' Op has type int64 that does not match expected type of string.
有人能给我一些修复的提示吗?
提前感谢!
更新:read_and_decode的完整代码清单
@mmry非常感谢您。现在我的代码在批次混洗时中断。具体如下:
ValueError: All shapes must be fully defined: [TensorShape([Dimension(None), Dimension(None), Dimension(None)]), TensorShape([])]
有什么建议吗?
label_text。那么你如何在解码器中解析这些数据呢?能否给一些建议?谢谢。 - mining
label_txt = (example.features.feature['label_text'].bytes_list.value[0].decode("utf-8")) - monchi