目录
  • 参考
  • torch.load()
  • 模型的保存
  • 模型加载中的map_location参数
    • map_location=None
    • map_location=torch.device()
    • map_location={xx:xx}
  • 总结

    参考

    TORCH.LOAD

    torch.load()

    函数格式为:torch.load(f, map_location=None, pickle_module=pickle, **pickle_load_args),一般我们使用的时候,基本只使用前两个参数。

    模型的保存

    模型保存有两种形式,一种是保存模型的state_dict(),只是保存模型的参数。那么加载时需要先创建一个模型的实例model,之后通过torch.load()将保存的模型参数加载进来,得到dict,再通过model.load_state_dict(dict)将模型的参数更新。

    另一种是将整个模型保存下来,之后加载的时候只需要通过torch.load()将模型加载,即可返回一个加载好的模型。

    具体可参考:PyTorch模型的保存与加载。

    模型加载中的map_location参数

    具体来说,map_location参数是用于重定向,比如此前模型的参数是在cpu中的,我们希望将其加载到cuda:0中。或者我们有多张卡,那么我们就可以将卡1中训练好的模型加载到卡2中,这在数据并行的分布式深度学习中可能会用到。

    首先定义一个AlexNet,并使用cuda:0将其训练了一个猫狗分类,之后把模型存储起来。

    map_location=None

    我们先把state_dict加载进来。

    model_path = "./cuda_model.pth"
    model = torch.load(model_path)
    print(next(model.parameters()).device)
    

    结果为:

    cuda:0

    因为保存的时候就是模型就是cuda:0的,所以加载进来也是。

    map_location=torch.device()

    model_path = "./cuda_model.pth"
    model = torch.load(model_path, map_location=torch.device('cpu'))
    print(next(model.parameters()).device)
    

    结果为:

    cpu

    模型从cuda:0变成了cpu

    map_location={xx:xx}

    model_path = "./cuda_model.pth"
    model = torch.load(model_path, map_location={'cuda:0':'cuda:1'})
    print(next(model.parameters()).device)
    

    结果为:

    cuda:1

    模型从cuda:0变成了cuda:1

    model_path = "./cuda_model.pth"
    model = torch.load(model_path, map_location={'cuda:2':'cpu'})
    print(next(model.parameters()).device)
    

    结果为:

    cuda:0

    模型还是cuda:0,并没有变成cpu。因为这个map_location的映射是不对的,原始的模型就是cuda:0,而映射是cuda:2cpu,是不对的。这种情况下,map_location返回None,也就是和不加map_location相同。

    总结

    声明:本站所有文章,如无特殊说明或标注,均为本站原创发布。任何个人或组织,在未征得本站同意时,禁止复制、盗用、采集、发布本站内容到任何网站、书籍等各类媒体平台。如若本站内容侵犯了原著者的合法权益,可联系我们进行处理。