TensorFlow高效进行张量乘法的方法

3

我在tensorflow中有两个张量,第一个张量是3D的,第二个是2D的。我想要像这样将它们相乘:

x = tf.placeholder(tf.float32, shape=[sequence_length, batch_size, hidden_num]) 
w = tf.get_variable("w", [hidden_num, 50]) 
b = tf.get_variable("b", [50])

output_list = []
for step_index in range(sequence_length):
    output = tf.matmul(x[step_index,  :,  :], w) + b
    output_list.append(output)
output = tf.pack(outputs_list)

我使用循环进行乘法运算,但我觉得这样太慢了。有什么最简单/清晰的方法可以使这个过程更快?

2个回答

2
您可以使用 batch_matmul。不幸的是,batch_matmul 不支持沿批次维度的广播,因此您需要复制 w 矩阵。这将使用更多内存,但所有操作都将保留在 TensorFlow 中。
a = tf.ones((5, 2, 3))
b = tf.ones((3, 1))
b = tf.reshape(b, (1, 3, 1))
b = tf.tile(b, [5, 1, 1])
c = tf.batch_matmul(a, b) # use tf.matmul in TF 1.0 
sess = tf.InteractiveSession()
sess.run(tf.shape(c))

这可以提供
array([5, 2, 1], dtype=int32)

谢谢您的回答,但我有一个问题,使用tf.tile后,变量b的大小发生了变化,因为其中有一些无用数据。因此,计算tf.nn.l2_loss(b)很困难。 - Nils Cao
顺便提一下,tile的输入可以是张量,所以您可以根据tf.shape(w)动态构建shape参数。 - Yaroslav Bulatov
1
在tensorflow 1.0中,请使用tf.matmul而不是tf.batch_matmul - holdenlee

1
你可以使用map_fn,它可以在第一维度上扫描函数。
x = tf.placeholder(tf.float32, shape=[sequence_length, batch_size, hidden_num]) 
w = tf.get_variable("w", [hidden_num, 50]) 
b = tf.get_variable("b", [50])

def mul_fn(current_input):
    return tf.matmul(current_input, w) + b

output = tf.map_fn(mul_fn, x)

我曾经使用这个来实现沿着序列的softmax扫描。


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接