InceptionV3和TensorFlow的迁移学习

4
我希望能够从给定的tensorflow inceptionV3示例中进行迁移学习。在这里https://github.com/AKSHAYUBHAT/VisualSearchServer/blob/master/notebooks/notebook_network.ipynb,按照图像分类示例和给出的运算符和张量名称,我可以创建我的图形。但是,当我将一个大小为(100, 299, 299, 3)的图像批次放入预计算的inception图中时,在pool_3层处出现以下形状错误:
ValueError: Cannot reshape a tensor with 204800 elements to shape [1, 2048] (2048 elements)

看起来这个inceptionV3图形不接受图像批次(batch)作为输入。我错了吗?

4个回答

4
实际上,如果你提取正确的内容,它可以用于迁移学习。将形状为[N,299,299,3]的图像作为ResizeBilinear:0输入批处理没有问题,然后使用pool_3:0张量。重塑之后会出现问题,但你可以自己进行重塑(反正之后你会有自己的层)。如果你想使用原始分类器并进行批处理,则可以在pool_3:0之上添加自己的重塑,然后添加softmax层,重用原始softmax的权重/偏置张量。

简而言之,对于双img堆栈,其形状为(2, 299, 299, 3),以下操作有效:

pooled_2 = sess.graph.get_tensor_by_name("pool_3:0").eval(session=sess, feed_dict={'ResizeBilinear:0':double_img})
pooled_2.shape
# => (2, 1, 1, 2048)

尝试过这个并得到了这样的错误提示:"Cannot feed value of shape (10, 299, 299, 3) for Tensor u'ResizeBilinear:0', which has shape '(1, 299, 299, 3)'"。和https://github.com/tensorflow/tensorflow/issues/1021相同的结果。 - Phillip Godzin

2

0

这样的东西应该可以:

    with g.as_default():
     inputs = tf.placeholder(tf.float32, shape=[batch_size, 299, 299, 3],
                                name='input')

        with slim.arg_scope(inception.inception_v3_arg_scope()):

            logits, end_points = inception.inception_v3( inputs, 
            num_classes=FLAGS.num_classes, is_training=False)
            variables_to_restore = lim.get_variables_to_restore(exclude=exclude)
        sess = tf.Session()

        saver = tf_saver.Saver(variables_to_restore)

然后您应该能够调用该操作:

        sess.run("pool_3:0",feed_dict={'ResizeBilinear:0':images})

0

etarion提出了一个非常好的观点。然而,我们不必自己重新塑形;相反,我们可以改变reshape作为输入的shape的值。也就是说,

input_tensor_name = 'import/input:0'
shape_tensor_name = 'import/InceptionV3/Predictions/Shape:0'
output_tensor_name= 'import/InceptionV3/Predictions/Reshape_1:0'

output_tensor = tf.import_graph_def(
    graph.as_graph_def(),
    input_map={input_tensor_name: image_batch,
               shape_tensor_name: [batch_size, num_class]},
    return_elements=[output_tensor_name])

这些张量名称基于inception_v3_2016_08_28_frozen.pb


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接