netket.jax.jacobian_default_mode

netket.jax.jacobian_default_mode(apply_fun, pars, model_state, samples, *, holomorphic=None, warn=True)

Returns the default mode for `netket.jax.jacobian` given a wave-function ansatz.

This function abstractly evaluates the ansatz (propagating only shapes and dtypes, without computing any values) to determine whether its output is real or complex, and uses that to choose the default mode for computing the Jacobian.
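A minimal sketch of the abstract-evaluation step, assuming it amounts to calling `jax.eval_shape` on the ansatz. The toy ansatz `log_psi` and the helper `output_is_complex` are hypothetical (not NetKet API), and `log_psi` takes the parameters directly rather than NetKet's variables layout. Only shapes and dtypes are propagated, so the check costs no real evaluation.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy ansatz: a log-amplitude linear in the samples.
def log_psi(pars, samples):
    # samples: (n_samples, n_sites), pars["w"]: (n_sites,) -> output: (n_samples,)
    return samples @ pars["w"]

def output_is_complex(apply_fun, pars, samples):
    # eval_shape traces apply_fun abstractly and returns a ShapeDtypeStruct:
    # no arithmetic is performed on the actual arrays.
    out = jax.eval_shape(apply_fun, pars, samples)
    return bool(jnp.issubdtype(out.dtype, jnp.complexfloating))

samples = jnp.ones((4, 3))
real_pars = {"w": jnp.ones(3, dtype=jnp.float32)}
complex_pars = {"w": jnp.ones(3, dtype=jnp.complex64)}

print(output_is_complex(log_psi, real_pars, samples))
print(output_is_complex(log_psi, complex_pars, samples))
```

Here complex parameters make the output complex, which is the case where the choice between the complex and holomorphic modes arises.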

In particular:
  • for functions with a real output, it will return RealMode.

  • for functions with a complex output, it will return:
      - ComplexMode, if holomorphic==False or it has not been specified. This forces the calculation of both the Jacobian and the adjoint Jacobian; see the documentation of `netket.jax.jacobian` for more details.
      - the holomorphic mode, if holomorphic==True. This computes only the complex-valued Jacobian and assumes the adjoint Jacobian to be zero.

This function will also emit a warning (controlled by the warn parameter) if holomorphic is not specified but the output is complex.
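The rules above can be sketched as a small decision function. This is not NetKet's implementation: `default_mode_sketch` is hypothetical, it receives the real/complex verdict directly instead of evaluating an ansatz, plain strings stand in for the returned mode values, and it assumes the complaint about an unspecified holomorphic flag is a warning gated by `warn`, as that parameter's description suggests.

```python
import warnings
from typing import Optional

# Hypothetical stand-ins for the mode values returned by the real function.
REAL, COMPLEX, HOLOMORPHIC = "real", "complex", "holomorphic"

def default_mode_sketch(
    output_is_complex: bool,
    holomorphic: Optional[bool] = None,
    warn: bool = True,
) -> str:
    # Real output: the real mode, regardless of the holomorphic flag.
    if not output_is_complex:
        return REAL
    # Complex output declared holomorphic: only one complex Jacobian is needed.
    if holomorphic:
        return HOLOMORPHIC
    # Complex output, flag missing: fall back to the complex mode, but complain.
    if holomorphic is None and warn:
        warnings.warn(
            "holomorphic was not specified for a complex-valued ansatz; "
            "defaulting to the complex mode"
        )
    return COMPLEX
```

Falling back to the complex mode is the conservative choice: it is correct for any complex-valued ansatz, holomorphic or not, at the cost of computing an extra Jacobian.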

Parameters:
  • apply_fun (Callable[[Any, Union[ndarray, Array]], Union[ndarray, Array]]) – A callable taking as input a pytree of parameters and the samples, and returning the output.

  • pars (Any) – The pytree of parameters.

  • model_state (Optional[Any]) – The optional model_state, according to the flax model definition.

  • samples (Union[ndarray, Array]) – An array of samples.

  • holomorphic (Optional[bool]) – A boolean specifying whether apply_fun is holomorphic or not (None by default).

  • warn (bool) – A boolean specifying whether to raise a warning when holomorphic is not specified. For internal use only.
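Setting holomorphic=True is a promise about apply_fun. A minimal pure-JAX sketch (the helper `conj_derivative` is hypothetical, not NetKet API) checks that promise for a scalar function by computing the Wirtinger derivative with respect to the conjugate variable from two forward-mode JVPs. It vanishes for a holomorphic function such as z², and not for |z|² = z·conj(z). Reading this vanishing conjugate derivative as the "adjoint Jacobian assumed to be zero" in the holomorphic mode is our interpretation, stated here as an assumption.

```python
import jax
import jax.numpy as jnp

def conj_derivative(f, z):
    """Wirtinger derivative df/d(conj z) of a scalar function f at z.

    Hypothetical helper. With z = x + i*y,
    df/d(conj z) = (df/dx + i * df/dy) / 2, which is zero exactly when
    f satisfies the Cauchy-Riemann equations (f holomorphic).
    """
    x, y = jnp.real(z), jnp.imag(z)

    def g(x, y):
        # f viewed as a function of two real variables
        return f(x + 1j * y)

    # Forward-mode directional derivatives along x and along y.
    _, df_dx = jax.jvp(g, (x, y), (jnp.ones_like(x), jnp.zeros_like(y)))
    _, df_dy = jax.jvp(g, (x, y), (jnp.zeros_like(x), jnp.ones_like(y)))
    return 0.5 * (df_dx + 1j * df_dy)

z0 = jnp.asarray(1.0 + 2.0j)

# Holomorphic: f(z) = z**2, so the conjugate derivative should be ~0.
print(conj_derivative(lambda z: z**2, z0))

# Not holomorphic: f(z) = z * conj(z) = |z|**2, so the conjugate derivative equals z.
print(conj_derivative(lambda z: z * jnp.conj(z), z0))
```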

Return type:

JacobianMode