netket.jax.jacobian

netket.jax.jacobian(apply_fun, params, samples, model_state=None, *, mode, pdf=None, chunk_size=None, center=False, dense=False)

Computes the jacobian of a neural-network model with respect to its parameters. Unlike jax.jacrev, this function supports models whose parameters are real, complex, or a mix of both, as well as non-holomorphic models.

In the context of NQS (neural quantum states), if you pass the log-wavefunction to this function, it computes the log-derivative of the wavefunction with respect to the parameters, i.e. the matrix commonly denoted:

\[O_k(\sigma) = \frac{\partial \ln \Psi(\sigma)}{\partial \theta_k}\]
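As a minimal illustration of \(O_k(\sigma)\) in plain JAX (not NetKet's implementation), the sketch below differentiates a toy log-wavefunction with real parameters and real output, vectorizing jax.grad over the samples. The model log_psi and the arrays theta and samples are hypothetical, made up for this example.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy ansatz (not a NetKet model): ln Psi(sigma) = sum_k tanh(theta_k * sigma_k)
def log_psi(theta, sigma):
    return jnp.sum(jnp.tanh(theta * sigma))

theta = jnp.array([0.1, -0.3, 0.5])
samples = jnp.array([[1.0, -1.0, 1.0],
                     [-1.0, -1.0, 1.0]])

# O[s, k] = d ln Psi(samples[s]) / d theta_k: grad w.r.t. theta, vmapped over the samples
O = jax.vmap(jax.grad(log_psi), in_axes=(None, 0))(theta, samples)

# Analytic result for comparison: sigma_k * (1 - tanh(theta_k * sigma_k)**2)
O_exact = samples * (1.0 - jnp.tanh(theta * samples) ** 2)
print(O.shape)  # (2, 3)
```

Each row of O holds the log-derivatives for one sample, which is the layout the leading sample dimension of the returned jacobian follows.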
This function has three modes of operation, which must be selected through the mode keyword argument:
  • mode="real": The returned jacobian is real. The imaginary part of \(\ln\Psi(\sigma)\), if present, is discarded. This mode is useful for models describing real-valued states with a sign. It coincides with the \(O_k(\sigma)\) matrix for models with real parameters and real output.

  • mode="complex": The returned jacobian is complex. This mode returns the standard \(O_k(\sigma)\) matrix for models with real parameters and complex output. If your model has complex parameters and is not holomorphic, you should use this mode as well. In that case, the real and imaginary parts of the parameters are split, so that the jacobian and conjugate-jacobian are returned as two separate objects.

  • mode="holomorphic": Returns correct results only if your model is holomorphic. It works like mode="real", but returns a complex-valued jacobian.
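The difference between the three modes can be mimicked for a single sample in plain JAX. This is only an illustration under stated assumptions, not how NetKet implements the modes: log_psi is a hypothetical model with real parameters and complex output, and log_psi_holo is a hypothetical model that is holomorphic in its complex parameters.

```python
import jax
import jax.numpy as jnp

sigma = jnp.array([1.0, -1.0])

# Hypothetical model with real parameters and complex output
def log_psi(theta):
    return theta[0] * sigma[0] + 1j * theta[1] * sigma[1]

theta = jnp.array([0.2, 0.7])

# Analogue of mode="real": differentiate only Re ln Psi (the imaginary part is discarded)
o_real = jax.grad(lambda t: jnp.real(log_psi(t)))(theta)

# Analogue of mode="complex" for real parameters: combine the gradients of Re and Im
g_re = jax.grad(lambda t: jnp.real(log_psi(t)))(theta)
g_im = jax.grad(lambda t: jnp.imag(log_psi(t)))(theta)
o_complex = g_re + 1j * g_im

# Analogue of mode="holomorphic": complex derivative of a holomorphic model
def log_psi_holo(theta_c):
    return jnp.sum(jnp.exp(theta_c) * sigma)

theta_c = jnp.array([0.2 + 0.1j, -0.4 + 0.3j])
o_holo = jax.grad(log_psi_holo, holomorphic=True)(theta_c)
```

For real parameters the complex-mode result is simply \(\partial_{\theta}\,\mathrm{Re}\ln\Psi + i\,\partial_{\theta}\,\mathrm{Im}\ln\Psi\), so it reduces to the real-mode result whenever the output is real.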

The returned jacobian has the same PyTree structure as the parameters, with an additional leading dimension equal to the number of samples if mode="real" or mode="holomorphic", or if you use mode="complex" with real-valued parameters. If you use mode="complex" with complex-valued parameters, the returned pytree has two leading dimensions: the first iterates over the samples, and the second has size 2 and iterates over the real and imaginary parts of the parameters (essentially giving the jacobian and conjugate-jacobian).
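For complex parameters used with mode="complex", the two leading dimensions can be sketched in plain JAX by writing each parameter as \(\theta = a + ib\) and differentiating the complex output with respect to the real arrays a and b. This is an illustration only: log_psi and split_jacobian are hypothetical names, and placing \(\partial/\partial a\) at index 0 and \(\partial/\partial b\) at index 1 of the size-2 axis is an assumption about the ordering. The last lines recover the jacobian and conjugate-jacobian through the Wirtinger relations \(\partial/\partial\theta = (\partial/\partial a - i\,\partial/\partial b)/2\) and \(\partial/\partial\theta^* = (\partial/\partial a + i\,\partial/\partial b)/2\).

```python
import jax
import jax.numpy as jnp

# Hypothetical non-holomorphic model with complex parameters:
# ln Psi(sigma) = sum_k conj(theta_k) * sigma_k + theta_k**2
def log_psi(theta, sigma):
    return jnp.sum(jnp.conj(theta) * sigma + theta**2)

def split_jacobian(theta, sigma):
    # Write theta = a + i b and treat a, b as independent real parameters
    f = lambda a, b: log_psi(a + 1j * b, sigma)
    # jax.jacfwd accepts real inputs with a complex output
    d_a, d_b = jax.jacfwd(f, argnums=(0, 1))(jnp.real(theta), jnp.imag(theta))
    return jnp.stack([d_a, d_b], axis=0)  # shape (2, n_params)

theta = jnp.array([0.3 - 0.2j, 0.1 + 0.5j])
samples = jnp.array([[1.0, -1.0], [-1.0, 1.0], [1.0, 1.0]])

# Leading axes: (n_samples, 2, ...)
J = jax.vmap(split_jacobian, in_axes=(None, 0))(theta, samples)
print(J.shape)  # (3, 2, 2)

# Wirtinger derivatives recovered from the real/imaginary split
jac = (J[:, 0] - 1j * J[:, 1]) / 2       # d lnPsi / d theta: equals 2 * theta for this model
conj_jac = (J[:, 0] + 1j * J[:, 1]) / 2  # d lnPsi / d conj(theta): equals sigma for this model
```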

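If dense is True, the jacobian is returned as a dense array instead of a pytree, with the parameter pytree flattened into a single axis. This is somewhat similar to applying jax.grad of apply_fun to every sample with jax.vmap and ravelling each resulting per-sample gradient pytree into a vector.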
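A sketch of the difference between the pytree layout and a dense layout, in plain JAX and assuming a hypothetical linear model with a dict of parameters. The dense matrix here is built by reshaping each leaf to (n_samples, -1) and concatenating the leaves in jax.tree_util.tree_leaves order; NetKet's own ordering of the flattened parameters is not asserted here and may differ.

```python
import jax
import jax.numpy as jnp

# Hypothetical pytree-parametrized model: ln Psi(sigma) = w . sigma + b
def log_psi(params, sigma):
    return jnp.dot(params["w"], sigma) + params["b"]

params = {"w": jnp.array([0.5, -0.2, 0.1]), "b": jnp.array(0.3)}
samples = jnp.array([[1.0, -1.0, 1.0],
                     [-1.0, 1.0, 1.0]])
n = samples.shape[0]

# Pytree jacobian: same structure as params, with a leading sample axis on every leaf
jac_tree = jax.vmap(jax.grad(log_psi), in_axes=(None, 0))(params, samples)

# Dense jacobian: flatten every leaf except the sample axis and concatenate.
# Dict leaves come out in sorted-key order, so column 0 is "b" and columns 1..3 are "w".
leaves = jax.tree_util.tree_leaves(jac_tree)
jac_dense = jnp.concatenate([x.reshape(n, -1) for x in leaves], axis=1)
print(jac_tree["w"].shape, jac_tree["b"].shape, jac_dense.shape)  # (2, 3) (2,) (2, 4)
```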

Somewhat opaquely, in the ‘real’ and ‘complex’ modes this function also internally splits all parameters into real ones (for C→R, R&C→R, R&C→C and general C→C models). The resulting \(\Delta O_{jk}\) is therefore only compatible with pytree vectors that have been split to real in the same way.
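To see why a split-to-real jacobian must be paired with split-to-real vectors, the sketch below contracts the derivatives with respect to \(a = \mathrm{Re}\,\theta\) and \(b = \mathrm{Im}\,\theta\) with the pair (Re v, Im v) for a complex direction v, and compares the result with jax.jvp, which treats a complex tangent as exactly such a pair of real tangents. The model is a hypothetical non-holomorphic toy; this illustrates the principle and is not NetKet's internal code.

```python
import jax
import jax.numpy as jnp

# Hypothetical non-holomorphic model with complex parameters (single fixed sample)
def log_psi(theta):
    sigma = jnp.array([1.0, -1.0])
    return jnp.sum(jnp.conj(theta) * sigma + theta**2)

theta = jnp.array([0.3 - 0.2j, 0.1 + 0.5j])
v = jnp.array([0.4 + 0.1j, -0.2 + 0.3j])  # a complex direction in parameter space

# Derivatives w.r.t. the split (real) parameters a = Re(theta), b = Im(theta)
f = lambda a, b: log_psi(a + 1j * b)
d_a, d_b = jax.jacfwd(f, argnums=(0, 1))(jnp.real(theta), jnp.imag(theta))

# Contract with the split-to-real vector (Re v, Im v)
dir_deriv = jnp.dot(d_a, jnp.real(v)) + jnp.dot(d_b, jnp.imag(v))

# Reference: forward-mode derivative of log_psi along the complex tangent v
_, ref = jax.jvp(log_psi, (theta,), (v,))
```

Contracting the same derivatives with the complex vector v directly (for example d_a with v) would not reproduce ref for a non-holomorphic model, which is why the vector must be split the same way as the parameters.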

Parameters:
  • apply_fun (Callable) – The forward pass of the Ansatz

  • params (Any) – a pytree of parameters p

  • samples (Union[ndarray, Array]) – an array of (n in total) batched samples σ

  • model_state (Optional[Any]) – non-trainable state variables of the model

  • mode (str) – differentiation mode, must be one of ‘real’, ‘complex’ or ‘holomorphic’, as described above.

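  • pdf (Union[ndarray, Array, None]) – the probability |ψ(σ)|² of each sample if exact optimization (summation over all configurations instead of Monte Carlo sampling) is being used, otherwise None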

  • chunk_size (Optional[int]) – if not None, the jacobian is computed in chunks of at most this many samples, which reduces peak memory usage (default: None)

  • center (bool) – if True, the jacobian is centered by subtracting its average over the samples.

  • dense (bool) – if True, the jacobian is returned as a dense array instead of a pytree (see above).

Return type:

Any
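The chunk_size, center and pdf options can be illustrated in plain JAX with a hypothetical toy model. This is a sketch of one plausible reading of these options, not NetKet's implementation: chunked_jacobian is a made-up helper that chunks with a Python loop, and centering with a pdf is assumed here to subtract the pdf-weighted mean over the samples.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy ansatz: ln Psi(sigma) = sum_k tanh(theta_k * sigma_k)
def log_psi(theta, sigma):
    return jnp.sum(jnp.tanh(theta * sigma))

theta = jnp.array([0.1, -0.3, 0.5])
samples = jnp.array([[1.0, -1.0, 1.0],
                     [-1.0, -1.0, 1.0],
                     [1.0, 1.0, -1.0],
                     [-1.0, 1.0, 1.0]])

per_sample_grad = jax.vmap(jax.grad(log_psi), in_axes=(None, 0))

# Unchunked jacobian, one row per sample
O = per_sample_grad(theta, samples)

# Hypothetical chunked version: at most chunk_size samples are differentiated at a time
def chunked_jacobian(theta, samples, chunk_size):
    chunks = [per_sample_grad(theta, samples[i:i + chunk_size])
              for i in range(0, samples.shape[0], chunk_size)]
    return jnp.concatenate(chunks, axis=0)

O_chunked = chunked_jacobian(theta, samples, chunk_size=2)

# Centering with uniform weights (Monte Carlo samples, pdf=None)
O_centered = O - O.mean(axis=0, keepdims=True)

# Centering with explicit probabilities |psi(sigma)|^2 that sum to one
pdf = jnp.array([0.1, 0.2, 0.3, 0.4])
O_centered_pdf = O - jnp.sum(pdf[:, None] * O, axis=0, keepdims=True)
```

Chunking only changes how many per-sample gradients are computed at once, not the result, which is why O and O_chunked agree.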