
torch.linalg.eigh

torch.linalg.eigh(A, UPLO='L', *, out=None) -> (Tensor, Tensor)

Computes the eigenvalue decomposition of a complex Hermitian or real symmetric matrix.

Letting $\mathbb{K}$ be $\mathbb{R}$ or $\mathbb{C}$, the eigenvalue decomposition of a complex Hermitian or real symmetric matrix $A \in \mathbb{K}^{n \times n}$ is defined as

$$A = Q \operatorname{diag}(\Lambda) Q^{\text{H}} \qquad Q \in \mathbb{K}^{n \times n}, \ \Lambda \in \mathbb{R}^n$$

where $Q^{\text{H}}$ is the conjugate transpose when $Q$ is complex, and the transpose when $Q$ is real-valued. $Q$ is orthogonal in the real case and unitary in the complex case.
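The definition above can be checked numerically. The sketch below builds a random Hermitian matrix, verifies that the returned $Q$ is unitary ($Q^{\text{H}} Q = I$), and reassembles $A$ from its factors. It assumes a PyTorch version recent enough to provide `Tensor.mH` (conjugate transpose of the last two dimensions); the matrix size and seed are arbitrary.

```python
import torch

torch.manual_seed(0)

# A random complex Hermitian matrix: B + B^H is Hermitian for any square B.
B = torch.randn(4, 4, dtype=torch.complex128)
A = B + B.mH

L, Q = torch.linalg.eigh(A)

# Q is unitary: Q^H Q = I.
identity = torch.eye(4, dtype=Q.dtype)
assert torch.allclose(Q.mH @ Q, identity)

# A = Q diag(Lambda) Q^H. Lambda is real, so cast it to Q's dtype before diag_embed.
A_rebuilt = Q @ torch.diag_embed(L.to(Q.dtype)) @ Q.mH
assert torch.allclose(A_rebuilt, A)
```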

Supports input of float, double, cfloat and cdouble dtypes. Also supports batches of matrices, and if A is a batch of matrices then the output has the same batch dimensions.
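As a rough illustration of batching, the sketch below decomposes a batch of real symmetric matrices in one call; the batch shape `(2, 3)` and matrix size 5 are arbitrary choices. It assumes `Tensor.mT` (transpose of the last two dimensions) is available.

```python
import torch

# A batch of 2 x 3 real symmetric 5 x 5 matrices (the leading dims are batch dims).
B = torch.randn(2, 3, 5, 5, dtype=torch.float32)
A = B + B.mT  # .mT transposes the last two dims, so each matrix is symmetric

w, v = torch.linalg.eigh(A)
print(w.shape)  # torch.Size([2, 3, 5]): one eigenvalue vector per matrix
print(v.shape)  # torch.Size([2, 3, 5, 5]): one eigenvector matrix per matrix
```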

A is assumed to be Hermitian (resp. symmetric), but this is not checked internally. Instead:

  • If UPLO='L' (default), only the lower triangular part of the matrix is used in the computation.

  • If UPLO='U', only the upper triangular part of the matrix is used.
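One way to see the effect of UPLO is to corrupt only one triangle of a symmetric matrix. In the sketch below (seed and sizes are arbitrary), noise is added strictly above the diagonal, so with the default UPLO='L' the result is the same as for the clean matrix.

```python
import torch

torch.manual_seed(0)
B = torch.randn(4, 4, dtype=torch.float64)
S = B + B.mT  # a genuinely symmetric matrix

# Add noise strictly above the diagonal (triu with diagonal=1), so A is no
# longer symmetric, but its lower triangle (and diagonal) still equals S's.
noise = torch.triu(torch.randn(4, 4, dtype=torch.float64), diagonal=1)
A = S + noise

w_ref = torch.linalg.eigh(S).eigenvalues
# UPLO='L' only reads the lower triangle, so the noise is ignored.
w_L = torch.linalg.eigh(A, UPLO='L').eigenvalues
assert torch.allclose(w_L, w_ref)

# UPLO='U' reads the noisy upper triangle instead, so in general it
# returns the eigenvalues of a different symmetric matrix.
w_U = torch.linalg.eigh(A, UPLO='U').eigenvalues
```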

The eigenvalues are returned in ascending order.
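Because of this ordering, the smallest and largest eigenvalues are simply the first and last entries along the last dimension. A minimal check on an arbitrary random batch:

```python
import torch

B = torch.randn(8, 6, 6, dtype=torch.float64)
A = B + B.mT
w = torch.linalg.eigh(A).eigenvalues

# Consecutive differences along the last dim are non-negative: ascending order.
assert (torch.diff(w, dim=-1) >= 0).all()

# Extreme eigenvalues of every matrix in the batch, without any extra sorting.
lam_min, lam_max = w[..., 0], w[..., -1]
```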

Note

When inputs are on a CUDA device, this function synchronizes that device with the CPU.

Note

The eigenvalues of real symmetric or complex Hermitian matrices are always real.

Warning

The eigenvectors of a symmetric matrix are not unique, nor are they continuous with respect to A. Due to this lack of uniqueness, different hardware and software may compute different eigenvectors.

This non-uniqueness is caused by the fact that multiplying an eigenvector by $-1$ in the real case or by $e^{i\phi}, \phi \in \mathbb{R}$ in the complex case produces another set of valid eigenvectors of the matrix. This non-uniqueness problem is even worse when the matrix has repeated eigenvalues. In this case, one may multiply the associated eigenvectors spanning the subspace by an orthogonal (resp. unitary) matrix, and the resulting eigenvectors will be valid eigenvectors.
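The sign ambiguity in the real case can be demonstrated directly. The sketch below shows that negating every eigenvector still satisfies $A Q = Q \operatorname{diag}(\Lambda)$, and then applies one possible sign convention so that results can be compared across runs or devices. `canonicalize_signs` is a hypothetical helper, not a PyTorch API; it assumes real inputs with distinct eigenvalues and does not resolve the rotational freedom within repeated eigenspaces.

```python
import torch

torch.manual_seed(0)
B = torch.randn(4, 4, dtype=torch.float64)
A = B + B.mT
w, v = torch.linalg.eigh(A)

# Flipping the sign of every eigenvector still satisfies A v = v diag(w).
# (v * w scales column j of v by w[j].)
v_flipped = -v
assert torch.allclose(A @ v_flipped, v_flipped * w)

def canonicalize_signs(v: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper (not part of PyTorch): flip each column so that its
    largest-magnitude entry is positive. Meant for real eigenvectors only."""
    idx = v.abs().argmax(dim=-2, keepdim=True)    # row of the largest entry per column
    signs = torch.sign(torch.gather(v, -2, idx))  # sign of that entry, shape (..., 1, n)
    return v * signs

# Both sign choices map to the same canonical eigenvectors.
assert torch.allclose(canonicalize_signs(v), canonicalize_signs(v_flipped))
```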

Warning

Gradients computed using the eigenvectors tensor will only be finite when A has distinct eigenvalues. Furthermore, if the distance between any two eigenvalues is close to zero, the gradient will be numerically unstable, as it depends on the eigenvalues $\lambda_i$ through the computation of $\frac{1}{\min_{i \neq j} |\lambda_i - \lambda_j|}$.
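The quantity in that factor can be monitored before backpropagating through eigenvectors. In the sketch below, `eigengap` is a hypothetical helper that computes $\min_{i \neq j} |\lambda_i - \lambda_j|$; since eigenvalues come back sorted, it only needs neighbouring differences. The second part contrasts this with gradients of the eigenvalues alone, which stay finite even for the identity matrix, whose eigenvalues are all equal.

```python
import torch

def eigengap(w: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: smallest |lambda_i - lambda_j| over i != j.
    eigh returns eigenvalues in ascending order, so comparing neighbouring
    entries is enough. Requires at least two eigenvalues."""
    return torch.diff(w, dim=-1).min(dim=-1).values

# A diagonal matrix with two nearly equal eigenvalues.
A = torch.diag(torch.tensor([1.0, 1.0 + 1e-10, 2.0], dtype=torch.float64))
w, v = torch.linalg.eigh(A)
gap = eigengap(w)
print(gap)      # approximately 1e-10
print(1 / gap)  # approximately 1e10: the factor that amplifies eigenvector gradients

# Gradients of the eigenvalues alone do not involve 1 / (lambda_i - lambda_j),
# so they stay finite even for exactly repeated eigenvalues.
X = torch.eye(3, dtype=torch.float64, requires_grad=True)
torch.linalg.eigvalsh(X).sum().backward()
assert torch.isfinite(X.grad).all()
```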

See also

torch.linalg.eigvalsh() computes only the eigenvalues of a Hermitian matrix. Unlike torch.linalg.eigh(), the gradients of eigvalsh() are always numerically stable.

torch.linalg.cholesky() for a different decomposition of a Hermitian matrix. The Cholesky decomposition gives less information about the matrix but is much faster to compute than the eigenvalue decomposition.

torch.linalg.eig() for a (slower) function that computes the eigenvalue decomposition of a not necessarily Hermitian square matrix.

torch.linalg.svd() for a (slower) function that computes the more general SVD decomposition of matrices of any shape.

torch.linalg.qr() for another (much faster) decomposition that works on general matrices.
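Some of the relationships above can be checked on a real symmetric matrix: eigvalsh() returns the same eigenvalues, eig() returns them as complex numbers with (numerically) zero imaginary parts and in no guaranteed order, and the singular values from svd() are the absolute values of the eigenvalues. This is a sketch with an arbitrary random input; the tolerances are assumptions chosen for float64.

```python
import torch

torch.manual_seed(0)
B = torch.randn(5, 5, dtype=torch.float64)
A = B + B.mT  # real symmetric

w = torch.linalg.eigh(A).eigenvalues

# eigvalsh: same (ascending) eigenvalues, without returning eigenvectors.
assert torch.allclose(torch.linalg.eigvalsh(A), w)

# eig does not assume symmetry: its eigenvalues are complex and come in no
# guaranteed order. For a symmetric input the imaginary parts are ~0.
w_eig = torch.linalg.eig(A).eigenvalues
assert torch.allclose(w_eig.imag, torch.zeros_like(w_eig.imag), atol=1e-10)
assert torch.allclose(torch.sort(w_eig.real).values, w)

# svd: for a symmetric matrix the singular values are |lambda_i|,
# returned in descending order.
s = torch.linalg.svd(A).S
assert torch.allclose(s, torch.sort(w.abs(), descending=True).values)
```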

Parameters
  • A (Tensor) – tensor of shape (*, n, n) where * is zero or more batch dimensions consisting of symmetric or Hermitian matrices.

  • UPLO ('L', 'U', optional) – controls whether to use the upper or lower triangular part of A in the computations. Default: ‘L’.

Keyword Arguments

out (tuple, optional) – output tuple of two tensors. Ignored if None. Default: None.
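A minimal sketch of using out with preallocated tensors, assuming empty tensors of the correct dtypes are resized to the required shapes: the eigenvalues tensor must be real, and the eigenvectors tensor has the same dtype as A.

```python
import torch

A = torch.randn(3, 3, dtype=torch.float64)
A = A + A.mT

# Preallocated outputs: eigenvalues are real (float64 here),
# eigenvectors share A's dtype. Empty tensors get resized.
w_out = torch.empty(0, dtype=torch.float64)
v_out = torch.empty(0, dtype=torch.float64)
torch.linalg.eigh(A, out=(w_out, v_out))

w, v = torch.linalg.eigh(A)
assert torch.allclose(w_out, w)
```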

Returns

A named tuple (eigenvalues, eigenvectors) which corresponds to $\Lambda$ and $Q$ above.

eigenvalues will always be real-valued, even when A is complex, and are sorted in ascending order.

eigenvectors will have the same dtype as A.
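The result can be used either through its field names or by unpacking it like an ordinary tuple. The sketch below uses an arbitrary complex64 Hermitian input to show the resulting dtypes.

```python
import torch

A = torch.randn(3, 3, dtype=torch.complex64)
A = A + A.mH  # Hermitian

res = torch.linalg.eigh(A)
# Fields can be accessed by name or unpacked like a regular tuple.
print(res.eigenvalues.dtype)   # torch.float32: real, even though A is complex
print(res.eigenvectors.dtype)  # torch.complex64: same dtype as A
w, v = res
```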

Examples:

>>> a = torch.randn(2, 2, dtype=torch.complex128)
>>> a = a + a.t().conj()  # creates a Hermitian matrix
>>> a
tensor([[2.9228+0.0000j, 0.2029-0.0862j],
        [0.2029+0.0862j, 0.3464+0.0000j]], dtype=torch.complex128)
>>> w, v = torch.linalg.eigh(a)
>>> w
tensor([0.3277, 2.9415], dtype=torch.float64)
>>> v
tensor([[-0.0846+-0.0000j, -0.9964+0.0000j],
        [ 0.9170+0.3898j, -0.0779-0.0331j]], dtype=torch.complex128)
>>> torch.allclose(torch.matmul(v, torch.matmul(w.to(v.dtype).diag_embed(), v.t().conj())), a)
True

>>> a = torch.randn(3, 2, 2, dtype=torch.float64)
>>> a = a + a.transpose(-2, -1)  # creates a symmetric matrix
>>> w, v = torch.linalg.eigh(a)
>>> torch.allclose(torch.matmul(v, torch.matmul(w.diag_embed(), v.transpose(-2, -1))), a)
True
