netket.optimizer.solver.nan_fallback

netket.optimizer.solver.nan_fallback(primary_solver, fallback_solver)

Creates a solver that transparently falls back to a more robust solver when the primary solver's output contains NaN or Inf.

The primary solver always runs. If its output contains NaN or Inf (or the right-hand side b contains NaN), its result is replaced by that of the fallback solver. The fallback is selected with jax.lax.cond(), so it only executes at runtime when it is actually needed.

The returned info dict always contains a solver_fallback key indicating whether the fallback was activated. Any additional entries from the primary solver’s info are also included. Info from the fallback solver is not included, to avoid running it unconditionally.
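The control flow described above can be sketched in a few lines of JAX. This is a minimal illustration, not NetKet's implementation: it assumes each solver is called as solver(A, b) and returns an (x, info) pair, and the names nan_fallback_sketch, cholesky_solver and pinv_solver are hypothetical toy solvers built from jax.numpy.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import cho_solve


def nan_fallback_sketch(primary_solver, fallback_solver):
    """Hypothetical illustration of the fallback logic; not NetKet's code."""

    def solver(A, b):
        # The primary solver always runs.
        x, primary_info = primary_solver(A, b)
        # Trigger the fallback on NaN/Inf in the output or NaN in b.
        failed = jnp.any(~jnp.isfinite(x)) | jnp.any(jnp.isnan(b))
        # lax.cond evaluates only the selected branch at runtime (unless the
        # predicate is batched under vmap, where it lowers to a select).
        x = jax.lax.cond(
            failed,
            lambda A, b: fallback_solver(A, b)[0],  # fallback info is discarded
            lambda A, b: x,
            A,
            b,
        )
        info = dict(primary_info or {})  # keep any primary-solver info
        info["solver_fallback"] = failed
        return x, info

    return solver


def cholesky_solver(A, b):
    # jnp.linalg.cholesky returns NaNs when A is not positive definite.
    L = jnp.linalg.cholesky(A)
    return cho_solve((L, True), b), None


def pinv_solver(A, b):
    return jnp.linalg.pinv(A) @ b, None


solver = jax.jit(nan_fallback_sketch(cholesky_solver, pinv_solver))

# Positive-definite A: Cholesky succeeds, no fallback.
x_ok, info_ok = solver(jnp.array([[2.0, 0.0], [0.0, 4.0]]), jnp.array([2.0, 4.0]))
# Indefinite A (eigenvalues 3 and -1): Cholesky yields NaN, pinv takes over.
x_bad, info_bad = solver(jnp.array([[1.0, 2.0], [2.0, 1.0]]), jnp.array([1.0, 0.0]))
print(x_ok, info_ok["solver_fallback"])
print(x_bad, info_bad["solver_fallback"])
```

Only the primary solver's info survives in the returned dict; reading the fallback's info would require evaluating it outside the cond, which is exactly what the design avoids.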

The returned solver supports equality and hashing, so two calls to nan_fallback with the same pair of solvers produce equal objects and do not trigger JAX recompilation.
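To illustrate why equality and hashing matter, here is a hypothetical stand-in written as a frozen dataclass (FallbackSolverSketch is an assumed name, not the NetKet class), together with a toy cache built on functools.lru_cache. Like JAX's cache for static arguments, the toy cache looks entries up by hash and equality, so wrappers built from the same pair of solvers hit the cache while plain closures do not.

```python
import functools
from dataclasses import dataclass
from typing import Any, Callable


@dataclass(frozen=True)
class FallbackSolverSketch:
    """Hypothetical stand-in for the object returned by nan_fallback.

    With frozen=True (and the default eq=True), dataclass generates __eq__
    and __hash__ from the fields, so two wrappers built from the same
    solvers compare equal and hash identically.
    """

    primary_solver: Callable[..., Any]
    fallback_solver: Callable[..., Any]

    def __call__(self, A, b):
        # Fallback logic omitted: only equality/hashing is illustrated here.
        return self.primary_solver(A, b)


def primary(A, b):
    return b, {}


def fallback(A, b):
    return b, {}


def closure_wrapper(p, f):
    # A plain closure: every call returns a new, unequal function object.
    def solver(A, b):
        return p(A, b)

    return solver


compile_count = 0


@functools.lru_cache(maxsize=None)
def fake_compile(solver):
    # Toy stand-in for a JIT cache keyed on a static argument via hash/eq.
    global compile_count
    compile_count += 1
    return solver


fake_compile(FallbackSolverSketch(primary, fallback))
fake_compile(FallbackSolverSketch(primary, fallback))
print(compile_count)  # 1: the second call is a cache hit

fake_compile(closure_wrapper(primary, fallback))
fake_compile(closure_wrapper(primary, fallback))
print(compile_count)  # 3: each closure is a new cache key
```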

Parameters:
  • primary_solver – The preferred (usually faster) solver.

  • fallback_solver – The robust solver to use when the primary fails.

Returns:

A new solver function with the same signature as the inputs.

Example

Create a Cholesky solver that falls back to the pseudo-inverse when numerical issues arise:

>>> from netket.optimizer.solver import nan_fallback, cholesky, pinv_smooth
>>> solver = nan_fallback(cholesky, pinv_smooth(rtol=1e-6))
>>> x, info = solver(A, b)
>>> if info["solver_fallback"]:
...     print("Cholesky failed, used pinv_smooth instead")