函数tf.gather()的反函数

6
假设我有一个形状为[B,D]的张量a,以及一个包含形状为[B]的指数列表I。现在我想要使用列表中的索引将张量扩展到形状[M,D],其中M > B。请注意,这些索引属于范围[0,M]。具体而言,I是从张量a到另一个具有较大维度0值的张量的行映射。这个功能与函数tf.gather()相反。请问有解决方案吗?谢谢

2
tf.scatter_nd等和tf.zeros/tf.fill - Patwie
非常符合我的需求,谢谢! - lenhhoxung
1个回答

1

tf.scatter_ndtf.gather_nd的反向操作。让我们通过一个往返示例来看看:

  • 首先,我们创建一个形状为((5,4,1,2,3))的张量,其中除了索引[1,2,0,0][3,0,0,1]处的元素外,所有元素都为零。这两个位置的元素分别为[16, 12, 11][3,0,0,1],使用tf.scatter_nd
  • 其次,我们通过在生成的张量上应用tf.gather_nd来执行反向操作,以获取两个原始向量,即[16, 12, 11][3,0,0,1]
updates = [[16, 12, 11],[18, 40, 37]]
indices = [[1,2,0,0], [3,0,0,1]]
shape = (5,4,1,2,3)

# first step - scatter
scat_tensor = tf.scatter_nd(indices=indices, updates=updates, shape=shape)
print(f"verification: expected {updates[0]}, got {scat_tensor[1,2,0,0]}")

# now the reverse step
reconstructed_updates = tf.gather_nd(scat_tensor, indices)
print(f"verification: expected {updates }, got {reconstructed_updates }")

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接