torch.nn.utils.parametrizations.spectral_norm
torch.nn.utils.parametrizations.spectral_norm(module, name='weight', n_power_iterations=1, eps=1e-12, dim=None)
Applies spectral normalization to a parameter in the given module.
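$$\mathbf{W}_{SN} = \dfrac{\mathbf{W}}{\sigma(\mathbf{W})}, \qquad \sigma(\mathbf{W}) = \max_{\mathbf{h}: \mathbf{h} \ne 0} \dfrac{\|\mathbf{W} \mathbf{h}\|_2}{\|\mathbf{h}\|_2}$$
When applied on a vector, it simplifies to
$$\mathbf{x}_{SN} = \dfrac{\mathbf{x}}{\|\mathbf{x}\|_2}$$
Spectral normalization stabilizes the training of discriminators (critics) in Generative Adversarial Networks (GANs) by reducing the Lipschitz constant of the model. $\sigma$ is approximated by performing one iteration of the power method (by default) every time the weight is accessed. If the weight tensor has more than two dimensions, it is reshaped to 2D before the power iteration is applied to compute the spectral norm.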
See Spectral Normalization for Generative Adversarial Networks .
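The power iteration described above can be sketched in dependency-free Python. This is an illustrative approximation rather than the library's implementation: the function name `estimate_sigma`, the list-of-rows matrix representation, and the all-ones starting vector are assumptions made for this example.

```python
import math


def _normalize(x, eps):
    # Divide by max(||x||_2, eps) so a zero vector cannot cause a division by zero.
    norm = math.sqrt(sum(a * a for a in x))
    return [a / max(norm, eps) for a in x]


def estimate_sigma(W, n_power_iterations=1, eps=1e-12):
    """Estimate the largest singular value of a 2D matrix W (a list of rows)."""
    if n_power_iterations < 1:
        raise ValueError("n_power_iterations must be at least 1")
    rows, cols = len(W), len(W[0])
    # Assumed start vector; it must not be orthogonal to the leading singular vector.
    u = _normalize([1.0] * rows, eps)
    for _ in range(n_power_iterations):
        # v <- normalize(W^T u)
        v = _normalize([sum(W[i][j] * u[i] for i in range(rows)) for j in range(cols)], eps)
        # u <- normalize(W v)
        u = _normalize([sum(W[i][j] * v[j] for j in range(cols)) for i in range(rows)], eps)
    # sigma is approximated by u^T W v
    return sum(u[i] * W[i][j] * v[j] for i in range(rows) for j in range(cols))


W = [[3.0, 0.0], [0.0, 1.0]]  # singular values are 3 and 1
sigma_rough = estimate_sigma(W, n_power_iterations=1)
sigma_fine = estimate_sigma(W, n_power_iterations=20)
W_sn = [[w / sigma_fine for w in row] for row in W]  # W_SN = W / sigma(W)
```

With a single iteration the estimate falls below the true value of 3, and further iterations tighten it. This is why repeated accesses during training gradually refine the estimate.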
Note
This function is implemented using the new parametrization functionality in torch.nn.utils.parametrize.register_parametrization(). It is a reimplementation of torch.nn.utils.spectral_norm().
Note
When this constraint is registered, the singular vectors associated with the largest singular value are estimated rather than sampled at random. They are then updated by performing n_power_iterations iterations of the power method whenever the tensor is accessed with the module in training mode.
Note
If the _SpectralNorm module, i.e., module.parametrizations.weight[idx], is in training mode on removal, it will perform another power iteration. To avoid this iteration, set the module to eval mode before removing it.
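A minimal sketch of the removal pattern from the last note. It assumes the parametrization is removed with `torch.nn.utils.parametrize.remove_parametrizations()`; that call is not named in the text above, so treat its use here as an assumption. Switching to eval mode first means no further power iteration runs, so the stored weight should match the last normalized value.

```python
import torch
import torch.nn as nn
from torch.nn.utils import parametrize
from torch.nn.utils.parametrizations import spectral_norm

torch.manual_seed(0)
layer = spectral_norm(nn.Linear(20, 40))

# In eval mode the singular-vector estimates are frozen, so repeated
# accesses of the weight return the same values.
layer.eval()
with torch.no_grad():
    first_access = layer.weight.clone()
    second_access = layer.weight.clone()

# Remove the parametrization while keeping the normalized weight
# (leave_parametrized=True is the default; it is spelled out for clarity).
parametrize.remove_parametrizations(layer, "weight", leave_parametrized=True)
is_still_parametrized = parametrize.is_parametrized(layer, "weight")
```

After removal, `layer.weight` is an ordinary parameter again and is no longer renormalized on access.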
- Parameters
module (nn.Module) – containing module
name (str, optional) – name of weight parameter
n_power_iterations (int, optional) – number of power iterations to calculate spectral norm
eps (float, optional) – epsilon for numerical stability in calculating norms
dim (int, optional) – dimension corresponding to number of outputs. The default is 0, except for modules that are instances of ConvTranspose{1,2,3}d, for which it is 1
- Returns
The original module with a new parametrization registered to the specified weight
Example:
>>> snm = spectral_norm(nn.Linear(20, 40))
>>> snm
ParametrizedLinear(
  in_features=20, out_features=40, bias=True
  (parametrizations): ModuleDict(
    (weight): ParametrizationList(
      (0): _SpectralNorm()
    )
  )
)
>>> torch.linalg.matrix_norm(snm.weight, 2)
tensor(1.0000, grad_fn=<CopyBackwards>)
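To illustrate the `dim` parameter and the 2D reshape for higher-dimensional weights, the following sketch applies spectral normalization to convolutional layers and flattens their weights in the way the description suggests. The helper `as_matrix` is hypothetical and only approximates the internal reshape. `n_power_iterations=20` is chosen here only to make the estimate tight enough to inspect.

```python
import torch
import torch.nn as nn
from torch.nn.utils.parametrizations import spectral_norm

torch.manual_seed(0)
conv = spectral_norm(nn.Conv2d(3, 8, kernel_size=3), n_power_iterations=20)
deconv = spectral_norm(nn.ConvTranspose2d(8, 3, kernel_size=3), n_power_iterations=20)


def as_matrix(weight, dim):
    # Hypothetical helper: move `dim` to the front, then flatten the rest.
    return weight.movedim(dim, 0).flatten(1)


# Conv2d weight is (out, in, kH, kW): outputs are on dim 0 (the default).
conv_matrix = as_matrix(conv.weight, dim=0)
# ConvTranspose2d weight is (in, out, kH, kW): outputs are on dim 1.
deconv_matrix = as_matrix(deconv.weight, dim=1)

# The spectral norm of each flattened weight should be close to 1.
conv_sigma = torch.linalg.matrix_norm(conv_matrix, 2).item()
deconv_sigma = torch.linalg.matrix_norm(deconv_matrix, 2).item()
```

Because the largest singular value is only estimated, the measured norms are close to 1 rather than exactly 1.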