使用Tensorflow DNN加载检查点并评估单个图像

4
我正在大学进行研究,研究牛津17朵花的AlexNet示例。该示例使用基于TensorFlow的tflearn API。在我的GPU上训练非常顺利,在一段时间后达到了约97%的准确率。
不幸的是,tflearn中尚未实现单个图像的评估,我必须使用model.predict(...)来预测每个批次的所有数据,并循环遍历所有测试集并自己计算准确性。
到目前为止,我的培训代码:
...
import image_loader
X, Y = image_loader.load_data(one_hot=True, shuffle=False)

X = X.reshape(244,244)

# Build network
network = input_data(shape=[None, 224, 224, 3])

network = conv_2d(network, 96, 11, strides=4, activation='relu')
network = max_pool_2d(network, 3, strides=2)
network = local_response_normalization(network)

network = conv_2d(network, 256, 5, activation='relu')
network = max_pool_2d(network, 3, strides=2)
network = local_response_normalization(network)

network = conv_2d(network, 384, 3, activation='relu')
network = conv_2d(network, 384, 3, activation='relu')
network = conv_2d(network, 256, 3, activation='relu')
network = max_pool_2d(network, 3, strides=2)
network = local_response_normalization(network)

network = fully_connected(network, 4096, activation='tanh')
network = dropout(network, 0.5)

network = fully_connected(network, 4096, activation='tanh')
network = dropout(network, 0.5)

network = fully_connected(network, 17, activation='softmax')
network = regression(network, optimizer='momentum',
                 loss='categorical_crossentropy',
                 learning_rate=0.01)

# Training
model = tflearn.DNN(network, checkpoint_path='model_ba',
                max_checkpoints=1, tensorboard_verbose=0)
model.fit(X, Y, n_epoch=3, validation_set=0.1, shuffle=True,
      show_metric=True, batch_size=32, snapshot_step=400,
      snapshot_epoch=False, run_id='ba_soccer_network')

该代码正在保存一个名为“model_ba”的检查点以及以.meta文件形式保存的网络。 有没有可能使用tensorflow加载保存的检查点并评估单个图像?
提前感谢,阿诺

你能检查一下 network/model 是否有 savewrite 方法吗?(灵感来自 这里 - Antonio
是的,确实有一个 model.save() 方法可以保存 ckpt 和 meta 文件(即使在这个 API 中也有一个 model.load() 方法,我需要在不使用 tflearn API 的情况下将保存的 ckpt 和 meta 加载到 tensorflow 代码中)。 - ArnoXf
1个回答

0

保存模型: model.save('name.tflearn')

加载模型: model.load('name.tflearn')

循环测试只需加载模型并按照以下代码执行

files_path = '/your/test/images/directory/path'
img_files_path = os.path.join(files_path, '*.jpg')
img_files = sorted(glob(img_files_path))

for f in img_files:
    try:
        img = Image.open(f).convert('RGB')
        img = ImageOps.fit(img, ((64, 64)), Image.ANTIALIAS)

        img_arr = np.array(img)
        img_arr = img_arr.reshape(-1, 64, 64, 3).astype("float")

        pred = model.predict(img_arr)
        print(" %s" % pred[0])

    except:
        continue

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接