The Run Loop and Callback Hooks

This page describes the internal structure of the optimization loop executed by AbstractDriver and the hooks available to AbstractCallback implementations.

If you only need simple per-step logic, see the legacy callback interface described in the Drivers user guide. The hook-based API described here is more flexible: it lets you intervene before and after the gradient computation, before the parameter update, and at the very start or end of a run.

The run loop

The pseudocode below shows every point at which a callback can be invoked during AbstractDriver.run(). The labels in angle brackets (<hook_name>) correspond directly to the methods of AbstractCallback.

# driver.run(n_iter, ...)

callbacks.on_run_start(step_count, driver)                    # <on_run_start>

for step in range(step_count, step_count + n_iter):
    step_log_data = {}
    step_attempt = 0

    while True:  # inner loop: allows step rejection and retry
        callbacks.on_step_start(step, step_log_data, driver)  # <on_step_start>

        driver.reset_step()  # resets the sampler

        callbacks.on_compute_update_start(                     # <on_compute_update_start>
            step, step_log_data, driver
        )
        loss_stats, dp = driver.compute_loss_and_update()

        reject = callbacks.on_compute_update_end(              # <on_compute_update_end>
            step, step_log_data, driver
        )
        if reject:
            step_attempt += 1
            continue   # retry from on_step_start
        else:
            break      # step accepted

    # step accepted — loss is added to step_log_data
    # step_count and parameters are still at the current (pre-update) values
    callbacks.before_parameter_update(step, step_log_data, driver)  # <before_parameter_update>
    driver.update_parameters(dp)

    callbacks.on_step_end(step, step_log_data, driver)        # <on_step_end>
    step_count += 1

callbacks.on_run_end(step_count, driver)                      # <on_run_end>

# If StopRun is raised anywhere, on_run_end is still called.
# If any other exception is raised, on_run_error is called instead.
# callbacks.on_run_error(step_count, error, driver)           # <on_run_error>
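The hook order above can be exercised without a real driver. The sketch below is a toy model, not NetKet's implementation: Recorder and toy_run are hypothetical names, the driver is replaced by None, and reset_step(), compute_loss_and_update(), and update_parameters() are left out. It only shows which hooks fire, and in what order, when the first attempt of a step is rejected.

```python
class Recorder:
    """Toy callback (not a NetKet class) that records hook invocations."""

    def __init__(self, reject_once_at=None):
        self.events = []
        self.reject_once_at = reject_once_at  # step to reject on its first attempt

    def on_run_start(self, step, driver):
        self.events.append(("on_run_start", step))

    def on_step_start(self, step, log_data, driver):
        self.events.append(("on_step_start", step))

    def on_compute_update_start(self, step, log_data, driver):
        self.events.append(("on_compute_update_start", step))

    def on_compute_update_end(self, step, log_data, driver):
        self.events.append(("on_compute_update_end", step))
        if step == self.reject_once_at:
            self.reject_once_at = None  # accept the retry
            return True
        return False

    def before_parameter_update(self, step, log_data, driver):
        self.events.append(("before_parameter_update", step))

    def on_step_end(self, step, log_data, driver):
        self.events.append(("on_step_end", step))

    def on_run_end(self, step, driver):
        self.events.append(("on_run_end", step))


def toy_run(cb, n_iter, step_count=0, driver=None):
    """Mirror of the pseudocode with all driver work left out."""
    cb.on_run_start(step_count, driver)
    for step in range(step_count, step_count + n_iter):
        log_data = {}
        while True:  # retry until on_compute_update_end accepts the step
            cb.on_step_start(step, log_data, driver)
            cb.on_compute_update_start(step, log_data, driver)
            if not cb.on_compute_update_end(step, log_data, driver):
                break
        cb.before_parameter_update(step, log_data, driver)
        cb.on_step_end(step, log_data, driver)
        step_count += 1
    cb.on_run_end(step_count, driver)
    return step_count


rec = Recorder(reject_once_at=0)
final_step = toy_run(rec, n_iter=1)
for name, step in rec.events:
    print(f"{name:26s} step={step}")
```

In this toy run, on_step_start, on_compute_update_start, and on_compute_update_end each fire twice for step 0 (once for the rejected attempt, once for the retry), while before_parameter_update and on_step_end fire once.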

Hook reference

The table below summarises all hooks in calling order: when each one is called and what it returns. Their signatures are listed in the next section.

| Hook | Called when | Return value |
| --- | --- | --- |
| on_run_start() | Once, before the iteration loop | None |
| on_step_start() | At the start of every step (and every retry) | None |
| on_compute_update_start() | After reset_step(), before compute_loss_and_update() | None |
| on_compute_update_end() | After compute_loss_and_update(), before accepting the step | bool; if True, the step is retried |
| before_parameter_update() | After the step is accepted; step_count and parameters are still at their current values | None |
| on_step_end() | After parameters have been updated | None |
| on_run_end() | Once, after the loop finishes (also called on StopRun) | None |
| on_run_error() | When an exception (other than StopRun) terminates the loop | None |

Hook signatures

def on_run_start(self, step: int, driver) -> None: ...

def on_step_start(self, step: int, log_data: dict, driver) -> None: ...

def on_compute_update_start(self, step: int, log_data: dict, driver) -> None: ...

def on_compute_update_end(self, step: int, log_data: dict, driver) -> bool: ...

def before_parameter_update(self, step: int, log_data: dict, driver) -> None: ...

def on_step_end(self, step: int, log_data: dict, driver) -> None: ...

def on_run_end(self, step: int, driver) -> None: ...

def on_run_error(self, step: int, error: Exception, driver) -> None: ...

step is the current value of driver.step_count — a monotonically increasing integer that is not reset between consecutive calls to run().

log_data is the per-step dictionary that will eventually be passed to loggers. Callbacks may add arbitrary keys to it.

Implementing a custom callback

Subclass AbstractCallback and override only the hooks you need; every hook has a no-op default implementation.

import netket as nk

class MyCallback(nk.callbacks.AbstractCallback):

    def on_run_start(self, step, driver):
        print(f"Starting optimisation from step {step}")

    def before_parameter_update(self, step, log_data, driver):
        # Parameters and step_count still reflect the current step.
        # compute_something is a placeholder for your own function.
        log_data["my_quantity"] = compute_something(driver.state)

    def on_run_end(self, step, driver):
        print(f"Finished at step {step}")

Pass the callback to run():

gs.run(n_iter=300, out="output", callback=MyCallback())

Multiple callbacks can be provided as a list:

gs.run(n_iter=300, callback=[MyCallback(), nk.callbacks.EarlyStopping(...)])

Stopping the optimisation

To gracefully stop the run from inside a callback, raise StopRun (or a subclass of it). The driver will call on_run_end before exiting.

class StopAfterConvergence(nk.callbacks.AbstractCallback):
    def on_step_end(self, step, log_data, driver):
        # is_converged is a placeholder for your own convergence test.
        if is_converged(log_data):
            raise nk.callbacks.StopRun("Energy converged.")
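The difference between a graceful stop and an error can be illustrated with a small stand-alone model. Everything in the sketch below is hypothetical: StopRun is a local stand-in for the exception named on this page, toy_run reduces the loop to its on_step_end call, and re-raising the error after on_run_error is an assumption of the sketch, not documented behaviour.

```python
class StopRun(Exception):
    """Local stand-in for the StopRun exception named on this page."""


class StopAtStep:
    """Toy callback: requests a graceful stop once `limit` is reached."""

    def __init__(self, limit):
        self.limit = limit
        self.calls = []

    def on_step_end(self, step, log_data, driver):
        if step >= self.limit:
            raise StopRun(f"reached step {step}")

    def on_run_end(self, step, driver):
        self.calls.append(("on_run_end", step))

    def on_run_error(self, step, error, driver):
        self.calls.append(("on_run_error", type(error).__name__))


class Faulty(StopAtStep):
    """Toy callback that fails with an ordinary exception."""

    def on_step_end(self, step, log_data, driver):
        raise ValueError("unexpected failure")


def toy_run(cb, n_iter, step_count=0, driver=None):
    try:
        for step in range(step_count, step_count + n_iter):
            cb.on_step_end(step, {}, driver)
            step_count += 1
    except StopRun:
        pass  # graceful stop: fall through to on_run_end
    except Exception as error:
        cb.on_run_error(step_count, error, driver)
        raise  # assumption of this sketch: the error still propagates
    cb.on_run_end(step_count, driver)
    return step_count


graceful = StopAtStep(limit=2)
stopped_at = toy_run(graceful, n_iter=10)
print(graceful.calls)

faulty = Faulty(limit=0)
try:
    toy_run(faulty, n_iter=3)
except ValueError:
    pass
print(faulty.calls)
```

The graceful callback sees only on_run_end, and the faulty one sees only on_run_error, matching the two comments at the end of the run-loop pseudocode.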

Injecting custom samples

The on_compute_update_start hook runs after reset_step() has cleared the sample cache but before compute_loss_and_update() has drawn any new samples. This makes it the right place to inject a custom set of configurations that the driver will use for that step instead of sampling from the Markov chain.

This works because of how the sample cache is managed. reset_step() calls driver.state.reset(), which sets the internal cache driver.state._samples to None. When compute_loss_and_update() later accesses driver.state.samples, the property sees that _samples is None and draws new samples. If you assign to _samples before that point, the property returns your array directly and no MCMC sampling takes place.

class FixedSamplesCallback(nk.callbacks.AbstractCallback):
    def __init__(self, samples):
        self.samples = samples  # pre-computed array of configurations

    def on_compute_update_start(self, step, log_data, driver):
        # Bypass MCMC: use the same fixed samples at every step.
        driver.state._samples = self.samples

Note

_samples is an internal field of MCState (note the leading underscore). Assigning to it is supported as an extension point for this use case only; for all other interactions with the variational state, use the public API, in particular samples, sample(), and reset().
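The caching behaviour that makes sample injection possible can be sketched with a small stand-in class. ToyState below is hypothetical and only imitates the reset()/samples/_samples interplay described in this section; it draws random spins with the standard library instead of running a Markov chain, and n_draws is an illustrative counter that does not exist on MCState.

```python
import random


class ToyState:
    """Stand-in for a variational state with a lazily filled sample cache."""

    def __init__(self, n_samples=4, seed=0):
        self._rng = random.Random(seed)
        self._n_samples = n_samples
        self._samples = None
        self.n_draws = 0  # counts how often the sampler actually ran

    def reset(self):
        # Clear the cache so the next access to `samples` draws again.
        self._samples = None

    def sample(self):
        self.n_draws += 1
        self._samples = [self._rng.choice((-1, 1)) for _ in range(self._n_samples)]
        return self._samples

    @property
    def samples(self):
        if self._samples is None:
            self.sample()
        return self._samples


state = ToyState()

state.reset()
drawn = state.samples            # cache is empty, so the sampler runs
draws_after_first_access = state.n_draws

state.reset()
state._samples = [1, 1, -1, -1]  # injected before the first access
injected = state.samples         # returned as-is, the sampler does not run
draws_after_injection = state.n_draws
```

Because the property only samples when the cache is empty, filling the cache between reset() and the first access is enough to bypass the sampler for that step.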

Rejecting a step

on_compute_update_end() may return True to signal that the current step should be discarded and recomputed from on_step_start. This can be used, for example, to implement an accept/reject scheme based on the computed loss or gradient norm.

class RejectBadSteps(nk.callbacks.AbstractCallback):
    def on_compute_update_end(self, step, log_data, driver):
        # compute_norm is a placeholder for your own norm function.
        grad_norm = compute_norm(driver._dp)
        if grad_norm > 1e6:
            return True   # reject and retry
        return False

The number of retries for the current step is available as driver._step_attempt.
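A rejection rule that always returns True for a bad step could retry forever, so it can be useful to cap the number of attempts. The sketch below is one possible way to do that, not a NetKet class: RejectWithRetryCap is a hypothetical name, the drivers are faked with SimpleNamespace objects, and _dp is treated as a flat list of floats for simplicity (in a real driver it would be whatever structure the update has).

```python
import math
from types import SimpleNamespace


class RejectWithRetryCap:
    """Reject steps with a large update norm, giving up after `max_retries`."""

    def __init__(self, threshold=1e6, max_retries=3):
        self.threshold = threshold
        self.max_retries = max_retries

    def on_compute_update_end(self, step, log_data, driver):
        norm = math.sqrt(sum(x * x for x in driver._dp))
        log_data["update_norm"] = norm
        if norm <= self.threshold:
            return False  # accept the step
        if driver._step_attempt >= self.max_retries:
            raise RuntimeError(
                f"step {step} still exceeds the threshold "
                f"after {self.max_retries} retries"
            )
        return True  # reject and retry


cb = RejectWithRetryCap(threshold=10.0, max_retries=2)

# Fake drivers exposing only the two attributes the callback reads.
good = SimpleNamespace(_dp=[3.0, 4.0], _step_attempt=0)
bad = SimpleNamespace(_dp=[30.0, 40.0], _step_attempt=0)

good_log = {}
accepted = not cb.on_compute_update_end(0, good_log, good)
rejected = cb.on_compute_update_end(0, {}, bad)

bad._step_attempt = 2  # pretend two retries have already happened
try:
    cb.on_compute_update_end(0, {}, bad)
    gave_up = False
except RuntimeError:
    gave_up = True
```

Raising an ordinary exception here routes the failure through on_run_error; raising StopRun instead would end the run gracefully.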

Callback ordering

When multiple callbacks are active, they are sorted by callback_order (ascending) before every hook is called. The default value is 0; built-in loggers use 10 so that user callbacks run first and can populate log_data before it is written to disk.

This ordering is especially relevant for before_parameter_update: callbacks that estimate observables (order 0) will always run before loggers that snapshot and write log_data (order 10), guaranteeing that the logged data is complete.

Override callback_order to control the relative execution order of your callback:

class EarlyCallback(nk.callbacks.AbstractCallback):
    @property
    def callback_order(self):
        return -10   # runs before everything else
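The effect of callback_order can be illustrated with a few lines of plain Python. OrderedToy and call_hook below are hypothetical helpers, not part of NetKet; the dispatch simply sorts by callback_order in ascending order before calling the hook, as described above. Python's sorted() is stable, so callbacks with equal order keep their list order in this sketch; whether the library makes the same guarantee is not stated on this page.

```python
class OrderedToy:
    """Toy callback that appends its name to log_data when the hook fires."""

    def __init__(self, name, order=0):
        self.name = name
        self._order = order

    @property
    def callback_order(self):
        return self._order

    def before_parameter_update(self, step, log_data, driver):
        log_data.setdefault("trace", []).append(self.name)


def call_hook(callbacks, hook, *args):
    """Call `hook` on every callback, lowest callback_order first."""
    for cb in sorted(callbacks, key=lambda cb: cb.callback_order):
        getattr(cb, hook)(*args)


callbacks = [
    OrderedToy("logger", order=10),   # mimics a built-in logger
    OrderedToy("observable"),         # default order 0
    OrderedToy("early", order=-10),   # runs before everything else
]

log_data = {}
call_hook(callbacks, "before_parameter_update", 0, log_data, None)
print(log_data["trace"])
```

Regardless of the order in which the callbacks were passed, the logger-like callback runs last and therefore sees everything the others wrote to log_data.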

Relationship to legacy callbacks

Legacy callbacks — plain callables with the signature (step, log_data, driver) -> bool — are transparently wrapped and invoked during before_parameter_update (at callback_order=0). They can still be used, but the hook-based API described on this page gives access to more stages of the loop and is recommended for new code.
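As a rough illustration of what such wrapping might look like, the sketch below adapts a plain callable to the hook interface. It is not the library's wrapper: LegacyWrapper and the local StopRun are hypothetical names, and the sketch assumes that a falsy return value from the legacy callable means "stop the run", which this page does not state.

```python
class StopRun(Exception):
    """Local stand-in for the StopRun exception named on this page."""


class LegacyWrapper:
    """Adapts a (step, log_data, driver) -> bool callable to the hook API."""

    callback_order = 0

    def __init__(self, fn):
        self.fn = fn

    def before_parameter_update(self, step, log_data, driver):
        keep_going = self.fn(step, log_data, driver)
        if not keep_going:
            # Assumed convention: a falsy result requests a graceful stop.
            raise StopRun(f"legacy callback requested a stop at step {step}")


def legacy_stop_after_three(step, log_data, driver):
    log_data["seen"] = step
    return step < 3


wrapped = LegacyWrapper(legacy_stop_after_three)
stopped_at = None
for step in range(10):
    try:
        wrapped.before_parameter_update(step, {}, None)
    except StopRun:
        stopped_at = step
        break
print(stopped_at)
```

Translating the return value into StopRun keeps the stop graceful, so on_run_end would still be called by a loop that follows the pseudocode on this page.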