torch.linalg.cholesky_ex
torch.linalg.cholesky_ex(A, *, check_errors=False, out=None) -> (Tensor, Tensor)

Computes the Cholesky decomposition of a complex Hermitian or real symmetric positive-definite matrix.
This function skips the (slow) error checking and error message construction of torch.linalg.cholesky(), and instead directly returns the LAPACK error codes as part of a named tuple (L, info). This makes it a faster way to check whether a matrix is positive-definite, and it lets you handle decomposition errors more gracefully or performantly than torch.linalg.cholesky() does.

Supports input of float, double, cfloat and cdouble dtypes. Also supports batches of matrices: if A is a batch of matrices, the output has the same batch dimensions.

If A is not a Hermitian positive-definite matrix, or if it is a batch of matrices and one or more of them is not Hermitian positive-definite, then info stores a positive integer for each such matrix. That integer is the order of the smallest leading principal submatrix that is not positive-definite, i.e. the point at which the decomposition could not be completed. An info filled with zeros indicates that the decomposition succeeded. If check_errors=True and info contains positive integers, a RuntimeError is raised.

Note
If A is on a CUDA device, this function may synchronize that device with the CPU.

Warning
This function is “experimental” and it may change in a future PyTorch release.
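As a minimal sketch of the error-code behaviour described above, the snippet below factors a batch containing one positive-definite matrix and one that is not, and reads `info` instead of catching an exception. The helper `positive_definite_mask` is hypothetical (not part of PyTorch), and the value 3 for the second matrix assumes the LAPACK convention that `info` reports the order of the first leading principal submatrix that fails to be positive-definite.

```python
import torch

def positive_definite_mask(A: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: True where a matrix (or batch entry) is positive-definite.

    Reads cholesky_ex's info code instead of catching an exception.
    """
    # info == 0 means the factorization of that matrix succeeded
    return torch.linalg.cholesky_ex(A).info == 0

# Batch of two real symmetric 3x3 matrices
good = 2.0 * torch.eye(3, dtype=torch.float64)                         # positive-definite
bad = torch.diag(torch.tensor([1.0, 1.0, -1.0], dtype=torch.float64))  # orders 1 and 2 are fine, order 3 fails
batch = torch.stack([good, bad])

L, info = torch.linalg.cholesky_ex(batch)
print(info.tolist())                           # [0, 3]
print(positive_definite_mask(batch).tolist())  # [True, False]
```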
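See also
torch.linalg.cholesky() is a NumPy-compatible variant that always checks for errors.

Parameters
A (Tensor) – tensor of shape (*, n, n) where * is zero or more batch dimensions, consisting of complex Hermitian or real symmetric positive-definite matrices.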
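Keyword Arguments
check_errors (bool, optional) – if True, a RuntimeError is raised when info contains a positive integer. Default: False.
out (tuple, optional) – tuple of two tensors to write the output to. Ignored if None. Default: None.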
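The two keyword arguments can be sketched on small real matrices. This assumes a PyTorch release in which the failure exception is RuntimeError or a subclass of it (recent releases raise torch.linalg.LinAlgError, which subclasses RuntimeError), and that out accepts a preallocated (L, info) pair with info of dtype torch.int32, matching the dtype shown in the example output.

```python
import torch

# Symmetric but NOT positive-definite: eigenvalues are 3 and -1, and the
# leading principal minor of order 2 is 1*1 - 2*2 = -3.
bad = torch.tensor([[1.0, 2.0], [2.0, 1.0]], dtype=torch.float64)

# check_errors=False (the default): the failure is only reported through info
_, info = torch.linalg.cholesky_ex(bad)
print(info.item())  # 2

# check_errors=True: a positive info value is turned into an exception
try:
    torch.linalg.cholesky_ex(bad, check_errors=True)
    raised = False
except RuntimeError:
    raised = True
print(raised)  # True

# out=: write the results into preallocated tensors (info is int32)
good = torch.tensor([[4.0, 2.0], [2.0, 3.0]], dtype=torch.float64)
L_out = torch.empty(2, 2, dtype=torch.float64)
info_out = torch.empty((), dtype=torch.int32)
torch.linalg.cholesky_ex(good, out=(L_out, info_out))
# L_out now holds the lower-triangular factor [[2, 0], [1, sqrt(2)]]
print(info_out.item())  # 0
```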
Examples:
>>> a = torch.randn(2, 2, dtype=torch.complex128)
>>> a = a @ a.t().conj()  # creates a Hermitian positive-definite matrix
>>> l, info = torch.linalg.cholesky_ex(a)
>>> a
tensor([[ 2.3792+0.0000j, -0.9023+0.9831j],
        [-0.9023-0.9831j,  0.8757+0.0000j]], dtype=torch.complex128)
>>> l
tensor([[ 1.5425+0.0000j,  0.0000+0.0000j],
        [-0.5850-0.6374j,  0.3567+0.0000j]], dtype=torch.complex128)
>>> info
tensor(0, dtype=torch.int32)
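A further sketch, using only the behaviour described on this page, checks two properties on a seeded batch: the outputs keep the batch dimensions of the input, and L @ L^H reconstructs A. The shift by 3 * I is an illustrative choice (not from the original) that keeps the random matrices comfortably positive-definite.

```python
import torch

torch.manual_seed(0)

# A batch of four 3x3 complex Hermitian positive-definite matrices:
# X @ X^H is positive semi-definite, and adding 3 * I makes it strictly positive-definite.
X = torch.randn(4, 3, 3, dtype=torch.complex128)
A = X @ X.transpose(-2, -1).conj() + 3 * torch.eye(3, dtype=torch.complex128)

L, info = torch.linalg.cholesky_ex(A)

print(L.shape)     # torch.Size([4, 3, 3])  (same batch dimensions as A)
print(info.shape)  # torch.Size([4])        (one error code per matrix)
print(bool((info == 0).all()))  # True: every factorization succeeded

# L is lower-triangular, and L @ L^H reconstructs A up to rounding error
L_H = L.transpose(-2, -1).conj()
print(torch.equal(L, L.tril()))    # True
print(torch.allclose(L @ L_H, A))  # True
```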