CosineAnnealingWarmRestarts

class torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0, T_mult=1, eta_min=0, last_epoch=-1, verbose=False)[source]

Set the learning rate of each parameter group using a cosine annealing schedule, where $\eta_{max}$ is set to the initial lr, $T_{cur}$ is the number of epochs since the last restart and $T_i$ is the number of epochs between two warm restarts in SGDR:

$$\eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 + \cos\left(\frac{T_{cur}}{T_i}\pi\right)\right)$$

When $T_{cur} = T_i$, set $\eta_t = \eta_{min}$. When $T_{cur} = 0$ after restart, set $\eta_t = \eta_{max}$.
It has been proposed in SGDR: Stochastic Gradient Descent with Warm Restarts.
- Parameters
optimizer (Optimizer) – Wrapped optimizer.
T_0 (int) – Number of iterations for the first restart.
T_mult (int, optional) – A factor by which $T_i$ increases after a restart. Default: 1.
eta_min (float, optional) – Minimum learning rate. Default: 0.
last_epoch (int, optional) – The index of last epoch. Default: -1.
verbose (bool) – If True, prints a message to stdout for each update. Default: False.
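The closed form above can be checked numerically. The following is a minimal sketch, not part of the original documentation: it assumes a throwaway torch.nn.Linear model, an SGD optimizer with initial lr 0.1, and the arbitrary values T_0=10, T_mult=1, eta_min=0.001, and it compares get_last_lr() against the formula for the first few epochs (before any restart, so $T_{cur}$ equals the epoch count).
>>> import math
>>> import torch
>>> model = torch.nn.Linear(2, 2)  # throwaway model, only used to build an optimizer
>>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
>>> scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=1, eta_min=0.001)
>>> eta_max, eta_min, T_i = 0.1, 0.001, 10
>>> for epoch in range(1, 5):
>>>     optimizer.step()
>>>     scheduler.step()  # advances T_cur by one epoch
>>>     expected = eta_min + 0.5 * (eta_max - eta_min) * (1 + math.cos(math.pi * epoch / T_i))
>>>     print(epoch, scheduler.get_last_lr()[0], expected)  # the two values should match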
- get_last_lr() – Return the last computed learning rate by the current scheduler.
- load_state_dict(state_dict) – Load the scheduler's state.
  - Parameters
    state_dict (dict) – scheduler state. Should be an object returned from a call to state_dict().
- print_lr(is_verbose, group, lr, epoch=None) – Display the current learning rate.
- state_dict() – Return the state of the scheduler as a dict. It contains an entry for every variable in self.__dict__ which is not the optimizer.
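Together, state_dict() and load_state_dict() allow the scheduler to be checkpointed alongside the optimizer it wraps (the optimizer is deliberately excluded from the scheduler's own state). A minimal sketch, assuming a throwaway model, an SGD optimizer and a hypothetical checkpoint.pt path:
>>> import torch
>>> model = torch.nn.Linear(4, 4)
>>> optimizer = torch.optim.SGD(model.parameters(), lr=0.05)
>>> scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=5, T_mult=2)
>>> for _ in range(3):  # a few epochs of (dummy) training
>>>     optimizer.step()
>>>     scheduler.step()
>>> # scheduler.state_dict() holds T_0, T_i, T_cur, last_epoch, etc., but not the optimizer
>>> torch.save({'optimizer': optimizer.state_dict(), 'scheduler': scheduler.state_dict()}, 'checkpoint.pt')
>>> # Later: rebuild the optimizer and scheduler, then restore their states in place
>>> checkpoint = torch.load('checkpoint.pt')
>>> optimizer.load_state_dict(checkpoint['optimizer'])
>>> scheduler.load_state_dict(checkpoint['scheduler'])
>>> scheduler.get_last_lr()  # resumes from the saved point in the cycle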
- step(epoch=None)[source] – Step can be called after every batch update.
Example
>>> scheduler = CosineAnnealingWarmRestarts(optimizer, T_0, T_mult)
>>> iters = len(dataloader)
>>> for epoch in range(20):
>>>     for i, sample in enumerate(dataloader):
>>>         inputs, labels = sample['inputs'], sample['labels']
>>>         optimizer.zero_grad()
>>>         outputs = net(inputs)
>>>         loss = criterion(outputs, labels)
>>>         loss.backward()
>>>         optimizer.step()
>>>         scheduler.step(epoch + i / iters)
This function can be called in an interleaved way.
Example
>>> scheduler = CosineAnnealingWarmRestarts(optimizer, T_0, T_mult)
>>> for epoch in range(20):
>>>     scheduler.step()
>>> scheduler.step(26)
>>> scheduler.step()  # scheduler.step(27), instead of scheduler.step(20)
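The length of each cosine cycle grows by the factor T_mult after every restart. The sketch below is not from the original documentation; it assumes a throwaway model, SGD with an initial lr of 0.1, and the arbitrary values T_0=2, T_mult=2. Stepping once per epoch, the printed learning rate jumps back to 0.1 at epochs 2, 6 and 14, i.e. after cycles of length 2, 4 and 8.
>>> import torch
>>> model = torch.nn.Linear(2, 2)
>>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
>>> scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=2, T_mult=2)
>>> for epoch in range(15):
>>>     optimizer.step()
>>>     scheduler.step()  # one call per epoch; restarts are handled internally
>>>     print(epoch + 1, round(scheduler.get_last_lr()[0], 4))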