我不确定你是否仍然遇到这个问题(因为已经过去一个月了),但是我通过使用tf.tensordot
和tf.map_fn
解决了同样的问题,它们接受嵌套的输入元素并在第一个(通常是批次)维度上并行执行函数。以下函数在任意秩的张量的最后两个维度上执行批处理矩阵乘法(只要最后两个轴匹配用于矩阵乘法的目的):
def matmul_final_two_dims(tensor1, tensor2):
_your_dtype_here = tf.float64
return tf.map_fn(lambda xy: tf.tensordot(xy[0], xy[1], axes=[[-1], [-2]]),
elems=(tensor1, tensor2), dtype=_your_dtype_here)
使用示例:
>> batchsize = 3
>> tensor1 = np.random.rand(batchsize,3,4,5,2)
>> tensor2 = np.random.rand(batchsize,2,3,2,4)
>> sess.run(tf.shape(matmul_final_two_dims(tensor1, tensor2)))
array([3, 3, 4, 5, 2, 3, 4], dtype=int32)
>> matmul_final_two_dims(tensor1,tensor2)
<tf.Tensor 'map_1/TensorArrayStack/TensorArrayGatherV3:0' shape=(3, 3, 4, 5, 2, 3, 4) dtype=float64>
请注意,输出的第一个维度是正确的批处理大小,形状中的最后一个
2
被张量收缩掉了。但是,您需要进行某种
tf.transpose
操作,以将维度-
5
索引放在正确的位置,因为输出矩阵的索引按照它们在输入张量中出现的顺序排序。
我正在使用TFv1.1。
tf.map_fn
可以并行化,但我不确定上述是否是最有效的实现。供参考:
tf.tensordot API
tf.map_fn API
编辑:上述是对我有用的内容,但我认为您还可以使用
einsum
(
docs here)来完成您想要的内容。
>> tensor1 = tf.constant(np.random.rand(3,4,5))
>> tensor2 = tf.constant(np.random.rand(3,5,7))
>> tf.einsum('bij,bjk->bik', tensor1, tensor2)
<tf.Tensor 'transpose_2:0' shape=(3, 4, 7) dtype=float64>
100,3,100,4
完全合理。你想在轴100
上做什么?逐元素相乘而不进行缩并? - romeric