netket.jax.jacobian
- netket.jax.jacobian(apply_fun, params, samples, model_state=None, *, mode, pdf=None, chunk_size=None, center=False, dense=False)
Computes the jacobian of a neural-network model with respect to its parameters. Unlike jax.jacrev, this function supports models with both real and complex parameters, as well as non-holomorphic models.
In the context of neural quantum states (NQS), if you pass the log-wavefunction to this function, it will compute the log-derivative of the wavefunction with respect to the parameters, i.e. the matrix commonly known as:
\[O_k(\sigma) = \frac{\partial \ln \Psi(\sigma)}{\partial \theta_k}\]

This function has three modes of operation that must be specified through the mode keyword argument:
- mode="real": The returned jacobian is real. The imaginary part of \(\ln\Psi(\sigma)\) is discarded if present. This mode is useful for models describing real-valued states with a sign. It coincides with the \(O_k(\sigma)\) matrix for real-parameter, real-output models.
- mode="complex": The returned jacobian is complex. This mode returns the standard \(O_k(\sigma)\) matrix for real-parameter, complex-output models. If your model has complex parameters and is not holomorphic, you should use this mode as well. In that case, the jacobian and the conjugate jacobian are returned as two separate objects, obtained by splitting the parameters into their real and imaginary parts.
- mode="holomorphic": Returns correct results only if your model is holomorphic. Works like mode="real", but returns a complex-valued jacobian.
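What each mode is meant to return can be illustrated with plain JAX on a toy model. The sketch below is not NetKet's implementation: it assumes a hypothetical log-amplitude log_psi with real parameters and complex output and, for the holomorphic case, a hypothetical log_psi_c with complex parameters. Per-sample gradients are computed with jax.vmap and jax.grad; the split into jacobian and conjugate jacobian for non-holomorphic models with complex parameters is not shown.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy model: real parameters, complex output,
# ln Psi(sigma) = sigma . w_re + i * (sigma . w_im)
def log_psi(params, sigma):
    return jnp.dot(sigma, params["w_re"]) + 1j * jnp.dot(sigma, params["w_im"])

def per_sample_grad(f, params, samples, **grad_kwargs):
    # Gradient of f(params, sigma) w.r.t. params, vectorized over the samples
    return jax.vmap(jax.grad(f, **grad_kwargs), in_axes=(None, 0))(params, samples)

def jac_real(params, samples):
    # Like mode="real": only Re[ln Psi] is differentiated, Im[ln Psi] is dropped
    return per_sample_grad(lambda p, s: jnp.real(log_psi(p, s)), params, samples)

def jac_complex(params, samples):
    # Like mode="complex" with real parameters: O_k = dRe/dtheta_k + i dIm/dtheta_k
    g_re = per_sample_grad(lambda p, s: jnp.real(log_psi(p, s)), params, samples)
    g_im = per_sample_grad(lambda p, s: jnp.imag(log_psi(p, s)), params, samples)
    return jax.tree_util.tree_map(lambda a, b: a + 1j * b, g_re, g_im)

# Hypothetical holomorphic model with complex parameters: ln Psi = sigma . w
def log_psi_c(params, sigma):
    return jnp.dot(sigma, params["w"])

def jac_holomorphic(params, samples):
    # Like mode="holomorphic": complex derivative, valid only for holomorphic models
    return per_sample_grad(log_psi_c, params, samples, holomorphic=True)

samples = jnp.array([[1.0, -1.0], [-1.0, 1.0], [1.0, 1.0]])
params = {"w_re": jnp.array([0.1, 0.2]), "w_im": jnp.array([0.3, -0.4])}
params_c = {"w": jnp.array([0.1 + 0.3j, 0.2 - 0.4j])}

O_real = jac_real(params, samples)          # dict of real (3, 2) arrays
O_complex = jac_complex(params, samples)    # dict of complex (3, 2) arrays
O_holo = jac_holomorphic(params_c, samples)  # dict of complex (3, 2) arrays
```

Note how O_real misses the dependence on w_im entirely, since that parameter only enters the imaginary part of ln Ψ, whereas O_complex keeps it as an imaginary contribution.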
The returned jacobian has the same PyTree structure as the parameters, with an additional leading dimension equal to the number of samples if mode="real" or mode="holomorphic", or if you have real-valued parameters and use mode="complex". If you have complex-valued parameters and use mode="complex", the returned PyTree will have two leading dimensions: the first iterates over the samples, and the second, of size 2, iterates over the real and imaginary parts of the parameters (essentially giving the jacobian and the conjugate jacobian).

If dense is True, the returned jacobian is a dense matrix, somewhat similar to what would be obtained by calling jax.vmap(jax.grad(apply_fun))(parameters).
In a somewhat opaque way, in the 'real' and 'complex' modes this function also internally splits all parameters into real ones (for C→R, R&C→R, R&C→C and general C→C models), resulting in the respective ΔOⱼₖ, which is only compatible with split-to-real PyTree vectors.
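As a rough illustration of the PyTree and dense layouts for a real-parameter, real-output model, the sketch below builds the per-sample PyTree jacobian with jax.vmap(jax.grad(...)) over the samples, then flattens each sample's gradient into one row with jax.flatten_util.ravel_pytree. The toy model log_psi and its parameter names are hypothetical, and the column ordering produced by ravel_pytree is not necessarily the one NetKet uses for dense=True.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

# Hypothetical real-valued toy model with a matrix and a vector parameter
def log_psi(params, sigma):
    return jnp.sum(jnp.tanh(params["W"] @ sigma)) + jnp.dot(params["b"], sigma)

params = {"W": 0.1 * jnp.ones((3, 2)), "b": jnp.zeros(2)}
samples = jnp.array([[1.0, -1.0], [-1.0, 1.0], [1.0, 1.0], [-1.0, -1.0]])

# PyTree layout: same structure as params, plus a leading sample axis
jac_tree = jax.vmap(jax.grad(log_psi), in_axes=(None, 0))(params, samples)
# jac_tree["W"].shape == (4, 3, 2) and jac_tree["b"].shape == (4, 2)

# Dense layout: each sample's gradient flattened into one row
jac_dense = jax.vmap(lambda g: ravel_pytree(g)[0])(jac_tree)
# jac_dense.shape == (4, 8): the 6 entries of W, then the 2 entries of b
```

The dense form is convenient for linear algebra on the (n_samples, n_params) matrix, while the PyTree form keeps each parameter's shape and can be combined directly with other PyTrees of the same structure.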
- Parameters
  - apply_fun (Callable) – The forward pass of the Ansatz.
  - model_state (Optional[Any]) – Non-trainable state variables of the model.
  - params (Any) – A PyTree of parameters p.
  - samples (Union[ndarray, Array]) – An array of (n in total) batched samples σ.
  - mode (str) – Differentiation mode; must be one of 'real', 'complex' or 'holomorphic', as described above.
  - pdf (Union[ndarray, Array, None]) – |ψ(x)|^2 if exact optimization is being used, else None.
  - chunk_size (Optional[int]) – The size of the chunks in which the gradient should be computed (default: None).
  - center (bool) – Whether the jacobian should be centered, i.e. have its mean over the samples subtracted.
  - dense (bool) – Whether to return the jacobian as a dense matrix instead of a PyTree, as described above.
- Return type
  - A PyTree with the structure described above, or a dense matrix if dense is True.
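The effect of center=True can be sketched on a dense jacobian as subtracting each column's mean over the samples. The helper center_jacobian below is hypothetical and not part of NetKet; it assumes that when pdf is given it is normalized and is used to weight the mean (an exact expectation value), and that otherwise a plain Monte Carlo average over the samples is used.

```python
import jax.numpy as jnp

def center_jacobian(jac, pdf=None):
    """Hypothetical helper: subtract the per-parameter mean over the samples.

    jac: dense jacobian of shape (n_samples, n_params)
    pdf: optional normalized weights |psi(x)|^2 of shape (n_samples,)
    """
    if pdf is None:
        mean = jnp.mean(jac, axis=0)  # plain average over the samples
    else:
        mean = jnp.sum(pdf[:, None] * jac, axis=0)  # pdf-weighted average
    return jac - mean[None, :]

jac = jnp.array([[1.0, 2.0], [3.0, 4.0]])
centered_mc = center_jacobian(jac)  # [[-1., -1.], [1., 1.]]
centered_exact = center_jacobian(jac, pdf=jnp.array([0.25, 0.75]))  # [[-1.5, -1.5], [0.5, 0.5]]
```

After centering, the weighted column sums vanish, which is the property one typically relies on when building covariance-like quantities such as the quantum geometric tensor from the centered ΔOⱼₖ.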