使用Tensorflow预训练的inception_resnet_v2模型

15
我一直在尝试使用谷歌发布的预训练inception_resnet_v2模型。我正在使用他们的模型定义(https://github.com/tensorflow/models/blob/master/slim/nets/inception_resnet_v2.py)和给定的检查点(http://download.tensorflow.org/models/inception_resnet_v2_2016_08_30.tar.gz)在tensorflow中加载模型,如下所示[下载并提取检查点文件,并下载样本图片dog.jpg和panda.jpg以测试此代码]-
import tensorflow as tf
slim = tf.contrib.slim
from PIL import Image
from inception_resnet_v2 import *
import numpy as np

checkpoint_file = 'inception_resnet_v2_2016_08_30.ckpt'
sample_images = ['dog.jpg', 'panda.jpg']
#Load the model
sess = tf.Session()
arg_scope = inception_resnet_v2_arg_scope()
with slim.arg_scope(arg_scope):
  logits, end_points = inception_resnet_v2(input_tensor, is_training=False)
saver = tf.train.Saver()
saver.restore(sess, checkpoint_file)
for image in sample_images:
  im = Image.open(image).resize((299,299))
  im = np.array(im)
  im = im.reshape(-1,299,299,3)
  predict_values, logit_values = sess.run([end_points['Predictions'], logits], feed_dict={input_tensor: im})
  print (np.max(predict_values), np.max(logit_values))
  print (np.argmax(predict_values), np.argmax(logit_values))

然而,这个模型代码的结果并没有给出预期的结果(不管输入图像是什么,都会预测到类别918)。有人能帮我理解我错在哪里吗?


1
需要进行Imagenet预处理吗? - jeandut
1
顺便问一下,你的代码中 input_tensor 是什么意思? - tidy
不要直接输入图像,尝试输入以下内容:im = np.array(im) - np.mean(np.array(im)) - Pramod Patil
1个回答

14

Inception网络期望输入图像的颜色通道缩放在[-1, 1]之间,可参见 此处

您可以使用现有的预处理方法,或者在示例中自行对图像进行缩放:im = 2 *(im / 255.0)-1.0 ,然后再将其输入到网络中。

如果不对输入进行缩放,[0-255]的值比网络预期的要大得多,因此偏差会强烈地倾向于预测918类别(漫画书)。


1
@alemi,为什么要使用im = 2 *(im / 255.0)-1.0而不是im = im / 255.0进行归一化? - tidy
1
因为零是特殊的,由于第一次操作是矩阵乘法,而黑色相当普遍。 - alemi

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接