Tensorflow tf.map_fn 参数

4

我正在尝试构造参数,以便它们可以与tf.map_fn()正常工作,但大多数示例文档只讨论与函数参数形状相同的数组或张量。

链接包括:

TensorFlow 的 map_fn 支持接受多个张量吗?

我的具体示例是这样的: 我有一些 TensorFlow 函数,期望 [None, 2] 和 [x,y] 作为参数张量形状。

张量 A 的形状为 [batch_size, x*y, 2]

张量 B 的形状为 [batch_size, x, y]

lambdaData = (tensorA, tensorB)
lambdaFunc = lambda x: tensorflowFunc(x[0], x[1])
returnValues = tf.map_fn(lambdaFunc, lambdaData)

来自tensorflow文档:

If elems is a (possibly nested) list or tuple of tensors, then each of these 
tensors must have a matching first (unpack) dimension

由于张量A和B只在第0维匹配,我无法堆叠或连接它们;我还尝试创建lambdaData:

  1. 两个张量的列表
  2. 两个张量的元组
  3. 张量对的列表

上述所有方法都导致不同的维度不匹配错误。我想按照文档建议将所有数据放入单个张量中,但由于张量A和张量B之间的维度不匹配,我无法这样做。是否有人尝试过使用元组或参数列表来传递elems?

2个回答

4

一种更简洁的解决方案是指定map_fndtype参数(请参见文档),例如:

tf.map_fn(lambda x: fn(*x), elements, dtype=tf.float32)

如果fn只返回一个float32值。


dtype 在最新的 Tensorflow 版本中已经被弃用,应该使用 fn_output_signature 标志。 - Gaurav Srivastava

3
原来 tf.map_fn 的错误信息非常具有误导性;虽然文档没有详细说明,但是如果你传递一个元组/张量列表时,需要将返回值的数量与函数的参数数量完全匹配。最简单的方法是返回无关紧要的东西,然后只获取第一个返回值。
print(a.shape) #[batch, 784, 2]
print(b.shape) #[batch, 28, 28]
lambdaData = (a, b)
testFunc = lambda x: return <somethingUseful>, 0
returnValues, _ = tf.map_fn(testFunc, lambdaData)

正常运作。

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接