Shortcuts

Source code for torch.optim.lr_scheduler

import types
import math
import warnings
import weakref
from bisect import bisect_right
from functools import partial, wraps

from torch._six import inf

from .optimizer import Optimizer


class _LRScheduler(object):
    def __init__(self, optimizer, last_epoch=-1):
        if not isinstance(optimizer, Optimizer):
            raise TypeError('{} is not an Optimizer'.format(
                type(optimizer).__name__))
        self.optimizer = optimizer
        if last_epoch == -1:
            for group in optimizer.param_groups:
                group.setdefault('initial_lr', group['lr'])
            last_epoch = 0
        else:
            for i, group in enumerate(optimizer.param_groups):
                if 'initial_lr' not in group:
                    raise KeyError("param 'initial_lr' is not specified "
                                   "in param_groups[{}] when resuming an optimizer".format(i))
        self.base_lrs = list(map(lambda group: group['initial_lr'], optimizer.param_groups))
        self.last_epoch = last_epoch

        # Following https://github.com/pytorch/pytorch/issues/20124
        # We would like to ensure that `lr_scheduler.step()` is called after
        # `optimizer.step()`
        def with_counter(func, opt):
            @wraps(func)
            def wrapper(*args, **kwargs):
                opt._step_count += 1
                return func(*args, **kwargs)
            wrapper._with_counter = True
            return wrapper

        self.optimizer.step = with_counter(self.optimizer.step, self.optimizer)
        self.optimizer._step_count = 0
        self._step_count = 0
        self.step(last_epoch)

    def state_dict(self):
        """Returns the state of the scheduler as a :class:`dict`.

        It contains an entry for every variable in self.__dict__ which
        is not the optimizer.
        """
        return {key: value for key, value in self.__dict__.items() if key != 'optimizer'}

    def load_state_dict(self, state_dict):
        """Loads the schedulers state.

        Arguments:
            state_dict (dict): scheduler state. Should be an object returned
                from a call to :meth:`state_dict`.
        """
        self.__dict__.update(state_dict)

    def get_lr(self):
        raise NotImplementedError

    def step(self, epoch=None):
        # Raise a warning if old pattern is detected
        # https://github.com/pytorch/pytorch/issues/20124
        if self._step_count == 1:
            if not hasattr(self.optimizer.step, "_with_counter"):
                warnings.warn("Seems like `optimizer.step()` has been overridden after learning rate scheduler "
                              "initialization. Please, make sure to call `optimizer.step()` before "
                              "`lr_scheduler.step()`. See more details at "
                              "https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate", UserWarning)

            # Just check if there were two first lr_scheduler.step() calls before optimizer.step()
            elif self.optimizer._step_count < 1:
                warnings.warn("Detected call of `lr_scheduler.step()` before `optimizer.step()`. "
                              "In PyTorch 1.1.0 and later, you should call them in the opposite order: "
                              "`optimizer.step()` before `lr_scheduler.step()`.  Failure to do this "
                              "will result in PyTorch skipping the first value of the learning rate schedule."
                              "See more details at "
                              "https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate", UserWarning)
        self._step_count += 1

        if epoch is None:
            epoch = self.last_epoch + 1
        self.last_epoch = epoch
        for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
            param_group['lr'] = lr


class LambdaLR(_LRScheduler):
    """Sets the learning rate of each parameter group to the initial lr
    times a given function. When last_epoch=-1, sets initial lr as lr.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        lr_lambda (function or list): A function which computes a multiplicative
            factor given an integer parameter epoch, or a list of such
            functions, one for each group in optimizer.param_groups.
        last_epoch (int): The index of last epoch. Default: -1.

    Raises:
        ValueError: if ``lr_lambda`` is a list/tuple whose length differs
            from the number of param groups.

    Example:
        >>> # Assuming optimizer has two groups.
        >>> lambda1 = lambda epoch: epoch // 30
        >>> lambda2 = lambda epoch: 0.95 ** epoch
        >>> scheduler = LambdaLR(optimizer, lr_lambda=[lambda1, lambda2])
        >>> for epoch in range(100):
        >>>     train(...)
        >>>     validate(...)
        >>>     scheduler.step()
    """

    def __init__(self, optimizer, lr_lambda, last_epoch=-1):
        if not isinstance(lr_lambda, (list, tuple)):
            self.lr_lambdas = [lr_lambda] * len(optimizer.param_groups)
        else:
            if len(lr_lambda) != len(optimizer.param_groups):
                raise ValueError("Expected {} lr_lambdas, but got {}".format(
                    len(optimizer.param_groups), len(lr_lambda)))
            self.lr_lambdas = list(lr_lambda)
        # The base constructor assigns self.optimizer / self.last_epoch itself
        # and immediately calls step(), which needs self.lr_lambdas to exist.
        super(LambdaLR, self).__init__(optimizer, last_epoch)

    def state_dict(self):
        """Returns the state of the scheduler as a :class:`dict`.

        It contains an entry for every variable in self.__dict__ which
        is not the optimizer.
        The learning rate lambda functions will only be saved if they are
        callable objects and not if they are functions or lambdas: for each
        non-function entry a shallow copy of its ``__dict__`` is stored, and
        ``None`` otherwise.
        """
        state_dict = {key: value for key, value in self.__dict__.items()
                      if key not in ('optimizer', 'lr_lambdas')}
        state_dict['lr_lambdas'] = [None] * len(self.lr_lambdas)

        for idx, fn in enumerate(self.lr_lambdas):
            if not isinstance(fn, types.FunctionType):
                state_dict['lr_lambdas'][idx] = fn.__dict__.copy()

        return state_dict

    def load_state_dict(self, state_dict):
        """Loads the schedulers state.

        ``state_dict`` itself is left unmodified, so the same dict can be
        loaded more than once.

        Arguments:
            state_dict (dict): scheduler state. Should be an object returned
                from a call to :meth:`state_dict`.
        """
        lr_lambdas = state_dict['lr_lambdas']
        self.__dict__.update({key: value for key, value in state_dict.items()
                              if key != 'lr_lambdas'})

        # Saved entries are attribute dicts of callable objects; None marks a
        # plain function/lambda whose state could not be saved.
        for idx, fn in enumerate(lr_lambdas):
            if fn is not None:
                self.lr_lambdas[idx].__dict__.update(fn)

    def get_lr(self):
        """Return ``base_lr * lr_lambda(last_epoch)`` for every param group."""
        return [base_lr * lmbda(self.last_epoch)
                for lmbda, base_lr in zip(self.lr_lambdas, self.base_lrs)]
class StepLR(_LRScheduler):
    """Decays the learning rate of each parameter group by ``gamma`` once
    every ``step_size`` epochs, starting from the initial lr.
    When last_epoch=-1, sets initial lr as lr.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        step_size (int): Period of learning rate decay.
        gamma (float): Multiplicative factor of learning rate decay.
            Default: 0.1.
        last_epoch (int): The index of last epoch. Default: -1.

    Example:
        >>> # Assuming optimizer uses lr = 0.05 for all groups
        >>> # lr = 0.05     if epoch < 30
        >>> # lr = 0.005    if 30 <= epoch < 60
        >>> # lr = 0.0005   if 60 <= epoch < 90
        >>> # ...
        >>> scheduler = StepLR(optimizer, step_size=30, gamma=0.1)
        >>> for epoch in range(100):
        >>>     train(...)
        >>>     validate(...)
        >>>     scheduler.step()
    """

    def __init__(self, optimizer, step_size, gamma=0.1, last_epoch=-1):
        self.step_size = step_size
        self.gamma = gamma
        super(StepLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        """Closed form: ``base_lr * gamma ** (last_epoch // step_size)``."""
        completed_periods = self.last_epoch // self.step_size
        decay = self.gamma ** completed_periods
        return [initial * decay for initial in self.base_lrs]
class MultiStepLR(_LRScheduler):
    """Set the learning rate of each parameter group to the initial lr decayed
    by gamma once the number of epoch reaches one of the milestones. When
    last_epoch=-1, sets initial lr as lr.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        milestones (list): List of epoch indices. Must be increasing. Any
            iterable is accepted; it is stored as a list.
        gamma (float): Multiplicative factor of learning rate decay.
            Default: 0.1.
        last_epoch (int): The index of last epoch. Default: -1.

    Raises:
        ValueError: if ``milestones`` is not sorted in increasing order.

    Example:
        >>> # Assuming optimizer uses lr = 0.05 for all groups
        >>> # lr = 0.05     if epoch < 30
        >>> # lr = 0.005    if 30 <= epoch < 80
        >>> # lr = 0.0005   if epoch >= 80
        >>> scheduler = MultiStepLR(optimizer, milestones=[30,80], gamma=0.1)
        >>> for epoch in range(100):
        >>>     train(...)
        >>>     validate(...)
        >>>     scheduler.step()
    """

    def __init__(self, optimizer, milestones, gamma=0.1, last_epoch=-1):
        # Materialize once: a one-shot iterable would otherwise be consumed by
        # the sortedness check, and bisect_right needs a sequence.
        milestones = list(milestones)
        if milestones != sorted(milestones):
            raise ValueError('Milestones should be a list of'
                             ' increasing integers. Got {}'.format(milestones))
        self.milestones = milestones
        self.gamma = gamma
        super(MultiStepLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return ``base_lr * gamma ** k`` where k is the number of milestones <= last_epoch."""
        return [base_lr * self.gamma ** bisect_right(self.milestones, self.last_epoch)
                for base_lr in self.base_lrs]
class ExponentialLR(_LRScheduler):
    """Multiply the initial lr of each parameter group by ``gamma`` once per
    epoch. When last_epoch=-1, sets initial lr as lr.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        gamma (float): Multiplicative factor of learning rate decay.
        last_epoch (int): The index of last epoch. Default: -1.
    """

    def __init__(self, optimizer, gamma, last_epoch=-1):
        self.gamma = gamma
        super(ExponentialLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        """Closed form: ``base_lr * gamma ** last_epoch``."""
        factor = self.gamma ** self.last_epoch
        return [initial * factor for initial in self.base_lrs]
class CosineAnnealingLR(_LRScheduler):
    r"""Set the learning rate of each parameter group using a cosine annealing
    schedule, where :math:`\eta_{max}` is set to the initial lr and
    :math:`T_{cur}` is the number of epochs since the last restart in SGDR:

    .. math::

        \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})(1 +
        \cos(\frac{T_{cur}}{T_{max}}\pi))

    When last_epoch=-1, sets initial lr as lr.

    It has been proposed in
    `SGDR: Stochastic Gradient Descent with Warm Restarts`_. Note that this only
    implements the cosine annealing part of SGDR, and not the restarts.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        T_max (int): Maximum number of iterations.
        eta_min (float): Minimum learning rate. Default: 0.
        last_epoch (int): The index of last epoch. Default: -1.

    .. _SGDR\: Stochastic Gradient Descent with Warm Restarts:
        https://arxiv.org/abs/1608.03983
    """

    def __init__(self, optimizer, T_max, eta_min=0, last_epoch=-1):
        self.T_max = T_max
        self.eta_min = eta_min
        super(CosineAnnealingLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        """Closed-form cosine value between ``base_lr`` and ``eta_min`` at ``last_epoch``."""
        # (1 + cos(pi * t / T_max)) is shared by every group; the per-group
        # arithmetic below keeps the original evaluation order.
        cosine_term = 1 + math.cos(math.pi * self.last_epoch / self.T_max)
        floor_lr = self.eta_min
        return [floor_lr + (initial - floor_lr) * cosine_term / 2
                for initial in self.base_lrs]
class ReduceLROnPlateau(object):
    """Reduce learning rate when a metric has stopped improving.
    Models often benefit from reducing the learning rate by a factor
    of 2-10 once learning stagnates. This scheduler reads a metrics
    quantity and if no improvement is seen for a 'patience' number
    of epochs, the learning rate is reduced.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        mode (str): One of `min`, `max`. In `min` mode, lr will
            be reduced when the quantity monitored has stopped
            decreasing; in `max` mode it will be reduced when the
            quantity monitored has stopped increasing. Default: 'min'.
        factor (float): Factor by which the learning rate will be
            reduced. new_lr = lr * factor. Default: 0.1.
        patience (int): Number of epochs with no improvement after
            which learning rate will be reduced. For example, if
            `patience = 2`, then we will ignore the first 2 epochs
            with no improvement, and will only decrease the LR after the
            3rd epoch if the loss still hasn't improved then.
            Default: 10.
        verbose (bool): If ``True``, prints a message to stdout for
            each update. Default: ``False``.
        threshold (float): Threshold for measuring the new optimum,
            to only focus on significant changes. Default: 1e-4.
        threshold_mode (str): One of `rel`, `abs`. In `rel` mode,
            dynamic_threshold = best * ( 1 + threshold ) in 'max'
            mode or best * ( 1 - threshold ) in `min` mode.
            In `abs` mode, dynamic_threshold = best + threshold in
            `max` mode or best - threshold in `min` mode. Default: 'rel'.
        cooldown (int): Number of epochs to wait before resuming
            normal operation after lr has been reduced. Default: 0.
        min_lr (float or list): A scalar or a list of scalars. A
            lower bound on the learning rate of all param groups
            or each group respectively. Default: 0.
        eps (float): Minimal decay applied to lr. If the difference
            between new and old lr is smaller than eps, the update is
            ignored. Default: 1e-8.

    Raises:
        ValueError: if ``factor >= 1.0``, ``min_lr`` has the wrong length,
            or ``mode`` / ``threshold_mode`` is unknown.
        TypeError: if ``optimizer`` is not an :class:`Optimizer`.

    Example:
        >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
        >>> scheduler = ReduceLROnPlateau(optimizer, 'min')
        >>> for epoch in range(10):
        >>>     train(...)
        >>>     val_loss = validate(...)
        >>>     # Note that step should be called after validate()
        >>>     scheduler.step(val_loss)
    """

    def __init__(self, optimizer, mode='min', factor=0.1, patience=10,
                 verbose=False, threshold=1e-4, threshold_mode='rel',
                 cooldown=0, min_lr=0, eps=1e-8):

        if factor >= 1.0:
            raise ValueError('Factor should be < 1.0.')
        self.factor = factor

        if not isinstance(optimizer, Optimizer):
            raise TypeError('{} is not an Optimizer'.format(
                type(optimizer).__name__))
        self.optimizer = optimizer

        if isinstance(min_lr, (list, tuple)):
            if len(min_lr) != len(optimizer.param_groups):
                raise ValueError("expected {} min_lrs, got {}".format(
                    len(optimizer.param_groups), len(min_lr)))
            self.min_lrs = list(min_lr)
        else:
            self.min_lrs = [min_lr] * len(optimizer.param_groups)

        self.patience = patience
        self.verbose = verbose
        self.cooldown = cooldown
        self.cooldown_counter = 0
        self.mode = mode
        self.threshold = threshold
        self.threshold_mode = threshold_mode
        self.best = None
        self.num_bad_epochs = None
        self.mode_worse = None  # the worse value for the chosen mode
        self.is_better = None
        self.eps = eps
        self.last_epoch = -1
        self._init_is_better(mode=mode, threshold=threshold,
                             threshold_mode=threshold_mode)
        self._reset()

    def _reset(self):
        """Resets num_bad_epochs counter and cooldown counter."""
        self.best = self.mode_worse
        self.cooldown_counter = 0
        self.num_bad_epochs = 0

    def step(self, metrics, epoch=None):
        """Record one observation of the monitored metric.

        Args:
            metrics: value convertible with ``float()`` (e.g. a number or a
                zero-dim Tensor).
            epoch (int or float, optional): epoch index to record; defaults
                to ``last_epoch + 1``.
        """
        # convert `metrics` to float, in case it's a zero-dim Tensor
        current = float(metrics)
        if epoch is None:
            epoch = self.last_epoch + 1
        self.last_epoch = epoch

        if self.is_better(current, self.best):
            self.best = current
            self.num_bad_epochs = 0
        else:
            self.num_bad_epochs += 1

        if self.in_cooldown:
            self.cooldown_counter -= 1
            self.num_bad_epochs = 0  # ignore any bad epochs in cooldown

        if self.num_bad_epochs > self.patience:
            self._reduce_lr(epoch)
            self.cooldown_counter = self.cooldown
            self.num_bad_epochs = 0

    def _reduce_lr(self, epoch):
        """Scale every group's lr by ``factor``, clamped at its ``min_lr``.

        Changes smaller than ``eps`` are skipped.
        """
        for i, param_group in enumerate(self.optimizer.param_groups):
            old_lr = float(param_group['lr'])
            new_lr = max(old_lr * self.factor, self.min_lrs[i])
            if old_lr - new_lr > self.eps:
                param_group['lr'] = new_lr
                if self.verbose:
                    # '{:5d}' rejects non-integral values, so fractional
                    # epochs passed to step() get a fixed-point format.
                    if isinstance(epoch, float):
                        epoch_str = '{:5.2f}'.format(epoch)
                    else:
                        epoch_str = '{:5d}'.format(epoch)
                    print('Epoch {}: reducing learning rate'
                          ' of group {} to {:.4e}.'.format(epoch_str, i, new_lr))

    @property
    def in_cooldown(self):
        """True while the post-reduction cooldown counter is still positive."""
        return self.cooldown_counter > 0

    def _cmp(self, mode, threshold_mode, threshold, a, best):
        """Return True if ``a`` is a significant improvement over ``best``."""
        if mode == 'min' and threshold_mode == 'rel':
            rel_epsilon = 1. - threshold
            return a < best * rel_epsilon

        elif mode == 'min' and threshold_mode == 'abs':
            return a < best - threshold

        elif mode == 'max' and threshold_mode == 'rel':
            rel_epsilon = threshold + 1.
            return a > best * rel_epsilon

        else:  # mode == 'max' and threshold_mode == 'abs':
            return a > best + threshold

    def _init_is_better(self, mode, threshold, threshold_mode):
        """Validate the modes and build ``self.is_better`` / ``self.mode_worse``."""
        if mode not in {'min', 'max'}:
            raise ValueError('mode {} is unknown!'.format(mode))
        if threshold_mode not in {'rel', 'abs'}:
            raise ValueError('threshold mode {} is unknown!'.format(threshold_mode))

        if mode == 'min':
            self.mode_worse = inf
        else:  # mode == 'max':
            self.mode_worse = -inf

        self.is_better = partial(self._cmp, mode, threshold_mode, threshold)

    def state_dict(self):
        """Return all attributes except the optimizer and the ``is_better`` partial."""
        return {key: value for key, value in self.__dict__.items()
                if key not in {'optimizer', 'is_better'}}

    def load_state_dict(self, state_dict):
        """Restore attributes from ``state_dict`` and rebuild ``is_better``."""
        self.__dict__.update(state_dict)
        self._init_is_better(mode=self.mode, threshold=self.threshold,
                             threshold_mode=self.threshold_mode)
class CyclicLR(_LRScheduler):
    """Sets the learning rate of each parameter group according to
    cyclical learning rate policy (CLR). The policy cycles the learning
    rate between two boundaries with a constant frequency, as detailed in
    the paper `Cyclical Learning Rates for Training Neural Networks`_.
    The distance between the two boundaries can be scaled on a per-iteration
    or per-cycle basis.

    Cyclical learning rate policy changes the learning rate after every batch.
    `step` should be called after a batch has been used for training.

    This class has three built-in policies, as put forth in the paper:

    "triangular":
        A basic triangular cycle w/ no amplitude scaling.
    "triangular2":
        A basic triangular cycle that scales initial amplitude by half each cycle.
    "exp_range":
        A cycle that scales initial amplitude by gamma**(cycle iterations) at each
        cycle iteration.

    This implementation was adapted from the github repo: `bckenstler/CLR`_

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        base_lr (float or list): Initial learning rate which is the
            lower boundary in the cycle for each parameter group.
        max_lr (float or list): Upper learning rate boundaries in the cycle
            for each parameter group. Functionally,
            it defines the cycle amplitude (max_lr - base_lr).
            The lr at any cycle is the sum of base_lr
            and some scaling of the amplitude; therefore
            max_lr may not actually be reached depending on
            scaling function.
        step_size_up (int): Number of training iterations in the
            increasing half of a cycle. Default: 2000
        step_size_down (int): Number of training iterations in the
            decreasing half of a cycle. If step_size_down is None,
            it is set to step_size_up. Default: None
        mode (str): One of {triangular, triangular2, exp_range}.
            Values correspond to policies detailed above.
            If scale_fn is not None, this argument is ignored.
            Default: 'triangular'
        gamma (float): Constant in 'exp_range' scaling function:
            gamma**(cycle iterations)
            Default: 1.0
        scale_fn (function): Custom scaling policy defined by a single
            argument lambda function, where
            0 <= scale_fn(x) <= 1 for all x >= 0.
            If specified, then 'mode' is ignored.
            Default: None
        scale_mode (str): {'cycle', 'iterations'}.
            Defines whether scale_fn is evaluated on
            cycle number or cycle iterations (training
            iterations since start of cycle).
            Default: 'cycle'
        cycle_momentum (bool): If ``True``, momentum is cycled inversely
            to learning rate between 'base_momentum' and 'max_momentum'.
            Default: True
        base_momentum (float or list): Lower momentum boundaries in the cycle
            for each parameter group. Note that momentum is cycled inversely
            to learning rate; at the peak of a cycle, momentum is
            'base_momentum' and learning rate is 'max_lr'.
            Default: 0.8
        max_momentum (float or list): Upper momentum boundaries in the cycle
            for each parameter group. Functionally,
            it defines the cycle amplitude (max_momentum - base_momentum).
            The momentum at any cycle is the difference of max_momentum
            and some scaling of the amplitude; therefore
            base_momentum may not actually be reached depending on
            scaling function. Note that momentum is cycled inversely
            to learning rate; at the start of a cycle, momentum is 'max_momentum'
            and learning rate is 'base_lr'
            Default: 0.9
        last_epoch (int): The index of the last batch. This parameter is used when
            resuming a training job. Since `step()` should be invoked after each
            batch instead of after each epoch, this number represents the total
            number of *batches* computed, not the total number of epochs computed.
            When last_epoch=-1, the schedule is started from the beginning.
            Default: -1

    Example:
        >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
        >>> scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer, base_lr=0.01, max_lr=0.1)
        >>> data_loader = torch.utils.data.DataLoader(...)
        >>> for epoch in range(10):
        >>>     for batch in data_loader:
        >>>         train_batch(...)
        >>>         scheduler.step()


    .. _Cyclical Learning Rates for Training Neural Networks: https://arxiv.org/abs/1506.01186
    .. _bckenstler/CLR: https://github.com/bckenstler/CLR
    """

    def __init__(self,
                 optimizer,
                 base_lr,
                 max_lr,
                 step_size_up=2000,
                 step_size_down=None,
                 mode='triangular',
                 gamma=1.,
                 scale_fn=None,
                 scale_mode='cycle',
                 cycle_momentum=True,
                 base_momentum=0.8,
                 max_momentum=0.9,
                 last_epoch=-1):

        if not isinstance(optimizer, Optimizer):
            raise TypeError('{} is not an Optimizer'.format(
                type(optimizer).__name__))
        self.optimizer = optimizer

        base_lrs = self._format_param('base_lr', optimizer, base_lr)
        if last_epoch == -1:
            for lr, group in zip(base_lrs, optimizer.param_groups):
                group['lr'] = lr
                # Overwrite rather than rely on the base class's setdefault:
                # a stale 'initial_lr' left by an earlier scheduler would
                # otherwise silently replace the requested base_lr.
                group['initial_lr'] = lr

        self.max_lrs = self._format_param('max_lr', optimizer, max_lr)

        step_size_up = float(step_size_up)
        step_size_down = float(step_size_down) if step_size_down is not None else step_size_up
        self.total_size = step_size_up + step_size_down
        self.step_ratio = step_size_up / self.total_size

        if mode not in ['triangular', 'triangular2', 'exp_range'] \
                and scale_fn is None:
            raise ValueError('mode is invalid and scale_fn is None')

        self.mode = mode
        self.gamma = gamma

        if scale_fn is None:
            if self.mode == 'triangular':
                self.scale_fn = self._triangular_scale_fn
                self.scale_mode = 'cycle'
            elif self.mode == 'triangular2':
                self.scale_fn = self._triangular2_scale_fn
                self.scale_mode = 'cycle'
            elif self.mode == 'exp_range':
                self.scale_fn = self._exp_range_scale_fn
                self.scale_mode = 'iterations'
        else:
            self.scale_fn = scale_fn
            self.scale_mode = scale_mode

        self.cycle_momentum = cycle_momentum
        if cycle_momentum:
            if 'momentum' not in optimizer.defaults:
                raise ValueError('optimizer must support momentum with `cycle_momentum` option enabled')

            base_momentums = self._format_param('base_momentum', optimizer, base_momentum)
            if last_epoch == -1:
                for momentum, group in zip(base_momentums, optimizer.param_groups):
                    group['momentum'] = momentum
            # Use the configured lower bounds. Reading group['momentum'] back
            # is wrong when resuming, because it then holds the current
            # (cycled) momentum rather than base_momentum.
            self.base_momentums = base_momentums
            self.max_momentums = self._format_param('max_momentum', optimizer, max_momentum)

        super(CyclicLR, self).__init__(optimizer, last_epoch)

    def _format_param(self, name, optimizer, param):
        """Return correctly formatted lr/momentum for each param group.

        A list/tuple must have one entry per param group and is returned as
        a list; a scalar is repeated for every group.
        """
        if isinstance(param, (list, tuple)):
            if len(param) != len(optimizer.param_groups):
                raise ValueError("expected {} values for {}, got {}".format(
                    len(optimizer.param_groups), name, len(param)))
            return list(param)
        else:
            return [param] * len(optimizer.param_groups)

    def _triangular_scale_fn(self, x):
        """Constant amplitude scale (policy 'triangular')."""
        return 1.

    def _triangular2_scale_fn(self, x):
        """Halve the amplitude every cycle (policy 'triangular2')."""
        return 1 / (2. ** (x - 1))

    def _exp_range_scale_fn(self, x):
        """Scale the amplitude by gamma**x (policy 'exp_range')."""
        return self.gamma**(x)

    def get_lr(self):
        """Calculates the learning rate at batch index. This function treats
        `self.last_epoch` as the last batch index.

        If `self.cycle_momentum` is ``True``, this function has a side effect of
        updating the optimizer's momentum.
        """
        cycle = math.floor(1 + self.last_epoch / self.total_size)
        # Position within the current cycle, in [0, 1).
        x = 1. + self.last_epoch / self.total_size - cycle
        if x <= self.step_ratio:
            scale_factor = x / self.step_ratio
        else:
            scale_factor = (x - 1) / (self.step_ratio - 1)

        lrs = []
        for base_lr, max_lr in zip(self.base_lrs, self.max_lrs):
            base_height = (max_lr - base_lr) * scale_factor
            if self.scale_mode == 'cycle':
                lr = base_lr + base_height * self.scale_fn(cycle)
            else:
                lr = base_lr + base_height * self.scale_fn(self.last_epoch)
            lrs.append(lr)

        if self.cycle_momentum:
            momentums = []
            for base_momentum, max_momentum in zip(self.base_momentums, self.max_momentums):
                base_height = (max_momentum - base_momentum) * scale_factor
                if self.scale_mode == 'cycle':
                    momentum = max_momentum - base_height * self.scale_fn(cycle)
                else:
                    momentum = max_momentum - base_height * self.scale_fn(self.last_epoch)
                momentums.append(momentum)
            for param_group, momentum in zip(self.optimizer.param_groups, momentums):
                param_group['momentum'] = momentum

        return lrs
class CosineAnnealingWarmRestarts(_LRScheduler):
    r"""Set the learning rate of each parameter group using a cosine annealing
    schedule, where :math:`\eta_{max}` is set to the initial lr, :math:`T_{cur}`
    is the number of epochs since the last restart and :math:`T_{i}` is the number
    of epochs between two warm restarts in SGDR:

    .. math::

        \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})(1 +
        \cos(\frac{T_{cur}}{T_{i}}\pi))

    When :math:`T_{cur}=T_{i}`, set :math:`\eta_t = \eta_{min}`.
    When :math:`T_{cur}=0` (after restart), set :math:`\eta_t=\eta_{max}`.

    It has been proposed in
    `SGDR: Stochastic Gradient Descent with Warm Restarts`_.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        T_0 (int): Number of iterations for the first restart.
        T_mult (int, optional): A factor increases :math:`T_{i}` after a restart. Default: 1.
        eta_min (float, optional): Minimum learning rate. Default: 0.
        last_epoch (int, optional): The index of last epoch. Default: -1.

    Raises:
        ValueError: if ``T_0`` is not a positive int or ``T_mult`` is not an
            int >= 1.

    .. _SGDR\: Stochastic Gradient Descent with Warm Restarts:
        https://arxiv.org/abs/1608.03983
    """

    def __init__(self, optimizer, T_0, T_mult=1, eta_min=0, last_epoch=-1):
        # Type check first so a non-numeric argument yields the documented
        # ValueError instead of a TypeError from the comparison.
        if not isinstance(T_0, int) or T_0 <= 0:
            raise ValueError("Expected positive integer T_0, but got {}".format(T_0))
        if not isinstance(T_mult, int) or T_mult < 1:
            raise ValueError("Expected integer T_mult >= 1, but got {}".format(T_mult))
        self.T_0 = T_0
        self.T_i = T_0
        self.T_mult = T_mult
        self.eta_min = eta_min
        # Placeholder only: the base constructor calls self.step(epoch) with an
        # explicit epoch, which computes the real T_cur / T_i. It must not be
        # overwritten afterwards, or resuming past the first restart would
        # leave T_cur equal to the absolute epoch.
        self.T_cur = last_epoch
        super(CosineAnnealingWarmRestarts, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        """Cosine value between ``base_lr`` and ``eta_min`` at position T_cur / T_i."""
        return [self.eta_min + (base_lr - self.eta_min) * (1 + math.cos(math.pi * self.T_cur / self.T_i)) / 2
                for base_lr in self.base_lrs]

    def step(self, epoch=None):
        """Step could be called after every batch update

        Example:
            >>> scheduler = CosineAnnealingWarmRestarts(optimizer, T_0, T_mult)
            >>> iters = len(dataloader)
            >>> for epoch in range(20):
            >>>     for i, sample in enumerate(dataloader):
            >>>         inputs, labels = sample['inputs'], sample['labels']
            >>>         scheduler.step(epoch + i / iters)
            >>>         optimizer.zero_grad()
            >>>         outputs = net(inputs)
            >>>         loss = criterion(outputs, labels)
            >>>         loss.backward()
            >>>         optimizer.step()

        This function can be called in an interleaved way.

        Example:
            >>> scheduler = CosineAnnealingWarmRestarts(optimizer, T_0, T_mult)
            >>> for epoch in range(20):
            >>>     scheduler.step()
            >>> scheduler.step(26)
            >>> scheduler.step() # scheduler.step(27), instead of scheduler(20)

        Raises:
            ValueError: if an explicit ``epoch`` is negative.
        """
        if epoch is None:
            epoch = self.last_epoch + 1
            self.T_cur = self.T_cur + 1
            if self.T_cur >= self.T_i:
                self.T_cur = self.T_cur - self.T_i
                self.T_i = self.T_i * self.T_mult
        else:
            if epoch < 0:
                raise ValueError("Expected non-negative epoch, but got {}".format(epoch))
            if epoch >= self.T_0:
                if self.T_mult == 1:
                    self.T_cur = epoch % self.T_0
                else:
                    def cycle_start(k):
                        # First epoch of restart cycle k (geometric series
                        # T_0 * (1 + T_mult + ... + T_mult**(k-1))); exact
                        # integer because T_0 and T_mult are ints.
                        return self.T_0 * (self.T_mult ** k - 1) // (self.T_mult - 1)

                    n = int(math.log((epoch / self.T_0 * (self.T_mult - 1) + 1), self.T_mult))
                    # math.log can land just below an exact integer (e.g.
                    # math.log(1000, 10) == 2.9999999999999996), which would
                    # place a restart epoch in the previous cycle. Correct n
                    # against the exact integer cycle boundaries.
                    while cycle_start(n + 1) <= epoch:
                        n += 1
                    while n > 0 and cycle_start(n) > epoch:
                        n -= 1
                    self.T_cur = epoch - cycle_start(n)
                    self.T_i = self.T_0 * self.T_mult ** (n)
            else:
                self.T_i = self.T_0
                self.T_cur = epoch
        self.last_epoch = math.floor(epoch)
        for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
            param_group['lr'] = lr

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources