为什么我不能在@tf.function中使用TensorArray.gather()?

3

从 TensorArray 中读取数据:

def __init__(self, size):
    """Allocate one float32 TensorArray per replay-buffer field.

    Args:
        size: Number of slots in every buffer.

    Each buffer is created with ``clear_after_read=False`` so a slot can be
    read more than once (e.g. by repeated sampling) without being cleared.
    """
    # NOTE(review): get_sample reads self.size, which is not assigned here;
    # presumably it is set elsewhere in the real class — confirm.

    def _new_buffer():
        # Every field uses the identical dtype / size / read policy.
        return tf.TensorArray(tf.float32, size=size, clear_after_read=False)

    # Same construction order as before: obs, obs2, act, rew, done.
    self.obs_buf = _new_buffer()
    self.obs2_buf = _new_buffer()
    self.act_buf = _new_buffer()
    self.rew_buf = _new_buffer()
    self.done_buf = _new_buffer()

def get_sample(self, batch_size):
    """Sample ``batch_size`` random transitions from the replay buffers.

    Draws ``batch_size`` int32 indices uniformly from ``[0, self.size)`` and
    gathers the matching entries of every buffer.

    Args:
        batch_size: Number of indices to sample.

    Returns:
        Tuple ``(obs, act, rew, next_obs, done)`` built from ``obs_buf``,
        ``act_buf``, ``rew_buf``, ``obs2_buf`` and ``done_buf`` respectively.
    """
    # NOTE(review): self.size is not set in the __init__ shown in the
    # question; presumably it is assigned elsewhere — confirm.
    idxs = tf.random.uniform(shape=[batch_size], maxval=self.size, dtype=tf.int32)
    tf.print(idxs)
    # HERE IS THE ISSUE: called inside @tf.function, this gather() raises
    # OperatorNotAllowedInGraphError. The traceback shows its implementation
    # looping over `indices` in Python, and iterating a symbolic tf.Tensor is
    # not allowed in graph mode. The accepted fix is to keep the buffers in
    # tf.Variable objects and sample with tf.gather(buffer, idxs) instead.
    # A parenthesised tuple replaces the original backslash continuations,
    # because a comment after a trailing `\` is a SyntaxError.
    return (
        self.obs_buf.gather(indices=idxs),
        self.act_buf.gather(indices=idxs),
        self.rew_buf.gather(indices=idxs),
        self.obs2_buf.gather(indices=idxs),
        self.done_buf.gather(indices=idxs),
    )

使用:

@tf.function
def train(self, rpm, batch_size, gradient_steps):
    """Run ``gradient_steps`` updates, each on a fresh batch sampled from ``rpm``.

    Args:
        rpm: Replay buffer exposing ``get_sample(batch_size)``.
        batch_size: Number of transitions per sampled batch.
        gradient_steps: Number of loop iterations (steps 1..gradient_steps).

    Because this function is wrapped in @tf.function, ``rpm.get_sample`` is
    traced into the graph. That traced call is the one that fails in the
    question's traceback.
    """
    for gradient_step in tf.range(1, gradient_steps + 1):
        obs, act, rew, next_obs, done = rpm.get_sample(batch_size)

        with tf.GradientTape() as tape:
            ...  # loss computation / gradient update omitted in the question

问题:

回溯(最近的调用最先): 文件“.\main.py”,第130行, rl_training.train() 文件“C:\Users\user\Documents\Projects\rl-toolkit\rl_training.py”,第129行, train() self._rpm,self.batch_size,self.gradient_steps,logging_wandb=self.logging_wandb 文件“C:\Users\user\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\eager\def_function.py”,第828行, call() result = self._call(*args,**kwds) 文件“C:\Users\user\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\eager\def_function.py”,第871行, _call() self._initialize(args,kwds,add_initializers_to=initializers) 文件“C:\Users\user\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\eager\def_function.py”,第726行, _initialize *args, **kwds)) 文件“C:\Users\user\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\eager\function.py”,第2969行, _get_concrete_function_internal_garbage_collected graph_function,_ = self._maybe_define_function(args,kwargs) 文件“C:\Users\user\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\eager\function.py”,第3361行, _maybe_define_function graph_function = self._create_graph_function(args,kwargs) 文件“C:\Users\user\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\eager\function.py”,第3206行, _create_graph_function capture_by_value=self._capture_by_value), 文件“C:\Users\user\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\framework\func_graph.py”,第990行, func_graph_from_py_func func_outputs = python_func(*func_args,**func_kwargs) 文件“C:\Users\user\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\eager\def_function.py”,第634行, wrapped_fn out = weak_wrapped_fn().wrapped(*args,**kwds) 文件“C:\Users\user\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\eager\function.py”,第3887行, bound_method_wrapper return wrapped_fn(*args,**kwargs) 
文件“C:\Users\user\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\framework\func_graph.py”,第977行, wrapper raise e.ag_error_metadata.to_exception(e) tensorflow.python.framework.errors_impl.OperatorNotAllowedInGraphError:在用户代码中:

C:\Users\user\Documents\Projects\rl-toolkit\policy\sac\sac.py:183 update  *
    obs, act, rew, next_obs, done = rpm.get_sample(batch_size)
C:\Users\user\Documents\Projects\rl-toolkit\utils\replay_buffer.py:39 __call__  *
    return self.obs_buf.gather(indices=idxs),                    self.act_buf.gather(indices=idxs),                    self.rew_buf.gather(indices=idxs),                    self.obs2_buf.gather(indices=idxs),                   self.done_buf.gather(indices=idxs)
C:\Users\user\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\ops\tensor_array_ops.py:1190 gather  **
    return self._implementation.gather(indices, name=name)
C:\Users\user\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\ops\tensor_array_ops.py:861 gather
    return array_ops.stack([self._maybe_zero(i) for i in indices])
C:\Users\user\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\framework\ops.py:505 __iter__
    self._disallow_iteration()
C:\Users\user\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\framework\ops.py:498 _disallow_iteration
    self._disallow_when_autograph_enabled("iterating over `tf.Tensor`")
C:\Users\user\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\framework\ops.py:476 _disallow_when_autograph_enabled
    " indicate you are trying to use an unsupported feature.".format(task))

OperatorNotAllowedInGraphError: iterating over `tf.Tensor` is not allowed: AutoGraph did convert this function. This might indicate you are trying to use an unsupported feature.

为什么我在这种情况下不能使用TensorArray?还有哪些替代方案可供选择?

这个 https://github.com/tensorflow/tensorflow/issues/31952 有帮助吗? - Andrey
抱歉,这没有帮助,因为我的问题出在 tf.TensorArray.gather() 上,而不是 tf.gather()。此外,那里提到的 +0 变通方法在这种情况下也不起作用。 - Martin Kubovčík
1个回答

1

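问题已经解决。必须使用tf.Variable而不是tf.TensorArray。原因在于:这些 TensorArray 是在 `__init__` 中(通常位于 @tf.function 之外)以 eager 方式创建的。回溯中的 tensor_array_ops.py:861 显示,其 gather() 会用 Python 循环逐个遍历 indices;而在 @tf.function 内部,idxs 是符号张量,不允许被迭代,因此会报错。改用 tf.Variable 保存缓冲区后,可以用 `tf.gather(self.obs_buf, idxs)` 进行采样。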


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接