# Source code for hyperion.torch.trainers.torch_trainer

"""
 Copyright 2019 Johns Hopkins University  (Author: Jesus Villalba)
 Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)
"""

import os
import contextlib
from collections import OrderedDict as ODict
from enum import Enum
from jsonargparse import ArgumentParser, ActionParser
import logging
from pathlib import Path

import torch
import torch.nn as nn
import torch.cuda.amp as amp
from torch.optim.swa_utils import AveragedModel, SWALR
import torch.distributed as dist

from fairscale.optim.grad_scaler import ShardedGradScaler

from ..utils import MetricAcc, TorchDDP, FairShardedDDP, FairFullyShardedDDP
from ..loggers import LoggerList, CSVLogger, ProgLogger, TensorBoardLogger, WAndBLogger
from ..optim import OptimizerFactory as OF
from ..lr_schedulers import LRSchedulerFactory as LRSF
from ..lr_schedulers import LRScheduler as LRS


class DDPType(str, Enum):
    """Flavours of distributed data parallel training understood by the trainer.

    Because the enum also derives from ``str``, each member compares equal
    to its plain string value (e.g. ``DDPType.DDP == "ddp"``).
    """

    DDP = "ddp"
    OSS_DDP = "oss_ddp"
    OSS_SHARDED_DDP = "oss_sharded_ddp"
    FULLY_SHARDED_DDP = "fully_sharded_ddp"


# String value of every DDPType member, in declaration order; these are the
# accepted values of the ``--ddp-type`` command-line option.
ddp_choices = [kind.value for kind in DDPType.__members__.values()]


class TorchTrainer(object):
    """Base Trainer class to train basic neural network models

    Attributes:
      model: model object.
      loss: nn.Module loss class
      optim: pytorch optimizer object or optimizer options dict
      epochs: max. number of epochs
      exp_path: experiment output path
      cur_epoch: current epoch
      grad_acc_steps: gradient accumulation steps to simulate larger batch size.
      device: cpu/gpu device
      metrics: extra metrics to compute besides cxe (dict name -> callable);
        None means no extra metrics.
      lrsched: learning rate scheduler object or options dict
      loggers: LoggerList object, loggers write training progress to std. output and file.
      ddp: if True use distributed data parallel training
      ddp_type: type of distributed data parallel in
        (ddp, oss_ddp, oss_sharded_ddp, fully_sharded_ddp)
      train_mode: training mode in ['train', 'ft-full', 'ft-last-layer']
      use_amp: uses mixed precision training.
      log_interval: number of optim. steps between log outputs
      use_tensorboard: use tensorboard logger
      use_wandb: use wandb logger
      wandb: wandb dictionary of options
      grad_clip: norm to clip gradients, if 0 there is no clipping
      grad_clip_norm: norm type to clip gradients
      swa_start: epoch to start doing swa
      swa_lr: SWA learning rate
      swa_anneal_epochs: SWA learning rate anneal epochs
      cpu_offload: CPU offload of gradients when using fully sharded ddp
    """

    def __init__(
        self,
        model,
        loss,
        optim=None,
        epochs=100,
        exp_path="./train",
        cur_epoch=0,
        grad_acc_steps=1,
        device=None,
        metrics=None,
        lrsched=None,
        loggers=None,
        ddp=False,
        ddp_type="ddp",
        train_mode="train",
        use_amp=False,
        log_interval=10,
        use_tensorboard=False,
        use_wandb=False,
        wandb=None,
        grad_clip=0,
        grad_clip_norm=2,
        swa_start=0,
        swa_lr=1e-3,
        swa_anneal_epochs=10,
        cpu_offload=False,
    ):
        # None defaults replace shared mutable dict defaults; behaviour is the
        # same as passing an empty dict.
        optim = {} if optim is None else optim
        wandb = {} if wandb is None else wandb

        self.model = model
        self.loss = loss
        self.epochs = epochs
        self.cur_epoch = cur_epoch
        self.grad_acc_steps = grad_acc_steps
        self.exp_path = Path(exp_path)

        if loggers is None:
            self.loggers = self._default_loggers(
                log_interval, use_tensorboard, use_wandb, wandb
            )
        elif isinstance(loggers, list):
            self.loggers = LoggerList(loggers)
        else:
            self.loggers = loggers

        # train/validation loops iterate self.metrics.items(), so an absent
        # metrics dict is normalized to an empty one.
        self.metrics = {} if metrics is None else metrics
        self.device = device
        self.train_mode = train_mode
        self.use_amp = use_amp
        self.grad_clip = grad_clip
        self.grad_clip_norm = grad_clip_norm
        self.swa_start = swa_start
        self.do_swa = swa_start > 0
        self.swa_lr = swa_lr
        self.swa_anneal_epochs = swa_anneal_epochs
        self.amp_args = {}

        if device is not None:
            self.model.to(device)
            if loss is not None:
                self.loss.to(device)

        self.ddp = ddp
        self.ddp_type = ddp_type
        self.rank = 0
        self.world_size = 1
        if ddp:
            self.rank = dist.get_rank()
            self.world_size = dist.get_world_size()
            if ddp_type == DDPType.DDP or ddp_type == DDPType.OSS_DDP:
                self.model = nn.SyncBatchNorm.convert_sync_batchnorm(self.model)
                if self.rank == 0:
                    logging.info(
                        "training in multiple gpus with distributed-data-parallel"
                    )
                # plain DDP uses a regular optimizer, OSS_DDP a sharded one
                oss = ddp_type != DDPType.DDP
                self.optimizer = self._make_optimizer(optim, self.model, oss=oss)
                self.model = TorchDDP(
                    self.model, device_ids=[device], output_device=device
                )
            elif ddp_type == DDPType.OSS_SHARDED_DDP:
                self.model = nn.SyncBatchNorm.convert_sync_batchnorm(self.model)
                if self.rank == 0:
                    logging.info(
                        "training in multiple gpus with fair sharded-distributed-data-parallel"
                    )
                # the sharded DDP wrapper needs the optimizer at construction
                self.optimizer = self._make_optimizer(optim, self.model, oss=True)
                self.model = FairShardedDDP(self.model, self.optimizer)
            else:
                if self.rank == 0:
                    logging.info(
                        "training in multiple gpus with fair fully-sharded-distributed-data-parallel"
                    )
                # SyncBatchNorm is not supported here, it raises an exception.
                # The optimizer is built on the wrapped (flattened) parameters.
                self.model = FairFullyShardedDDP(
                    self.model, mixed_precision=self.use_amp, cpu_offload=cpu_offload
                )
                self.optimizer = self._make_optimizer(optim, self.model, oss=False)
        else:
            self.optimizer = self._make_optimizer(optim, self.model)

        # make the learning rate scheduler
        self.lr_scheduler = self._make_lr_sched(lrsched, self.optimizer)

        if self.use_amp:
            if ddp and ddp_type != DDPType.DDP:
                if self.rank == 0:
                    logging.info(
                        "using automatic mixed precision training with sharded-grad-scaler"
                    )
                self.grad_scaler = ShardedGradScaler()
            else:
                if self.rank == 0:
                    logging.info(
                        "using automatic mixed precision training with grad-scaler"
                    )
                self.grad_scaler = amp.GradScaler()
            self.amp_autocast = amp.autocast
        else:
            self.amp_autocast = contextlib.nullcontext

        self.in_swa = False
        if self.do_swa:
            if self.rank == 0:
                logging.info("init SWA model")
            self.swa_model = AveragedModel(self.model)
            self.swa_scheduler = SWALR(
                self.optimizer, swa_lr=self.swa_lr, anneal_epochs=self.swa_anneal_epochs
            )

    def fit(self, train_data, val_data=None):
        """Training function, it performs the training and validation epochs

        After the last epoch, if SWA was active, an extra pass over the
        training data refreshes the batch-norm statistics of the averaged
        model, which is then saved with save_swa_model.

        Args:
          train_data: PyTorch data loader for the training loop
          val_data: PyTorch data loader for the validation loop
        """
        self.exp_path.mkdir(parents=True, exist_ok=True)

        # resuming from a checkpoint that was already inside the SWA phase
        if self.do_swa and self.cur_epoch >= self.swa_start:
            self.in_swa = True

        val_logs = {}
        self.loggers.on_train_begin(epochs=self.epochs)
        for epoch in range(self.cur_epoch, self.epochs):
            self.loggers.on_epoch_begin(epoch, batches=len(train_data))
            if self.lr_scheduler is not None:
                # this is needed by cosine scheduler
                epoch_updates = len(train_data) // self.grad_acc_steps
                self.lr_scheduler.on_epoch_begin(epoch, epoch_updates=epoch_updates)

            logs = self.train_epoch(train_data)
            if val_data is not None:
                val_logs = self.validation_epoch(val_data)
                logs.update(val_logs)

            self.cur_epoch += 1

            self.loggers.on_epoch_end(logs)
            if self.do_swa and self.cur_epoch >= self.swa_start:
                # SWA phase: accumulate averaged weights, SWALR drives the lr
                self.in_swa = True
                self.swa_model.update_parameters(self.model)
                self.swa_scheduler.step()
            elif self.lr_scheduler is not None:
                self.lr_scheduler.on_epoch_end(logs)

            self.save_checkpoint(logs)

        if self.in_swa:
            self.loggers.on_epoch_begin(self.cur_epoch, batches=len(train_data))
            self.model = self.swa_model.module
            logs = self.bn_update_epoch(train_data)

            if val_data is not None:
                val_logs = self.validation_epoch(val_data)
                logs.update(val_logs)

            self.cur_epoch += 1
            self.loggers.on_epoch_end(logs)
            self.save_swa_model(logs)

    def set_train_mode(self):
        """Puts the model in training mode.

        For train_mode "train" this is model.train(); any other mode is
        delegated to model.train_mode(mode).
        """
        if self.train_mode == "train":
            self.model.train()
        else:
            self.model.train_mode(self.train_mode)

    def train_epoch(self, data_loader):
        """Training epoch loop

        Args:
          data_loader: PyTorch data loader return input/output pairs

        Returns:
          ODict of epoch metrics prefixed with "train_", plus the current "lr".
        """
        metric_acc = MetricAcc(device=self.device)
        batch_metrics = ODict()
        self.set_train_mode()
        for batch, (data, target) in enumerate(data_loader):
            self.loggers.on_batch_begin(batch)

            # gradients are accumulated over grad_acc_steps batches
            if batch % self.grad_acc_steps == 0:
                self.optimizer.zero_grad()

            data, target = data.to(self.device), target.to(self.device)
            batch_size = data.shape[0]
            with self.amp_autocast():
                output = self.model(data)
                loss = self.loss(output, target).mean() / self.grad_acc_steps

            if self.use_amp:
                self.grad_scaler.scale(loss).backward()
            else:
                loss.backward()

            if (batch + 1) % self.grad_acc_steps == 0:
                if self.lr_scheduler is not None and not self.in_swa:
                    self.lr_scheduler.on_opt_step()
                self.update_model()

            self._reduce_metric(loss)
            # undo the accumulation scaling for logging
            batch_metrics["loss"] = loss.item() * self.grad_acc_steps
            for k, metric in self.metrics.items():
                batch_metrics[k] = metric(output, target)

            metric_acc.update(batch_metrics, batch_size)
            logs = metric_acc.metrics
            logs["lr"] = self._get_lr()
            self.loggers.on_batch_end(logs=logs, batch_size=batch_size)

        logs = metric_acc.metrics
        logs = ODict(("train_" + k, v) for k, v in logs.items())
        logs["lr"] = self._get_lr()
        return logs

    def validation_epoch(self, data_loader, swa_update_bn=False):
        """Validation epoch loop

        Args:
          data_loader: PyTorch data loader return input/output pairs
          swa_update_bn: if True, run the model in training mode (without
            gradients) and tag the metrics "train_" instead of "val_"; used
            by bn_update_epoch.

        Returns:
          ODict of epoch metrics prefixed with "val_" (or "train_").
        """
        metric_acc = MetricAcc(self.device)
        batch_metrics = ODict()
        with torch.no_grad():
            if swa_update_bn:
                log_tag = "train_"
                self.set_train_mode()
            else:
                log_tag = "val_"
                self.model.eval()

            for data, target in data_loader:
                data, target = data.to(self.device), target.to(self.device)
                batch_size = data.shape[0]
                with self.amp_autocast():
                    output = self.model(data, **self.amp_args)
                    loss = self.loss(output, target)

                batch_metrics["loss"] = loss.mean().item()
                for k, metric in self.metrics.items():
                    batch_metrics[k] = metric(output, target)
                metric_acc.update(batch_metrics, batch_size)

        logs = metric_acc.metrics
        logs = ODict((log_tag + k, v) for k, v in logs.items())
        return logs

    def bn_update_epoch(self, data_loader):
        """Runs a forward-only pass in training mode (see validation_epoch with
        swa_update_bn=True) and returns its "train_" logs plus the current lr."""
        logs = self.validation_epoch(data_loader, swa_update_bn=True)
        logs["lr"] = self._get_lr()
        return logs

    def _reduce_metric(self, metric):
        """Hook called by train_epoch with every batch loss.

        The base implementation does nothing. train_epoch ignores the return
        value, so an override that reduces across ddp ranks must do it in
        place.
        """
        return metric

    def _clip_grad_norm(self, model, optim, grad_clip, grad_clip_norm):
        """Clips the gradient norm with the method matching the ddp flavour.

        Exactly one clipping operation is applied:
          - fully sharded ddp: the wrapper's own clip_grad_norm_ member
          - oss ddp / oss sharded ddp: the OSS optimizer's clip_grad_norm
          - plain ddp or no ddp: torch.nn.utils.clip_grad_norm_
        """
        if self.ddp:
            if self.ddp_type == DDPType.FULLY_SHARDED_DDP:
                # we have to use the member function in FullyShardedDDP class
                model.clip_grad_norm_(grad_clip, norm_type=grad_clip_norm)
                return
            if self.ddp_type != DDPType.DDP:
                # with the OSS optimizer, clipping goes through the
                # optimizer wrapper; do not clip a second time below
                optim.clip_grad_norm(grad_clip, norm_type=grad_clip_norm)
                return

        nn.utils.clip_grad_norm_(
            model.parameters(), grad_clip, norm_type=grad_clip_norm
        )

    def update_model(self):
        """Optional gradient clipping followed by an optimizer step.

        With AMP, gradients are unscaled before clipping and the step goes
        through the grad scaler.
        """
        if self.use_amp:
            if self.grad_clip > 0:
                self.grad_scaler.unscale_(self.optimizer)
                self._clip_grad_norm(
                    self.model, self.optimizer, self.grad_clip, self.grad_clip_norm
                )
            self.grad_scaler.step(self.optimizer)
            self.grad_scaler.update()
        else:
            if self.grad_clip > 0:
                self._clip_grad_norm(
                    self.model, self.optimizer, self.grad_clip, self.grad_clip_norm
                )
            self.optimizer.step()

    def _make_optimizer(self, optim, model, oss=False):
        """Returns optim as-is if it is already an optimizer, otherwise builds
        one from an options dict (None is treated as an empty dict)."""
        if isinstance(optim, torch.optim.Optimizer):
            return optim

        if optim is None:
            optim = {}
        assert isinstance(optim, dict)
        opt_args = OF.filter_args(**optim)
        opt_args["oss"] = oss
        if self.rank == 0:
            logging.info("optimizer args={}".format(opt_args))
        optimizer = OF.create(model.parameters(), **opt_args)
        return optimizer

    def _make_lr_sched(self, lr_sched, optim):
        """Returns lr_sched as-is if it is None or a scheduler object,
        otherwise builds one from an options dict."""
        if lr_sched is None or isinstance(lr_sched, LRS):
            return lr_sched

        assert isinstance(lr_sched, dict)
        args = LRSF.filter_args(**lr_sched)
        if self.rank == 0:
            logging.info("lr scheduler args={}".format(args))
        lr_sched = LRSF.create(optim, **args)
        return lr_sched

    def _default_loggers(self, log_interval, use_tensorboard, use_wandb, wandb):
        """Creates the default loggers: progress + CSV file, plus optional
        TensorBoard and wandb loggers under exp_path."""
        prog_log = ProgLogger(interval=log_interval)
        csv_log = CSVLogger(self.exp_path / "train.log", append=True)
        loggers = [prog_log, csv_log]
        if use_tensorboard:
            loggers.append(
                TensorBoardLogger(self.exp_path / "tb", interval=log_interval)
            )
        if use_wandb:
            loggers.append(
                WAndBLogger(
                    **wandb, path=self.exp_path / "wandb", interval=log_interval
                )
            )
        return LoggerList(loggers)

    def _get_lr(self):
        """Returns the current learning rate to show in the loggers

        This is the lr of the first parameter group (None if there are none).
        """
        for param_group in self.optimizer.param_groups:
            return param_group["lr"]

    def checkpoint(self, logs=None):
        """Creates a checkpoint of the training, to save and posterior recovery

        Args:
          logs: logs containing the current value of the metrics.

        Returns:
          dict with epoch, rng state, model config/state, optimizer, loss,
          optional lr scheduler, logs and (during SWA) swa model/scheduler states.
        """
        checkpoint = {
            "epoch": self.cur_epoch,
            "rng_state": torch.get_rng_state(),
            "model_cfg": self.model.get_config(),
            "model_state_dict": self.model.state_dict(),
            "optimizer_state_dict": self.optimizer.state_dict(),
            "loss_state_dict": self.loss.state_dict()
            if self.loss is not None
            else None,
        }

        if self.lr_scheduler is not None:
            checkpoint["lr_scheduler_state_dict"] = self.lr_scheduler.state_dict()

        if logs is not None:
            checkpoint["logs"] = logs

        if self.in_swa:
            checkpoint["swa_model_state_dict"] = self.swa_model.state_dict()
            checkpoint["swa_scheduler_state_dict"] = self.swa_scheduler.state_dict()

        return checkpoint

    def save_checkpoint(self, logs=None):
        """Saves a checkpoint of the training status to
        exp_path/model_epXXXX.pth (rank 0 only).

        Args:
          logs: logs containing the current value of the metrics.
        """
        if self.ddp and (
            self.ddp_type == DDPType.OSS_DDP or self.ddp_type == DDPType.OSS_SHARDED_DDP
        ):
            # The sharded OSS optimizer state must be gathered before
            # state_dict() can be taken; this is called on every rank.
            # See the checkpointing in
            # https://github.com/facebookresearch/fairscale/blob/master/benchmarks/oss.py
            # Memory usage could spill over from there.
            self.optimizer.consolidate_state_dict()

        if self.rank != 0:
            return

        checkpoint = self.checkpoint(logs)
        file_path = "%s/model_ep%04d.pth" % (self.exp_path, self.cur_epoch)
        torch.save(checkpoint, file_path)

    def save_swa_model(self, logs=None):
        """Saves a checkpoint whose model_state_dict holds the SWA-averaged
        model, to exp_path/swa_model_epXXXX.pth (rank 0 only).

        Args:
          logs: logs containing the current value of the metrics.
        """
        if self.rank != 0:
            return

        checkpoint = self.checkpoint(logs)
        checkpoint["model_state_dict"] = checkpoint.pop("swa_model_state_dict")
        file_path = "%s/swa_model_ep%04d.pth" % (self.exp_path, self.cur_epoch)
        torch.save(checkpoint, file_path)

    def load_checkpoint(self, file_path):
        """Loads a training checkpoint from file.

        Args:
          file_path: checkpoint file path

        Returns:
          the logs stored in the checkpoint, or None.
        """
        # torch.load unpickles the file: only load trusted checkpoints.
        checkpoint = torch.load(file_path, map_location=torch.device("cpu"))
        torch.set_rng_state(checkpoint["rng_state"])
        if self.rank > 0:
            # this will make sure that each process produces different data
            # when using ddp (advance the RNG by a rank-dependent amount)
            torch.rand(1000 * self.rank)

        self.cur_epoch = checkpoint["epoch"]
        try:
            self.model.load_state_dict(checkpoint["model_state_dict"])
        except RuntimeError:
            # state dict keys did not match: try the wrapped module instead
            self.model.module.load_state_dict(checkpoint["model_state_dict"])
        self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        if self.loss is not None:
            self.loss.load_state_dict(checkpoint["loss_state_dict"])
        if self.lr_scheduler is not None:
            self.lr_scheduler.load_state_dict(checkpoint["lr_scheduler_state_dict"])

        if self.do_swa:
            if "swa_model_state_dict" in checkpoint:
                self.swa_model.load_state_dict(checkpoint["swa_model_state_dict"])
                self.swa_scheduler.load_state_dict(
                    checkpoint["swa_scheduler_state_dict"]
                )
            else:
                # NOTE(review): presumably rebuilt so SWALR picks up the
                # learning rates just restored into the optimizer.
                self.swa_scheduler = SWALR(
                    self.optimizer,
                    swa_lr=self.swa_lr,
                    anneal_epochs=self.swa_anneal_epochs,
                )

        logs = checkpoint.get("logs")

        del checkpoint
        if self.device is not None:
            torch.cuda.empty_cache()

        return logs

    def load_last_checkpoint(self):
        """Loads the last training checkpoint in the experiment dir.

        Returns:
          the logs of the loaded checkpoint, or None if none was found.
        """
        for epoch in range(self.epochs, 0, -1):
            file_path = "%s/model_ep%04d.pth" % (self.exp_path, epoch)
            if os.path.isfile(file_path):
                return self.load_checkpoint(file_path)

        return None

    @staticmethod
    def filter_args(**kwargs):
        """Keeps only the keyword arguments accepted by the trainer options."""
        valid_args = (
            "grad_acc_steps",
            "epochs",
            "log_interval",
            "use_amp",
            "ddp_type",
            "grad_clip",
            "grad_clip_norm",
            "swa_start",
            "swa_lr",
            "swa_anneal_epochs",
            "exp_path",
            "optim",
            "lrsched",
            "cpu_offload",
            "use_tensorboard",
            "use_wandb",
            "wandb",
        )
        args = dict((k, kwargs[k]) for k in valid_args if k in kwargs)
        return args

    @staticmethod
    def add_class_args(parser, prefix=None, skip=None):
        """Adds the trainer options to a jsonargparse parser.

        Args:
          parser: parser to add the options to.
          prefix: if not None, the options are added to a sub-parser nested
            under "--<prefix>".
          skip: option groups not to add ("optim", "lrsched").
        """
        if skip is None:
            skip = []
        if prefix is not None:
            outer_parser = parser
            parser = ArgumentParser(prog="")

        if "optim" not in skip:
            OF.add_class_args(parser, prefix="optim")

        if "lrsched" not in skip:
            LRSF.add_class_args(parser, prefix="lrsched")

        parser.add_argument(
            "--grad-acc-steps",
            type=int,
            default=1,
            help="gradient accumulation batches before weigth update",
        )
        parser.add_argument("--epochs", type=int, default=200, help="number of epochs")
        parser.add_argument(
            "--log-interval",
            type=int,
            default=10,
            help="how many batches to wait before logging training status",
        )
        parser.add_argument(
            "--use-tensorboard",
            action="store_true",
            default=False,
            help="use tensorboard logger",
        )
        parser.add_argument(
            "--use-wandb", action="store_true", default=False, help="use wandb logger"
        )
        parser.add_argument("--wandb.project", default=None, help="wandb project name")
        parser.add_argument("--wandb.group", default=None, help="wandb group name")
        parser.add_argument("--wandb.name", default=None, help="wandb display name")
        parser.add_argument(
            "--wandb.mode",
            default="online",
            choices=["online", "offline"],
            help="wandb mode (online, offline)",
        )
        parser.add_argument(
            "--ddp-type",
            default="ddp",
            choices=ddp_choices,
            help="DDP type in {}".format(ddp_choices),
        )
        parser.add_argument(
            "--use-amp",
            action="store_true",
            default=False,
            help="use mixed precision training",
        )
        parser.add_argument(
            "--cpu-offload",
            action="store_true",
            default=False,
            help="CPU offload of gradients when using fully_sharded_ddp",
        )
        parser.add_argument(
            "--grad-clip", type=float, default=0, help="gradient clipping norm value"
        )
        # NOTE(review): there is no `type` here, so a value given on the
        # command line arrives as a string and only "inf" matches `choices`;
        # values loaded from a config file keep their parsed type.
        parser.add_argument(
            "--grad-clip-norm",
            default=2,
            choices=["inf", 1, 2],
            help="gradient clipping norm type",
        )
        parser.add_argument(
            "--swa-start",
            type=int,
            default=0,
            help="start epoch for SWA, if 0 it does not use SWA",
        )
        parser.add_argument(
            "--swa-lr", type=float, default=1e-3, help="learning rate for SWA phase"
        )
        parser.add_argument(
            "--swa-anneal-epochs",
            type=int,
            default=10,
            help="SWA learning rate anneal epochs",
        )

        parser.add_argument("--exp-path", help="experiment path")

        if prefix is not None:
            outer_parser.add_argument("--" + prefix, action=ActionParser(parser=parser))

    add_argparse_args = add_class_args