如何在Theano中保存/序列化一个训练好的模型?

8

我按照加载和保存文档中的说明保存了模型。

# saving trained model
f = file('models/simple_model.save', 'wb')
cPickle.dump(ca, f, protocol=cPickle.HIGHEST_PROTOCOL)
f.close()

ca 是一个经过训练的自编码器,它是类 cA 的实例。在我构建和保存模型的脚本中,我可以毫不费力地调用 ca.get_reconstructed_input(...)ca.get_hidden_values(...)

在另一个脚本中,我尝试加载已训练好的模型。

# loading the trained model
model_file = file('models/simple_model.save', 'rb')
ca = cPickle.load(model_file)
model_file.close()

我遇到了以下错误。
ca = cPickle.load(model_file)

AttributeError: 'module' object has no attribute 'cA'

2个回答

12

对于被pickle的对象,所有类的定义都需要被执行反序列化的脚本所知道。在其他StackOverflow问题中有更多相关信息(例如,AttributeError: 'module' object has no attribute 'newperson')。

只要正确导入cA,您的代码就是正确的。考虑到您得到的错误,可能并非如此。确保使用from cA import cA而不是import cA

或者,您的模型由其参数定义,因此您可以仅pickle参数值。这可以通过两种方式完成,具体取决于您的视角。

  1. 保存Theano共享变量。在这里,我们假设ca.params是Theano共享变量实例的常规Python列表。

    cPickle.dump(ca.params, f, protocol=cPickle.HIGHEST_PROTOCOL)
    
  2. 保存存储在Theano共享变量中的NumPy数组。

  3. cPickle.dump([param.get_value() for param in ca.params], f, protocol=cPickle.HIGHEST_PROTOCOL)
    
    当您想要加载模型时,您需要重新初始化参数。例如,创建类的新实例,然后执行以下任一操作:

ca.params = cPickle.load(f)
ca.W, ca.b, ca.b_prime = ca.params
或者
ca.params = [theano.shared(param) for param in cPickle.load(f)]
ca.W, ca.b, ca.b_prime = ca.params

请注意,您需要同时设置params字段和单独的参数字段。


我看到的错误是因为我使用了import cA而不是from cA import cA。我发布的代码在其他方面是正确的。你提供的替代方法也是正确的。我认为最干净的关闭此线程的方法是在你的第一段中添加一些内容(标识问题的真正来源),例如“确保你使用的是from cA import cA而不仅仅是import cA”,然后我可以将你的答案标记为已接受。谢谢! - xagg
对我来说,从cpickle加载模型的速度和编译它的速度差不多慢。 - Erik Aronesty
据我所了解,重要的是要知道这个pickle文件基本上将绑定到相同的硬件上,至少你不能在基于CPU的Theano上加载基于CUDA的模型。我对这个事实感到非常惊讶,因为在不同的硬件之间传输学习网络是一个非常棘手的任务。 - flaschenpost
学习的神经网络只需要训练好的参数即可。您不需要将编译的Theano函数进行pickle。实际上,出于您提供的原因,避免这样做可能更好。最好的方法是仅将模型参数(作为numpy数组,而不是Theano共享变量)进行pickle,然后将其加载回编译的网络形式中(即CPU或GPU版本)。 - Daniel Renshaw
1
@DanielRenshaw 可以使用 theano.misc.pkl_utils.dump()theano.misc.pkl_utils.load() 进行序列化/反序列化。 - loretoparisi

0

保存模型的另一种替代方式是保存其权重和架构,然后加载相同的方式,就像我们对预训练CNN所做的那样:

def save_model(model):


   model_json = model.to_json()
   open('cifar10_architecture.json', 'w').write(model_json)
   model.save_weights('cifar10_weights.h5', overwrite=True)

来源/参考:https://blog.rescale.com/neural-networks-using-keras-on- Rescale/


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接