How do I load LMDB files into TensorFlow?


I have a huge dataset (1 TB) spread across roughly 3,000 CSV files. My plan is to convert it into one large LMDB file so it can be read quickly while training a neural network. However, I have not found any documentation on how to load an LMDB file into TensorFlow. Does anyone know how to do this? I know TensorFlow can read CSV files, but I think that would be too slow.
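(For reference, the conversion step could use the lmdb Python package. This is only a minimal sketch; the key scheme, the per-row serialization, and the map_size value are assumptions, not part of the question:)

import csv
import glob

import lmdb
import numpy as np

# Sketch: store each CSV row as a serialized float32 array under an
# incrementing integer key. Key scheme, serialization format, and
# map_size are assumptions for illustration.
env = lmdb.open('data.lmdb', map_size=2 * 1024 ** 4)  # room for ~2 TB

with env.begin(write=True) as txn:
    i = 0
    for path in glob.glob('csv_dir/*.csv'):
        with open(path) as f:
            for row in csv.reader(f):
                txn.put(str(i).encode('ascii'),
                        np.asarray(row, dtype=np.float32).tobytes())
                i += 1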

1 Answer

According to this document, there are several ways to read data in TensorFlow (for example, feeding it through placeholders, or reading from files). If you want to delegate shuffling and batching to the framework, you need to create an input pipeline. The problem is how to inject the LMDB data into the symbolic input pipeline. One possible solution is to use the tf.py_func op. Here is an example:
import tensorflow as tf

def create_input_pipeline(lmdb_env, keys, num_epochs=10, batch_size=64):
    key_producer = tf.train.string_input_producer(keys,
                                                  num_epochs=num_epochs,
                                                  shuffle=True)
    single_key = key_producer.dequeue()

    def get_bytes_from_lmdb(key):
        # Runs as regular Python: look the key up in the LMDB environment.
        with lmdb_env.begin() as txn:
            lmdb_val = txn.get(key)
        example = get_example_from_val(lmdb_val)  # A single example (numpy array)
        label = get_label_from_val(lmdb_val)      # The label, could be a scalar
        return example, label

    single_example, single_label = tf.py_func(get_bytes_from_lmdb,
                                              [single_key], [tf.float32, tf.float32])
    # tf.train.batch requires static shapes, so if you know the shapes
    # of the tensors, set them here:
    # single_example.set_shape([224, 224, 3])

    batch_examples, batch_labels = tf.train.batch([single_example, single_label],
                                                  batch_size)
    return batch_examples, batch_labels

The tf.py_func op inserts a call to regular Python code inside the TensorFlow graph; we need to specify its inputs and the number and types of its outputs. tf.train.string_input_producer creates a shuffled queue with the given keys, and tf.train.batch creates another queue that contains batches of data. When training, each evaluation of batch_examples or batch_labels will dequeue another batch from that queue.
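Putting it together, using the pipeline might look like this (a minimal sketch; the LMDB path and the way the keys are collected are assumptions):

import lmdb
import tensorflow as tf

# Open the LMDB environment and collect all keys up front
# (path and key layout are assumptions for this sketch).
lmdb_env = lmdb.open('data.lmdb', readonly=True, lock=False)
with lmdb_env.begin() as txn:
    keys = [key for key, _ in txn.cursor()]

batch_examples, batch_labels = create_input_pipeline(lmdb_env, keys,
                                                     num_epochs=10,
                                                     batch_size=64)
# batch_examples / batch_labels can now be fed into the model; each
# sess.run on them dequeues the next shuffled batch.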

Because we created queues, we need to take care to start the QueueRunner threads before we begin training. This is done as follows (from the TensorFlow docs):

# Create the graph, etc.
# Note: string_input_producer with num_epochs stores the epoch counter in a
# local variable, so initialize local variables along with the global ones.
# (tf.initialize_all_variables is deprecated in favor of the calls below.)
init_op = tf.group(tf.global_variables_initializer(),
                   tf.local_variables_initializer())

# Create a session for running operations in the Graph.
sess = tf.Session()

# Initialize the variables (like the epoch counter).
sess.run(init_op)

# Start input enqueue threads.
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)

try:
    while not coord.should_stop():
        # Run training steps or whatever
        sess.run(train_op)

except tf.errors.OutOfRangeError:
    print('Done training -- epoch limit reached')
finally:
    # When done, ask the threads to stop.
    coord.request_stop()

# Wait for threads to finish.
coord.join(threads)
sess.close()

