如何在使用ppc64le和x86架构的不同版本pytorch(1.3.1和1.6.x)之间加载检查点?

12
由于硬件限制,我被困在使用旧版本的pytorch和torchvision上。因此,在不同计算机、集群和我的个人mac之间发送和接收检查点时出现了问题。我想知道是否有任何方法可以加载模型以避免这个问题?例如,当使用1.6.x时,也许将模型保存为旧格式和新格式。当然,对于1.3.1到1.6.x是不可能的,但至少我希望有些东西能够奏效。你有什么建议吗?当然,我的理想解决方案是我不必担心它,我可以始终统一地加载和保存我的检查点和所有通常使用pickle的东西在所有硬件上。首先出现的错误是一个zip jit错误。
RuntimeError: /home/miranda9/data/f.pt is a zip archive (did you mean to use torch.jit.load()?)

所以我使用了那个(和其他的pickle库):

# %%
import torch
from pathlib import Path


def load(path):
    import torch
    import pickle
    import dill

    path = str(path)
    try:
        db = torch.load(path)
        f = db['f']
    except Exception as e:
        db = torch.jit.load(path)
        f = db['f']
        #with open():
        # db = pickle.load(open(path, "r+"))
        # db = dill.load(open(path, "r+"))
        #raise ValueError(f'FAILED: {e}')
    return db, f

p = "~/data/f.pt"
path = Path(p).expanduser()

db, f = load(path)

Din, nb_examples = 1, 5
x = torch.distributions.Normal(loc=0.0, scale=1.0).sample(sample_shape=(nb_examples, Din))

y = f(x)

print(y)
print('Success!\a')

但是我收到了不同版本的PyTorch的投诉,我被迫使用它们:

Traceback (most recent call last):
  File "hal_pg.py", line 27, in <module>
    db, f = load(path)
  File "hal_pg.py", line 16, in load
    db = torch.jit.load(path)
  File "/home/miranda9/.conda/envs/wmlce-v1.7.0-py3.7/lib/python3.7/site-packages/torch/jit/__init__.py", line 239, in load
    cpp_module = torch._C.import_ir_module(cu, f, map_location, _extra_files)
RuntimeError: version_number <= kMaxSupportedFileFormatVersion INTERNAL ASSERT FAILED at /opt/anaconda/conda-bld/pytorch-base_1581395437985/work/caffe2/serialize/inline_container.cc:131, please report a bug to PyTorch. Attempted to read a PyTorch file with version 3, but the maximum supported version for reading is 1. Your PyTorch installation may be too old. (init at /opt/anaconda/conda-bld/pytorch-base_1581395437985/work/caffe2/serialize/inline_container.cc:131)
frame #0: c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0xbc (0x7fff7b527b9c in /home/miranda9/.conda/envs/wmlce-v1.7.0-py3.7/lib/python3.7/site-packages/torch/lib/libc10.so)
frame #1: caffe2::serialize::PyTorchStreamReader::init() + 0x1d98 (0x7fff1d293c78 in /home/miranda9/.conda/envs/wmlce-v1.7.0-py3.7/lib/python3.7/site-packages/torch/lib/libtorch.so)
frame #2: caffe2::serialize::PyTorchStreamReader::PyTorchStreamReader(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0x88 (0x7fff1d2950d8 in /home/miranda9/.conda/envs/wmlce-v1.7.0-py3.7/lib/python3.7/site-packages/torch/lib/libtorch.so)
frame #3: torch::jit::import_ir_module(std::shared_ptr<torch::jit::script::CompilationUnit>, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, c10::optional<c10::Device>, std::unordered_map<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, std::hash<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > >, std::equal_to<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > >, std::allocator<std::pair<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > > > >&) + 0x64 (0x7fff1e624664 in /home/miranda9/.conda/envs/wmlce-v1.7.0-py3.7/lib/python3.7/site-packages/torch/lib/libtorch.so)
frame #4: <unknown function> + 0x70e210 (0x7fff7c0ae210 in /home/miranda9/.conda/envs/wmlce-v1.7.0-py3.7/lib/python3.7/site-packages/torch/lib/libtorch_python.so)
frame #5: <unknown function> + 0x28efc4 (0x7fff7bc2efc4 in /home/miranda9/.conda/envs/wmlce-v1.7.0-py3.7/lib/python3.7/site-packages/torch/lib/libtorch_python.so)
<omitting python frames>
frame #26: <unknown function> + 0x25280 (0x7fff84b35280 in /lib64/libc.so.6)
frame #27: __libc_start_main + 0xc4 (0x7fff84b35474 in /lib64/libc.so.6)

有没有什么想法可以让所有集群保持一致?我甚至无法打开pickle文件。


也许这只是因为我被迫使用的当前pytorch版本不支持 :(

RuntimeError: version_number <= kMaxSupportedFileFormatVersion INTERNAL ASSERT FAILED at /opt/anaconda/conda-bld/pytorch-base_1581395437985/work/caffe2/serialize/inline_container.cc:131, please report a bug to PyTorch. Attempted to read a PyTorch file with version 3, but the maximum supported version for reading is 1. Your PyTorch installation may be too old. (init at /opt/anaconda/conda-bld/pytorch-base_1581395437985/work/caffe2/serialize/inline_container.cc:131)
frame #0: c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0xbc (0x7fff83ba7b9c in /home/miranda9/.conda/envs/automl-meta-learning_wmlce-v1.7.0-py3.7/lib/python3.7/site-packages/torch/lib/libc10.so)
frame #1: caffe2::serialize::PyTorchStreamReader::init() + 0x1d98 (0x7fff25993c78 in /home/miranda9/.conda/envs/automl-meta-learning_wmlce-v1.7.0-py3.7/lib/python3.7/site-packages/torch/lib/libtorch.so)
frame #2: caffe2::serialize::PyTorchStreamReader::PyTorchStreamReader(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0x88 (0x7fff259950d8 in /home/miranda9/.conda/envs/automl-meta-learning_wmlce-v1.7.0-py3.7/lib/python3.7/site-packages/torch/lib/libtorch.so)
frame #3: torch::jit::import_ir_module(std::shared_ptr<torch::jit::script::CompilationUnit>, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, c10::optional<c10::Device>, std::unordered_map<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, std::hash<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > >, std::equal_to<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > >, std::allocator<std::pair<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > > > >&) + 0x64 (0x7fff26d24664 in /home/miranda9/.conda/envs/automl-meta-learning_wmlce-v1.7.0-py3.7/lib/python3.7/site-packages/torch/lib/libtorch.so)
frame #4: <unknown function> + 0x70e210 (0x7fff8472e210 in /home/miranda9/.conda/envs/automl-meta-learning_wmlce-v1.7.0-py3.7/lib/python3.7/site-packages/torch/lib/libtorch_python.so)
frame #5: <unknown function> + 0x28efc4 (0x7fff842aefc4 in /home/miranda9/.conda/envs/automl-meta-learning_wmlce-v1.7.0-py3.7/lib/python3.7/site-packages/torch/lib/libtorch_python.so)
<omitting python frames>
frame #23: <unknown function> + 0x25280 (0x7fff8d335280 in /lib64/libc.so.6)
frame #24: __libc_start_main + 0xc4 (0x7fff8d335474 in /lib64/libc.so.6)

使用代码:

from pathlib import Path

import torch

path = '/home/miranda9/data/dataset/'
path = Path(path).expanduser() / 'fi_db.pt'
path = str(path)

# db = torch.load(path)
# torch.jit.load(path)
db = torch.jit.load(str(path))

print(db)

相关链接:


你考虑过使用虚拟环境 venv 吗?在处理 Python 项目时,这是一个很好的实践。这样你就可以在同一台机器上拥有不同的版本。 - Tiago Martins Peres
@TiagoMartinsPeres李大仁 当然。事实上,我使用的HPC没有它们就无法工作。我正在使用一个IBM克隆环境。我加载了module load wmlce/1.7.0-py3.7。那应该如何解决问题呢?由于硬件架构ppc64le的缘故,pytorch的版本是固定的。因此,在我的情况下,我不知道使用不同版本的pytorch是否有任何好处(因为这很可能是不可能的)。 - Charlie Parker
4个回答

3

我相信开发人员的意图是传递一个标志来保存为 pickle。只是默认行为的更改。

对于先前检查点的文件,请在较新的环境中重新加载保存的 zip 文件权重(使用 pytorch>=1.6),然后再次将其用作 pickle 检查点(无需重新训练);

从下一次开始更新代码并添加标志

从 ver 1.6 开始弃用 :

我们已经将 torch.save 默认切换为基于 zip 文件的格式,而不是旧的基于 pickle 的格式。torch.load 仍然保留了加载旧格式的能力,但推荐使用新格式。新格式是:

更加友好的检查和构建工具以操作保存文件解决了一个长期存在的问题,即依赖于序列化张量值的模块的序列化(getstate, setstate)函数获取错误数据与 TorchScript 序列化格式相同,使得跨 PyTorch 的序列化更加一致

使用方法如下:

m = MyMod()
torch.save(m.state_dict(), 'mymod.pt') # Saves a zipfile to mymod.pt

要使用旧格式,请传递标志_use_new_zipfile_serialization=False

m = MyMod()
torch.save(m.state_dict(), 'mymod.pt', _use_new_zipfile_serialization=False) # Saves pickle

对于使用旧版本保存的检查点,我该如何在新版本的PyTorch中加载它? - Nagabhushan S N

1

在@maxim velikanov的回答基础上,我创建了一个单独的OrderedDict,其中键与模型的原始状态字典相同,但每个张量值都转换为列表。

然后将此OrderedDict转储到JSON文件中。

def save_model_json(model, path):
    actual_dict = OrderedDict()
    for k, v in model.state_dict().items():
      actual_dict[k] = v.tolist()
    with open(path, 'w') as f:
      json.dump(actual_dict, f)

加载器可以将文件加载为JSON格式,每个列表/整数在复制其值到原始状态字典之前将转换回张量。
def load_model_json(model, path):
  data_dict = OrderedDict()
  with open(path, 'r') as f:
    data_dict = json.load(f)    
  own_state = model.state_dict()
  for k, v in data_dict.items():
    print('Loading parameter:', k)
    if not k in own_state:
      print('Parameter', k, 'not found in own_state!!!')
    if type(v) == list or type(v) == int:
      v = torch.tensor(v)
    own_state[k].copy_(v)
  model.load_state_dict(own_state)
  print('Model loaded')

1
这并不是一个理想的解决方案,但它可以将较新版本的检查点传输到旧版本中。
我也使用ppc64le,并面临同样的问题。可以将模型保存为文本格式,任何PyTorch版本都可以读取。我在ppc64le机器上安装了PyTorch v1.3.0,在笔记本电脑上安装了v1.7.0(不需要显卡)。
步骤1:通过较新的PyTorch版本保存模型。
def save_model_txt(model, path):
    fout = open(path, 'w')
    for k, v in model.state_dict().items():
        fout.write(str(k) + '\n')
        fout.write(str(v.tolist()) + '\n')
    fout.close()

在保存之前,我这样加载模型:
checkpoint = torch.load(path, map_location=torch.device('cpu'))
model.load_state_dict(checkpoint, strict=False)

步骤2:传输文本文件
步骤3:在旧版PyTorch中加载文本文件。
def load_model_txt(model, path):
    data_dict = {}
    fin = open(path, 'r')
    i = 0
    odd = 1
    prev_key = None
    while True:
        s = fin.readline().strip()
        if not s:
            break
        if odd:
            prev_key = s
        else:
            print('Iter', i)
            val = eval(s)
            if type(val) != type([]):
                data_dict[prev_key] = torch.FloatTensor([eval(s)])[0]
            else:
                data_dict[prev_key] = torch.FloatTensor(eval(s))
            i += 1
        odd = (odd + 1) % 2

    # Replace existing values with loaded

    print('Loading...')
    own_state = model.state_dict()
    print('Items:', len(own_state.items()))
    for k, v in data_dict.items():
        if not k in own_state:
            print('Parameter', k, 'not found in own_state!!!')
        else:
            try:
                own_state[k].copy_(v)
            except:
                print('Key:', k)
                print('Old:', own_state[k])
                print('New:', v)
                sys.exit(0)
    print('Model loaded')

模型在加载之前必须进行初始化。空模型将传递到函数中。 限制 如果您的模型state_dict包含除(str: torch.Tensor)值以外的其他内容,则此方法将无法正常工作。您可以使用以下方式检查state_dict的内容
for k, v in model.state_dict().items():
    ...

阅读以下内容以理解:

https://pytorch.org/tutorials/recipes/recipes/saving_and_loading_models_for_inference.html

https://discuss.pytorch.org/t/how-to-load-part-of-pre-trained-model/1113


0

当我加载处理过的数据时,遇到了类似的问题。我之前在torch 1.8中保存了数据为'xxx.pt',但是在torch 1.2中加载它时失败了,即使使用torch.jit.load()也无法成功加载。我的唯一解决方案是在旧版本中重新保存数据。


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接