netket.callbacks.AbstractCallback
- class netket.callbacks.AbstractCallback
Bases: Pytree
Abstract base class for callbacks in advanced variational drivers.
This class is a Pytree, so it can be used with JAX transformations and handles serialisation automatically. However, its fields must be declared as class attributes using struct.field(pytree_node=False).
Subclass this class and override any of the hook methods to inject custom logic at specific points of the optimisation loop. All hook methods have no-op default implementations, so you only need to override the ones you need.
To stop the optimisation early from inside any hook, raise StopRun (or a subclass of it). The driver will catch it, call on_run_end() on all callbacks, and return normally.
For a full description of the run loop structure and every available hook, including pseudocode showing exactly when each hook is called, see The Run Loop and Callback Hooks.
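The early-stopping contract can be illustrated with a toy, pure-Python run loop. This is a minimal sketch under stated assumptions, not NetKet's driver: ToyDriver, ToyCallback, StopAfter and RecordEnd are hypothetical names, StopRun is defined locally as a stand-in for NetKet's exception, and the on_run_end(step, driver) signature is assumed. It shows a hook raising StopRun, the loop catching it, and on_run_end() still being called on every callback before run() returns normally.

```python
class StopRun(Exception):
    """Local stand-in for NetKet's StopRun exception (assumption)."""


class ToyCallback:
    """Hypothetical base with no-op hooks, mirroring AbstractCallback's design."""

    callback_order = 0

    def before_parameter_update(self, step, log_data, driver):
        pass

    def on_run_end(self, step, driver):  # signature assumed for this sketch
        pass


class StopAfter(ToyCallback):
    """Hypothetical callback: request an early stop once `max_step` is reached."""

    def __init__(self, max_step):
        self.max_step = max_step

    def before_parameter_update(self, step, log_data, driver):
        if step >= self.max_step:
            raise StopRun(f"stopping at step {step}")


class RecordEnd(ToyCallback):
    """Hypothetical callback: remember the step at which on_run_end ran."""

    def __init__(self):
        self.ended_at = None

    def on_run_end(self, step, driver):
        self.ended_at = step


class ToyDriver:
    """Hypothetical run loop; NOT NetKet's implementation."""

    def __init__(self, callbacks):
        # Lower callback_order values are called first.
        self.callbacks = sorted(callbacks, key=lambda c: c.callback_order)
        self.step_count = 0

    def run(self, n_steps):
        try:
            for _ in range(n_steps):
                log_data = {}
                for cb in self.callbacks:
                    cb.before_parameter_update(self.step_count, log_data, self)
                self.step_count += 1  # the parameter update would happen here
        except StopRun:
            pass  # caught by the driver (subclasses too) instead of propagating
        # on_run_end runs on every callback, whether or not the run was stopped.
        for cb in self.callbacks:
            cb.on_run_end(self.step_count, self)
        return self.step_count


recorder = RecordEnd()
driver = ToyDriver([StopAfter(max_step=3), recorder])
final_step = driver.run(n_steps=100)  # returns normally, no exception escapes
print(final_step, recorder.ended_at)  # 3 3
```

Catching StopRun with a plain except clause is what makes "or a subclass of it" work: any exception class derived from StopRun is handled by the same branch.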
- Attributes
- callback_order
An integer representing the order in which this callback should be called.
Callbacks with lower values are called first, and callbacks with higher values are called later.
This can be redefined in subclasses to change the order in which callbacks are called. (Default: 0 for all callbacks, except loggers, which use 10.)
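Assuming the driver simply sorts its callbacks by callback_order in ascending order (an assumption about the implementation; the sketch does not import NetKet), overriding the class attribute in a subclass is enough to move a callback earlier or later. All class names below are hypothetical, and the negative order assumes negative integers are accepted; it is used only to show a callback running before the default order 0.

```python
class BaseCallback:
    """Hypothetical stand-in for AbstractCallback: default order 0."""

    callback_order = 0


class ObservableCallback(BaseCallback):
    pass  # inherits the default order 0


class LoggerCallback(BaseCallback):
    callback_order = 10  # loggers default to 10


class EarlyCallback(BaseCallback):
    callback_order = -1  # hypothetical: run before every order-0 callback


callbacks = [LoggerCallback(), ObservableCallback(), EarlyCallback()]
# Assumed driver behaviour: ascending sort, lower values called first.
ordered = sorted(callbacks, key=lambda c: c.callback_order)
print([type(c).__name__ for c in ordered])
# ['EarlyCallback', 'ObservableCallback', 'LoggerCallback']
```

Python's sorted() is stable, so callbacks sharing the same order keep the order in which they were passed; whether NetKet breaks ties the same way is not stated here.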
- Methods
- before_parameter_update(step, log_data, driver)
Called after all update logic has been computed and the step has been accepted, but before the driver applies the parameter update.
At this point:
- The loss and its gradient have been computed by compute_loss_and_update().
- The step has been accepted (not rejected by on_compute_update_end()).
- driver.step_count still refers to the current step; it has not yet been incremented.
- The variational state parameters have not yet changed.
This is the right place to estimate additional observables, add data to log_data, or take a snapshot of the state for logging. Callbacks with a lower callback_order run first, so observable callbacks (order 0) are guaranteed to populate log_data before logger callbacks (order 10) read it.
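The following toy loop sketches why before_parameter_update() is a convenient place for observables. It is plain Python, not NetKet: ToyDriver, ParamNormObservable, ListLogger and the param_norm2 key are hypothetical, and the loss 0.5 * |p|^2 merely stands in for a real variational energy. The observable (order 0) writes into log_data, the logger (order 10) reads it afterwards even though it was registered first, and both see the step index and parameters from before the update.

```python
class ToyCallback:
    """Hypothetical base with a no-op hook, mirroring AbstractCallback."""

    callback_order = 0

    def before_parameter_update(self, step, log_data, driver):
        pass


class ParamNormObservable(ToyCallback):
    """Hypothetical observable: squared norm of the current parameters."""

    callback_order = 0

    def before_parameter_update(self, step, log_data, driver):
        # Parameters are not updated yet, so this describes the same state
        # that produced this step's loss.
        log_data["param_norm2"] = sum(p * p for p in driver.params)


class ListLogger(ToyCallback):
    """Hypothetical logger: keeps a copy of log_data for every step."""

    callback_order = 10

    def __init__(self):
        self.history = []

    def before_parameter_update(self, step, log_data, driver):
        # Runs after all order-0 callbacks, so param_norm2 is already present.
        self.history.append((step, dict(log_data)))


class ToyDriver:
    """Hypothetical driver; NOT NetKet's implementation."""

    def __init__(self, params, callbacks):
        self.params = list(params)
        self.callbacks = sorted(callbacks, key=lambda c: c.callback_order)
        self.step_count = 0

    def run(self, n_steps, lr=0.5):
        for _ in range(n_steps):
            # Toy loss 0.5 * |p|^2, whose gradient is p itself.
            loss = 0.5 * sum(p * p for p in self.params)
            grad = list(self.params)
            log_data = {"loss": loss}
            for cb in self.callbacks:
                cb.before_parameter_update(self.step_count, log_data, self)
            # Only now are the parameters changed and the step counter advanced.
            self.params = [p - lr * g for p, g in zip(self.params, grad)]
            self.step_count += 1


logger = ListLogger()
# The logger is listed first on purpose: callback_order still puts it last.
driver = ToyDriver(params=[2.0, 0.0], callbacks=[logger, ParamNormObservable()])
driver.run(n_steps=2)
print(logger.history)
# [(0, {'loss': 2.0, 'param_norm2': 4.0}), (1, {'loss': 0.5, 'param_norm2': 1.0})]
```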
- on_compute_update_end(step, log_data, driver)
Called at the end of the compute-update phase, after the loss and its gradient have been computed.
This is called before the parameters are updated, so it can be used to implement custom logic for rejecting a step based on the computed loss or gradient.
- Return type:
bool | None
- Returns:
A boolean indicating whether to reject the step (i.e. repeat it with the same parameters). If it returns None, it is treated as False.
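A typical use of on_compute_update_end() is to discard steps with a non-finite loss. The sketch below is a hypothetical, NetKet-free model of that behaviour: RejectNonFinite and ToyDriver are invented for illustration, the "loss" key in log_data is assumed, and a list of pre-set values plays the role of noisy Monte Carlo estimates so that a repeated step sees new data. Returning True rejects the step, while returning None is treated as False, as documented above.

```python
import math


class RejectNonFinite:
    """Hypothetical callback: reject any step whose loss is NaN or infinite."""

    callback_order = 0

    def on_compute_update_end(self, step, log_data, driver):
        if not math.isfinite(log_data["loss"]):
            return True  # reject: redo this step with the same parameters
        return None  # None is treated as False, so the step is accepted


class ToyDriver:
    """Hypothetical driver; NOT NetKet's implementation."""

    def __init__(self, loss_estimates, callbacks):
        # Pre-set values stand in for noisy Monte Carlo estimates: every
        # attempt consumes the next one, so a retried step sees new data.
        self._estimates = iter(loss_estimates)
        self.callbacks = sorted(callbacks, key=lambda c: c.callback_order)
        self.step_count = 0
        self.accepted_losses = []
        self.n_rejected = 0

    def run(self, n_steps):
        while self.step_count < n_steps:
            log_data = {"loss": next(self._estimates)}
            # Ask every callback (assumed); any truthy answer rejects the step.
            answers = [
                cb.on_compute_update_end(self.step_count, log_data, self)
                for cb in self.callbacks
            ]
            if any(answers):  # None counts as False here
                self.n_rejected += 1
                continue  # parameters and step_count are left unchanged
            self.accepted_losses.append(log_data["loss"])
            self.step_count += 1  # the parameter update would happen here


driver = ToyDriver([1.0, float("nan"), 0.8, 0.6], callbacks=[RejectNonFinite()])
driver.run(n_steps=3)
print(driver.accepted_losses, driver.n_rejected)  # [1.0, 0.8, 0.6] 1
```

Because a rejected step keeps the same parameters, rejection only helps when the loss estimate is stochastic; with a deterministic loss the same step would simply be rejected again.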