有没有办法将PyTorch张量转换为TensorFlow张量?

6

https://github.com/taoshen58/BiBloSA/blob/ec67cbdc411278dd29e8888e9fd6451695efc26c/context_fusion/self_attn.py#L29

我需要使用上面链接中实现的mulit_dimensional_attention,它是在TensorFlow中实现的,但我正在使用PyTorch,所以我能否将PyTorch Tensor转换为TensorFlow Tensor,还是必须在PyTorch中实现它。
在这里我尝试使用的代码,我必须将'rep_tensor'作为TensorFlow张量类型传递,但我有PyTorch张量。
def multi_dimensional_attention(rep_tensor, rep_mask=None, scope=None,
                                keep_prob=1., is_train=None, wd=0., activation='elu',
                                tensor_dict=None, name=None):

    # bs, sl, vec = tf.shape(rep_tensor)[0], tf.shape(rep_tensor)[1], tf.shape(rep_tensor)[2]

    ivec = rep_tensor.shape[2]
    with tf.variable_scope(scope or 'multi_dimensional_attention'):
        map1 = bn_dense_layer(rep_tensor, ivec, True, 0., 'bn_dense_map1', activation,
                              False, wd, keep_prob, is_train)
        map2 = bn_dense_layer(map1, ivec, True, 0., 'bn_dense_map2', 'linear',
                              False, wd, keep_prob, is_train)
        # map2_masked = exp_mask_for_high_rank(map2, rep_mask)

        soft = tf.nn.softmax(map2, 1)  # bs,sl,vec
        attn_output = tf.reduce_sum(soft * rep_tensor, 1)  # bs, vec

        # save attn
        if tensor_dict is not None and name is not None:
            tensor_dict[name] = soft

        return attn_output

2个回答

13
您可以将pytorch张量转换为numpy数组,然后将其转换为tensorflow张量,反之亦然:
import torch
import tensorflow as tf

pytorch_tensor = torch.zeros(10)
np_tensor = pytorch_tensor.numpy()
tf_tensor = tf.convert_to_tensor(np_tensor)

话虽如此,如果你想训练一个同时使用pytorch和tensorflow的模型,这将会是...尴尬的,慢的,有缺陷的,并且至少需要很长时间编写。因为这些库必须找出如何反向传播成本。

所以除非你拥有预先训练好的pytorch attention block,我建议你仅仅坚持使用其中一个库,因为两个库都有足够多的实现例子和训练好的模型。Tensorflow通常会更快一些,但速度差异并不那么显着,并且我提出的“hack”的方式可能会使整个过程比单独使用任何一个库更慢。


2

@George的解决方案很好,但至少在2023年4月,转换为numpy的中间步骤不再必要(尽管我不确定它是否在转换过程中内部使用)。因此,这是一个简单的命令:

import torch
import tensorflow as tf

pytorch_tensor = torch.zeros(10)
tf_tensor = tf.convert_to_tensor(pytorch_tensor)

关于笨拙、缓慢等评论,我曾经看到过像Yolov5这样的巨大作品在某个时候混合使用这种组合。

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接