Contents
References:
Reproducing the Problem:
Analysis:
Cause and Solution

References:
"strict=False, but still 'size mismatch: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is ...'" (CSDN blog)
"Fixing the 'size mismatch for xx.weight' error" (CSDN blog)
Reproducing the Problem:
RuntimeError: Error(s) in loading state_dict for MMDataParallel:
    size mismatch for module.neck.def_convs.0.reppoints_pts_init_out.weight: copying a param with shape torch.Size([14, 64, 1, 1]) from checkpoint, the shape in current model is torch.Size([21, 64, 1, 1]).
    size mismatch for module.neck.def_convs.0.reppoints_pts_init_out.bias: copying a param with shape torch.Size([14]) from checkpoint, the shape in current model is torch.Size([21]).
    size mismatch for module.neck.def_convs.0.reppoints_pts_refine_out.weight: copying a param with shape torch.Size([14, 64, 1, 1]) from checkpoint, the shape in current model is torch.Size([21, 64, 1, 1]).
    size mismatch for module.neck.def_convs.0.reppoints_pts_refine_out.bias: copying a param with shape torch.Size([14]) from checkpoint, the shape in current model is torch.Size([21]).
    size mismatch for module.neck.def_convs.1.reppoints_pts_init_out.weight: copying a param with shape torch.Size([10, 64, 1, 1]) from checkpoint, the shape in current model is torch.Size([15, 64, 1, 1]).
    size mismatch for module.neck.def_convs.1.reppoints_pts_init_out.bias: copying a param with shape torch.Size([10]) from checkpoint, the shape in current model is torch.Size([15]).
    size mismatch for module.neck.def_convs.1.reppoints_pts_refine_out.weight: copying a param with shape torch.Size([10, 64, 1, 1]) from checkpoint, the shape in current model is torch.Size([15, 64, 1, 1]).
    size mismatch for module.neck.def_convs.1.reppoints_pts_refine_out.bias: copying a param with shape torch.Size([10]) from checkpoint, the shape in current model is torch.Size([15]).
    size mismatch for module.neck.def_convs.2.reppoints_pts_init_out.weight: copying a param with shape torch.Size([6, 64, 1, 1]) from checkpoint, the shape in current model is torch.Size([9, 64, 1, 1]).
    size mismatch for module.neck.def_convs.2.reppoints_pts_init_out.bias: copying a param with shape torch.Size([6]) from checkpoint, the shape in current model is torch.Size([9]).
    size mismatch for module.neck.def_convs.2.reppoints_pts_refine_out.weight: copying a param with shape torch.Size([6, 64, 1, 1]) from checkpoint, the shape in current model is torch.Size([9, 64, 1, 1]).
    size mismatch for module.neck.def_convs.2.reppoints_pts_refine_out.bias: copying a param with shape torch.Size([6]) from checkpoint, the shape in current model is torch.Size([9]).
Analysis:
When using

    model.load_state_dict(state_dict, strict=False)

the expectation is that parameters that do not match between the model and the checkpoint file are skipped for now, and only the ones that match normally are loaded from the file.
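As a quick aside, strict=False really does skip keys that do not line up by name. A minimal sketch of my own (not from the original post) to illustrate:

    import torch
    import torch.nn as nn

    model = nn.Linear(4, 4)
    state_dict = model.state_dict()
    state_dict["not_in_model"] = torch.zeros(1)  # a key the model does not have

    # With strict=False the stray key is skipped instead of raising an error;
    # load_state_dict reports it in the returned NamedTuple.
    info = model.load_state_dict(state_dict, strict=False)
    print(info.unexpected_keys)  # ['not_in_model']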
However, even though I had explicitly passed strict=False, the size mismatch error shown above was still raised.
Cause and Solution
After consulting some references, the explanation turns out to be the following:
strict=False only guarantees that keys present in the model but not in the checkpoint (or vice versa) are skipped. Once a key in the model does match a key in the checkpoint, PyTorch goes ahead and loads that parameter, and loading requires the tensor shapes to be identical; hence the error above.
This often comes up when fine-tuning: we replace the final fully connected layer of a pretrained model so that its number of output classes matches our own dataset, then load the pretrained weights. We know the fully connected layer's parameter shapes will not match: an ImageNet-1K checkpoint's final fully connected layer has an output dimension of 1000, while for a 10-class dataset of our own we change that output dimension to 10. Because the key names are identical, PyTorch still tries to load the checkpoint tensor, the 1000 and 10 dimensions do not match, and the error is raised.
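To make that concrete, here is a minimal sketch of my own (assuming torchvision is available; the "checkpoint" is simulated by a fresh 1000-class model rather than a real pretrained file):

    import torch.nn as nn
    from torchvision.models import resnet18

    checkpoint = resnet18().state_dict()            # stands in for a 1000-class ImageNet checkpoint
    model = resnet18()
    model.fc = nn.Linear(model.fc.in_features, 10)  # replace the head for a 10-class dataset

    try:
        model.load_state_dict(checkpoint, strict=False)
    except RuntimeError as e:
        print(e)  # size mismatch for fc.weight and fc.bias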
The solution is to read in the .pth file and pop the parameters of the layers we do not need (usually the final fully connected layer) from the state dict before loading.
Taking my own task as an example: suppose we have a lane detection model together with a parameter file. After loading the .pth file (i.e. pretrained_model = torch.load(model_dir)), we pop the mismatched entries and then load the rest:
    pretrained_model = torch.load(model_dir)

    # Drop the entries whose shapes no longer match the current model.
    pretrained_model['net'].pop("module.neck.def_convs.0.reppoints_pts_init_out.weight")
    pretrained_model['net'].pop("module.neck.def_convs.0.reppoints_pts_init_out.bias")
    pretrained_model['net'].pop("module.neck.def_convs.0.reppoints_pts_refine_out.weight")
    pretrained_model['net'].pop("module.neck.def_convs.0.reppoints_pts_refine_out.bias")
    pretrained_model['net'].pop("module.neck.def_convs.1.reppoints_pts_init_out.weight")
    pretrained_model['net'].pop("module.neck.def_convs.1.reppoints_pts_init_out.bias")
    pretrained_model['net'].pop("module.neck.def_convs.1.reppoints_pts_refine_out.weight")
    pretrained_model['net'].pop("module.neck.def_convs.1.reppoints_pts_refine_out.bias")
    pretrained_model['net'].pop("module.neck.def_convs.2.reppoints_pts_init_out.weight")
    pretrained_model['net'].pop("module.neck.def_convs.2.reppoints_pts_init_out.bias")
    pretrained_model['net'].pop("module.neck.def_convs.2.reppoints_pts_refine_out.weight")
    pretrained_model['net'].pop("module.neck.def_convs.2.reppoints_pts_refine_out.bias")

    # load_state_dict with strict=False loads the pretrained weights even though they
    # no longer fully match the current model's structure, so the remaining weights
    # can be loaded into a model that only partially matches the checkpoint.
    net.load_state_dict(pretrained_model['net'], strict=False)
At this point, the model runs normally.
Even though the weight and bias parameters of those layers are now reported as missing, that is expected: we are modifying the model for fine-tuning, we do not need those parameters in the first place, and we popped them from the checkpoint dict ourselves. The parameters of every other layer load normally, and we can move on to fine-tuning the model.
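If you want to double-check this, load_state_dict returns the lists of missing and unexpected keys, so you can confirm that only the popped layers are affected. A small sketch reusing net and pretrained_model from the snippet above:

    result = net.load_state_dict(pretrained_model['net'], strict=False)
    print(result.missing_keys)     # only the popped reppoints_* weights and biases
    print(result.unexpected_keys)  # should be empty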
Since we do not need these parameters anyway, we simply pop the key-value pairs from the dict, so that PyTorch never attempts to load the mismatched shapes at load time.
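A more generic variant (my own sketch, not from the original post): instead of listing the keys by hand, compare every checkpoint tensor's shape against the current model and pop whatever disagrees. Reusing the simulated resnet18 setup from the earlier sketch:

    import torch.nn as nn
    from torchvision.models import resnet18

    checkpoint = resnet18().state_dict()            # simulated 1000-class checkpoint
    model = resnet18()
    model.fc = nn.Linear(model.fc.in_features, 10)

    model_sd = model.state_dict()
    for key in list(checkpoint.keys()):             # list() lets us pop while iterating
        if key in model_sd and checkpoint[key].shape != model_sd[key].shape:
            checkpoint.pop(key)                     # drops fc.weight and fc.bias here

    model.load_state_dict(checkpoint, strict=False)  # now loads without errors

This way the same code keeps working even if you later change which layers of the model you replace.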