pytorch冻结模型某层参数
Published on Aug. 22, 2023, 12:10 p.m.
pytorch如何冻结某层参数
bert模型12层,参数量达一亿,bert模型做微调有的时候就需要只训练部分参数,那么就需要把其他的参数冻结掉就可以
import torch.nn as nn
from transformers import BertModel
import torch
model = BertModel()
# 这里是一般情况,冻结多层
for para in model.parameters():
para.requires_grad = False
# 冻结某一层
model.linear1.weight.requires_grad = False
设置优化器
做训练的时候,优化器中一定要添加过滤器filter把requires_grad = False的参数过滤掉,这样就不会更新这些参数了。
optimizer = torch.optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=0.1)
# 多个层分别设置
# 另外对于多层可以设置不同的学习率
#好多人都说crf建议设置很大的学习率。
optimizer = getattr(torch.optim,optimizer_name)([
{'params': filter(lambda p: p.requires_grad, or_model.parameters()), 'lr': learning_rate},
{'params': model.parameters(), 'lr': learning_rate}]
)