深度学习中预训练模型库非常重要,它能够帮助我们非常方便的获取到模型的结构、模型的权重文件等,这大大降低了入门深度学习的门槛,如高性能的硬件设备(服务器、GPU),同时使用迁移学习的思想能够大大缩短我们开发可实用模型的时间。常见的预训练模型库包含有 torchvision.models(CV 模型)、transformers(CV 和 NLP 大模型相关)、timm(包含 CV 领域小模型和大模型,开发公司同 transformers 的 hugging face)。其中 timm 非常方便我们查看模型结构,同时可加载预训练的模型权重,且支持的模型比较多。本篇介绍 timm。

安装

timm 的官方网址是:https://github.com/huggingface/pytorch-image-models

1
pip install timm

使用

查看预训练模型

1
2
3
4
import timm

avail_pretrained_models = timm.list_models(pretrained=True)
len(avail_pretrained_models)
1
2
# 截止目前已经有 1289 个预训练模型
1289

查看某类模型

1
2
all_densnet_models = timm.list_models("*sam*")
all_densnet_models
1
2
3
4
['samvit_base_patch16',
'samvit_base_patch16_224',
'samvit_huge_patch16',
'samvit_large_patch16']

查看模型结构

1
timm.create_model(model_name="vit_huge_patch14_224")

加载模型

1
2
model = timm.create_model("samvit_huge_patch16", pretrained=True)
model.default_cfg
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
{'url': 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth',
'hf_hub_id': 'timm/samvit_huge_patch16.sa1b',
'architecture': 'samvit_huge_patch16',
'tag': 'sa1b',
'custom_load': False,
'input_size': (3, 1024, 1024),
'fixed_input_size': True,
'interpolation': 'bicubic',
'crop_pct': 1.0,
'crop_mode': 'center',
'mean': (0.485, 0.456, 0.406),
'std': (0.229, 0.224, 0.225),
'num_classes': 0,
'pool_size': None,
'first_conv': 'patch_embed.proj',
'classifier': 'head.fc',
'license': 'apache-2.0'}

保存模型权重

1
2
3
4
5
6
7
8
9
10
import torch

# 保存模型权重
torch.save(model.state_dict(),'./timm_model-state_dict.pth')

# 保存整个模型,size 更大一些
torch.save(model.state_dict(),'./timm_model.pth')

# 加载模型权重
model.load_state_dict(torch.load('./timm_model-state_dict.pth'))

可视化模型结构

当我们有模型权重文件(*.pth)后,我们可以使用 netron 来可视化模型结构,更加直观。

netron 网址为:https://netron.app/

参考文献

  1. 6.3 模型微调 - timm
  2. 视觉 Transformer 优秀开源工作:timm 库 vision transformer 代码解读
  3. timm——pytorch下的迁移学习模型库·详细使用教程