假设有一个训练好的模型,并且我们只想微调部分参数。
比如,这里我们只想更新最后一部分的参数:
可以看到,这里的模块叫b4。
我们可以直接通过获取模块的名字来进行更新:
def update(model,flag=True): for name,p in model.named_parameters(): if "b4" in name: print("update only",name) p.requires_grad = flag
也就是说 只要模块名字包含b4 就会让他跟新网络。
对应的optimizer 的设置如下:
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=lr_)
然后直接训练就行。
方法二也可以直接 把这些符合条件的 parameters 加入 list中,并传给 optimizer
def update(model,flag=True): paras = [] for name,p in model.named_parameters(): if "b4" in name: print("update only",name) p.requires_grad = flag paras.append(p) return paras
optimizer = torch.optim.Adam(paras, lr=lr_)
直接训练就行。##