netket.optimizer.solver.cholesky_distributed
- netket.optimizer.solver.cholesky_distributed(A, b, *, local_tile_size=None, x0=None)
Solve the linear system using a distributed Cholesky factorization.
This is the distributed, multi-GPU equivalent of cholesky(). It uses jaxmg’s potrs function, which performs a Cholesky factorization followed by triangular solves (the LAPACK POTRF + POTRS sequence) and is optimized for sharded arrays, with automatic tiling and communication optimization.
Note
This solver requires the jaxmg package to be installed.
Note
This solver expects the input matrix A to be sharded with P("S", None).
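A minimal sketch of preparing an input with the layout this note describes. It assumes a one-dimensional device mesh whose axis is named "S"; the mesh construction, the matrix size, and the variable names are illustrative assumptions rather than anything taken from the NetKet source. The rows of a symmetric positive-definite matrix are split across all visible devices and the columns are left unsharded:

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Assumption: a 1D mesh over all visible devices, with the axis named "S"
# so that it matches the P("S", None) spec.
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("S",))

# The row count must be divisible by the number of devices.
n = 8 * devices.size

# Build a symmetric positive-definite test matrix: M @ M.T + n * I.
M = jax.random.normal(jax.random.PRNGKey(0), (n, n))
A = M @ M.T + n * jnp.eye(n)

# Shard rows along "S"; leave the column dimension unsharded (None).
A_sharded = jax.device_put(A, NamedSharding(mesh, P("S", None)))

print(A_sharded.sharding.spec)
```

On a single device this runs but the sharding is trivial; with several GPUs each device holds a contiguous block of rows.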
For single-device or small-scale computations, use cholesky() instead, which has less overhead. Use cholesky_distributed when:
- Running on multiple GPUs with sharded arrays and >= 8k samples.
- The NTK/QGT matrix is very large and doesn’t fit on a single device.
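As a rough aid for these two criteria, the memory needed for a dense n × n matrix can be estimated directly. The helper names below are hypothetical. The sketch also assumes float64 entries, reads "8k" as 8192, and takes the matrix dimension to equal the number of samples (as for an NTK-style matrix); it ignores any workspace the solver itself needs.

```python
def dense_matrix_bytes(n: int, itemsize: int = 8) -> int:
    """Bytes needed to store a dense n x n matrix (itemsize=8 for float64)."""
    return n * n * itemsize


def prefer_distributed(n: int, n_devices: int, device_bytes: int,
                       sample_threshold: int = 8192) -> bool:
    """Hypothetical rule of thumb mirroring the two criteria above."""
    too_big = dense_matrix_bytes(n) > device_bytes
    many_samples = n_devices > 1 and n >= sample_threshold
    return too_big or many_samples


GiB = 1024**3
print(dense_matrix_bytes(65536) / GiB)         # 32.0
print(prefer_distributed(4096, 1, 80 * GiB))   # False
print(prefer_distributed(16384, 4, 80 * GiB))  # True
```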
- Parameters:
A – the matrix A in Ax=b (should be positive definite, sharded)
b – the vector b in Ax=b
local_tile_size –
Tile size for matrix A. Controls the tiling strategy for distributed computation (see NVIDIA cuSOLVER tile strategy). Defaults to A.shape[0] // jax.local_device_count(), with a minimum of 64, to distribute work evenly across devices.
The tile size determines the memory/communication tradeoff:
- Larger tiles (closer to matrix size): Less communication overhead, more memory per device
- Smaller tiles (64-1024): More communication, less memory per device, better load balancing
See arXiv:2601.14466 for details specific to NQS. For most applications, the default (matrix size / device count) works well. Adjust if you encounter memory issues (decrease local_tile_size) or excessive communication overhead (increase local_tile_size).
x0 – unused (kept for API compatibility)
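The default for local_tile_size described above can be written out as a one-line calculation. The function name is hypothetical, and applying the minimum as a lower clamp is an assumption based on the wording of the parameter description, not on the source code:

```python
def default_local_tile_size(n_rows: int, n_local_devices: int, minimum: int = 64) -> int:
    """Rows of A divided evenly across local devices, clamped from below at `minimum`."""
    return max(n_rows // n_local_devices, minimum)


print(default_local_tile_size(16384, 4))  # 4096
print(default_local_tile_size(128, 4))    # 64
```

In the second call the even split would give 32 rows per device, so the minimum of 64 takes over.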
- Returns:
(solution, None) where solution is the unraveled result.
- Return type:
Example
>>> import netket as nk
>>> import jax
>>> from functools import partial
>>> from jax.sharding import PartitionSpec as P
>>>
>>> # For multi-GPU setup with sharding enabled.
>>> # A and b are required positional arguments in the signature above,
>>> # so only the keyword argument is bound here.
>>> solver = partial(nk.optimizer.solver.cholesky_distributed, local_tile_size=2**12)
>>> driver = nk.driver.VMC_SR(
...     hamiltonian, optimizer, variational_state=vstate,
...     linear_solver=solver, diag_shift=0.01
... )
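To make the factor-then-solve sequence from the description concrete (POTRF followed by POTRS), here is a plain-Python, single-process sketch. It is not how jaxmg or NetKet implement the solver: there is no tiling, sharding, or GPU execution, and the function names and the 3 × 3 test system are illustrative assumptions.

```python
import math


def cholesky_lower(A):
    """Return lower-triangular L with A = L L^T (the POTRF step)."""
    n = len(A)
    L = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1):
            s = sum(L[i][k] * L[j][k] for k in range(j))
            if i == j:
                d = A[i][i] - s
                if d <= 0.0:
                    raise ValueError("matrix is not positive definite")
                L[i][j] = math.sqrt(d)
            else:
                L[i][j] = (A[i][j] - s) / L[j][j]
    return L


def solve_with_cholesky(L, b):
    """Solve A x = b given A = L L^T (the POTRS step)."""
    n = len(L)
    # Forward substitution: L y = b.
    y = [0.0] * n
    for i in range(n):
        y[i] = (b[i] - sum(L[i][k] * y[k] for k in range(i))) / L[i][i]
    # Back substitution: L^T x = y.
    x = [0.0] * n
    for i in reversed(range(n)):
        x[i] = (y[i] - sum(L[k][i] * x[k] for k in range(i + 1, n))) / L[i][i]
    return x


# Small symmetric positive-definite system used only for illustration.
A = [[4.0, 2.0, 0.0],
     [2.0, 5.0, 3.0],
     [0.0, 3.0, 6.0]]
b = [2.0, 5.0, 3.0]

x = solve_with_cholesky(cholesky_lower(A), b)
residual = [sum(A[i][j] * x[j] for j in range(3)) - b[i] for i in range(3)]
print(max(abs(r) for r in residual))
```

The factorization fails with an error when A is not positive definite, which is why the parameter description asks for a positive-definite matrix.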