Tensorflow批量稀疏矩阵乘法

5
我希望能够在批处理中将稀疏张量乘以密集张量。例如,我有一个稀疏张量,相应的密集形状为(20,65536,65536),其中20是批次大小。我想将批次中的每个(65536,65536)与来自张量形状(20,65536)的相应的(65536x1)相乘,该张量具有密集表示。tf.sparse_tensor_dense_matmul只接受秩为2的稀疏张量。是否有一种方法可以在批处理中执行此操作?如果可能的话,我想避免将稀疏矩阵转换为密集矩阵,因为会受到内存限制的影响。
2个回答

2
假设 a 是一个形状为 (20, 65536, 65536) 的稀疏张量,b 是一个形状为 (20, 65536) 的密集张量,则可以按以下方式执行批稀疏-密集矩阵乘法:
y_sparse = tf.sparse.reduce_sum_sparse(a * b[:, None, :], axis=-1)

这个解决方案扩展了张量b的第二维,以实现隐式广播。然后,通过执行稀疏-密集乘法和最后一个轴上的稀疏求和来进行批量矩阵乘法。
如果b有第三个维度,因此它是一批矩阵,则可以逐个相乘它们的列,然后将它们连接起来:
multiplied_dims = []
for i in range (b.shape[-1]):
  multiplied_dims.append(tf.expand_dims(tf.sparse.reduce_sum(a * b[:, :, i][:, None, :], axis=-1), -1))
result = tf.concat(multiplied_dims, -1)

我认为你误解了问题,因为 b 的形状应该是 (20, 65536, 1) - McLP
1
哦,抱歉,我在答案中打错了 b 的形状(我已经更新了)- 代码应该按预期工作。 - rvinas
你能解释一下如果我有形状为(20,65536,3)的b,如何扩展它吗? - McLP
很抱歉,无法将形状为(20,65536,3)b进行广播,因为广播仅支持从密集到稀疏的操作。不过,您可以分别执行三个矩阵乘法,然后沿着最后一个轴连接结果。 - rvinas
可以在数学上实现,但不能仅使用TensorFlow。谢谢! - McLP
显示剩余3条评论

0
答案很简单 - 首先重塑稀疏张量,然后再乘以密集矩阵。类似这样的代码可以实现:
sparse_tensor_rank2 = tf.sparse_reshape(sparse_tensor, [-1, 65536])

请问您能否澄清为什么这会起作用,并扩展代码直到乘法之后? - McLP
它会工作,因为只是重塑后的乘法,但我猜它可能不能满足每个用例。 - Arka Mukherjee

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接