torch.linalg.inv

torch.linalg.inv(A, *, out=None) → Tensor

Computes the inverse of a square matrix if it exists. Throws a RuntimeError if the matrix is not invertible.
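Letting $\mathbb{K}$ be $\mathbb{R}$ or $\mathbb{C}$, for a matrix $A \in \mathbb{K}^{n \times n}$, its inverse matrix $A^{-1} \in \mathbb{K}^{n \times n}$ (if it exists) is defined as

$$A^{-1}A = AA^{-1} = \mathrm{I}_n$$

where $\mathrm{I}_n$ is the $n$-dimensional identity matrix.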
The inverse matrix exists if and only if $A$ is invertible. In this case, the inverse is unique.
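To make the definition concrete, the following sketch inverts a small matrix whose inverse can be worked out by hand and checks both products from the definition against the identity. The particular matrix is an illustrative choice, and `torch.allclose` is used because the products match the identity only up to floating-point round-off.

```python
import torch

# A 2x2 matrix with det(A) = 4*6 - 7*2 = 10, so it is invertible and
# A^{-1} = (1/10) * [[6, -7], [-2, 4]].
A = torch.tensor([[4.0, 7.0], [2.0, 6.0]], dtype=torch.float64)
A_inv = torch.linalg.inv(A)

I = torch.eye(2, dtype=torch.float64)
# Both products in the definition should equal the identity up to round-off.
left_ok = torch.allclose(A_inv @ A, I)
right_ok = torch.allclose(A @ A_inv, I)
print(left_ok, right_ok)  # True True
```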
Supports input of float, double, cfloat and cdouble dtypes. Also supports batches of matrices, and if A is a batch of matrices then the output has the same batch dimensions.

Note
When inputs are on a CUDA device, this function synchronizes that device with the CPU.
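The batching and dtype behavior described above can be sketched as follows. The batch shape `(2, 3)` and the `+ n * I` shift are illustrative choices: since `torch.rand` draws from [0, 1), the shift makes every matrix strictly diagonally dominant and therefore guaranteed to be invertible.

```python
import torch

torch.manual_seed(0)
n = 4
# Entries of torch.rand lie in [0, 1), so adding n * I makes every matrix strictly
# diagonally dominant (diagonal >= n > sum of the n - 1 off-diagonal entries),
# which guarantees invertibility.
A = torch.rand(2, 3, n, n) + n * torch.eye(n)
A_inv = torch.linalg.inv(A)
print(A_inv.shape)  # torch.Size([2, 3, 4, 4])

# Each matrix in the batch is inverted independently.
eye = torch.eye(n).expand_as(A)
print(torch.allclose(A @ A_inv, eye, atol=1e-5))

# The output dtype follows the input dtype, including complex dtypes.
B = A.to(torch.cdouble)
B_inv = torch.linalg.inv(B)
print(B_inv.dtype)  # torch.complex128
```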
Note
Consider using torch.linalg.solve() if possible for multiplying a matrix on the left by the inverse since, up to floating-point round-off:

torch.linalg.solve(A, B) == torch.linalg.inv(A) @ B
It is always preferred to use solve() when possible, as it is faster and more numerically stable than computing the inverse explicitly.

See also
torch.linalg.pinv() computes the pseudoinverse (Moore-Penrose inverse) of matrices of any shape.

torch.linalg.solve() computes torch.linalg.inv(A) @ B with a numerically stable algorithm.

Parameters
A (Tensor) – tensor of shape (*, n, n) where * is zero or more batch dimensions consisting of invertible matrices.
Keyword Arguments
out (Tensor, optional) – output tensor. Ignored if None. Default: None.
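As a minimal sketch of the `out` keyword, assuming a preallocated tensor of the matching shape and dtype, the result is written into that tensor and the same storage is returned. The diagonal matrix below is an illustrative choice whose inverse is easy to check by eye.

```python
import torch

A = torch.tensor([[2.0, 0.0], [0.0, 4.0]])
# Preallocate a tensor with the expected shape and dtype; inv writes the result into it.
buf = torch.empty(2, 2)
result = torch.linalg.inv(A, out=buf)
print(buf)
# The returned tensor shares storage with buf.
print(result.data_ptr() == buf.data_ptr())  # True
```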
Raises
RuntimeError – if the matrix A or any matrix in the batch of matrices A is not invertible.
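A short sketch of the error path: the matrix below is singular because its second row is twice its first, so inverting it raises. Catching `RuntimeError` follows the documented behavior; in recent PyTorch versions the exception raised is `torch.linalg.LinAlgError`, which is a subclass of `RuntimeError`, so the same `except` clause covers both.

```python
import torch

# Singular matrix: row 2 is 2 * row 1, so det(A) = 1*4 - 2*2 = 0.
A = torch.tensor([[1.0, 2.0], [2.0, 4.0]])
try:
    torch.linalg.inv(A)
    raised = False
except RuntimeError:
    # Recent versions raise torch.linalg.LinAlgError, a RuntimeError subclass.
    raised = True
print(raised)  # True
```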
Examples:
>>> import torch
>>> x = torch.rand(4, 4)
>>> y = torch.linalg.inv(x)
>>> z = x @ y
>>> z
tensor([[ 1.0000, -0.0000, -0.0000,  0.0000],
        [ 0.0000,  1.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  1.0000,  0.0000],
        [ 0.0000, -0.0000, -0.0000,  1.0000]])
>>> torch.dist(z, torch.eye(4))
tensor(1.1921e-07)

>>> # Batched inverse example
>>> x = torch.randn(2, 3, 4, 4)
>>> y = torch.linalg.inv(x)
>>> z = x @ y
>>> torch.dist(z, torch.eye(4).expand_as(x))
tensor(1.9073e-06)

>>> x = torch.rand(4, 4, dtype=torch.cdouble)
>>> y = torch.linalg.inv(x)
>>> z = x @ y
>>> torch.dist(z, torch.eye(4, dtype=torch.cdouble))
tensor(7.5107e-16, dtype=torch.float64)
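Following the note that recommends torch.linalg.solve() over an explicit inverse, this sketch computes the same product both ways and compares them. The matrix sizes and the diagonal shift are illustrative choices; the shift makes `A` strictly diagonally dominant, which guarantees it is invertible and well conditioned, so the two results should agree closely. For ill-conditioned matrices the gap between the two approaches can be larger.

```python
import torch

torch.manual_seed(0)
n = 5
# torch.rand gives entries in [0, 1); adding n * I makes A strictly diagonally
# dominant, which guarantees it is invertible.
A = torch.rand(n, n, dtype=torch.float64) + n * torch.eye(n, dtype=torch.float64)
B = torch.randn(n, 3, dtype=torch.float64)

X_solve = torch.linalg.solve(A, B)  # solves A X = B without forming A^{-1}
X_inv = torch.linalg.inv(A) @ B     # explicit inverse, then a matrix product

# Both give the same X up to floating-point round-off for this well-conditioned A.
print(torch.allclose(X_solve, X_inv))
# Residual of the solve-based answer.
print(torch.linalg.norm(A @ X_solve - B))
```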