资讯 小学 初中 高中 语言 会计职称 学历提升 法考 计算机考试 医护考试 建工考试 教育百科
栏目分类:
子分类:
返回
空麓网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
空麓网 > 计算机考试 > 软件开发 > 后端开发 > Python

Pytorch只更新预训练模型的部分参数

Python 更新时间: 发布时间: 计算机考试归档 最新发布

Pytorch只更新预训练模型的部分参数

Pytorch只更新预训练模型的部分参数

假设有一个训练好的模型,并且我们只想微调部分参数。
比如,这里我们只想更新最后一部分的参数:
可以看到,这里的模块叫b4。


我们可以直接通过获取模块的名字来进行更新:

方法1
def update(model,flag=True):
    for name,p in model.named_parameters():
        if "b4" in name:
            print("update only",name)
            p.requires_grad = flag

也就是说 只要模块名字包含b4 就会让他跟新网络。
对应的optimizer 的设置如下:

optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=lr_)

然后直接训练就行。

方法二

也可以直接 把这些符合条件的 parameters 加入 list中,并传给 optimizer

def update(model,flag=True):
    paras = []
    for name,p in model.named_parameters():
        if "b4" in name:
            print("update only",name)
            p.requires_grad = flag
            paras.append(p)
    return paras
optimizer = torch.optim.Adam(paras, lr=lr_)

直接训练就行。##

转载请注明:文章转载自 http://www.konglu.com/
本文地址:http://www.konglu.com/it/990193.html
免责声明:

我们致力于保护作者版权,注重分享,被刊用文章【Pytorch只更新预训练模型的部分参数】因无法核实真实出处,未能及时与作者取得联系,或有版权异议的,请联系管理员,我们会立即处理,本文部分文字与图片资源来自于网络,转载此文是出于传递更多信息之目的,若有来源标注错误或侵犯了您的合法权益,请立即通知我们,情况属实,我们会第一时间予以删除,并同时向您表示歉意,谢谢!

我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2023 成都空麓科技有限公司

ICP备案号:蜀ICP备2023000828号-2