netket.jax.vjp_chunked
- netket.jax.vjp_chunked(fun, *primals, has_aux=False, chunk_argnums=(), chunk_size=None, nondiff_argnums=(), return_forward=False, conjugate=False)
Calculate the vjp in small chunks for a function whose output has a leading dimension that depends only on the leading dimension of some of its arguments.
Note
If experimental sharding is activated, the arguments in chunk_argnums are assumed to be sharded (not replicated) among devices.
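To make the idea of a chunked vjp concrete, the following is a minimal sketch in plain JAX, not NetKet's implementation. It assumes a batched function `f(p, X)` and accumulates the vjp with respect to the non-chunked argument `p` one chunk of `X` at a time. The helper name `vjp_wrt_p_in_chunks` is hypothetical and exists only for illustration.

```python
import jax
import jax.numpy as jnp


def f(p, X):
    # Batched function: one scalar output per row of X.
    return jnp.log(jnp.sin(X) @ p)


def vjp_wrt_p_in_chunks(p, X, w, chunk_size):
    # Hypothetical helper (not part of NetKet): accumulate the vjp with
    # respect to p over chunks of the leading dimension of X and w.
    n = X.shape[0]
    assert n % chunk_size == 0, "chunk_size must divide the leading dimension"
    grad_p = jnp.zeros_like(p)
    for start in range(0, n, chunk_size):
        X_c = X[start:start + chunk_size]
        w_c = w[start:start + chunk_size]
        # Differentiate only with respect to p; X_c is held fixed.
        _, pullback = jax.vjp(lambda p_: f(p_, X_c), p)
        grad_p = grad_p + pullback(w_c)[0]
    return grad_p


k = jax.random.split(jax.random.PRNGKey(0), 3)
p = jax.random.uniform(k[0], shape=(8,))
X = jax.random.uniform(k[1], shape=(64, 8))
w = jax.random.uniform(k[2], shape=(64,))

chunked = vjp_wrt_p_in_chunks(p, X, w, chunk_size=16)
full = jax.vjp(f, p, X)[1](w)[0]
```

Because the output rows are independent of each other, the vjp with respect to `p` is a sum of per-chunk contributions, so `chunked` and `full` should agree up to floating-point error while only one chunk's intermediates are alive at a time.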
- Parameters:
fun – Function to be differentiated. It must accept chunks of size chunk_size of the primals in chunk_argnums.
primals – A sequence of primal values at which the Jacobian of fun should be evaluated.
has_aux – Optional, bool. Only False is implemented. Indicates whether fun returns a pair where the first element is considered the output of the mathematical function to be differentiated and the second element is auxiliary data. Default False.
chunk_argnums – An integer or tuple of integers indicating the primals which should be chunked. The leading dimension of each of the indicated primals must be the same as that of the output of fun.
chunk_size – An integer indicating the size of the chunks over which the vjp is computed. It must be an integer divisor of the leading dimension of the primals specified in chunk_argnums.
nondiff_argnums – An integer or tuple of integers indicating the primals with respect to which the function should not be differentiated. Specifying the arguments which are not needed should increase performance.
return_forward – Whether the returned function should also return the output of the forward pass.
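The divisibility requirement on chunk_size can be illustrated with a small sketch. The helper `to_chunks` below is hypothetical (it is not claimed to be how NetKet splits its inputs); it only shows why the leading dimension must be a multiple of the chunk size when an array is reshaped into equally sized chunks.

```python
import jax.numpy as jnp


def to_chunks(x, chunk_size):
    # Hypothetical helper: reshape (n, ...) into (n // chunk_size, chunk_size, ...).
    n = x.shape[0]
    if n % chunk_size != 0:
        raise ValueError(
            f"chunk_size={chunk_size} does not divide the leading dimension {n}"
        )
    return x.reshape((n // chunk_size, chunk_size) + x.shape[1:])


X = jnp.zeros((1024, 8))
X_chunks = to_chunks(X, 32)
print(X_chunks.shape)  # (32, 32, 8)
```

With 1024 samples and a chunk size of 32, the array splits into 32 chunks of 32 rows each; a chunk size such as 100 would be rejected by this sketch.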
- Returns:
a function corresponding to the vjp_fun returned by an equivalent jax.vjp(fun, *primals)[1] call, which computes the vjp in chunks (recomputing the forward pass on every call). If return_forward=True, the returned vjp_fun returns a tuple containing the output of the forward pass and the vjp.
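The return convention described above can be sketched without chunking. The function `make_vjp_fun` below is a hypothetical stand-in written in plain JAX, not NetKet's code. It assumes the tuple is ordered as (forward output, vjp), following the wording of the Returns text, and it recomputes the forward pass on each call rather than caching it.

```python
import jax
import jax.numpy as jnp


def make_vjp_fun(fun, *primals, return_forward=False):
    # Hypothetical stand-in for illustration only.
    def vjp_fun(cotangent):
        # The forward pass is recomputed on every call, not cached.
        out, pullback = jax.vjp(fun, *primals)
        grads = pullback(cotangent)
        # Assumed ordering: forward output first, then the vjp.
        return (out, grads) if return_forward else grads

    return vjp_fun


def f(p, X):
    return jnp.tanh(X @ p)


p = jnp.arange(1.0, 4.0)
X = jnp.ones((6, 3))
w = jnp.ones((6,))

grads = make_vjp_fun(f, p, X)(w)
out, grads_fwd = make_vjp_fun(f, p, X, return_forward=True)(w)
```

Recomputing the forward pass trades extra computation for lower memory use, since no intermediate values need to be stored between calls.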
Example
>>> import jax
>>> from netket.jax import vjp_chunked
>>> from functools import partial
>>>
>>> @partial(jax.vmap, in_axes=(None, 0))
... def f(p, x):
...     return jax.lax.log(p.dot(jax.lax.sin(x)))
>>>
>>> k = jax.random.split(jax.random.PRNGKey(123), 4)
>>> p = jax.random.uniform(k[0], shape=(8,))
>>> v = jax.random.uniform(k[1], shape=(8,))
>>> X = jax.random.uniform(k[2], shape=(1024,8))
>>> w = jax.random.uniform(k[3], shape=(1024,))
>>>
>>> vjp_fun_chunked = vjp_chunked(f, p, X, chunk_argnums=(1,), chunk_size=32, nondiff_argnums=1)
>>> vjp_fun = jax.vjp(f, p, X)[1]
>>>
>>> vjp_fun_chunked(w)
(Array([106.76358917, 113.3123931 , 101.95475061, 104.11138622,
       111.95590131, 109.17531467, 108.97138052, 106.89249739],      dtype=float64),)
>>> vjp_fun(w)[:1]
(Array([106.76358917, 113.3123931 , 101.95475061, 104.11138622,
       111.95590131, 109.17531467, 108.97138052, 106.89249739],      dtype=float64),)