通过常数因子扩展张量中一组行的规模

3

简短概述:如何将张量的一部分按2缩放(tf列表中存在行索引)


详细说明:

indices_of_scaling_ids: 存储行ID列表 Tensor("Squeeze:0", dtype=int64, device=/device:GPU:0) [1, 4, 5, 6, 12]

emb_inputs = tf.nn.embedding_lookup(embedding, self.all_rows) #形状为(batch_size=4, all_row_len, emb_size=128)的张量

因此,对于每个self.all_rows,都会计算emb_inputs

问题/挑战:我需要将在indices_of_scaling_ids中提到的每个row_id的emb_inputs缩放2.0。我尝试过各种切片方法,但似乎找不到好的解决方案。有人可以建议一下吗?谢谢

N.B. TensorFlow初学者

1个回答

1
尝试使用类似以下的内容:
SCALE = 2
emb_inputs = ...
indices_of_scaling_ids = ...
emb_shape = tf.shape(emb_inputs)
# Select indices in boolean array
r = tf.range(emb_shape[1])
mask = tf.reduce_any(tf.equal(r[:, tf.newaxis], indices_of_scaling_ids), axis=1)
# Tile the mask
mask = tf.tile(mask[tf.newaxis, :, tf.newaxis], (emb_shape[0], 1, emb_shape[2]))
# Choose scaled or not depending on indices
result = tf.where(mask, SCALE * emb_inputs, emb_inputs)

这真是太棒了!谢谢你!给我几分钟时间来试一下。 - Vishal Anand
我试了几个小时来让它工作。你能否评论一下如何修复这个错误? https://imgur.com/VjP7F4Z非常感谢! 附言:对于一个二级索引数字列表中的元素进行三维矩阵缩放比通常更困难,因为轴参数也不适用,掩码处理也更加复杂。 - Vishal Anand
1
@VishalAnand 我明白了,原始代码对我有效,但也许这取决于 TensorFlow 的版本,它是否支持仅为 tf.where 中的第一维提供一个数组。我已经在代码中添加了平铺操作,请查看是否适用于您。 - jdehesa
1
终于搞定了!对答案进行了微小的编辑(迭代第二索引,第一索引是批次号)。 再次感谢! - Vishal Anand
@VishalAnand 谢谢您的修复,我误读了一些问题细节,对造成的困惑感到抱歉。 - jdehesa
没有问题!我从这里学到的比大多数教程都要多。 - Vishal Anand

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接