我有一个保存在pb文件中的模型。我希望计算它的FLOPS。我的示例代码如下:
import tensorflow as tf
import sys
from tensorflow.python.platform import gfile
from tensorflow.core.protobuf import saved_model_pb2
from tensorflow.python.util import compat
pb_file = 'themodel.pb'
run_meta = tf.RunMetadata()
with tf.Session() as sess:
print("load graph")
with gfile.FastGFile(pb_path,'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
sess.graph.as_default()
tf.import_graph_def(graph_def, name='')
flops = tf.profiler.profile(tf.get_default_graph(), run_meta=run_meta,
options=tf.profiler.ProfileOptionBuilder.float_operation())
print("test flops:{:,}".format(flops.total_float_ops))
打印信息很奇怪。我的模型有数十层,但打印信息中只报告了18个flops。我非常确定模型已经正确加载,因为如果我尝试按如下方式打印每一层的名称:
print([n.name for n in tf.get_default_graph().as_graph_def().node])
打印信息显示了正确的网络。
我的代码出了什么问题?
谢谢!