Tensorflow - 使用数据集 API 对序列进行填充或截断

4
我正在尝试使用 Dataset API 来准备一个文本序列的 TFRecordDataset。处理后,我有一个字典,其中包含每个记录的张量。每个记录包含两个序列。
我使用 padded_batch 进行填充。
dataset = dataset.padded_batch(batch_size, padded_shapes= {
    'seq1': tf.TensorShape([None]),
    'seq2': tf.TensorShape([None])
})

这会将每个序列填充到批处理中的最大序列长度。然而,我想选择任意序列长度,并在真实序列长度小于该长度时进行填充,否则截断该序列。
当我尝试将None替换为100时,例如,我遇到了一个DataLossError

DataLossError:尝试填充到比输入元素更小的大小。

有没有一种方法可以在序列上实现类似tf.image.resize_image_with_crop_or_pad的功能呢?
2个回答

3

对于填充或截断没有简单的方法,但可以使用map函数获取包含所需长度元素的数据集。这是一个快速的示例:

k = 4
def pad_or_trunc(t):
    dim = tf.size(t)
    return tf.cond(tf.equal(dim, k), lambda: t, lambda: tf.cond(tf.greater(dim, k), lambda: tf.slice(t, [0], [k]), lambda: tf.concat([t, tf.zeros(k-dim, dtype=tf.int32)], 0)))

vals = tf.constant([[1, 1, 1], [2, 2, 2], [3, 3, 3]])
dset1 = tf.data.Dataset.from_tensor_slices(vals)
dset2 = dset1.map(pad_or_trunc)
iter = dset2.make_one_shot_iterator()

with tf.Session() as sess:
    while True:
        try:
            print(sess.run(iter.get_next()))
        except tf.errors.OutOfRangeError:
            break

这是一个不错的想法。然而,我会部分地将其更改为将序列截断到所需长度,就像您所做的那样,但是在单独的步骤中进行填充-也许使用DS.padded_batch-这样,我会用较少的填充来填充较小的样本,而不是将所有样本填充到与最长样本相同的大小。这只是针对TF2中新功能的更新。我知道这个解决方案提出已经有一年半了。 - MAltakrori
太神奇啦!我不知道你的代码在干什么,但稍作修改后就可以用在我的视频上。非常感谢!但是...如果我想要对称填充怎么办?在连接(concat)部分添加一个零张量(zeros tensor),并将尺寸(size)更改为(k-dim)//2和(k-dim)//2+1?那切片(slice)呢? - grofte

1
你可以使用 tf.slicetf.math.greater 截断所有较长的序列,然后使用 padded_batch 对序列进行填充。
一个例子可能是这样的:
import tensorflow as tf
import numpy as np

# data generator
def gen():
  for i in [np.array([1, 1, 1]), np.array([2, 2, 2, 2]), np.array([3, 3, 3, 3, 3])]:
    yield i

cut_or_pad = 4 # 100 in your example

def cut_if_longer(el):
  if tf.greater(tf.shape(el), cut_or_pad): # only slice if longer
    return tf.slice(el, begin=[0], size=[cut_or_pad])
  return el

# data pipeline
dataset = tf.data.Dataset.from_generator( gen, (tf.int32), (tf.TensorShape([None])))
dataset = dataset.map( lambda el: cut_if_longer(el))
dataset = dataset.padded_batch(batch_size=2, padded_shapes=[cut_or_pad], padding_values=-1)

list(dataset.take(2).as_numpy_iterator())

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接