我需要使用Tensorflow分析器来分析一些代码,这些代码因某种原因运行缓慢。不幸的是,涉及的代码使用了tf.Estimator,因此我无法弄清楚如何将运行元数据对象注入到会话run()调用中,以获取分析器所需的信息。
我该怎么办?
我需要使用Tensorflow分析器来分析一些代码,这些代码因某种原因运行缓慢。不幸的是,涉及的代码使用了tf.Estimator,因此我无法弄清楚如何将运行元数据对象注入到会话run()调用中,以获取分析器所需的信息。
我该怎么办?
tf.estimator
使用 tf.train.ProfilerHook
有效!
只需在 TrainSpec
钩子中添加一个 ProfilerHook
即可!
hook = tf.train.ProfilerHook(
save_steps=20,
output_dir=os.path.join(args.model_dir, "tracing"),
show_dataflow=True,
show_memory=True)
hooks = [hook]
train_spec = tf.estimator.TrainSpec(
hooks=hooks,
input_fn=lambda: input_fn())
然后,您可以在 model_dir/tracing
目录下获取类似于 timeline-{}.json
的跟踪文件,并打开 Chrome 中的 chrome://tracing
进行可视化!
with tf.contrib.tfprof.ProfileContext('/tmp/train_dir', dump_steps=[10]) as pctx:
estimator.train() # any thing you want to profile
然后你会在/tmp/train_dir/profile_10
得到一个文件
参数在https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/profiler/profile_context.py中定义