如何在TensorFlow 2.0中使用由Dataset.window()方法创建的窗口?

33

我正在尝试创建一个数据集,它将从时间序列中返回随机窗口,并将下一个值作为目标,使用TensorFlow 2.0。

我正在使用Dataset.window(),这看起来很有前途:

import tensorflow as tf

dataset = tf.data.Dataset.from_tensor_slices(tf.range(10))
dataset = dataset.window(5, shift=1, drop_remainder=True)
for window in dataset:
    print([elem.numpy() for elem in window])

输出:

[0, 1, 2, 3, 4]
[1, 2, 3, 4, 5]
[2, 3, 4, 5, 6]
[3, 4, 5, 6, 7]
[4, 5, 6, 7, 8]
[5, 6, 7, 8, 9]

不过,我希望使用最后一个值作为目标。如果每个窗口都是一个张量,我会使用:

dataset = dataset.map(lambda window: (window[:-1], window[-1:]))

然而,如果我尝试这样做,则会抛出异常:

TypeError: '_VariantDataset' object is not subscriptable
1个回答

42
解决方案是这样调用flat_map():
dataset = dataset.flat_map(lambda window: window.batch(5))

现在数据集中的每个条目都是一个窗口,所以您可以像这样分割它:

dataset = dataset.map(lambda window: (window[:-1], window[-1:]))

因此,完整的代码如下:

import tensorflow as tf

dataset = tf.data.Dataset.from_tensor_slices(tf.range(10))
dataset = dataset.window(5, shift=1, drop_remainder=True)
dataset = dataset.flat_map(lambda window: window.batch(5))
dataset = dataset.map(lambda window: (window[:-1], window[-1:]))

for X, y in dataset:
    print("Input:", X.numpy(), "Target:", y.numpy())

输出结果为:

Input: [0 1 2 3] Target: [4]
Input: [1 2 3 4] Target: [5]
Input: [2 3 4 5] Target: [6]
Input: [3 4 5 6] Target: [7]
Input: [4 5 6 7] Target: [8]
Input: [5 6 7 8] Target: [9]

8
虽然这并不是回答问题所必需的,但您能否详细说明为什么我们需要进行这个 flat_map 步骤?我仍然很难理解。 - Elisio Quintino
24
window()方法返回一个数据集,其中每个窗口本身又表示为一个数据集。例如{{1,2,3,4,5},{6,7,8,9,10},...},其中{...}表示一个数据集。但我们只想要一个包含张量的常规数据集:{[1,2,3,4,5],[6,7,8,9,10],...},其中[...]表示一个张量。flat_map()方法在转换每个嵌套数据集后,返回嵌套数据集中的所有张量。如果我们不进行批处理,则会得到:{1,2,3,4,5,6,7,8,9,10,...}。将每个窗口分批到其完整大小,我们将获得所需的{[1,2,3,4,5],[6,7,8,9,10],...}。清楚吗? - MiniQuark
有没有办法将这些样本制作成小批量?我们已经从window.batch(5)得到了一个None维度,因此当添加例如dataset.batch(3)时,我们会得到另一个None维度。 - Sip
好的,它实际上是有效的,因为window.batch调用的None维度当然是必要的。 - Sip
使用flat_map会失去len()的能力,但是可以使用较慢的len(list(dataset))版本。 - Harbor
如果我的数据集来自字典,我该如何使用它?即 dataset = tf.data.Dataset.from_tensor_slices(some_dict)。 - ablanch5

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接