netket.jax.jacobian#

netket.jax.jacobian(apply_fun, params, samples, model_state=None, *, mode, pdf=None, chunk_size=None, center=False, dense=False, _sqrt_rescale=False)[source]#

Computes the jacobian of a neural-network (NN) model with respect to its parameters. This function differs from jax.jacrev() because it supports models with both real and complex parameters, as well as non-holomorphic models.

In the context of NQS, if you pass the log-wavefunction to this function, it will compute the log-derivative of the wavefunction with respect to the parameters, i.e. the matrix commonly known as:

\[O_k(\sigma) = \frac{\partial \ln \Psi(\sigma)}{\partial \theta_k}\]

This function has three modes of operation, which must be specified through the mode keyword-argument (a minimal usage sketch follows the list):

  • mode="real": Which works for real-valued functions with real- valued parameters, or truncating the imaginary part of the function.

  • mode="complex" which always returns the correct result, but results in redundant computations if the function is holomorphic. For functions of real-parameters but complex output returns the derivatives of the real and imaginary part concatenated. If the parameters are complex the derivatives w.r.t. the real and imaginary part of the parameters are split into two different jacobians.

  • mode="holomorphic" for complex-valued, complex parametrs holomorphic functions.

Parameters:
  • apply_fun (Callable) – The function for which the jacobian should be computed. It must have the signature f: PyTree, Array -> Array, where the first argument is the PyTree of parameters with respect to which the jacobian will be computed, while the second argument is not differentiated. The second argument (samples) should be a 2D batch of inputs.

  • params (Any) – The PyTree of parameters (\(\theta\) in the equations), with respect to which the jacobian will be computed.

  • samples (Union[ndarray, Array]) – A batch of samples (\(\sigma\) in the equations), encoded in a 2D matrix where the first dimension is the batch dimension and the latter dimension encodes the different degrees of freedom. The gradient is not computed with respect to this argument.

  • model_state (Optional[Any]) – Optional model variables that are not trained/differentiated. See the jax documentation to understand how those are used.

  • mode (str) – differentiation mode, must be one of real, complex or holomorphic, as briefly described above. For a detailed explanation, read the discussion below.

  • pdf (Union[ndarray, Array]) – Optional coefficient that is used to multiply every row of the Jacobian. When performing calculations in full-summation, this can be used to multiply every row by \(|\psi(\sigma)|^2\), which is needed to compute the correct average.

  • chunk_size (Optional[int]) – Optional integer specifying the maximum number of samples for which the gradient is simultaneously computed. Low values require less memory, but might increase the computational cost (chunking is disabled by default).

  • center (bool) – a boolean specifying if the jacobian should be centered (disabled by default).

  • dense (bool) – a boolean flag (disabled by default) to specify whether the jacobian should be raveled into a contiguous dense array. For real and holomorphic mode this returns a 2D matrix where the first dimension matches the number of samples (the first axis of samples), while the second dimension matches the total number of parameters. This raveling is roughly equivalent to applying netket.jax.tree_ravel() to every sample of the pytree jacobian, i.e. a jax.vmap() of nk.jax.tree_ravel over nk.jax.jacobian(...) (a sketch of this equivalence is given below). If using complex mode with real parameters, the returned tensor has 3 dimensions, where the first and last match the other modes while the middle one has size 2 and encodes the gradient of the real and imaginary part of apply_fun. If using complex mode with complex parameters, the returned tensor has 3 dimensions, where the first runs over the samples, the second has size 2 as described above, and the last has twice the number of parameters: the first \(N_\text{pars}\) elements are the derivatives w.r.t. the real part of the parameters, while the second \(N_\text{pars}\) elements are the derivatives w.r.t. the imaginary part of the parameters.

  • _sqrt_rescale (bool) – internal boolean flag (disabled by default); do not rely on it. If enabled, the jacobian is rescaled by \(1/\sqrt{N_\text{samples}}\) to match the scaling that arises in some use cases, such as building the Quantum Geometric Tensor. If a pdf is specified, the scaling is instead \(\sqrt{\text{pdf}_i}\). This flag is temporary and internal and might be discontinued at any point in the future. Do not use it.

Return type:

Any
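For the dense flag, here is a hedged sketch of the equivalence mentioned above, reusing the hypothetical apply_fun, params and samples from the sketch in the introduction:

# dense=True ravels the per-sample pytree jacobian into a single 2D matrix
O_dense = nk.jax.jacobian(apply_fun, params, samples, mode="real", dense=True)

# roughly equivalent manual raveling:
# nk.jax.tree_ravel returns (flat_array, unravel_fn); we keep only the array
O_tree = nk.jax.jacobian(apply_fun, params, samples, mode="real")
O_manual = jax.vmap(lambda j: nk.jax.tree_ravel(j)[0])(O_tree)
# both should have shape (N_samples, N_parameters)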

Extra details of the different modes are given below:

Real-valued mode (mode='real')#

This mode should be used for functions with real output or if you wish to truncate the imaginary part of the jacobian. Practically, it computes the Jacobian defined as

\[O_k(\sigma) = \frac{\partial \Re[\ln\Psi(\sigma)]}{\partial \Re[\theta_k]}\]

and it should return a result roughly equivalent to the following listing:

# flatten the samples to a 2D batch and keep only the real part of the parameters
samples = samples.reshape(-1, samples.shape[-1])
parameters = jax.tree_util.tree_map(lambda x: x.real, parameters)
O_k = jax.jacrev(lambda pars: logpsi(pars, samples).real)(parameters)

The jacobian that is returned is a PyTree with the same structure as parameters (with an extra leading dimension running over the samples) and real data type. The imaginary part of \(\ln\Psi(\sigma)\) is discarded if present. This mode is useful for models describing real-valued states with a sign, and it coincides with the \(O_k(\sigma)\) matrix for real-valued, real-output models.
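Continuing the hypothetical sketch from the introduction, the structure of the result can be inspected directly; each leaf should carry a leading sample axis in front of the parameter shape:

O_k = nk.jax.jacobian(apply_fun, params, samples, mode="real")
# every leaf should have shape (N_samples, *param_shape) and a real dtype
print(jax.tree_util.tree_map(lambda x: x.shape, O_k))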

Complex-valued, non-holomorphic mode (mode='complex')#

This function computes all the information necessary to reconstruct the Jacobian and potentially the conjugate-Jacobian that is non-zero for non-holomorphic functions. It should be used for:

  • complex-valued functions with real parameters, of which we do not want to truncate the imaginary part;

  • complex-valued functions with mixed real and complex parameters, which are therefore not holomorphic;

  • complex-valued functions with complex parameters which are not holomorphic (if the function is holomorphic, the results will be correct but the returned data will be redundant).

If all parameters \(\theta_k\) are real, this mode returns the derivatives of the real and imaginary part of the function,

\[O^{r}_k(\sigma) = \frac{\partial \Re[\ln\Psi(\sigma)]}{\partial \theta_k} \qquad O^{i}_k(\sigma) = \frac{\partial \Im[\ln\Psi(\sigma)]}{\partial \theta_k}\]

where \(O^{r}_k(\sigma)\) and \(O^{i}_k(\sigma)\) are real-valued pytrees with the same shape as the original parameters. In practice, it should return a result roughly equivalent to the following listing:

samples = samples.reshape(-1, samples.shape[-1])
Or_k = jax.jacrev(lambda pars: logpsi(pars, samples).real)(parameters)
Oi_k = jax.jacrev(lambda pars: logpsi(pars, samples).imag)(parameters)
# stack the two real jacobians along a new axis of size 2 -> leaves of shape (N_s, 2, ...)
O_k = jax.tree_util.tree_map(lambda jr, ji: jnp.stack([jr, ji], axis=1),
                             Or_k, Oi_k)

As both Or_k and Oi_k are real, instead of stacking them we could also construct the full complex Jacobian directly. We chose not to do this for performance reasons, but downstream users are free to do so if they wish.

If you wish to get the complex jacobian in the case of real parameters, it is possible to define

\[O_k(\sigma) = O^{r}_k(\sigma) + i O^{i}_k(\sigma)\]

which is now complex-valued. In code, this is equivalent to

O_k_cmplx = jax.tree_util.tree_map(lambda jri: jri[:, 0] + 1j * jri[:, 1], O_k)

If some parameters \(\theta_k\) are complex, this mode splits the \(N\) complex parameters into \(2N\) real parameters, where the first block of \(N\) parameters corresponds to the real parts and the second block to the imaginary parts, and then follows the logic discussed above.

In formulas, this can be seen as defining the vector of \(2N\) real parameters

\[\tilde\theta = (\Re[\theta], \Im[\theta])\]

and then computing the same quantities as above

\[O^{r}_k(\sigma) = \frac{\partial \Re[\ln\Psi(\sigma)]}{\partial \tilde\theta_k} \qquad O^{i}_k(\sigma) = \frac{\partial \Im[\ln\Psi(\sigma)]}{\partial \tilde\theta_k}\]

where now those objects have twice the number of elements as the parameters. In practice, it should return a result roughly equivalent to the following listing:

samples = samples.reshape(-1, samples.shape[-1])
# tree_to_real splits the parameters into a real pytree like
# {'real': jax.tree.map(jnp.real, pars), 'imag': jax.tree.map(jnp.imag, pars)}
# and also returns a function to reconstruct the original complex pytree
pars_real, reconstruct = nk.jax.tree_to_real(parameters)
Or_k = jax.jacrev(lambda pars_re: logpsi(reconstruct(pars_re), samples).real)(pars_real)
Oi_k = jax.jacrev(lambda pars_re: logpsi(reconstruct(pars_re), samples).imag)(pars_real)
# stack the two real jacobians along a new axis of size 2 -> leaves of shape (N_s, 2, ...)
O_k = jax.tree_util.tree_map(lambda jr, ji: jnp.stack([jr, ji], axis=1),
                             Or_k, Oi_k)

This code is also valid if all parameters are real, in which case the part of O_k corresponding to the (unchanged) real parameters is identical to what was described above. For complex parameters, the part of the split pytree corresponding to \(\Im[\theta]\) contains the derivatives with respect to their imaginary part. Every element in O_k has the shape (N_s, 2, ...), where \(N_{s}\) is the number of samples and the axis of size 2 runs over the derivatives of the real and imaginary parts of the function.

Holomorphic mode (mode='holomorphic')#

This function computes the gradient with respect to the complex parameters \(\theta\). It can only be applied to functions whose parameters are all complex-valued and which are holomorphic, i.e. they satisfy the Cauchy-Riemann equations (this can be checked numerically with is_probably_holomorphic()). This function is roughly equivalent to

samples = samples.reshape(-1, samples.shape[-1])
O_k = jax.jacrev(lambda pars: logpsi(pars, samples), holomorphic=True)(parameters)

If the function is not holomorphic the result will be numerically wrong.
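As a hedged sketch of a defensive pattern, one can check holomorphism numerically before choosing the mode; this assumes is_probably_holomorphic() is exposed as nk.utils.is_probably_holomorphic (check where it lives in your NetKet version) and a model with all-complex parameters, with apply_fun, params and samples defined as described above:

# fall back to the always-correct complex mode if holomorphy cannot be confirmed
if nk.utils.is_probably_holomorphic(apply_fun, params, samples):
    O_k = nk.jax.jacobian(apply_fun, params, samples, mode="holomorphic")
else:
    O_k = nk.jax.jacobian(apply_fun, params, samples, mode="complex")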

The returned jacobian has the same PyTree structure as the parameters, with an additional leading dimension equal to the number of samples when mode=real or mode=holomorphic. When mode=complex, the returned pytree has two leading dimensions: the first runs over the samples, and the second has size 2 and runs over the derivatives of the real and imaginary parts of the function, as described above (together these encode the jacobian and the conjugate-jacobian). If the parameters are complex, they are additionally split into their real and imaginary parts at the pytree level.

If dense is True, the returned jacobian is raveled into a dense matrix, somewhat similar to what would be obtained by calling jax.vmap(jax.grad(apply_fun), in_axes=(None, 0))(parameters, samples) and then raveling each per-sample gradient into a flat vector.

Note that, somewhat opaquely, in the 'real' and 'complex' modes this function also internally splits all parameters to real (covering the C→R, R&C→R, R&C→C and general C→C cases). The resulting \(\Delta O_{jk}\) is therefore only compatible with pytree vectors that have been split to real in the same way (e.g. with nk.jax.tree_to_real).