netket.callbacks.AutoChunkSize

class netket.callbacks.AutoChunkSize

Bases: AbstractCallback

Automatically tunes the chunk sizes of (i) the sampler, (ii) the forward pass and (iii) the backward pass of a driver and its variational state, in order to avoid out-of-memory (OOM) errors.

This callback is useful when the optimal chunk size is not known in advance, and the user wants to avoid manually tuning it.

The callback will try to run the sampler, forward pass and backward pass with decreasing chunk sizes until it finds one that fits in memory. The chunk size is halved on each failure, down to minimum_chunk_size.

Once the largest working chunk sizes are found, the callback prints them and stores them in the callback object. If you reuse the same callback object in a subsequent run, it reuses those chunk sizes.
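The search described above can be sketched for a single pass as follows. This is a minimal, self-contained illustration, not NetKet's implementation: the class `ChunkSizeTuner`, the probe function and the use of `MemoryError` as a stand-in for an accelerator OOM error are all hypothetical, and halving by integer division clamped at the minimum is an assumption about rounding. The real callback tunes the sampler, forward and backward chunk sizes, presumably each with its own search of this kind.

```python
from typing import Callable, List, Optional


class ChunkSizeTuner:
    """Hypothetical helper mimicking the documented search for one pass.

    MemoryError stands in for an accelerator out-of-memory error.
    """

    def __init__(self, chunk_size: Optional[int] = None, minimum_chunk_size: int = 1):
        self.chunk_size = chunk_size  # stored result, reused by later runs
        self.minimum_chunk_size = minimum_chunk_size
        self.attempts: List[int] = []  # chunk sizes probed during the last tune()

    def tune(self, probe: Callable[[int], None], default_start: int) -> int:
        # Start from the stored value if there is one, else from the default
        # (e.g. the number of samples for the forward pass).
        chunk = self.chunk_size if self.chunk_size is not None else default_start
        self.attempts = []
        while True:
            self.attempts.append(chunk)
            try:
                probe(chunk)  # run the pass with this chunk size
            except MemoryError:
                if chunk <= self.minimum_chunk_size:
                    raise RuntimeError(
                        f"Out of memory even at minimum_chunk_size={self.minimum_chunk_size}"
                    ) from None
                # Halve (integer division), never going below the minimum.
                chunk = max(chunk // 2, self.minimum_chunk_size)
            else:
                self.chunk_size = chunk  # remember the working value
                print(f"Using chunk size {chunk}")
                return chunk


def fake_forward_pass(chunk: int) -> None:
    """Pretend that any chunk larger than 200 samples exhausts memory."""
    if chunk > 200:
        raise MemoryError
```

With 1000 samples, the first call to `tune` probes 1000, 500, 250 and 125; calling it again on the same object starts from the stored 125 and succeeds on the first probe, which mirrors how a reused callback object keeps its chunk sizes.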

Supporting custom drivers

To tune the forward pass, the callback needs to know which operator to use to estimate the loss. This is determined by the get_forward_operator() dispatch function, which is already registered for the built-in drivers (VMC, VMC_SR, Infidelity_SR).

If you implement a custom driver, you must register a new overload so that AutoChunkSize knows how to probe the forward pass:

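from netket._src.callbacks.auto_chunk_size import get_forward_operator
from netket.driver import AbstractOptimizationDriver  # assumed import path for the driver base class

class MyDriver(AbstractOptimizationDriver):
    def __init__(self, hamiltonian, ...):  # remaining constructor arguments elided
        self._ham = hamiltonian
        ...

# Register at the bottom of the module that defines MyDriver.
@get_forward_operator.dispatch
def _(driver: MyDriver):
    # Return the operator whose expectation value is the optimised loss.
    return driver._ham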

The overload should return an operator whose expectation value exercises the same forward-pass code path as the actual optimisation step. The registration is best placed at the bottom of the file that defines the driver class, so it takes effect as soon as the driver is imported.
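The type-based lookup behind this registration can be illustrated without NetKet. The sketch below uses the standard library's `functools.singledispatch` as a stand-in for NetKet's own dispatch mechanism (which the snippet above extends with `.dispatch`, not `.register`); `BaseDriver`, `MyDriver` and `MySubDriver` are hypothetical classes. Raising `NotImplementedError` for unregistered drivers is a choice of this sketch; the behaviour of the real callback in that case is not specified here.

```python
from functools import singledispatch


class BaseDriver:
    """Hypothetical stand-in for a driver base class."""


class MyDriver(BaseDriver):
    def __init__(self, hamiltonian):
        self._ham = hamiltonian


@singledispatch
def get_forward_operator(driver):
    # Fallback for driver types that have no registered overload.
    raise NotImplementedError(
        f"No forward operator registered for {type(driver).__name__}"
    )


@get_forward_operator.register
def _(driver: MyDriver):
    # The overload is selected from the type annotation of `driver`.
    return driver._ham


class MySubDriver(MyDriver):
    """Subclasses reuse the parent's overload unless they register their own."""
```

Because dispatch follows the class hierarchy, a driver that subclasses an already-registered driver only needs its own overload if its loss is estimated from a different operator.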

__init__(sampler_chunk_size=None, chunk_size=None, chunk_size_bwd=None, minimum_chunk_size=1)

Initialize the callback.

Parameters:
  • sampler_chunk_size (int | None) – The initial chunk size to use for the sampler (default: None). If None, the callback will try to find a chunk size that works starting from the number of chains per rank.

  • chunk_size (int | None) – The initial chunk size to use for the forward pass (default: None). If None, the callback will try to find a chunk size that works starting from the number of samples.

  • chunk_size_bwd (int | None) – The initial chunk size to use for the backward pass (default: None). If None, the callback will try to find a chunk size that works starting from the number of samples.

  • minimum_chunk_size (int) – The minimum chunk size to use (default: 1). If the chunk size is decreased to this value and it still fails, the callback will raise an error.

Attributes
callback_order

An integer representing the order in which this callback should be called.

Lower numbers are called first, and higher numbers are called later.

This can be redefined in subclasses to change the order in which callbacks are called. (Default: 0 for all callbacks except loggers, which use 10.)
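The ordering rule amounts to sorting callbacks by `callback_order`. In this minimal sketch `FakeCallback` and the callback names are hypothetical, and the assumption that callbacks with equal order keep their registration order (as Python's stable `sorted` guarantees) is not stated by the documentation.

```python
from dataclasses import dataclass


@dataclass
class FakeCallback:
    """Hypothetical callback carrying only a name and a callback_order."""
    name: str
    callback_order: int = 0


registered = [
    FakeCallback("tensorboard_logger", callback_order=10),
    FakeCallback("auto_chunk_size"),
    FakeCallback("observables"),
]

# sorted() is stable: callbacks with equal order keep their registration order.
call_order = [cb.name for cb in sorted(registered, key=lambda cb: cb.callback_order)]
print(call_order)  # ['auto_chunk_size', 'observables', 'tensorboard_logger']
```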

minimum_chunk_size: int
sampler_chunk_size: int | None
chunk_size: int | None
chunk_size_bwd: int | None
Methods
before_parameter_update(step, log_data, driver)

Called after all update logic has been computed and the step has been accepted, but before the driver applies the parameter update.

At this point:

  • The loss and its gradient have been computed by compute_loss_and_update().

  • The step has been accepted (not rejected by on_compute_update_end()).

  • driver.step_count still refers to the current step — it has not yet been incremented.

  • The variational state parameters have not yet changed.

This is the right place to estimate additional observables, add data to log_data, or take a snapshot of the state for logging. Callbacks with a lower callback_order run first, so callbacks that estimate observables (order 0) are guaranteed to populate log_data before logger callbacks (order 10) read it.
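The guarantees listed above can be made concrete with a toy driver. Everything in this sketch is hypothetical (`ToyDriver`, its quadratic loss and plain gradient step, `SnapshotCallback`, `RecordingLogger`); it only reproduces the documented ordering: callbacks see the old parameters and the old step_count, lower callback_order runs first, and the update is applied afterwards.

```python
class ToyDriver:
    """Hypothetical driver reproducing the documented state at before_parameter_update."""

    def __init__(self, params, callbacks):
        self.params = params
        self.step_count = 0
        # Lower callback_order runs first.
        self.callbacks = sorted(callbacks, key=lambda cb: cb.callback_order)

    def step(self, learning_rate=0.1):
        # Toy loss-and-update computation: loss = p**2, gradient = 2p.
        log_data = {"loss": self.params ** 2}
        update = -learning_rate * 2 * self.params
        for cb in self.callbacks:
            cb.before_parameter_update(self.step_count, log_data, self)
        # Only now are the parameters changed and the step counter incremented.
        self.params += update
        self.step_count += 1
        return log_data


class SnapshotCallback:
    callback_order = 0  # runs before loggers

    def before_parameter_update(self, step, log_data, driver):
        log_data["params"] = driver.params
        log_data["step"] = step


class RecordingLogger:
    callback_order = 10

    def __init__(self):
        self.records = []

    def before_parameter_update(self, step, log_data, driver):
        self.records.append(dict(log_data))
```

Registering the logger first and the snapshot callback second still lets the logger see the snapshot, because the driver sorts by callback_order before calling the hooks.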

on_compute_update_end(step, log_data, driver)

Called at the end of the compute-update phase, after the loss and its gradient have been computed.

This is called before the parameters are updated, so it can be used to implement custom logic for rejecting a step based on the computed loss or gradient.

Return type:

bool

Returns:

A boolean indicating whether to reject the step (i.e. repeat it with the same parameters). If it returns None, it is treated as False.
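A rejection-and-retry loop built on this return value might look like the sketch below. It assumes, without confirmation from the documentation, that a step is repeated if any callback returns a truthy value; `run_step`, `max_attempts`, `RejectLargeLoss`, `SilentCallback` and the toy `noisy_loss` are all hypothetical. Note how `any()` naturally treats `None` as `False`.

```python
def run_step(compute_update, callbacks, max_attempts=10):
    """Hypothetical driver logic: recompute the update while any callback rejects it."""
    for attempt in range(1, max_attempts + 1):
        log_data = {"loss": compute_update(attempt)}
        # Call every callback; None and False both mean "accept".
        votes = [cb.on_compute_update_end(0, log_data, None) for cb in callbacks]
        if not any(votes):
            return log_data["loss"], attempt
    raise RuntimeError(f"Step rejected {max_attempts} times in a row")


class RejectLargeLoss:
    def __init__(self, threshold):
        self.threshold = threshold

    def on_compute_update_end(self, step, log_data, driver):
        return log_data["loss"] > self.threshold


class SilentCallback:
    def on_compute_update_end(self, step, log_data, driver):
        return None  # treated as False (accept)


def noisy_loss(attempt):
    """Toy loss estimate that happens to improve on every retry."""
    return 10.0 / attempt
```

With a threshold of 4.0, the first two estimates (10.0 and 5.0) are rejected and the third (about 3.33) is accepted.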

on_compute_update_start(step, log_data, driver)
on_run_end(step, driver)
on_run_error(step, error, driver)
on_run_start(step, driver)
on_step_end(step, log_data, driver)
on_step_start(step, log_data, driver)
replace(**kwargs)

Replace the values of the fields of the object with the values of the keyword arguments. If the object is a dataclass, dataclasses.replace will be used. Otherwise, a new object will be created with the same type as the original object.

Return type:

TypeVar(P, bound=Pytree)

Parameters:
  • self (P)

  • kwargs (Any)
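For a dataclass, the behaviour described above reduces to `dataclasses.replace`, which returns a new object and leaves the original unchanged. Whether AutoChunkSize is itself a dataclass is not stated here, so this sketch uses a hypothetical frozen dataclass `ChunkSizes` that mirrors the callback's documented attributes; with the real callback the equivalent call, per the signature above, would be `callback.replace(chunk_size_bwd=64)`.

```python
import dataclasses
from typing import Optional


@dataclasses.dataclass(frozen=True)
class ChunkSizes:
    """Hypothetical frozen dataclass mirroring AutoChunkSize's attributes."""
    sampler_chunk_size: Optional[int] = None
    chunk_size: Optional[int] = None
    chunk_size_bwd: Optional[int] = None
    minimum_chunk_size: int = 1


tuned = ChunkSizes(sampler_chunk_size=16, chunk_size=512, chunk_size_bwd=128)
# replace() builds a new object; every field not passed is copied unchanged.
smaller_bwd = dataclasses.replace(tuned, chunk_size_bwd=64)
```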