

torch.solve(input, A, *, out=None) -> (Tensor, Tensor)

This function returns the solution to the system of linear equations represented by AX = B and the LU factorization of A, in that order, as a namedtuple (solution, LU).

LU contains the L and U factors of the LU factorization of A.
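To make the (solution, LU) pair concrete, here is a minimal pure-Python sketch of solving AX = B through an LU factorization with partial pivoting. It is an illustration under stated assumptions, not PyTorch's implementation, which delegates to optimized linear-algebra libraries. The names `lu_factor_packed`, `solve_with_lu`, and `SolveResult` are hypothetical. The sketch assumes the packed storage convention of LAPACK-style routines: U on and above the diagonal, and the multipliers of the unit-lower-triangular L strictly below it.

```python
from collections import namedtuple

# Hypothetical result type mirroring the (solution, LU) namedtuple described above.
SolveResult = namedtuple("SolveResult", ["solution", "LU"])


def lu_factor_packed(A):
    """Gaussian elimination with partial pivoting (sketch).

    Returns (LU, perm): LU holds U on and above the diagonal and the
    multipliers of the unit-lower-triangular L below it; perm[i] is the
    original row index of the row that ended up in position i.
    """
    n = len(A)
    LU = [list(row) for row in A]  # work on a copy, leave A untouched
    perm = list(range(n))
    for k in range(n):
        # Pivot: row with the largest |entry| in column k (abs() also handles complex).
        p = max(range(k, n), key=lambda i: abs(LU[i][k]))
        if LU[p][k] == 0:
            raise ValueError("matrix is singular (exactly zero pivot)")
        LU[k], LU[p] = LU[p], LU[k]
        perm[k], perm[p] = perm[p], perm[k]
        for i in range(k + 1, n):
            LU[i][k] /= LU[k][k]  # multiplier, stored in the L part
            for j in range(k + 1, n):
                LU[i][j] -= LU[i][k] * LU[k][j]
    return LU, perm


def solve_with_lu(B, A):
    """Solve A X = B for an n x n A and an n x k B (lists of lists)."""
    LU, perm = lu_factor_packed(A)
    n, k = len(A), len(B[0])
    # Apply the row permutation P to B, since P A = L U.
    X = [[B[perm[i]][c] for c in range(k)] for i in range(n)]
    for c in range(k):
        # Forward substitution with L (unit diagonal, so no division).
        for i in range(n):
            for j in range(i):
                X[i][c] -= LU[i][j] * X[j][c]
        # Back substitution with U.
        for i in reversed(range(n)):
            for j in range(i + 1, n):
                X[i][c] -= LU[i][j] * X[j][c]
            X[i][c] /= LU[i][i]
    return SolveResult(X, LU)


# Usage: 2x + y = 3, 4x + 3y = 7 has the solution x = y = 1.
A = [[2.0, 1.0], [4.0, 3.0]]
B = [[3.0], [7.0]]
X, LU = solve_with_lu(B, A)
print(X)   # [[1.0], [1.0]]
print(LU)  # [[4.0, 3.0], [0.5, -0.5]]
```

Because `abs()` is defined for Python `complex` values, the same pivot selection also works for complex matrices. The pivot order `perm` stays internal here, so only `(solution, LU)` is returned.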

torch.solve(B, A) accepts either 2D inputs B, A or inputs that are batches of 2D matrices. If the inputs are batches, it returns batched outputs solution, LU.
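The batching behaviour can be pictured as applying a 2D solver independently to each matrix in the batch. The sketch below assumes nested Python lists whose batch dimensions are identical for B and A; the helper names `solve_2d` and `batched_solve` are hypothetical. It uses plain Gaussian elimination with partial pivoting on the augmented matrix [A | B] and returns only the solution, not an LU factor. It illustrates the semantics only, not how PyTorch executes batched solves.

```python
def solve_2d(B, A):
    """Solve A X = B by Gaussian elimination with partial pivoting on [A | B]."""
    n, k = len(A), len(B[0])
    M = [list(A[i]) + list(B[i]) for i in range(n)]  # augmented matrix
    for col in range(n):
        p = max(range(col, n), key=lambda r: abs(M[r][col]))
        if M[p][col] == 0:
            raise ValueError("matrix is singular (exactly zero pivot)")
        M[col], M[p] = M[p], M[col]
        for r in range(col + 1, n):
            f = M[r][col] / M[col][col]
            for c in range(col, n + k):
                M[r][c] -= f * M[col][c]
    # Back substitution on the upper-triangular system.
    X = [[0.0] * k for _ in range(n)]
    for i in reversed(range(n)):
        for c in range(k):
            s = M[i][n + c] - sum(M[i][j] * X[j][c] for j in range(i + 1, n))
            X[i][c] = s / M[i][i]
    return X


def batched_solve(B, A):
    """Apply solve_2d to every matrix in nested batch lists of matching shape."""
    # A 2D matrix is a list of rows whose entries are numbers, not lists.
    if not isinstance(A[0][0], list):
        return solve_2d(B, A)
    if len(A) != len(B):
        raise ValueError("batch dimensions of B and A must match")
    return [batched_solve(b, a) for b, a in zip(B, A)]


# Usage: a batch of two 2x2 systems, each with one right-hand side.
A = [[[2.0, 0.0], [0.0, 4.0]], [[1.0, 0.0], [0.0, 1.0]]]
B = [[[2.0], [8.0]], [[3.0], [5.0]]]
print(batched_solve(B, A))  # [[[1.0], [2.0]], [[3.0], [5.0]]]
```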

Supports real-valued and complex-valued inputs.


torch.solve() is deprecated in favor of torch.linalg.solve() and will be removed in a future PyTorch release. torch.linalg.solve() has its arguments reversed and does not return the LU factorization of the input. If you need the LU factorization, compute it separately; it can then be used with torch.lu_solve() and torch.lu_unpack().

X = torch.solve(B, A).solution should be replaced with

X = torch.linalg.solve(A, B)


Irrespective of the original strides, the returned matrices solution and LU will be transposed, i.e. with strides like B.contiguous().transpose(-1, -2).stride() and A.contiguous().transpose(-1, -2).stride() respectively.

Parameters

  • input (Tensor) – input matrix B of size (*, m, k), where * is zero or more batch dimensions.

  • A (Tensor) – input square matrix of size (*, m, m), where * is zero or more batch dimensions.

Keyword Arguments

out ((Tensor, Tensor), optional) – optional output tuple.
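The shape rules in the argument list above can be summarized as a small check: A must be square in its last two dimensions, the second-to-last dimension of B must equal m, and the outputs solution and LU take the shapes of B and A respectively. The helper `solve_output_shapes` is hypothetical and, as an assumption, requires identical batch dimensions instead of modelling any broadcasting PyTorch may perform.

```python
def solve_output_shapes(b_shape, a_shape):
    """Return the expected (solution, LU) shapes for inputs B and A (sketch).

    Assumes B has shape (*, m, k), A has shape (*, m, m), and that the
    batch dimensions * are identical (broadcasting is not modelled).
    """
    if len(a_shape) < 2 or a_shape[-1] != a_shape[-2]:
        raise ValueError(f"A must be square in its last two dims, got {a_shape}")
    if len(b_shape) < 2 or b_shape[-2] != a_shape[-1]:
        raise ValueError(f"B must have shape (*, m, k) with m = {a_shape[-1]}, got {b_shape}")
    if tuple(b_shape[:-2]) != tuple(a_shape[:-2]):
        raise ValueError("batch dimensions of B and A differ")
    # solution has the shape of B, LU has the shape of A.
    return tuple(b_shape), tuple(a_shape)


# Usage with the shapes from the batched example below.
print(solve_output_shapes((2, 3, 1, 4, 6), (2, 3, 1, 4, 4)))
# ((2, 3, 1, 4, 6), (2, 3, 1, 4, 4))
```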


>>> import torch
>>> A = torch.tensor([[6.80, -2.11,  5.66,  5.97,  8.23],
...                   [-6.05, -3.30,  5.36, -4.44,  1.08],
...                   [-0.45,  2.58, -2.70,  0.27,  9.04],
...                   [8.32,  2.71,  4.35,  -7.17,  2.14],
...                   [-9.67, -5.14, -7.26,  6.08, -6.87]]).t()
>>> B = torch.tensor([[4.02,  6.19, -8.22, -7.57, -3.03],
...                   [-1.56,  4.00, -8.67,  1.75,  2.86],
...                   [9.81, -4.09, -4.57, -8.61,  8.99]]).t()
>>> X, LU = torch.solve(B, A)
>>> torch.dist(B, torch.mm(A, X))  # residual ||B - AX||, expected to be close to zero

>>> # Batched solver example
>>> A = torch.randn(2, 3, 1, 4, 4)
>>> B = torch.randn(2, 3, 1, 4, 6)
>>> X, LU = torch.solve(B, A)
>>> torch.dist(B, A.matmul(X))  # residual ||B - AX||, expected to be close to zero

