torch.linalg.tensorsolve

torch.linalg.tensorsolve(A, B, dims=None, *, out=None) → Tensor

Computes the solution X to the system torch.tensordot(A, X) = B.

If m is the product of the first B.ndim dimensions of A and n is the product of the rest of the dimensions, this function expects m and n to be equal.

The returned tensor X satisfies tensordot(A, X, dims=X.ndim) == B. X has shape A.shape[B.ndim:].
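The shape condition above means the problem can be read as an ordinary square linear system. The following is a minimal sketch of that reading, assuming row-major flattening via reshape and randomly generated inputs (the names x_matrix and x_tensor are illustrative only); it is not presented as the library's internal implementation.

```python
import math
import torch

torch.manual_seed(0)
# Double precision keeps the comparison between the two routes tight.
A = torch.randn(2, 3, 3, 2, dtype=torch.float64)
B = torch.randn(2, 3, dtype=torch.float64)

m = math.prod(A.shape[:B.ndim])  # product of the first B.ndim dimensions of A
n = math.prod(A.shape[B.ndim:])  # product of the remaining dimensions of A
assert m == n

# Flatten to an m x n system, solve it, then restore the trailing shape of A.
x_flat = torch.linalg.solve(A.reshape(m, n), B.reshape(m))
x_matrix = x_flat.reshape(A.shape[B.ndim:])

x_tensor = torch.linalg.tensorsolve(A, B)
print(torch.allclose(x_matrix, x_tensor))
```

Both routes should agree up to floating-point error, and the result has shape A.shape[B.ndim:].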

If dims is specified, the dimensions of A listed in dims are moved to the end before solving, that is, A is permuted as

A = movedim(A, dims, range(-len(dims), 0))
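As a hedged illustration of the dims argument, the sketch below moves the listed dimensions to the end by hand with torch.movedim and compares the result with passing dims directly. The shapes and the names A_moved, x_dims and x_manual are assumptions chosen for this example.

```python
import torch

torch.manual_seed(0)
A = torch.randn(6, 4, 4, 3, 2, dtype=torch.float64)
B = torch.randn(4, 3, 2, dtype=torch.float64)
dims = (0, 2)

# Move the listed dimensions to the end; the other dimensions keep their order.
A_moved = torch.movedim(A, dims, tuple(range(-len(dims), 0)))
print(A_moved.shape)  # torch.Size([4, 3, 2, 6, 4])

x_dims = torch.linalg.tensorsolve(A, B, dims=dims)
x_manual = torch.linalg.tensorsolve(A_moved, B)
print(torch.allclose(x_dims, x_manual))
```

After the move, the first B.ndim dimensions of A match B.shape and the trailing dimensions give the shape of the solution.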

Supports inputs of float, double, cfloat and cdouble dtypes.

See also

torch.linalg.tensorinv() computes the multiplicative inverse of torch.tensordot().
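The relationship between the two functions can be sketched as follows, assuming ind=B.ndim is passed to torch.linalg.tensorinv so that both functions split the dimensions of A at the same position. The inputs are random and the variable names are illustrative.

```python
import torch

torch.manual_seed(0)
A = torch.randn(2, 3, 3, 2, dtype=torch.float64)
B = torch.randn(2, 3, dtype=torch.float64)

# The inverse has shape A.shape[B.ndim:] + A.shape[:B.ndim], here (3, 2, 2, 3).
A_inv = torch.linalg.tensorinv(A, ind=B.ndim)

# Contracting the inverse with B over B.ndim dimensions recovers the solution.
x_inv = torch.tensordot(A_inv, B, dims=B.ndim)
x_solve = torch.linalg.tensorsolve(A, B)
print(torch.allclose(x_inv, x_solve))
```

When only one right-hand side is needed, solving directly is usually preferable to forming the inverse explicitly.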

Parameters
  • A (Tensor) – tensor to solve for. Its shape must satisfy prod(A.shape[:B.ndim]) == prod(A.shape[B.ndim:]).

  • B (Tensor) – tensor of shape A.shape[:B.ndim].

  • dims (Tuple[int], optional) – dimensions of A to be moved. If None, no dimensions are moved. Default: None.

Keyword Arguments

out (Tensor, optional) – output tensor. Ignored if None. Default: None.

Raises

RuntimeError – if the reshaped A.view(m, m) with m as above is not invertible, or if the product of the first B.ndim dimensions of A is not equal to the product of the rest of its dimensions.
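Both failure modes can be provoked deliberately. The sketch below uses a hypothetical helper, raises_runtime_error, which is not part of PyTorch, and assumes CPU tensors; the exact error messages may differ between versions, so none are shown.

```python
import torch

def raises_runtime_error(A, B):
    """Hypothetical helper: report whether tensorsolve(A, B) raises RuntimeError."""
    try:
        torch.linalg.tensorsolve(A, B)
    except RuntimeError:
        return True
    return False

B = torch.randn(2, 3)

# Shape mismatch: the first B.ndim dimensions multiply to 6, the rest to 4.
A_bad_shape = torch.randn(2, 3, 4)
print(raises_runtime_error(A_bad_shape, B))

# Singular system: an all-zero A flattens to a non-invertible 6 x 6 matrix.
A_singular = torch.zeros(2, 3, 6)
print(raises_runtime_error(A_singular, B))

# Well-posed system for comparison: the identity reshaped to (2, 3, 6).
A_ok = torch.eye(6).reshape(2, 3, 6)
print(raises_runtime_error(A_ok, B))
```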

Examples:

>>> a = torch.eye(2 * 3 * 4).reshape((2 * 3, 4, 2, 3, 4))
>>> b = torch.randn(2 * 3, 4)
>>> x = torch.linalg.tensorsolve(a, b)
>>> x.shape
torch.Size([2, 3, 4])
>>> torch.allclose(torch.tensordot(a, x, dims=x.ndim), b)
True

>>> a = torch.randn(6, 4, 4, 3, 2)
>>> b = torch.randn(4, 3, 2)
>>> x = torch.linalg.tensorsolve(a, b, dims=(0, 2))
>>> x.shape
torch.Size([6, 4])
>>> a = a.permute(1, 3, 4, 0, 2)
>>> a.shape[b.ndim:]
torch.Size([6, 4])
>>> torch.allclose(torch.tensordot(a, x, dims=x.ndim), b, atol=1e-6)
True
