获取Keras中VGG-16所有已知类的列表

12

我使用了Keras中预训练的VGG-16模型。

到目前为止,我的工作源代码如下:

from keras.applications.vgg16 import VGG16
from keras.preprocessing.image import load_img
from keras.preprocessing.image import img_to_array
from keras.applications.vgg16 import preprocess_input
from keras.applications.vgg16 import decode_predictions

model = VGG16()

print(model.summary())

image = load_img('./pictures/door.jpg', target_size=(224, 224))
image = img_to_array(image)  #output Numpy-array

image = image.reshape((1, image.shape[0], image.shape[1], image.shape[2]))

image = preprocess_input(image)
yhat = model.predict(image)

label = decode_predictions(yhat)
label = label[0][0]

print('%s (%.2f%%)' % (label[1], label[2]*100))

我发现这个模型是在1000类上训练的。有没有可能获取这个模型所训练的类别列表?打印出所有的预测标签并不可行,因为只返回了5个。

先行致谢。

4个回答

10

你可以使用decode_predictions函数并在top参数中传递类的总数top=1000(只有默认值为5)。

或者你可以看看Keras内部是如何实现的:它会下载文件imagenet_class_index.json(通常将其缓存到~/.keras/models/中)。这是一个简单的json文件,包含了所有的类标签。


知道如何做是很好的,但如果有一个 jupyter notebook 的链接来展示已完成的操作,那就更好了。 - wordsforthewise
@YSelf。谢谢。我发现这是在R中运行Keras时列出所有类标签最简单的方法。 - R.S.
1
嘿!我正在使用Google Colab编写我的代码,但是当我输入model.decode_classifications(predictions, top = 1000)时,它显示“AttributeError: 'Model' object has no attribute 'decode_predictions'”。 prediction = model.predict(img) model.decode_predictions(predictions, top = 1000) - TheSHETTY-Paradise

2
我认为如果您像这样做:

我认为如果您这样做:

vgg16 = keras.applications.vgg16.VGG16(include_top=True,
                               weights='imagenet',
                               input_tensor=None,
                               input_shape=None,
                               pooling=None,
                               classes=1000)

vgg16.decode_predictions(np.arange(1000), top=1000)

将您的预测数组替换为np.arange(1000)。目前代码未经测试。

这里是训练标签的链接,我想:http://image-net.org/challenges/LSVRC/2014/browse-synsets


0
如果您稍微编辑一下代码,就可以获得所提供示例的所有顶级预测列表。Tensorflow的decode_predictions返回一个列表类预测元组的列表。因此,首先将top = 1000参数添加到label = decode_predictions(yhat, top=1000)中,正如@YSelf建议的那样。然后将label = label[0][0]更改为label = label[0][:]以选择所有预测。标签将看起来像这样:
[('n04252225', 'snowplow', 0.4144803),
('n03796401', 'moving_van', 0.09205707),
('n04461696', 'tow_truck', 0.08912289),
('n03930630', 'pickup', 0.07173037),
('n04467665', 'trailer_truck', 0.048759833),
('n02930766', 'cab', 0.043586567),
('n04037443', 'racer', 0.036957625),....)]

从这里开始,您需要进行元组解包。如果您只想获取1000个类的列表,您可以调用[y for (x,y,z) in label],然后您将获得所有1000个类的列表。输出将如下所示:

['snowplow',
'moving_van',
'tow_truck',
'pickup',
'trailer_truck',
'cab',
'racer',....]

0

这一行代码将打印出所有类别名称和索引:

decode_predictions(np.expand_dims(np.arange(1000), 0), top=1000)

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接