netket.optimizer.solver.cholesky_distributed

netket.optimizer.solver.cholesky_distributed(A, b, *, local_tile_size=None, x0=None)

Solve the linear system using a distributed Cholesky factorization.

This is the distributed/multi-GPU equivalent of cholesky(). It uses jaxmg’s potrs function (LAPACK POTRF+POTRS), which performs a Cholesky factorization followed by triangular solves and is optimized for sharded arrays through automatic tiling and communication optimization.
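As a single-device illustration of what "a Cholesky factorization followed by triangular solves" means, here is a minimal pure-Python sketch for a small real symmetric positive-definite matrix. The helper names cholesky_factor and cholesky_solve are hypothetical and are not part of NetKet or jaxmg. The sketch only shows the mathematical steps, not the tiling or communication that the distributed solver handles.

```python
import math


def cholesky_factor(A):
    """Factorization step (POTRF-like): return lower-triangular L with A = L L^T."""
    n = len(A)
    L = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1):
            s = sum(L[i][k] * L[j][k] for k in range(j))
            if i == j:
                d = A[i][i] - s
                if d <= 0.0:
                    raise ValueError("matrix is not positive definite")
                L[i][j] = math.sqrt(d)
            else:
                L[i][j] = (A[i][j] - s) / L[j][j]
    return L


def cholesky_solve(L, b):
    """Solve step (POTRS-like): solve L y = b, then L^T x = y."""
    n = len(L)
    # Forward substitution with the lower-triangular factor.
    y = [0.0] * n
    for i in range(n):
        y[i] = (b[i] - sum(L[i][k] * y[k] for k in range(i))) / L[i][i]
    # Back substitution with the transposed factor.
    x = [0.0] * n
    for i in reversed(range(n)):
        x[i] = (y[i] - sum(L[k][i] * x[k] for k in range(i + 1, n))) / L[i][i]
    return x


# Small illustrative system (values chosen for this sketch only).
A = [[4.0, 2.0], [2.0, 3.0]]
b = [2.0, 1.0]
x = cholesky_solve(cholesky_factor(A), b)

# Check the residual A x - b instead of relying on a printed value.
residual = [sum(A[i][j] * x[j] for j in range(2)) - b[i] for i in range(2)]
```

The factor is computed once and reused by both triangular solves, which is why this approach is cheaper than a general-purpose solve when A is positive definite.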

Note

This solver requires the jaxmg package to be installed.

Note

This solver expects the input matrix A to be sharded with P("S", None), i.e. with its rows split across the "S" mesh axis and its columns left unsharded.
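For orientation, the sketch below shows one way to place a matrix with the P("S", None) layout using plain JAX. It assumes you build the mesh yourself with a single axis named "S" over all visible devices. NetKet may instead provide its own mesh when sharding is enabled, so treat the mesh construction here as an assumption rather than the library's required setup.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Assumption: a one-dimensional mesh over all visible devices, axis named "S".
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("S",))

# Rows are split along "S"; the second dimension is not partitioned.
sharding = NamedSharding(mesh, P("S", None))

# The row count is chosen as a multiple of the device count so it splits evenly.
n = 8 * len(devices)
A = jnp.eye(n)
A_sharded = jax.device_put(A, sharding)
```

On a single-device machine this still runs; the mesh then has one entry and the whole matrix lives on that device.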

For single-device or small-scale computations, use cholesky() instead, which has less overhead. Use cholesky_distributed when:

  • You are running on multiple GPUs with sharded arrays and at least 8k samples.

  • The NTK/QGT matrix is too large to fit on a single device.

Parameters:
  • A – the matrix A in Ax=b (should be positive definite, sharded)

  • b – the vector b in Ax=b

  • local_tile_size

    Tile size for matrix A. Controls the tiling strategy for distributed computation (see NVIDIA cuSOLVER tile strategy). Defaults to A.shape[0] // jax.local_device_count() to distribute work evenly across devices, with a minimum of 64.

    The tile size determines the memory/communication tradeoff:

    • Larger tiles (closer to matrix size): Less communication overhead, more memory per device

    • Smaller tiles (64-1024): More communication, less memory per device, better load balancing

    See arXiv:2601.14466 for details specific to NQS. For most applications, the default (matrix size / device count) works well. If you run into memory issues, decrease local_tile_size; if communication overhead is excessive, increase it.

  • x0 – unused (kept for API compatibility)
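The default for local_tile_size described above can be sketched as a small helper. The function name default_local_tile_size is hypothetical, and the sketch assumes the minimum of 64 is applied after the integer division; the actual implementation may differ in detail.

```python
def default_local_tile_size(n_rows: int, n_local_devices: int, minimum: int = 64) -> int:
    """Rows per tile: n_rows // n_local_devices, but never below `minimum`."""
    if n_local_devices < 1:
        raise ValueError("n_local_devices must be at least 1")
    return max(minimum, n_rows // n_local_devices)


# A 16384-row matrix on 4 local devices: one tile per device.
large = default_local_tile_size(16384, 4)

# A 200-row matrix on 8 local devices: 200 // 8 is below the floor, so 64 is used.
small = default_local_tile_size(200, 8)
```

In real use, n_rows would be A.shape[0] and n_local_devices would be jax.local_device_count().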

Returns:

(solution, None) where solution is the unraveled result.

Return type:

tuple

Example

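>>> from functools import partial
>>> import netket as nk
>>> import jax
>>> from jax.sharding import PartitionSpec as P
>>>
>>> # For a multi-GPU setup with sharding enabled. The solver takes A and b
>>> # positionally, so bind the keyword argument with functools.partial.
>>> solver = partial(nk.optimizer.solver.cholesky_distributed, local_tile_size=2**12)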
>>> driver = nk.driver.VMC_SR(
...     hamiltonian, optimizer, variational_state=vstate,
...     linear_solver=solver, diag_shift=0.01
... )