netket.optimizer.solver.pinv_smooth_distributed

netket.optimizer.solver.pinv_smooth_distributed(A, b, *, rtol=1e-14, rtol_smooth=1e-14, local_tile_size=None, x0=None)

Solve the linear system using a distributed pseudo-inverse from eigendecomposition.

This is the distributed/multi-GPU equivalent of pinv_smooth(). It uses jaxmg’s syevd function (cuSOLVERMg SYEVD) to compute eigenvalues and eigenvectors of a symmetric/Hermitian matrix in a distributed manner, then applies the same smoothed regularization as pinv_smooth().

The eigenvalues \(\lambda_i\) smaller than \(r_\textrm{tol} \lambda_\textrm{max}\) are truncated (where \(\lambda_\textrm{max}\) is the largest eigenvalue).

The inverse eigenvalues are further smoothed by a second filter, originally introduced in Medvidovic and Sels, arXiv:2212.11289 (2022), with \(\epsilon\) set by rtol_smooth:

\[\tilde\lambda_i^{-1}=\frac{\lambda_i^{-1}}{1+\big(\epsilon\frac{\lambda_\textrm{max}}{\lambda_i}\big)^6}\]
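The truncation and smoothing steps can be illustrated with a minimal single-device NumPy sketch. This is not the library's distributed implementation: the helper name `pinv_smooth_reference` is hypothetical, `np.linalg.eigh` stands in for the distributed eigensolver, and \(\epsilon\) is assumed to be `rtol_smooth`, as stated in the parameter description.

```python
import numpy as np


def pinv_smooth_reference(A, b, rtol=1e-14, rtol_smooth=1e-14):
    """Hypothetical single-device sketch of a truncated, smoothed pinv solve."""
    # eigh returns eigenvalues in ascending order, so lam[-1] is lambda_max.
    lam, U = np.linalg.eigh(A)
    ratio = np.abs(lam / lam[-1])

    # Hard truncation: eigenvalues with lambda_i <= rtol * lambda_max
    # contribute nothing to the inverse.
    keep = ratio > rtol
    safe_lam = np.where(keep, lam, 1.0)  # avoid dividing by zero
    lam_inv = np.where(keep, 1.0 / safe_lam, 0.0)

    # Soft filter: divide by 1 + (eps * lambda_max / lambda_i)^6.
    safe_ratio = np.where(keep, ratio, 1.0)
    lam_inv = lam_inv / (1.0 + (rtol_smooth / safe_ratio) ** 6)

    # x = U diag(lam_inv) U^H b
    return U @ (lam_inv * (U.conj().T @ b))


# Well-conditioned case: agrees with a direct solve.
x_full = pinv_smooth_reference(np.diag([1.0, 2.0, 4.0]), np.array([1.0, 2.0, 4.0]))

# Rank-deficient case: the null direction is dropped.
x_trunc = pinv_smooth_reference(np.diag([0.0, 2.0]), np.array([1.0, 2.0]))

# An eigenvalue exactly at the smoothing threshold is damped by a factor of 2.
x_smooth = pinv_smooth_reference(
    np.diag([1e-3, 1.0]), np.array([1.0, 1.0]), rtol_smooth=1e-3
)
```

Because of the sixth power, the filter leaves eigenvalues well above \(\epsilon \lambda_\textrm{max}\) essentially untouched and suppresses those well below it, while avoiding the discontinuity of a hard cutoff.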

Note

This solver requires the jaxmg package to be installed.

Parameters:
  • A – the matrix A in Ax=b (should be symmetric/Hermitian, sharded)

  • b – the vector b in Ax=b

  • rtol (float) – Relative tolerance for small eigenvalues of A. For the purposes of rank determination, eigenvalues are treated as zero if they are smaller than rtol times the largest eigenvalue of A.

  • rtol_smooth (float) – Regularization parameter used with a similar effect to rtol but with a softer curve. See \(\epsilon\) in the formula above.

  • local_tile_size –

    Tile size for matrix A. Controls the tiling strategy for distributed computation (see NVIDIA cuSOLVER tile strategy). Defaults to A.shape[0] // jax.local_device_count() to distribute work evenly across devices, with a minimum of 64.

    The tile size determines the memory/communication tradeoff:

    • Larger tiles (closer to matrix size): Less communication overhead, more memory per device

    • Smaller tiles (64-1024): More communication, less memory per device, better load balancing

    See arXiv:2601.14466 for details specific to NQS. For most applications, the default (matrix size / device count) works well. Adjust if you encounter memory issues (decrease local_tile_size) or excessive communication overhead (increase local_tile_size).
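The default tile size described above can be reproduced with a small helper. This is a sketch of the documented rule only: the function name `default_local_tile_size` is hypothetical, and the device count is passed in explicitly (in practice it would come from `jax.local_device_count()`).

```python
def default_local_tile_size(n_rows, n_local_devices, minimum=64):
    """Hypothetical helper mirroring the documented default.

    n_rows corresponds to A.shape[0]; n_local_devices to
    jax.local_device_count().
    """
    # Even split across devices, but never below the documented minimum of 64.
    return max(n_rows // n_local_devices, minimum)


tile_large = default_local_tile_size(8192, 4)  # even split across 4 devices
tile_small = default_local_tile_size(100, 4)   # clamped to the minimum
```

For an 8192 x 8192 matrix on four local devices the rule gives a tile size of 2048, so the `local_tile_size=2**12` used in the example below is a larger-than-default choice for that configuration.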

Returns:

(solution, None) where solution is the unraveled result.

Return type:

tuple

Example

>>> import netket as nk
>>> import jax
>>> from jax.sharding import PartitionSpec as P
>>>
>>> # For a multi-GPU setup with sharding enabled.
>>> # hamiltonian, optimizer and vstate are assumed to be defined beforehand.
>>> solver = nk.optimizer.solver.pinv_smooth_distributed(
...     rtol=1e-14, rtol_smooth=1e-14, local_tile_size=2**12
... )
>>> driver = nk.driver.VMC_SR(
...     hamiltonian, optimizer, variational_state=vstate,
...     linear_solver=solver, diag_shift=0.01
... )