netket.jax.vjp_chunked

netket.jax.vjp_chunked(fun, *primals, has_aux=False, chunk_argnums=(), chunk_size=None, nondiff_argnums=(), return_forward=False, conjugate=False)

Calculate the vector-Jacobian product (vjp) in small chunks for a function in which the leading dimension of the output depends only on the leading dimension of some of its arguments.
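The chunking idea can be illustrated in plain JAX. The sketch below is not NetKet's implementation. It assumes a single chunked argument `X` whose leading axis matches the output of `f`, differentiates only with respect to the unchunked parameters `p`, and uses the hypothetical helper name `vjp_params_chunked`. Because row `i` of the output depends only on row `i` of `X`, the cotangent `w` can be split row-wise in the same way, and the per-chunk vjps with respect to `p` summed.

```python
import jax
import jax.numpy as jnp


def vjp_params_chunked(f, p, X, w, chunk_size):
    """Hypothetical helper: w^T (df/dp), computed one chunk of rows at a time.

    Assumes f(p, X) has shape (X.shape[0],) and that output row i depends
    only on row i of X.
    """
    n = X.shape[0]
    if n % chunk_size != 0:
        raise ValueError("chunk_size must divide the leading dimension of X")
    n_chunks = n // chunk_size
    # Split the chunked argument and the cotangent into (n_chunks, chunk_size, ...).
    X_chunks = X.reshape((n_chunks, chunk_size) + X.shape[1:])
    w_chunks = w.reshape((n_chunks, chunk_size) + w.shape[1:])

    def body(acc, xw):
        x_c, w_c = xw
        # Forward pass and pullback for this chunk only.
        _, pullback = jax.vjp(lambda p_: f(p_, x_c), p)
        (g_c,) = pullback(w_c)
        # Contributions to the gradient of the shared parameters add up.
        return jax.tree_util.tree_map(jnp.add, acc, g_c), None

    init = jax.tree_util.tree_map(jnp.zeros_like, p)
    grad_p, _ = jax.lax.scan(body, init, (X_chunks, w_chunks))
    return grad_p


# Same toy function as in the example at the end of this page.
f = jax.vmap(lambda p, x: jnp.log(p.dot(jnp.sin(x))), in_axes=(None, 0))
keys = jax.random.split(jax.random.PRNGKey(123), 3)
p = jax.random.uniform(keys[0], shape=(8,))
X = jax.random.uniform(keys[1], shape=(1024, 8))
w = jax.random.uniform(keys[2], shape=(1024,))

# Unchunked reference: vjp with respect to p only.
(reference,) = jax.vjp(lambda p_: f(p_, X), p)[1](w)
chunked = vjp_params_chunked(f, p, X, w, chunk_size=32)
```

The two results agree up to floating-point summation order; only the intermediate values of one chunk need to be held in memory at a time.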

Note

If experimental sharding is activated, the arguments in chunk_argnums are assumed to be sharded (not replicated) across devices.

Parameters:
  • fun – Function to be differentiated. It must accept chunks of size chunk_size of the primals in chunk_argnums.

  • primals – A sequence of primal values at which the Jacobian of fun should be evaluated.

  • has_aux – Optional, bool. Only False is implemented. Indicates whether fun returns a pair where the first element is considered the output of the mathematical function to be differentiated and the second element is auxiliary data. Default False.

  • chunk_argnums – An integer or tuple of integers indicating which primals should be chunked. The leading dimension of each indicated primal must equal the leading dimension of the output of fun.

  • chunk_size – An integer giving the size of the chunks over which the vjp is computed. It must divide the leading dimension of the primals specified in chunk_argnums.

  • nondiff_argnums – An integer or tuple of integers indicating the primals with respect to which fun should not be differentiated. Excluding arguments whose gradient is not needed should improve performance.

  • return_forward – Whether the returned function should also return the output of the forward pass.
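The `chunk_size` requirement can be made concrete with a small helper. `split_leading_axis` is a hypothetical name used only for illustration, not part of NetKet: it reshapes the leading axis of every leaf of a pytree into `(n_chunks, chunk_size)` and rejects sizes that do not divide that axis evenly, which is exactly the case the `chunk_size` parameter rules out.

```python
import jax
import jax.numpy as jnp


def split_leading_axis(tree, chunk_size):
    """Hypothetical helper: reshape each leaf from (n, ...) to (n // chunk_size, chunk_size, ...)."""

    def split(leaf):
        n = leaf.shape[0]
        if n % chunk_size != 0:
            raise ValueError(
                f"chunk_size={chunk_size} does not divide leading dimension {n}"
            )
        return leaf.reshape((n // chunk_size, chunk_size) + leaf.shape[1:])

    return jax.tree_util.tree_map(split, tree)


X = jnp.ones((1024, 8))
chunks = split_leading_axis(X, 32)  # 32 chunks of 32 rows each

try:
    split_leading_axis(X, 100)  # 1024 is not a multiple of 100
    rejected = False
except ValueError:
    rejected = True
```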

Returns:

A function corresponding to the vjp_fun returned by an equivalent jax.vjp(fun, *primals)[1] call, which computes the vjp in chunks (recomputing the forward pass on every call). If return_forward=True, the returned vjp_fun returns a tuple containing the output of the forward pass and the vjp.
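A minimal sketch of this return convention, assuming a two-argument `f(p, X)` differentiated with respect to `p` only (analogous to passing `nondiff_argnums=1`). `make_vjp_fun` is a hypothetical name, and chunking is deliberately omitted so the sketch isolates two points: nothing is cached between calls, so the forward pass is rerun each time, and with `return_forward=True` the forward output is returned alongside the tuple of cotangents.

```python
import jax
import jax.numpy as jnp


def make_vjp_fun(f, p, X, return_forward=False):
    """Hypothetical wrapper mimicking the documented return convention (no chunking)."""

    def vjp_fun(w):
        # Re-run the forward pass on every call instead of storing residuals.
        y, pullback = jax.vjp(lambda p_: f(p_, X), p)
        grads = pullback(w)  # tuple with a single entry: the cotangent for p
        return (y, grads) if return_forward else grads

    return vjp_fun


f = jax.vmap(lambda p, x: jnp.log(p.dot(jnp.sin(x))), in_axes=(None, 0))
p = jnp.linspace(0.1, 1.0, 8)
X = jnp.linspace(0.0, 1.0, 64 * 8).reshape(64, 8)
w = jnp.ones(64)

y, (g,) = make_vjp_fun(f, p, X, return_forward=True)(w)
(g2,) = make_vjp_fun(f, p, X)(w)
```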

Example

>>> import jax
>>> from netket.jax import vjp_chunked
>>> from functools import partial
>>>
>>> @partial(jax.vmap, in_axes=(None, 0))
... def f(p, x):
...     return jax.lax.log(p.dot(jax.lax.sin(x)))
>>>
>>> k = jax.random.split(jax.random.PRNGKey(123), 4)
>>> p = jax.random.uniform(k[0], shape=(8,))
>>> v = jax.random.uniform(k[1], shape=(8,))
>>> X = jax.random.uniform(k[2], shape=(1024,8))
>>> w = jax.random.uniform(k[3], shape=(1024,))
>>>
>>> vjp_fun_chunked = vjp_chunked(f, p, X, chunk_argnums=(1,), chunk_size=32, nondiff_argnums=1)
>>> vjp_fun = jax.vjp(f, p, X)[1]
>>>
>>> vjp_fun_chunked(w)
(Array([106.76358917, 113.3123931 , 101.95475061, 104.11138622,
        111.95590131, 109.17531467, 108.97138052, 106.89249739], dtype=float64),)
>>> vjp_fun(w)[:1]
(Array([106.76358917, 113.3123931 , 101.95475061, 104.11138622,
        111.95590131, 109.17531467, 108.97138052, 106.89249739], dtype=float64),)