如何导入手动下载的MNIST数据集?

20
我一直在尝试使用Keras示例,需要导入MNIST数据。
from keras.datasets import mnist
import numpy as np
(x_train, _), (x_test, _) = mnist.load_data()

它会生成如下错误信息:Exception: URL fetch failure on https://s3.amazonaws.com/img-datasets/mnist.pkl.gz: None -- [Errno 110] Connection timed out

这可能与我使用的网络环境有关。 是否有任何函数或代码可以让我直接导入已经手动下载的MNIST数据集?

我尝试了以下方法

import sys
import pickle
import gzip
f = gzip.open('/data/mnist.pkl.gz', 'rb')
  if sys.version_info < (3,):
    data = pickle.load(f)
else:
    data = pickle.load(f, encoding='bytes')
f.close()
import numpy as np
(x_train, _), (x_test, _) = data

然后我收到以下错误消息

Traceback (most recent call last):
File "test.py", line 45, in <module>
(x_train, _), (x_test, _) = data
ValueError: too many values to unpack (expected 2)
6个回答

13

好的,keras.datasets.mnist文件非常短小。您可以手动模拟相同的操作,即:

  1. https://s3.amazonaws.com/img-datasets/mnist.pkl.gz下载数据集。

import gzip
f = gzip.open('mnist.pkl.gz', 'rb')
if sys.version_info < (3,):
    data = cPickle.load(f)
else:
    data = cPickle.load(f, encoding='bytes')
f.close()
(x_train, _), (x_test, _) = data

嗨sygi,感谢您的建议。然而,我得到了如更新后的帖子所示的错误消息。唯一不同的是我使用了pickle。看起来在加载数据时它没有给我错误。 - user785099
1
我已经检查过了,在我的系统上,使用pickle和cPickle以及Python 2和3都可以正常工作。您确定您拥有相同的文件(md5 b39289ebd4f8755817b1352c8488b486)吗? - sygi
它现在可以工作了,不知道之前为什么会出现错误信息。非常感谢。 - user785099
在我的情况下,添加以下导入import sys; import pickle; import gzip;并使用pickle而不是cPickle有效 - 我正在macOS Mojave上使用Python 3.6.7。 - Giorgio Tempesta
虽然这是个很好的回答,但我建议使用"with open"而不是"f.close()",这样可以避免内存泄漏。 - undefined

12

Keras文件现在位于Google Cloud Storage的新路径中(之前位于AWS S3中):

https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz

使用时:

tf.keras.datasets.mnist.load_data()

您可以传递一个path参数。

load_data()将调用get_file(),它以fname作为参数,如果路径是完整的且文件存在,则不会下载。

示例:

# gsutil cp gs://tensorflow/tf-keras-datasets/mnist.npz /tmp/data/mnist.npz
# python3
>>> import tensorflow as tf
>>> path = '/tmp/data/mnist.npz'
>>> (train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data(path)
>>> len(train_images)
>>> 60000

1
load_data函数不再具有路径参数。 - Surendra

10

1
这正是我正在寻找的,非常感谢Tardis。 - Ashwin Hegde

5
  1. Download file https://s3.amazonaws.com/img-datasets/mnist.npz
  2. Move mnist.npz to .keras/datasets/ directory
  3. Load data

    import keras
    from keras.datasets import mnist
    
    (X_train, y_train), (X_test, y_test) = mnist.load_data()
    

1

keras.datasets.mnist.load_data()会尝试从远程仓库获取数据,即使指定了本地文件路径。然而,最简单的解决方法是使用numpy.load()来加载已下载的文件,就像他们所做的一样:

path = '/tmp/data/mnist.npz'

import numpy as np

with np.load(path, allow_pickle=True) as f:
    x_train, y_train = f['x_train'], f['y_train']
    x_test, y_test = f['x_test'], f['y_test']

0

Gogasca的回答对我有帮助,只需稍作调整即可。对于Python 3.9,在~/Library/Python/3.9/lib/python/site-packages/keras/datasets/mnist.py中更改代码,使其使用path变量作为完整路径而不是添加origin_folder,这样可以将任何本地路径传递给下载文件。

  1. 下载文件:https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
  2. 将其放入~/Library/Python/3.9/lib/python/site-packages/keras/datasets/或您喜欢的其他位置。
  3. 修改~/Library/Python/3.9/lib/python/site-packages/keras/datasets/mnist.py
path = path

""" origin_folder = 'https://storage.googleapis.com/tensorflow/tf-keras-datasets/' """
""" path = get_file(
path,origin=origin_folder + 'mnist.npz',file_hash='731c5ac602752760c8e48fbffcf8c3b850d9dc2a2aedcf2cc48468fc17b673d1') """

with np.load(path, allow_pickle=True) as f:  # pylint:
    disable=unexpected-keyword-arg
    x_train, y_train = f['x_train'], f['y_train']
    x_test, y_test = f['x_test'], f['y_test']
return (x_train, y_train), (x_test, y_test)

使用以下代码加载数据:
path = "/Users/username/Library/Python/3.9/lib/python/site-packages/keras/datasets/mnist.npz"
(train_images, train_labels), (test_images, test_labels ) = mnist.load_data(path=path)```

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接