netket.optimizer.solver.pinv_smooth_distributed
- netket.optimizer.solver.pinv_smooth_distributed(A, b, *, rtol=1e-14, rtol_smooth=1e-14, local_tile_size=None, x0=None)
Solve the linear system Ax=b using a distributed pseudo-inverse computed from an eigendecomposition.
This is the distributed (multi-GPU) equivalent of pinv_smooth(). It uses jaxmg's syevd function (cuSOLVERMg SYEVD) to compute the eigenvalues and eigenvectors of a symmetric/Hermitian matrix in a distributed manner, then applies the same smoothed regularization as pinv_smooth(). Eigenvalues \(\lambda_i\) smaller than \(r_\textrm{tol} \lambda_\textrm{max}\) are truncated, where \(\lambda_\textrm{max}\) is the largest eigenvalue.
The remaining eigenvalues are further smoothed with a second filter, originally introduced in Medvidovic and Sels, arXiv:2212.11289 (2022), given by the following expression, where \(\epsilon\) is rtol_smooth:
\[\tilde\lambda_i^{-1}=\frac{\lambda_i^{-1}}{1+\big(\epsilon\frac{\lambda_\textrm{max}}{\lambda_i}\big)^6}\]
Note
This solver requires the jaxmg package to be installed.
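The truncation and smoothing described above can be sketched on a single device with jnp.linalg.eigh. This is a minimal illustration under stated assumptions: A is a dense symmetric/Hermitian positive semi-definite matrix and b is a flat vector. It is not NetKet's implementation, it does not use jaxmg or any sharding, and the function name pinv_smooth_dense is hypothetical.

```python
import jax.numpy as jnp


def pinv_smooth_dense(A, b, rtol=1e-14, rtol_smooth=1e-14):
    """Hypothetical single-device sketch of a smoothed pseudo-inverse solve of A x = b.

    Assumes A is dense, symmetric/Hermitian and positive semi-definite, and b is 1D.
    """
    # Eigendecomposition A = Q diag(lam) Q^H; eigh returns eigenvalues in ascending order.
    lam, Q = jnp.linalg.eigh(A)
    lam_max = lam[-1]
    ratio = jnp.abs(lam / lam_max)

    # Hard truncation: eigenvalues below rtol * lam_max are treated as zero.
    keep = ratio > rtol
    safe_lam = jnp.where(keep, lam, 1.0)  # avoid division by zero in discarded modes
    inv = jnp.where(keep, 1.0 / safe_lam, 0.0)

    # Smooth filter: lam^-1 / (1 + (eps * lam_max / lam)^6), with eps = rtol_smooth.
    safe_ratio = jnp.where(keep, ratio, 1.0)
    inv = inv / (1.0 + (rtol_smooth / safe_ratio) ** 6)

    # x = Q diag(inv) Q^H b
    return Q @ (inv * (Q.conj().T @ b))
```

For a well-conditioned A with the default tolerances this reproduces an ordinary solve, while eigenvalues close to \(\epsilon \lambda_\textrm{max}\) are damped rather than cut off abruptly.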
- Parameters:
A – the matrix A in Ax=b (should be symmetric/Hermitian and sharded)
b – the vector b in Ax=b
rtol (float) – Relative tolerance for small eigenvalues of A. For the purposes of rank determination, eigenvalues are treated as zero if they are smaller than rtol times the largest eigenvalue of A.
rtol_smooth (float) – Regularization parameter with an effect similar to rtol but with a softer cutoff. See \(\epsilon\) in the formula above.
local_tile_size – Tile size for the matrix A. Controls the tiling strategy for the distributed computation (see the NVIDIA cuSOLVER tile strategy). Defaults to A.shape[0] // jax.local_device_count(), which distributes work evenly across devices, with a minimum of 64. The tile size determines the memory/communication tradeoff:
  - Larger tiles (closer to the matrix size): less communication overhead, more memory per device.
  - Smaller tiles (64-1024): more communication, less memory per device, better load balancing.
See arXiv:2601.14466 for details specific to NQS. For most applications the default (matrix size / device count) works well. Decrease local_tile_size if you run out of memory, and increase it if communication overhead dominates.
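The documented default for local_tile_size can be written out as a small helper. This is a sketch that simply restates the rule above (rows split evenly across local devices, never below 64); the helper name default_local_tile_size and its optional num_devices argument are hypothetical and not part of NetKet.

```python
from typing import Optional

import jax


def default_local_tile_size(n: int, num_devices: Optional[int] = None) -> int:
    """Hypothetical helper restating the documented default for local_tile_size."""
    if num_devices is None:
        # Number of devices attached to this process.
        num_devices = jax.local_device_count()
    # Split the n rows evenly across devices, but never go below 64.
    return max(n // num_devices, 64)
```

For example, an 8192 x 8192 matrix on 4 devices gives a tile size of 2048, while a 100 x 100 matrix on 4 devices falls back to the minimum of 64. Passing local_tile_size explicitly, as in the example below, overrides this default.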
- Returns:
(solution, None) where solution is the unraveled result.
Example
>>> import netket as nk
>>> import jax
>>> from jax.sharding import PartitionSpec as P
>>>
>>> # For a multi-GPU setup with sharding enabled
>>> solver = nk.optimizer.solver.pinv_smooth_distributed(
...     rtol=1e-14, rtol_smooth=1e-14, local_tile_size=2**12
... )
>>> # hamiltonian, optimizer and vstate are assumed to be defined already
>>> driver = nk.driver.VMC_SR(
...     hamiltonian, optimizer, variational_state=vstate,
...     linear_solver=solver, diag_shift=0.01
... )
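The return value is described as (solution, None) with solution "unraveled". Our reading, which is an assumption, is that b may be a pytree (for example, a parameter gradient), which is flattened to a vector for the dense solve and then restored to the same structure. The sketch below shows that pattern with jax.flatten_util.ravel_pytree; the function solve_pytree and its dense_solver argument are hypothetical illustrations, not NetKet API.

```python
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree


def solve_pytree(A, b_tree, dense_solver):
    """Hypothetical wrapper: flatten b, solve the dense system, unravel the result."""
    # Flatten the pytree into a 1D vector and keep the function that undoes it.
    b_flat, unravel = ravel_pytree(b_tree)
    x_flat = dense_solver(A, b_flat)
    # Restore the original pytree structure; the second element mirrors the documented None.
    return unravel(x_flat), None


# Usage: a dict of parameters flattens to a length-3 vector.
b_tree = {"w": jnp.ones(2), "c": jnp.array(3.0)}
x, info = solve_pytree(jnp.eye(3), b_tree, jnp.linalg.solve)
```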