The Run Loop and Callback Hooks#
This page describes the internal structure of the optimization loop executed by
AbstractDriver and the hooks available to
AbstractCallback implementations.
If you only need simple per-step logic, see the legacy callback interface described in the Drivers user guide. The hook-based API described here is more powerful and lets you intervene at every stage of a step, including before and after the gradient computation, before the parameter update, and at the very start or end of a run.
The run loop#
The pseudocode below shows every point at which a callback can be invoked during
AbstractDriver.run().
The labels in angle brackets (<hook_name>) correspond directly to the methods of
AbstractCallback.
# driver.run(n_iter, ...)
callbacks.on_run_start(step, driver) # <on_run_start>
for step in range(step_count, step_count + n_iter):
step_log_data = {}
step_attempt = 0
while True: # inner loop: allows step rejection and retry
callbacks.on_step_start(step, step_log_data, driver) # <on_step_start>
driver.reset_step() # resets the sampler
callbacks.on_compute_update_start( # <on_compute_update_start>
step, step_log_data, driver
)
loss_stats, dp = driver.compute_loss_and_update()
reject = callbacks.on_compute_update_end( # <on_compute_update_end>
step, step_log_data, driver
)
if reject:
step_attempt += 1
continue # retry from on_step_start
else:
break # step accepted
# step accepted — loss is added to step_log_data
# step_count and parameters are still at the current (pre-update) values
callbacks.before_parameter_update(step, step_log_data, driver) # <before_parameter_update>
driver.update_parameters(dp)
callbacks.on_step_end(step, step_log_data, driver) # <on_step_end>
step_count += 1
callbacks.on_run_end(step_count, driver) # <on_run_end>
# If StopRun is raised anywhere, on_run_end is still called.
# If any other exception is raised, on_run_error is called instead.
# callbacks.on_run_error(step_count, error, driver) # <on_run_error>
Hook reference#
The table below summarises all hooks in calling order, their signature, and their intended use.
Hook |
Called when |
Return value |
|---|---|---|
Once, before the iteration loop |
— |
|
At the start of every step (and every retry) |
— |
|
After |
— |
|
After |
|
|
After the step is accepted; |
— |
|
After parameters have been updated |
— |
|
Once, after the loop finishes (also called on |
— |
|
When an exception (other than |
— |
Hook signatures#
def on_run_start(self, step: int, driver) -> None: ...
def on_step_start(self, step: int, log_data: dict, driver) -> None: ...
def on_compute_update_start(self, step: int, log_data: dict, driver) -> None: ...
def on_compute_update_end(self, step: int, log_data: dict, driver) -> bool: ...
def before_parameter_update(self, step: int, log_data: dict, driver) -> None: ...
def on_step_end(self, step: int, log_data: dict, driver) -> None: ...
def on_run_end(self, step: int, driver) -> None: ...
def on_run_error(self, step: int, error: Exception, driver) -> None: ...
step is the current value of driver.step_count — a monotonically increasing integer
that is not reset between consecutive calls to run().
log_data is the per-step dictionary that will eventually be passed to loggers.
Callbacks may add arbitrary keys to it.
Implementing a custom callback#
Subclass AbstractCallback and override only the hooks you need.
All hooks have no-op default implementations, so you only need to implement the ones relevant
to your use case.
import netket as nk
class MyCallback(nk.callbacks.AbstractCallback):
def on_run_start(self, step, driver):
print(f"Starting optimisation from step {step}")
def before_parameter_update(self, step, log_data, driver):
# parameters and step_count still reflect the current step
log_data["my_quantity"] = compute_something(driver.state)
def on_run_end(self, step, driver):
print(f"Finished at step {step}")
Pass the callback to run():
gs.run(n_iter=300, out="output", callback=MyCallback())
Multiple callbacks can be provided as a list:
gs.run(n_iter=300, callback=[MyCallback(), nk.callbacks.EarlyStopping(...)])
Stopping the optimisation#
To gracefully stop the run from inside a callback, raise StopRun
(or a subclass of it). The driver will call on_run_end before exiting.
class StopAfterConvergence(nk.callbacks.AbstractCallback):
def on_step_end(self, step, log_data, driver):
if is_converged(log_data):
raise nk.callbacks.StopRun("Energy converged.")
Injecting custom samples#
The on_compute_update_start hook runs after reset_step() has cleared the sample
cache but before compute_loss_and_update() has drawn any new samples. This makes it
the right place to inject a custom set of configurations that the driver will use for
that step instead of sampling from the Markov chain.
The way this works: reset_step() calls driver.state.reset(), which sets the internal
cache driver.state._samples = None. When compute_loss_and_update() later accesses
driver.state.samples, the property sees that _samples is None and draws new samples.
If you assign to _samples before that point, the property returns your array directly
and no MCMC sampling takes place.
class FixedSamplesCallback(nk.callbacks.AbstractCallback):
def __init__(self, samples):
self.samples = samples # pre-computed array of configurations
def on_compute_update_start(self, step, log_data, driver):
# Bypass MCMC: use the same fixed samples at every step.
driver.state._samples = self.samples
Rejecting a step#
on_compute_update_end() may return True to
signal that the current step should be discarded and recomputed from on_step_start.
This can be used, for example, to implement an accept/reject scheme based on the computed
loss or gradient norm.
class RejectBadSteps(nk.callbacks.AbstractCallback):
def on_compute_update_end(self, step, log_data, driver):
grad_norm = compute_norm(driver._dp)
if grad_norm > 1e6:
return True # reject and retry
return False
The number of retries for the current step is available as driver._step_attempt.
Callback ordering#
When multiple callbacks are active, they are sorted by
callback_order (ascending) before every hook is
called. The default value is 0; built-in loggers use 10 so that user callbacks run
first and can populate log_data before it is written to disk.
This ordering is especially relevant for before_parameter_update: callbacks that estimate
observables (order 0) will always run before loggers that snapshot and write log_data
(order 10), guaranteeing that the logged data is complete.
Override callback_order to control the relative execution order of your callback:
class EarlyCallback(nk.callbacks.AbstractCallback):
@property
def callback_order(self):
return -10 # runs before everything else
Relationship to legacy callbacks#
Legacy callbacks — plain callables with the signature
(step, log_data, driver) -> bool — are transparently wrapped and invoked during
before_parameter_update (at callback_order=0). They can still be used, but the
hook-based API described on this page gives access to more stages of the loop and is
recommended for new code.