torch.nn.utils.parametrize.register_parametrization
torch.nn.utils.parametrize.register_parametrization(module, tensor_name, parametrization)

Adds a parametrization to a tensor in a module.
Assume that tensor_name="weight" for simplicity. When accessing module.weight, the module will return the parametrized version parametrization(module.weight), where the argument is the original, unparametrized tensor. If the original tensor requires a gradient, the backward pass will differentiate through the parametrization, and the optimizer will update the tensor accordingly.

The first time that a module registers a parametrization, this function will add an attribute parametrizations to the module of type ParametrizationList.

The list of parametrizations on the tensor weight will be accessible under module.parametrizations.weight.

The original tensor will be accessible under module.parametrizations.weight.original.

Parametrizations may be concatenated by registering several parametrizations on the same attribute. They are applied in the order in which they were registered, each one receiving the output of the previous one.
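The access pattern, gradient flow, and attribute layout described above can be sketched as follows. Double and AddOne are hypothetical toy parametrizations introduced only for illustration, and the assertions assume a reasonably recent PyTorch release.

```python
import torch
import torch.nn as nn
import torch.nn.utils.parametrize as P


class Double(nn.Module):
    # Hypothetical toy parametrization: scales the tensor by 2.
    def forward(self, X):
        return 2 * X


class AddOne(nn.Module):
    # Hypothetical toy parametrization: adds 1 elementwise.
    def forward(self, X):
        return X + 1


m = nn.Linear(3, 3)
W0 = m.weight.detach().clone()  # copy of the weight before parametrizing

P.register_parametrization(m, "weight", Double())

# The unparametrized tensor now lives under parametrizations.weight.original.
original = m.parametrizations.weight.original
assert torch.equal(original.detach(), W0)

# Accessing m.weight runs the parametrization on the original tensor.
assert torch.allclose(m.weight, 2 * W0)

# Gradients flow through the parametrization back to the original tensor:
# d(sum(2 * W))/dW is 2 everywhere.
m.weight.sum().backward()
assert torch.allclose(original.grad, torch.full_like(W0, 2.0))

# A second registration on the same tensor is appended to the list,
# so m.weight is now AddOne()(Double()(original)).
P.register_parametrization(m, "weight", AddOne())
assert len(m.parametrizations.weight) == 2
assert torch.allclose(m.weight, 2 * W0 + 1)
```

Note that the second parametrization sees the output of the first, which is why the result is 2 * W0 + 1 rather than 2 * (W0 + 1).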
The training mode of the registered parametrizations is updated on registration, if necessary, to match the training mode of the host module.
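A quick check of the training-mode behaviour, as a sketch: Identity is a hypothetical no-op parametrization used only to inspect its training flag, and the host module is put in evaluation mode before registration.

```python
import torch.nn as nn
import torch.nn.utils.parametrize as P


class Identity(nn.Module):
    # Hypothetical no-op parametrization, used only to inspect .training.
    def forward(self, X):
        return X


m = nn.Linear(4, 4)
m.eval()  # host module is in evaluation mode before registration

param = Identity()
assert param.training  # a freshly constructed module starts in training mode

P.register_parametrization(m, "weight", param)

# Registration switched the parametrization to match the host module.
assert not param.training
assert not m.parametrizations.weight[0].training
```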
Parametrized parameters and buffers have a built-in caching system that can be activated using the context manager cached().

A parametrization may optionally implement a method with signature

    def right_inverse(self, X: Tensor) -> Tensor
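If parametrization implements this method, it will be possible to assign to the parametrized tensor. This may be used to initialize the tensor, as shown in the example.

In most situations, right_inverse will be a function such that forward(right_inverse(X)) == X, that is, right_inverse is a right inverse of forward. Sometimes, when the parametrization is not surjective, it may be reasonable to relax this (for instance, by requiring the identity only for tensors X in the image of forward), as shown in the example below.

Parameters:
  module (nn.Module) – module on which to register the parametrization
  tensor_name (str) – name of the parameter or buffer on which to register the parametrization
  parametrization (nn.Module) – the parametrization to register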
Returns:
  module
Return type:
  Module
Raises:
  ValueError – if the module does not have a parameter or a buffer named tensor_name
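A small sketch of the return value and the error case listed above. Negate and the attribute name "not_a_tensor" are hypothetical, chosen only so that the second call targets a name that is neither a parameter nor a buffer.

```python
import torch.nn as nn
import torch.nn.utils.parametrize as P


class Negate(nn.Module):
    # Hypothetical toy parametrization: flips the sign of the tensor.
    def forward(self, X):
        return -X


m = nn.Linear(2, 2)

# The function returns the same module object it was given.
returned = P.register_parametrization(m, "weight", Negate())
assert returned is m

# A tensor_name that is neither a parameter nor a buffer raises ValueError.
try:
    P.register_parametrization(m, "not_a_tensor", Negate())
    raised = False
except ValueError:
    raised = True
assert raised
```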
Examples
>>> import torch
>>> import torch.nn.utils.parametrize as P
>>> class Symmetric(torch.nn.Module):
...     def forward(self, X):
...         return X.triu() + X.triu(1).T  # Return a symmetric matrix
...
...     def right_inverse(self, A):
...         return A.triu()
...
>>> m = torch.nn.Linear(5, 5)
>>> _ = P.register_parametrization(m, "weight", Symmetric())
>>> print(torch.allclose(m.weight, m.weight.T))  # m.weight is now symmetric
True
>>> A = torch.rand(5, 5)
>>> A = A + A.T  # A is now symmetric
>>> m.weight = A  # Initialize the weight to be the symmetric matrix A
>>> print(torch.allclose(m.weight, A))
True
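The relaxed right-inverse condition mentioned above can be checked directly on the Symmetric parametrization. This sketch redefines Symmetric so it is self-contained; the upper-triangular matrix B is an arbitrary non-symmetric input chosen for illustration. It also assumes, as in the PyTorch versions we are aware of, that assigning to the parametrized tensor stores right_inverse(value) as the original tensor.

```python
import torch
import torch.nn as nn
import torch.nn.utils.parametrize as P


class Symmetric(nn.Module):
    # Same parametrization as in the example above.
    def forward(self, X):
        return X.triu() + X.triu(1).T

    def right_inverse(self, A):
        return A.triu()


sym = Symmetric()

# For a symmetric A (in the image of forward), forward(right_inverse(A)) == A.
A = torch.rand(5, 5)
A = A + A.T
assert torch.allclose(sym(sym.right_inverse(A)), A)

# For a non-symmetric B the identity fails: forward only produces symmetric
# matrices, so right_inverse is a right inverse only on that image.
B = torch.triu(torch.ones(5, 5))  # upper triangular, hence not symmetric
assert not torch.allclose(sym(sym.right_inverse(B)), B)

# Assigning to the parametrized tensor stores right_inverse(A) as the original.
m = nn.Linear(5, 5)
P.register_parametrization(m, "weight", sym)
m.weight = A
assert torch.allclose(m.parametrizations.weight.original, A.triu())
```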
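The caching provided by the cached() context manager can be observed by counting how often the parametrization runs. CountingSymmetric is a hypothetical parametrization that increments a counter in forward(); only differences in the counter are compared, since registration itself may call forward() for internal checks.

```python
import torch.nn as nn
import torch.nn.utils.parametrize as P


class CountingSymmetric(nn.Module):
    # Hypothetical parametrization that counts how often forward() runs.
    def __init__(self):
        super().__init__()
        self.calls = 0

    def forward(self, X):
        self.calls += 1
        return X.triu() + X.triu(1).T


m = nn.Linear(4, 4)
counter = CountingSymmetric()
P.register_parametrization(m, "weight", counter)

# Without caching, every access to m.weight recomputes the parametrization.
before = counter.calls
_ = m.weight
_ = m.weight
uncached_calls = counter.calls - before

# Inside cached(), the parametrized tensor is computed once and reused.
before = counter.calls
with P.cached():
    w1 = m.weight
    w2 = m.weight
cached_calls = counter.calls - before

assert uncached_calls == 2
assert cached_calls == 1
assert w1 is w2
```

Caching is mainly useful when the same parametrized tensor is read several times in one forward pass, for example in a recurrent loop.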