如何将张量的dtype从tf.float32_ref转换为tf.float32?

4

我希望通过函数tf.cast()float32_ref类型的word_embeddings修改为float32类型:

   word_embeddings_modify=tf.cast(word_embeddings,dtype=tf.float32)

但是它并没有按预期工作,word_embeddings_modify的数据类型仍为tf.float32_ref。

   word_embeddings = tf.scatter_nd_update(var_output, error_word_f,sum_all)
   word_embeddings_modify=tf.cast(word_embeddings,dtype=tf.float32)
   word_embeddings_dropout = tf.nn.dropout(word_embeddings_2, dropout_pl)
1个回答

2
你可以使用tf.identity来取消引用_ref类型。
word_embeddings = tf.identity(word_embeddings)

欢迎 @user9799714 - 既然这个答案解决了你的问题,请记得将其标记为已接受! - nessuno
嗨,我有另一个问题,如何将张量dtype=tf.float32转换为dtype=tf.float32_ref? - user9799714
你需要提出一个新的答案,因为这是不同的话题 - 但首先请记得将我的答案标记为被接受的答案。 - nessuno

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接