pytorch冻结模型某层参数

Published on Aug. 22, 2023, 12:10 p.m.

pytorch如何冻结某层参数

bert模型12层,参数量达一亿,bert模型做微调有的时候就需要只训练部分参数,那么就需要把其他的参数冻结掉就可以

import torch.nn as nn
from transformers import BertModel
import torch

model = BertModel()
# 这里是一般情况,冻结多层
for para in model.parameters():
    para.requires_grad = False

# 冻结某一层
model.linear1.weight.requires_grad = False

设置优化器

做训练的时候,优化器中一定要添加过滤器filter把requires_grad = False的参数过滤掉,这样就不会更新这些参数了。

optimizer = torch.optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=0.1)

# 多个层分别设置
# 另外对于多层可以设置不同的学习率
#好多人都说crf建议设置很大的学习率。

optimizer = getattr(torch.optim,optimizer_name)([
              {'params': filter(lambda p: p.requires_grad, or_model.parameters()), 'lr': learning_rate},
              {'params': model.parameters(), 'lr': learning_rate}]
        )