Source code for hyperion.torch.optim.factory

"""
 Copyright 2019 Johns Hopkins University  (Author: Jesus Villalba)
 Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)
"""
from jsonargparse import ArgumentParser, ActionParser
import logging

from ...utils.misc import filter_args

import torch
import torch.optim as optim
from .radam import RAdam


[docs]class OptimizerFactory(object):
[docs] @staticmethod def create( params, opt_type, lr, momentum=0, beta1=0.9, beta2=0.99, rho=0.9, eps=1e-8, weight_decay=0, amsgrad=False, nesterov=False, lambd=0.0001, asgd_alpha=0.75, t0=1000000.0, rmsprop_alpha=0.99, centered=False, lr_decay=0, init_acc_val=0, max_iter=20, oss=False, ): kwargs = locals() base_opt = None if opt_type == "sgd": valid_args = ("lr", "momentum", "weight_decay", "nesterov") opt_args = filter_args(valid_args, kwargs) opt_args["dampening"] = 0 base_opt = optim.SGD # return optim.SGD(params, lr, momentum=momentum, dampening=0, # weight_decay=weight_decay, nesterov=nesterov) if opt_type == "adam": betas = (beta1, beta2) valid_args = ("lr", "eps", "weight_decay", "amsgrad") opt_args = filter_args(valid_args, kwargs) opt_args["betas"] = betas base_opt = optim.Adam # return optim.Adam( # params, lr, betas=(beta1, beta2), eps=eps, # weight_decay=weight_decay, amsgrad=amsgrad) if opt_type == "adamw": betas = (beta1, beta2) valid_args = ("lr", "eps", "weight_decay", "amsgrad") opt_args = filter_args(valid_args, kwargs) opt_args["betas"] = betas base_opt = optim.AdamW if opt_type == "radam": betas = (beta1, beta2) valid_args = ("lr", "eps", "weight_decay") opt_args = filter_args(valid_args, kwargs) opt_args["betas"] = betas base_opt = RAdam # return RAdam( # params, lr, betas=(beta1, beta2), eps=eps, # weight_decay=weight_decay) if opt_type == "adadelta": valid_args = ("lr", "eps", "weight_decay", "rho") opt_args = filter_args(valid_args, kwargs) base_opt = optim.Adadelta # return optim.Adadelta(params, lr, rho=rho, eps=eps, # weight_decay=weight_decay) if opt_type == "adagrad": valid_args = ("lr", "lr_decay", "weight_decay") opt_args = filter_args(valid_args, kwargs) opt_args["initial_accumulator_value"] = init_acc_val base_opt = optim.Adagrad # return optim.Adagrad( # params, lr, lr_decay=lr_decay, # weight_decay=weight_decay, initial_accumulator_value=init_acc_val) if opt_type == "sparse_adam": betas = (beta1, beta2) valid_args = ("lr", "eps") opt_args = filter_args(valid_args, kwargs) opt_args["betas"] = betas base_opt = optim.SparseAdam # return optim.SparseAdam(params, lr, betas=(beta1, beta2), eps=eps) if opt_type == "adamax": betas = (beta1, beta2) valid_args = ("lr", "eps", "weight_decay") opt_args = filter_args(valid_args, kwargs) opt_args["betas"] = betas base_opt = optim.Adamax # return optim.Adamax(params, lr, betas=(beta1, beta2), eps=eps, # weight_decay=weight_decay) if opt_type == "asgd": valid_args = ("lr", "lambd", "t0", "weight_decay") opt_args = filter_args(valid_args, kwargs) opt_args["alpha"] = asgd_alpha base_opt = optim.ASGD # return optim.ASGD(params, lr, lambd=lambd, alpha=asgd_alpha, t0=t0, # weight_decay=weight_decay) if opt_type == "lbfgs": valid_args = ("lr", "max_iter") opt_args = filter_args(valid_args, kwargs) base_opt = optim.LBFGS # return optim.LBFGS( # params, lr, max_iter=max_iter) if opt_type == "rmsprop": valid_args = ("lr", "eps", "momentum", "weight_decay", "centered") opt_args = filter_args(valid_args, kwargs) opt_args["alpha"] = rmsprop_alpha base_opt = optim.RMSprop # return optim.RMSprop( # params, lr, alpha=rmsprop_alpha, eps=eps, # weight_decay=weight_decay, momentum=momentum, centered=centered) if opt_type == "rprop": opts_args = {"lr": lr, "etas": (0.5, 1.2), "step_sizes": (1e-06, 50)} base_opt = optim.Rprop # return optim.Rprop(params, lr, etas=(0.5, 1.2), step_sizes=(1e-06, 50)) if base_opt is None: raise Exception("unknown optimizer %s" % opt_type) if oss: from fairscale.optim.oss import OSS logging.info("Optimizer uses OSS") return OSS(params, base_opt, **opt_args) return base_opt(params, **opt_args)
[docs] @staticmethod def filter_args(**kwargs): valid_args = ( "opt_type", "lr", "momentum", "beta1", "beta2", "rho", "eps", "weight_decay", "amsgrad", "nesterov", "lambd", "asgd_alpha", "t0", "rmsprop_alpha", "centered", "lr_decay", "init_acc_val", "max_iter", "oss", ) return filter_args(valid_args, kwargs)
[docs] @staticmethod def add_class_args(parser, prefix=None): if prefix is not None: outer_parser = parser parser = ArgumentParser(prog="") parser.add_argument( "--opt-type", type=str.lower, default="adam", choices=[ "sgd", "adam", "adamw", "radam", "adadelta", "adagrad", "sparse_adam", "adamax", "asgd", "lbfgs", "rmsprop", "rprop", ], help=( "Optimizers: SGD, Adam, AdaDelta, AdaGrad, SparseAdam " "AdaMax, ASGD, LFGS, RMSprop, Rprop" ), ) parser.add_argument( "--lr", default=0.001, type=float, help=("Initial learning rate") ) parser.add_argument("--momentum", default=0.6, type=float, help=("Momentum")) parser.add_argument( "--beta1", default=0.9, type=float, help=( "Beta_1 in Adam optimizers, " "coefficient used for computing " "running averages of gradient" ), ) parser.add_argument( "--beta2", default=0.99, type=float, help=( "Beta_2 in Adam optimizers" "coefficient used for computing " "running averages of gradient square" ), ) parser.add_argument( "--rho", default=0.9, type=float, help=( "Rho in AdaDelta," "coefficient used for computing a " "running average of squared gradients" ), ) parser.add_argument( "--eps", default=1e-8, type=float, help=( "Epsilon in RMSprop and Adam optimizers " "term added to the denominator " "to improve numerical stability" ), ) parser.add_argument( "--weight-decay", default=1e-6, type=float, help=("L2 regularization coefficient"), ) parser.add_argument( "--amsgrad", default=False, action="store_true", help=("AMSGrad variant of Adam"), ) parser.add_argument( "--nesterov", default=False, action="store_true", help=("Use Nesterov momentum in SGD"), ) parser.add_argument( "--lambd", default=0.0001, type=float, help=("decay term in ASGD") ) parser.add_argument( "--asgd-alpha", default=0.75, type=float, help=("power for eta update in ASGD"), ) parser.add_argument( "--t0", default=1e6, type=float, help=("point at which to start averaging in ASGD"), ) parser.add_argument( "--rmsprop-alpha", default=0.99, type=float, help=("smoothing constant in RMSprop"), ) parser.add_argument( "--centered", default=False, action="store_true", help=("Compute centered RMSprop, gradient normalized " "by its variance"), ) parser.add_argument( "--lr-decay", default=1e-6, type=float, help=("Learning rate decay in AdaGrad optimizer"), ) parser.add_argument( "--init-acc-val", default=0, type=float, help=("Init accum value in Adagrad"), ) parser.add_argument( "--max-iter", default=20, type=int, help=("max iterations in LBGS") ) if prefix is not None: outer_parser.add_argument("--" + prefix, action=ActionParser(parser=parser))
# help='optimizer options') add_argparse_args = add_class_args