资讯 小学 初中 高中 语言 会计职称 学历提升 法考 计算机考试 医护考试 建工考试 教育百科
栏目分类:
子分类:
返回
空麓网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
空麓网 > 计算机考试 > 软件开发 > 后端开发 > Python

Pytorch高级训练框架Ignite详细介绍与常用模版

Python 更新时间: 发布时间: 计算机考试归档 最新发布

Pytorch高级训练框架Ignite详细介绍与常用模版

引言

Ignite是Pytorch配套的高级框架,我们可以借其构筑一套标准化的训练流程,规范训练器在每个循环、轮次中的行为。本文将不再赘述Ignite的具体细节或者API,详见官方教程和其他博文。本文将分析Ignite的运行机制、如何将Pytorch训练代码转为Ignite范式,最后给出个人设计的标准化Ignite训练模版。

Ignite简介

 Ignite所做的事情就是我们在pytorch里常写的范式用更加机械、更加标注格式展现出来,这也就是为啥其核心被称为–Engine,高效而精密。Pytorh里常用的训练范式如下:

for ep in Epoch:	for batch in train_loader:	    model.train()	    inputs, targets = batch	    optimizer.zero_grad()	    outputs = model(inputs)	    loss = criterion(outputs, targets)	    loss.backward()	    optimizer.step()	    		if it%log_period:			print()	if ep%save_perid:		torch.save()

具体而言,可以拆解为批训练、批完结处理、轮次完结处理三个组成部分,批训练部分是网络训练的基础单元,完成数据当前批次读取、前向传播、反向传播等步骤,批完结处理负责在每个批次结束后输出模型训练的相关信息,轮次完结处理负责在每个epoch结束进行模型的保存、对模型的训练参数进行更新。这三个模型训练的主要组成部分在ignite中得到了完整的封装,围绕批训练构造了一个核心的Engine,将批完结处理和轮次完结处理附加该Engine运行的时间轴中,形成了批训练->批完结处理->轮次完结处理的流水线作业范式,更为详细的时间轴如下1

以下将从实用性的角度出发给出Ignite的建设框架,最终给出个人设计的Ignite使用模版,后续直接在train.py文件里直接调用do_train()函数即可利用Ignite进行模型训练。为讲解需要,中间每个子部分的代码为最终代码中相应部分重新排序得到,最终代码中其顺序会进行调整。

批训练

 批训练的代码较为简单,只需将原本的Pytorch版本批处理流程复制粘贴,最后将该过程函数化,并且实例化成Engine即可,代码如下所示,最终启动Engine,即可进行模型的训练。到此为止,实际上已经完成了狭义上的“模型”训练部分。

 def create_supervised_trainer(model,optimizer,criterion,                              device=None, non_blocking=False,                              prepare_batch=_prepare_batch,                              output_transform=lambda x, y, y_pred, loss: loss.item()):      """      有监督模型的Engine创建      Args:          model (`torch.nn.Module`):          optimizer (`torch.optim.Optimizer`):          loss_fn (torch.nn loss function):          device (str, optional):          non_blocking (bool, optional): if True and this copy is between CPU and GPU, the copy may occur asynchronously              with respect to the host. For other cases, this argument has no effect.          prepare_batch (callable, optional): 批处理函数,对dataloader的输出进行处理          output_transform (callable, optional): 输出变换函数,设定输出,默认情况下,输入为x,y,y_pred,loss,输出loss.item()      Note: engine在每个batch下的最终输出由transform所指定,默认传回loss.item      Returns:          Engine: 有监督任务的engine实例      """      if device:          model.to(device)      def _update(engine, batch):          model.train()          optimizer.zero_grad()          x,y = prepare_batch(batch, device=device, non_blocking=non_blocking)          output = model(x)          loss=criterion(output,y)          loss.backward()          optimizer.step()          return output_transform(x, y, None, loss)      return Engine(_update)trainer=create_supervised_trainer(model,optimizer,criterion,device)  # 建立ignite的enginetrainer.run(train_loader,max_epochs=cfg['max_epochs'])

批完结处理

 批完结处理部分我们常做的操作是输出模型在当前批的损失,Ignite中这一过程通过在Engine上附着于ITERAION_COMPLETEDE时触发的回调函数实现。实际上这只是限定了触发时间,具体进行何种操作,完全依赖于个人的选择。我们只需要知道该函数可以利用engine保留的当前批属性信息进行各种操作即可,具体可以利用哪些属性,见官方API2,本文只利用了常用的几个。

    ##########################################################################################    ###########                    Events.ITERATION_COMPLETED                    #############    ##########################################################################################    @trainer.on(Events.ITERATION_COMPLETED)    def log_training_loss(engine):        """        隔一定iteration输出模型损失        """        log_period=int(cfg["log_period"]*len(train_loader))  # 跑了log_period*len输出一次,取值<=1        if engine.state.iteration%log_period==0:            pbar.write(f"Epoch {engine.state.epoch}, iter {engine.state.iteration}: Loss {engine.state.output:.2f}")            pbar.update(log_period)    @trainer.on(Events.ITERATION_COMPLETED)    def scheduler_update(engine):        """        optional 每个ITER更新学习率        """        scheduler.step()

轮次完结处理

  轮次完结处理和批完结处理相同,也是通过回调函数实现,我们通常在轮次完结处理要进行模型的保存,这里就要做两件事:

  1. 在val_loader上验证模型效果
  2. 保留迄今为止效果最好的模型

针对于第一个要求,这里我同样采用了Ignite风格的Engine驱动范式,读者可以自行选择在这里切换为Pytorch范式,构建验证集Engine的代码如下:

    def create_supervised_evaluator(model, metric,                                device=None, non_blocking=False,                                prepare_batch=_prepare_batch,                                output_transform=lambda x, y, y_pred: (y_pred,y)):        """        构造evaluator        :param model:        :param metric: dict,key为metric名字,value为Metric类        :param device:        :param non_blocking:        :param prepare_batch:        :param output_transform:        :return:        """        if device:            model.to(device)        def _inference(engine, batch):            model.eval()            with torch.no_grad:                x, y = prepare_batch(batch, device=device, non_blocking=non_blocking)                output = model(x)            return output_transform(x, y, output)        engine=Engine(_inference)		# 附着metric        for name, metric in metric.items():            metric.attach(engine,name)        return engine

可以看到和trainer较为不同的点在于去除了opt等等选项,此外,由于保存模型时我们要依据验证集上的metric来判断是否要保存当前模型还是沿用此前的模型,因此额外将一个Metric类附着在了Engine上,它使得模型可以自动收集eval_engine每个轮次的输出,并进行metric的计算,ignite中提供了许多metric选项3,这里笔者给出自己定制mertric的范式如下,主要由reset(),update()和commpute()组成,reset()完成每个epoch的记录状态重置,update()则接受某一批次engine的输出值,commpute()完成最终的metric计算。值得一提的是,在trainer上我们并没有额外附着Loss类,而是直接用engine输出了loss,实际上或许你也可以用相同的方式对eval_engine进行处理。

class CustomMetric(Metric):    def __init__(self):        super(CustomMetric,self).__init__()    def reset(self) -> None:        self._num_correct=0        self._num_examples=0    def update(self, output) -> None:        '''        保存该轮次的输出        :param output: 每个batch engine的输出        :return:        '''        pred,label=output        pred=pred.detach()        label=label.detach()        indices=torch.argmax(pred,dim=1)        correct=torch.eq(indices,label).view(-1)        self._num_correct += torch.sum(correct).item()        self._num_examples += correct.shape[0]    def compute(self):        '''        计算总ACC        :return:        '''        return self._num_correct/self._num_examples

 完成第二步的方法Ignite同样进行了封装,即Checkpoint类4,但笔者也进行了自己的定制化,如下:

class BestCheckPoint():    def __init__(self,save_path,n_saved,model_name):        '''        建立存档点类        :param save_path: 存档点保存路径        :param n_saved:  保留的存档点数目        '''        self.save_path=save_path        self.n_save=n_saved        self.model_name=model_name        self.score=[]        if not os.path.exists(save_path):            os.mkdir(self.save_path)    def update(self,score):        '''        更新最优记录        :param score: 当前模型的metric        :return:        '''        if type(self.score)==torch.Tensor:            score=score.item()        if len(self.score)value:                self.score.remove(value)                self.score.append(score)                self.score.sort()                return value            else:                return False    def save(self,score,model):        '''        视当前得分判断是否保存当前模型并删除        :param score: 当前模型得分        :param model: 模型        :return:        '''        is_save=self.update(score)        if is_save:            torch.save(model.state_dict(), os.path.join(self.save_path, self.model_name + f"_{score:.4f}.pth"))            # pop的存档要删除            if not isinstance(is_save,bool):                # 似乎ignite存在并行机制,单步运行的时候没问题,多步就会发生早就remove的错误,可以通过保存每一次更新后的score验证                try:                    os.remove(os.path.join(self.save_path,self.model_name+f"_{is_save:.4f}.pth"))                except:                    print("already removed")

主要就是设置了一个metric池,新的metric进来后判断是否优于池子里最烂的模型,并以此判断是否进行保存。将这个CheckPoint类实例化,并且在trainer每个Epoch完成后遍历eval_engine得到当前模型在验证集上的metric,对其进行更新即可完成模型的保存,代码如下:

    evaluator=create_supervised_evaluator(model,{"ACC":CustomMetric()})    CP=BestCheckPoint(cfg['save_path'],cfg['n_saved'],cfg['model_name'])   ##########################################################################################    ###########                    Events.EPOCH_COMPLETED                    #############    ##########################################################################################    @trainer.on(Events.EPOCH_COMPLETED)    def save_model(engine):        '''        保存模型        :param engine:        :return:        '''        if engine.state.epoch % cfg['save_period']==0:            evaluator.run(val_loader)            metrics=evaluator.state.metrics            print(f"Validation Results - Epoch[{trainer.state.epoch}] Avg accuracy: {metrics['Acc']:.2f}")            CP.save(metrics['ACC'],model)

实际上这里还附加了判断,每隔 cfg['save_period']个轮次才进行验证集上的评估和模型保存。

运行框架

 将上述模块封装在一起,我们就可以得到了最终的ignite运行框架,而后只需导入该文件,并运行其中的do_train()函数即可轻松完成模型训练,其中为了方便模型进程的可视化,使用了pbar模块来进行显示,pbar在固定iter次后输出当前训练信息并更新进度条,在epoch完成后重置,同样通过回调函数的形式附加在了trainer上。

整体代码:

# -*- coding: utf-8 -*-# ---# @File: trainer.py# @Author: sgdy3# @E-mail: sgdy03@163.com# @Time: 2023/5/9 19:44# Describe: # ---import osfrom tqdm import tqdmimport igniteimport torchfrom ignite.engine import Enginefrom ignite.utils import convert_tensorfrom ignite.engine.engine import Engine, State, Eventsfrom ignite.engine import create_supervised_evaluatorfrom ignite.metrics import Metric,Accuracyclass BestCheckPoint():    def __init__(self,save_path,n_saved,model_name):        '''        建立存档点类        :param save_path: 存档点保存路径        :param n_saved:  保留的存档点数目        '''        self.save_path=save_path        self.n_save=n_saved        self.model_name=model_name        self.score=[]        if not os.path.exists(save_path):            os.mkdir(self.save_path)    def update(self,score):        '''        更新最优记录        :param score: 当前模型的metric        :return:        '''        if type(self.score)==torch.Tensor:            score=score.item()        if len(self.score)value:                self.score.remove(value)                self.score.append(score)                self.score.sort()                return value            else:                return False    def save(self,score,model):        '''        视当前得分判断是否保存当前模型并删除        :param score: 当前模型得分        :param model: 模型        :return:        '''        is_save=self.update(score)        if is_save:            torch.save(model.state_dict(), os.path.join(self.save_path, self.model_name + f"_{score:.4f}.pth"))            # pop的存档要删除            if not isinstance(is_save,bool):                # 似乎ignite存在并行机制,单步运行的时候没问题,多步就会发生早就remove的错误,可以通过保存每一次更新后的score验证                try:                    os.remove(os.path.join(self.save_path,self.model_name+f"_{is_save:.4f}.pth"))                except:                    print("already removed")class CustomMetric(Metric):    def __init__(self):        super(CustomMetric,self).__init__()    def reset(self) -> None:        self._num_correct=0        self._num_examples=0    def update(self, output) -> None:        '''        保存该轮次的输出        :param output: 每个batch engine的输出        :return:        '''        pred,label=output        pred=pred.detach()        label=label.detach()        indices=torch.argmax(pred,dim=1)        correct=torch.eq(indices,label).view(-1)        self._num_correct += torch.sum(correct).item()        self._num_examples += correct.shape[0]    def compute(self):        '''        计算总ACC        :return:        '''        return self._num_correct/self._num_examplesdef do_train(model,optimizer,criterion,scheduler,device,train_loader,val_loader,cfg):    def _prepare_batch(batch, device=None, non_blocking=False):        """        对dataloader每个batch的输出进行进一步的处理        :param batch: dataloader输出        :param device:        :param non_blocking:        :return:        """        device = "cuda:" + str(device)        x, y = batch        x = convert_tensor(x,device=device,non_blocking=non_blocking)        y = convert_tensor(y,device=device,non_blocking=non_blocking)        return x,y    def create_supervised_trainer(model,optimizer,criterion,                                device=None, non_blocking=False,                                prepare_batch=_prepare_batch,                                output_transform=lambda x, y, y_pred, loss: loss.item()):        """        有监督模型的Engine创建        Args:            model (`torch.nn.Module`):            optimizer (`torch.optim.Optimizer`):            loss_fn (torch.nn loss function):            device (str, optional):            non_blocking (bool, optional): if True and this copy is between CPU and GPU, the copy may occur asynchronously                with respect to the host. For other cases, this argument has no effect.            prepare_batch (callable, optional): 批处理函数,对dataloader的输出进行处理            output_transform (callable, optional): 输出变换函数,设定输出,默认情况下,输入为x,y,y_pred,loss,输出loss.item()        Note: engine在每个batch下的最终输出由transform所指定,默认传回loss.item        Returns:            Engine: 有监督任务的engine实例        """        if device:            model.to(device)        def _update(engine, batch):            model.train()            optimizer.zero_grad()            x,y = prepare_batch(batch, device=device, non_blocking=non_blocking)            output = model(x)            loss=criterion(output,y)            loss.backward()            optimizer.step()            return output_transform(x, y, None, loss)        return Engine(_update)    def create_supervised_evaluator(model, metric,                                device=None, non_blocking=False,                                prepare_batch=_prepare_batch,                                output_transform=lambda x, y, y_pred: (y_pred,y)):        """        构造evaluator        :param model:        :param metric: dict,key为metric名字,value为Metric类        :param device:        :param non_blocking:        :param prepare_batch:        :param output_transform:        :return:        """        if device:            model.to(device)        def _inference(engine, batch):            model.eval()            with torch.no_grad:                x, y = prepare_batch(batch, device=device, non_blocking=non_blocking)                output = model(x)            return output_transform(x, y, output)        engine=Engine(_inference)        for name, metric in metric.items():            metric.attach(engine,name)        return engine    trainer=create_supervised_trainer(model,optimizer,criterion,device)  # 建立ignite的engine    evaluator=create_supervised_evaluator(model,{"ACC":CustomMetric()})    CP=BestCheckPoint(cfg['save_path'],cfg['n_saved'],cfg['model_name'])    pbar=tqdm(total=len(train_loader))  # 为训练器迭代器建立进度条    ##########################################################################################    ###########                    Events.ITERATION_COMPLETED                    #############    ##########################################################################################    @trainer.on(Events.ITERATION_COMPLETED)    def log_training_loss(engine):        """        隔一定iteration输出模型损失        """        log_period=cfg["log_period"]        if engine.state.iteration%log_period==0:            pbar.write(f"Epoch {engine.state.epoch}, iter {engine.state.iteration}: Loss {engine.state.metrics['avg_loss']:.2f}")            pbar.update(log_period)    @trainer.on(Events.ITERATION_COMPLETED)    def scheduler_update(engine):        """        optional 每个ITER更新学习率        """        scheduler.step()    ##########################################################################################    ###########                    Events.EPOCH_COMPLETED                    #############    ##########################################################################################    @trainer.on(Events.EPOCH_COMPLETED)    def save_model(engine):        '''        保存模型        :param engine:        :return:        '''        if engine.state.epoch % cfg['save_period']==0:            evaluator.run(val_loader)            metrics=evaluator.state.metrics            print(f"Validation Results - Epoch[{trainer.state.epoch}] Avg accuracy: {metrics['Acc']:.2f}")            CP.save(metrics['ACC'],model)    @trainer.on(Events.EPOCH_COMPLETED)    def reset_bar(engine):        '''        重置进度条        :param engine:        :return:        '''        pbar.reset()    ##########################################################################################    #################                    training Start                    ###################    ##########################################################################################    trainer.run(train_loader,max_epochs=cfg['max_epochs'])    pbar.close()

参考


  1. Events | Pytorch-Ignite ↩︎

  2. State | Ignite ↩︎

  3. IGNITE.METRICS ↩︎

  4. CHECKPOINT ↩︎

转载请注明:文章转载自 http://www.konglu.com/
本文地址:http://www.konglu.com/it/1096014.html
免责声明:

我们致力于保护作者版权,注重分享,被刊用文章【Pytorch高级训练框架Ignite详细介绍与常用模版】因无法核实真实出处,未能及时与作者取得联系,或有版权异议的,请联系管理员,我们会立即处理,本文部分文字与图片资源来自于网络,转载此文是出于传递更多信息之目的,若有来源标注错误或侵犯了您的合法权益,请立即通知我们,情况属实,我们会第一时间予以删除,并同时向您表示歉意,谢谢!

我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2023 成都空麓科技有限公司

ICP备案号:蜀ICP备2023000828号-2