如何通过TensorFlow的tf.data API加载pickle文件

6

我有多个pickle文件存储在磁盘上,其中包含我的数据。我想使用tensorflow的tf.data.Dataset将我的数据加载到训练过程中。我的代码如下:

def _parse_file(path):
    image, label = *load pickle file*
    return image, label
paths = glob.glob('*.pkl')
print(len(paths))
dataset = tf.data.Dataset.from_tensor_slices(paths)
dataset = dataset.map(_parse_file)
iterator = dataset.make_one_shot_iterator()

问题是我不知道如何实现_parse_file函数。这个函数的参数path是张量类型。我尝试过了。

def _parse_file(path):
    with tf.Session() as s:
        p = s.run(path)
        image, label = pickle.load(open(p, 'rb'))
    return image, label

并且收到了错误信息:

InvalidArgumentError (see above for traceback): You must feed a value for placeholder tensor 'arg0' with dtype string
     [[Node: arg0 = Placeholder[dtype=DT_STRING, shape=<unknown>, _device="/job:localhost/replica:0/task:0/device:CPU:0"]()]]

在互联网上搜索后,我仍然不知道如何做。我将非常感谢任何能够为我提供提示的人。

为什么要将路径传递到会话中,只需加载图像并通过会话循环即可。 - Eliethesaiyan
1
@Eliethesaiyan 因为从数据集中检索到的“path”是张量类型,不能直接用作字符串。 - Zhao Chen
3个回答

3
我已经解决了这个问题。我应该使用tf.py_func,就像在这个文档中所说明的那样。

1
请描述一下您的解决方案的外观,因为对于新手来说,这个问题还不够清晰易懂。 - learner
太棒了,你能分享一下吗? - Kartik
1
@user3829269,提到的Tensorflow文档链接似乎已经过时了,您可以在此处找到此API的示例:https://www.tensorflow.org/api_docs/python/tf/compat/v1/py_func 或者在tf2.0中找到:https://www.tensorflow.org/api_docs/python/tf/py_function - Zhao Chen
我最终实现了该函数,但没有使用数据 API。感谢您的回复。 - Kartik
你好,能否详细说明一下你的解决方案?你是如何使用py_func的?我遇到了类似的问题 - https://stackoverflow.com/questions/74614110/loading-pickled-data-files-to-tensorflow-data-using-py-func - Ohm

1
这是我解决这个问题的方法。我没有使用tf.py_func,而是看一下下面的“load_encoding()”函数,它是读取pickle文件的函数。FACELIB_DIR包含了vggface2编码的目录,每个目录都以人名命名。
import tensorflow as tf
import pickle
import os

FACELIB_DIR='/var/noggin/FaceEncodings'

# Get list of all classes & build a quick int-lookup dictionary
labelNames = sorted([x for x in os.listdir(FACELIB_DIR) if os.path.isdir(os.path.join(FACELIB_DIR,x)) and not x.startswith('.')])
labelStrToInt = dict([(x,i) for i,x in enumerate(labelNames)])

# Function load_encoding - Loads Encoding data from enc2048 file in filepath
#    This reads an encoding from disk, and through the file path gets the label oneHot value, returns both
def load_encoding(file_path):
    with open(os.path.join(FACELIB_DIR,file_path),'rb') as fin:
        A,_ = pickle.loads(fin.read())    # encodings, source_image_name
    label_str = tf.strings.split(file_path, os.path.sep)[-2]
    return (A, labelStrToInt[label_str])

# Build the dataset of every enc2048 file in our data library
encpaths = []
for D in sorted([x for x in os.listdir(FACELIB_DIR) if os.path.isdir(os.path.join(FACELIB_DIR,x)) and not x.startswith('.')]):
    # All the encoding files
    encfiles = sorted(filter((lambda x: x.endswith('.enc2048')), os.listdir(os.path.join(FACELIB_DIR, D))))
    encpaths += [os.path.join(D,x) for x in encfiles]
dataset = tf.data.Dataset.from_tensor_slices(encpaths)

# Shuffle and speed improvements on the dataset
BATCH_SIZE = 64
from tensorflow.data import AUTOTUNE
dataset = (dataset
    .shuffle(1024)
    .cache()
    .repeat()
    .batch(BATCH_SIZE)
    .prefetch(AUTOTUNE)
)
    
# Benchmark our tf.data pipeline
import time
datasetGen = iter(dataset)
NUM_STEPS = 10000
start_time = time.time()
for i in range(0, NUM_STEPS):
    X = next(datasetGen)
totalTime = time.time() - start_time
print('==> tf.data generated {} tensors in {:.2f} seconds'.format(BATCH_SIZE * NUM_STEPS, totalTime))

你好@RubinMac,我在你的代码中没有看到你使用你创建的load_encoding函数。我这里有一个类似的方法,但是尝试了一下并没有成功 - https://stackoverflow.com/questions/74614110/loading-pickled-data-files-to-tensorflow-data-using-py-func - Ohm

-1

tf.py_func函数用于解决这个问题,正如文档中所提到的那样。


1
请添加细节 - Dr. Prof. Patrick

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接