Shortcuts

Source code for torch.optim.adam

import torch
from . import functional as F
from .optimizer import Optimizer


[docs]class Adam(Optimizer): r"""Implements Adam algorithm. It has been proposed in `Adam: A Method for Stochastic Optimization`_. The implementation of the L2 penalty follows changes proposed in `Decoupled Weight Decay Regularization`_. Arguments: params (iterable): iterable of parameters to optimize or dicts defining parameter groups lr (float, optional): learning rate (default: 1e-3) betas (Tuple[float, float], optional): coefficients used for computing running averages of gradient and its square (default: (0.9, 0.999)) eps (float, optional): term added to the denominator to improve numerical stability (default: 1e-8) weight_decay (float, optional): weight decay (L2 penalty) (default: 0) amsgrad (boolean, optional): whether to use the AMSGrad variant of this algorithm from the paper `On the Convergence of Adam and Beyond`_ (default: False) .. _Adam\: A Method for Stochastic Optimization: https://arxiv.org/abs/1412.6980 .. _Decoupled Weight Decay Regularization: https://arxiv.org/abs/1711.05101 .. _On the Convergence of Adam and Beyond: https://openreview.net/forum?id=ryQu7f-RZ """ def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False): if not 0.0 <= lr: raise ValueError("Invalid learning rate: {}".format(lr)) if not 0.0 <= eps: raise ValueError("Invalid epsilon value: {}".format(eps)) if not 0.0 <= betas[0] < 1.0: raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) if not 0.0 <= betas[1] < 1.0: raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) if not 0.0 <= weight_decay: raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad) super(Adam, self).__init__(params, defaults) def __setstate__(self, state): super(Adam, self).__setstate__(state) for group in self.param_groups: group.setdefault('amsgrad', False)
[docs] @torch.no_grad() def step(self, closure=None): """Performs a single optimization step. Arguments: closure (callable, optional): A closure that reevaluates the model and returns the loss. """ loss = None if closure is not None: with torch.enable_grad(): loss = closure() for group in self.param_groups: params_with_grad = [] grads = [] exp_avgs = [] exp_avg_sqs = [] state_sums = [] max_exp_avg_sqs = [] state_steps = [] for p in group['params']: if p.grad is not None: params_with_grad.append(p) if p.grad.is_sparse: raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') grads.append(p.grad) state = self.state[p] # Lazy state initialization if len(state) == 0: state['step'] = 0 # Exponential moving average of gradient values state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format) # Exponential moving average of squared gradient values state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format) if group['amsgrad']: # Maintains max of all exp. moving avg. of sq. grad. values state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format) exp_avgs.append(state['exp_avg']) exp_avg_sqs.append(state['exp_avg_sq']) if group['amsgrad']: max_exp_avg_sqs.append(state['max_exp_avg_sq']) # update the steps for each param group update state['step'] += 1 # record the step after step update state_steps.append(state['step']) beta1, beta2 = group['betas'] F.adam(params_with_grad, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, group['amsgrad'], beta1, beta2, group['lr'], group['weight_decay'], group['eps'] ) return loss

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources