如何在 TensorFlow 中对数据集进行切片?

3
我希望在 tf.data 中切片数据集。我的数据如下所示:
dataset = tf.data.Dataset.from_tensor_slices([[0, 1, 2, 3, 4],
                                              [1, 2, 3, 4, 5],
                                              [2, 3, 4, 5, 6],
                                              [3, 4, 5, 6, 7],
                                              [4, 5, 6, 7, 8]])

然后主要的数据是:

[0 1 2 3 4]
[1 2 3 4 5]
[2 3 4 5 6]
[3 4 5 6 7]
[4 5 6 7 8]

我想创建另一个张量数据集,其中包含像这样的数据:

       [[1, 2],
       [2, 3],
       [3, 4],
       [4, 5],
       [5, 6]]

在numpy中,它是这样的:

dataset[:,1:3]

如何在TensorFlow中做到这一点?

更新:

我使用以下方法实现了这个:

dataset2 = dataset.map(lambda data: data[1:3])
for val in dataset2:
    print(val.numpy())

但我认为有好的解决方案。

1个回答

1
根据我的看法,您的解决方案是最佳的解决方案。为了造福社区,我正在使用tf.data.Datasetas_numpy_iterator()方法来切片数据集(对您的代码进行了小的语法更改)。
请参考下面的代码:
import tensorflow as tf

dataset = tf.data.Dataset.from_tensor_slices([[0, 1, 2, 3, 4],
                                              [1, 2, 3, 4, 5],
                                              [2, 3, 4, 5, 6],
                                              [3, 4, 5, 6, 7],
                                              [4, 5, 6, 7, 8]])


dataset2 = dataset.map(lambda data: data[1:3])
for val in dataset2.as_numpy_iterator():
    print(val)

输出:

[1 2]
[2 3]
[3 4]
[4 5]
[5 6]

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接