Converting a Spark DataFrame to a TensorFlow Dataset (tf.data API)

I am trying to convert a Spark DataFrame into TFRecords and then read them back in TensorFlow as a dataset to feed my model, but it is not working.

Here is my attempt:

1) Get a SparkSession with the spark-tensorflow-connector library jar:

from pyspark import SparkConf
from pyspark.sql import SparkSession

spark = SparkSession.builder.config(
    conf=SparkConf().set("spark.jars", "path/to/spark-tensorflow-connector_2.11-1.6.0.jar")
).getOrCreate()
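
As a side note, the same setup can be expressed by resolving the connector from Maven Central via spark.jars.packages instead of pointing at a local jar; a minimal sketch, assuming the same Scala 2.11 / connector 1.6.0 build as the jar above:

from pyspark.sql import SparkSession

# Sketch: let Spark download the connector from Maven Central.
# The coordinate assumes the same version as the local jar above.
spark = (
    SparkSession.builder
    .config("spark.jars.packages",
            "org.tensorflow:spark-tensorflow-connector_2.11:1.6.0")
    .getOrCreate()
)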

2) Save the DataFrame as TFRecords (using a small example dataset here):

df = spark.createDataFrame([(1, 120), (2, 130), (2, 140)], ['A', 'B'])

path='path/example.tfrecord'
df.write.format("tfrecords").mode("overwrite").option("recordType", "Example").save(path)

3) Load the TFRecord files into the tf.data API (simplified to use only 'A' as a feature):

path2 = "path/example.tfrecord/*"
dataset=tf.data.TFRecordDataset(tf.compat.v1.gfile.Glob(path2))

def parse_func(buff):
    features = {'A': tf.compat.v1.FixedLenFeature(shape=[5], dtype=tf.int64)}
    tensor_dict = tf.compat.v1.parse_single_example(buff, features)
    return tensor_dict['A']

train_dataset = dataset.map(parse_func).batch(1)

But when I try to iterate over and print the dataset:

for x in train_dataset:
    print(x)

I get the following error:
2020-05-21 06:43:53.579843: W tensorflow/core/framework/op_kernel.cc:1655] OP_REQUIRES failed at iterator_ops.cc:941 : Data loss: corrupted record at 0
Traceback (most recent call last):
  File "/home/patrizio/PycharmProjects/pyspark-config/venv/lib/python3.6/site-packages/tensorflow_core/python/eager/context.py", line 1897, in execution_mode
2020-05-21 06:43:53.580090: W tensorflow/core/framework/op_kernel.cc:1655] OP_REQUIRES failed at example_parsing_ops.cc:93 : Invalid argument: Key: A.  Can't parse serialized Example.
2020-05-21 06:43:53.580567: W tensorflow/core/framework/op_kernel.cc:1655] OP_REQUIRES failed at example_parsing_ops.cc:93 : Invalid argument: Key: A.  Can't parse serialized Example.
    yield
  File "/home/patrizio/PycharmProjects/pyspark-config/venv/lib/python3.6/site-packages/tensorflow_core/python/data/ops/iterator_ops.py", line 659, in _next_internal
    output_shapes=self._flat_output_shapes)
  File "/home/patrizio/PycharmProjects/pyspark-config/venv/lib/python3.6/site-packages/tensorflow_core/python/ops/gen_dataset_ops.py", line 2479, in iterator_get_next_sync
    _ops.raise_from_not_ok_status(e, name)
  File "/home/patrizio/PycharmProjects/pyspark-config/venv/lib/python3.6/site-packages/tensorflow_core/python/framework/ops.py", line 6606, in raise_from_not_ok_status
    six.raise_from(core._status_to_exception(e.code, message), None)
  File "<string>", line 3, in raise_from
tensorflow.python.framework.errors_impl.DataLossError: corrupted record at 0 [Op:IteratorGetNextSync]

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/snap/pycharm-community/194/plugins/python-ce/helpers/pycharm/_jb_unittest_runner.py", line 35, in <module>
    sys.exit(main(argv=args, module=None, testRunner=unittestpy.TeamcityTestRunner, buffer=not JB_DISABLE_BUFFERING))
  File "/usr/lib/python3.6/unittest/main.py", line 94, in __init__
    self.parseArgs(argv)
  File "/usr/lib/python3.6/unittest/main.py", line 141, in parseArgs
    self.createTests()
  File "/usr/lib/python3.6/unittest/main.py", line 148, in createTests
    self.module)
  File "/usr/lib/python3.6/unittest/loader.py", line 219, in loadTestsFromNames
    suites = [self.loadTestsFromName(name, module) for name in names]
  File "/usr/lib/python3.6/unittest/loader.py", line 219, in <listcomp>
    suites = [self.loadTestsFromName(name, module) for name in names]
  File "/usr/lib/python3.6/unittest/loader.py", line 204, in loadTestsFromName
    test = obj()
  File "/home/patrizio/PycharmProjects/pyspark-config/tests/python/output/test_output.py", line 75, in test_TFRecord_new
    for x in train_dataset:
  File "/home/patrizio/PycharmProjects/pyspark-config/venv/lib/python3.6/site-packages/tensorflow_core/python/data/ops/iterator_ops.py", line 630, in __next__
    return self.next()
  File "/home/patrizio/PycharmProjects/pyspark-config/venv/lib/python3.6/site-packages/tensorflow_core/python/data/ops/iterator_ops.py", line 674, in next
    return self._next_internal()
  File "/home/patrizio/PycharmProjects/pyspark-config/venv/lib/python3.6/site-packages/tensorflow_core/python/data/ops/iterator_ops.py", line 665, in _next_internal
    return structure.from_compatible_tensor_list(self._element_spec, ret)
  File "/usr/lib/python3.6/contextlib.py", line 99, in __exit__
    self.gen.throw(type, value, traceback)
  File "/home/patrizio/PycharmProjects/pyspark-config/venv/lib/python3.6/site-packages/tensorflow_core/python/eager/context.py", line 1900, in execution_mode
    executor_new.wait()
  File "/home/patrizio/PycharmProjects/pyspark-config/venv/lib/python3.6/site-packages/tensorflow_core/python/eager/executor.py", line 67, in wait
    pywrap_tensorflow.TFE_ExecutorWaitForAllPendingNodes(self._handle)
tensorflow.python.framework.errors_impl.DataLossError: corrupted record at 0

Does anyone know how to deal with this?

Thanks a lot for your help.

1 Answer

I hope this is still relevant.

Your glob expression is incorrect. When Spark saves the examples as TFRecords, it also creates a _SUCCESS marker file in the output directory; the bare * pattern matches that file too, and since it is not a valid TFRecord, reading it is what produces the "corrupted record at 0" error. Include the extension in the pattern:

path2 = "path/example.tfrecord/*.tfrecord"

You can check the list of files that will actually be read by simply evaluating:
tf.io.gfile.glob(path2)
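
For illustration, the difference between the two patterns looks roughly like this (the part-file names below are hypothetical; Spark generates its own):

# The bare * pattern also matches the _SUCCESS marker, which is not
# a TFRecord and triggers the DataLossError:
tf.io.gfile.glob("path/example.tfrecord/*")
# ['path/example.tfrecord/_SUCCESS',
#  'path/example.tfrecord/part-r-00000.tfrecord', ...]

# Restricting to the extension matches only the record files:
tf.io.gfile.glob("path/example.tfrecord/*.tfrecord")
# ['path/example.tfrecord/part-r-00000.tfrecord', ...]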

I would also use the current tf.io API rather than the old compat.v1 one. The shape passed to tf.io.FixedLenFeature is wrong as well: each value is a scalar, not a vector of length 5, so the correct shape is simply []:
def parse_func(buff):
    features = {'A': tf.io.FixedLenFeature(shape=[], dtype=tf.int64)}
    tensor_dict = tf.io.parse_single_example(buff, features)
    return tensor_dict

train_dataset = dataset.map(parse_func).batch(3)
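
With those two fixes, iterating works; a minimal sketch of the result, assuming the three-row example DataFrame from the question (row order may differ, since Spark makes no ordering guarantee across partitions):

for x in train_dataset:
    print(x['A'])
# tf.Tensor([1 2 2], shape=(3,), dtype=int64)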

If you want to get fancier, tf.io.parse_example is preferable because it performs vectorized parsing. However, you then need to batch before parsing:

def parse_func(buff):
    features = {'A': tf.io.FixedLenFeature(shape=[], dtype=tf.int64)}
    tensor_dict = tf.io.parse_example(buff, features)
    return tensor_dict

train_dataset = dataset.batch(3).map(parse_func)

Batching Example protocol buffers and parsing them with parse_example may perform better than using this function directly. (source)

Hi, if I write the tfrecords like this, I don't get a file extension, so there are only the parts like part-r-00001. Any idea why that happens? It's a bit annoying, because there is also a .part-r-00001.crc file.... - thijsvdp
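
If the connector version in use writes extension-less part files, one workaround (a sketch; the directory layout is assumed from the comment above) is to glob the part files directly. The hidden .part-*.crc checksum files begin with a dot, so a part-* pattern does not match them:

# Sketch: match the extension-less part files directly; the hidden
# ".part-*.crc" checksum files start with a dot and are not matched.
files = tf.io.gfile.glob("path/example.tfrecord/part-*")
dataset = tf.data.TFRecordDataset(files)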
