"""
Copyright 2019 Johns Hopkins University (Author: Jesus Villalba)
Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""
import os
from collections import OrderedDict as ODict
import logging
import torch
import torch.nn as nn
from ..utils import MetricAcc
from .torch_trainer import TorchTrainer
class XVectorTrainer(TorchTrainer):
    """Trainer to train x-vector style models.

    When no loss is supplied, ``nn.CrossEntropyLoss`` is used. All other
    options are forwarded unchanged to ``TorchTrainer``.

    Attributes:
      model: x-Vector model object.
      optim: pytorch optimizer object or options dict
      epochs: max. number of epochs
      exp_path: experiment output path
      cur_epoch: current epoch
      grad_acc_steps: gradient accumulation steps to simulate larger batch size.
      device: cpu/gpu device
      metrics: extra metrics to compute besides cxe.
      lrsched: learning rate scheduler object or options dict
      loggers: LoggerList object, loggers write training progress to std. output and file.
               If None, it uses default loggers.
      ddp: if True use distributed data parallel training
      ddp_type: type of distributed data parallel in  (ddp, oss_ddp, oss_shared_ddp)
      loss: if None, it uses cross-entropy
      train_mode: training mode in ['train', 'ft-full', 'ft-last-layer']
      use_amp: uses mixed precision training.
      log_interval: number of optim. steps between log outputs
      use_tensorboard: use tensorboard logger
      use_wandb: use wandb logger
      wandb: wandb dictionary of options
      grad_clip: norm to clip gradients, if 0 there is no clipping
      grad_clip_norm: norm type to clip gradients
      swa_start: epoch to start doing swa
      swa_lr: SWA learning rate
      swa_anneal_epochs: SWA learning rate anneal epochs
      cpu_offload: CPU offload of gradients when using fully sharded ddp
    """

    def __init__(
        self,
        model,
        optim=None,
        epochs=100,
        exp_path="./train",
        cur_epoch=0,
        grad_acc_steps=1,
        device=None,
        metrics=None,
        lrsched=None,
        loggers=None,
        ddp=False,
        ddp_type="ddp",
        loss=None,
        train_mode="train",
        use_amp=False,
        log_interval=10,
        use_tensorboard=False,
        use_wandb=False,
        wandb=None,
        grad_clip=0,
        grad_clip_norm=2,
        swa_start=0,
        swa_lr=1e-3,
        swa_anneal_epochs=10,
        cpu_offload=False,
    ):
        """Builds the trainer, defaulting the loss to cross-entropy.

        See the class docstring for the meaning of each argument.
        ``optim`` and ``wandb`` default to ``None``, which is replaced by a
        fresh empty dict on every call.
        """
        # Use None sentinels instead of `{}` defaults: a dict literal in the
        # signature is created once and shared by every instance.
        if optim is None:
            optim = {}
        if wandb is None:
            wandb = {}

        if loss is None:
            loss = nn.CrossEntropyLoss()

        super().__init__(
            model,
            loss,
            optim,
            epochs,
            exp_path,
            cur_epoch=cur_epoch,
            grad_acc_steps=grad_acc_steps,
            device=device,
            metrics=metrics,
            lrsched=lrsched,
            loggers=loggers,
            ddp=ddp,
            ddp_type=ddp_type,
            train_mode=train_mode,
            use_amp=use_amp,
            log_interval=log_interval,
            use_tensorboard=use_tensorboard,
            use_wandb=use_wandb,
            wandb=wandb,
            grad_clip=grad_clip,
            grad_clip_norm=grad_clip_norm,
            swa_start=swa_start,
            swa_lr=swa_lr,
            swa_anneal_epochs=swa_anneal_epochs,
            cpu_offload=cpu_offload,
        )

    def train_epoch(self, data_loader):
        """Training epoch loop

        Args:
          data_loader: pytorch data loader returning features and class labels.

        Returns:
          Ordered dict with the accumulated metrics of the epoch, each key
          prefixed with "train_", plus the current learning rate under "lr".
        """
        self.model.update_loss_margin(self.cur_epoch)

        metric_acc = MetricAcc(device=self.device)
        batch_metrics = ODict()
        self.set_train_mode()

        for batch, (data, target) in enumerate(data_loader):
            self.loggers.on_batch_begin(batch)

            # Gradients are reset only at the start of each accumulation
            # window of `grad_acc_steps` batches.
            if batch % self.grad_acc_steps == 0:
                self.optimizer.zero_grad()

            data, target = data.to(self.device), target.to(self.device)
            batch_size = data.shape[0]

            with self.amp_autocast():
                output = self.model(data, target, **self.amp_args)
                # Divide by the number of accumulation steps so the summed
                # gradients correspond to the mean loss over the window.
                loss = self.loss(output, target).mean() / self.grad_acc_steps

            if self.use_amp:
                # With mixed precision, backprop goes through the scaled loss.
                self.grad_scaler.scale(loss).backward()
            else:
                loss.backward()

            # The model is updated once per accumulation window, on its
            # last batch.
            if (batch + 1) % self.grad_acc_steps == 0:
                if self.lr_scheduler is not None and not self.in_swa:
                    self.lr_scheduler.on_opt_step()
                self.update_model()

            # Undo the accumulation scaling so the logged loss is per-batch.
            batch_metrics["loss"] = loss.item() * self.grad_acc_steps
            for k, metric in self.metrics.items():
                batch_metrics[k] = metric(output, target)

            metric_acc.update(batch_metrics, batch_size)
            logs = metric_acc.metrics
            logs["lr"] = self._get_lr()
            self.loggers.on_batch_end(logs=logs, batch_size=batch_size)

        logs = metric_acc.metrics
        logs = ODict(("train_" + k, v) for k, v in logs.items())
        logs["lr"] = self._get_lr()
        return logs