import torch
import numpy as np


class Optimizer:
    def __init__(self, parameters, optimizer, lr, eps, lr_scheduler, **kwargs):
        # Setup torch optimizer
        self.opt_type = optimizer
        self.init_lr = lr
        self.sch_type = lr_scheduler
        opt = getattr(torch.optim, optimizer)

        if lr_scheduler == 'warmup':
            # Inverse-square-root warmup: the rate rises linearly for the
            # first warmup_step steps, peaks at init_lr, then decays as
            # step ** -0.5. The torch optimizer is created with lr=1.0 and
            # the effective rate is written into its param groups each step
            # in pre_step().
            warmup_step = 4000.0
            init_lr = lr
            self.lr_scheduler = lambda step: init_lr * warmup_step ** 0.5 * \
                np.minimum((step + 1) * warmup_step ** -1.5, (step + 1) ** -0.5)
            self.opt = opt(parameters, lr=1.0)
        else:
            self.lr_scheduler = None
            self.opt = opt(parameters, lr=lr, eps=eps)  # ToDo: 1e-8 better?

    def get_opt_state_dict(self):
        return self.opt.state_dict()

    def load_opt_state_dict(self, state_dict):
        self.opt.load_state_dict(state_dict)

    def pre_step(self, step):
        # Set the learning rate for the current step (if scheduled), clear
        # accumulated gradients, and report the rate in use.
        if self.lr_scheduler is not None:
            cur_lr = self.lr_scheduler(step)
            for param_group in self.opt.param_groups:
                param_group['lr'] = cur_lr
        else:
            cur_lr = self.init_lr
        self.opt.zero_grad()
        return cur_lr

    def step(self):
        self.opt.step()

    def create_msg(self):
        # One-line summary of the optimizer setup, for logging.
        return ['Optim.Info.| Algo. = {}\t| Lr = {}\t (schedule = {})'
                .format(self.opt_type, self.init_lr, self.sch_type)]
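

# --- Minimal usage sketch (illustrative, not part of the module above) ---
# Shows the intended call order per training step:
#     pre_step() -> backward() -> step()
# Assumptions: a toy linear model, the 'Adam' optimizer from torch.optim,
# and a hypothetical three-step loop with random inputs.
if __name__ == '__main__':
    model = torch.nn.Linear(4, 2)
    optim = Optimizer(model.parameters(), optimizer='Adam', lr=1e-3,
                      eps=1e-8, lr_scheduler='warmup')
    print(optim.create_msg()[0])

    for global_step in range(3):
        # pre_step() applies the warmup schedule and zeroes gradients.
        cur_lr = optim.pre_step(global_step)
        loss = model(torch.randn(8, 4)).pow(2).mean()  # dummy loss
        loss.backward()
        optim.step()
        print('step {}: lr = {:.3e}'.format(global_step, cur_lr))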