netket.logging.SaveVariationalState

class netket.logging.SaveVariationalState

Bases: AbstractCallback

Callback that saves the variational state at fixed intervals. The state is serialized with the nqxpack package, which stores it in a portable format.

If you run into problems saving or loading the variational state, please open an issue in the nqxpack issue tracker.

Warning

This callback requires the nqxpack package. It is an optional dependency that is not installed together with NetKet; you can install it with uv add nqxpack.

Warning

A limitation of nqxpack is that it can only save variational states whose model is defined inside an importable package. Models defined in a script or a notebook are not supported, because nqxpack cannot import them.
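The limitation follows from how Python locates a class: every class records its module (__module__) and qualified name (__qualname__), and a class defined in a script or notebook reports its module as __main__, which a later process cannot import to recover the same class. The sketch below is a hypothetical helper (is_importable is not part of nqxpack or NetKet) that illustrates this check with the standard library only; nqxpack's actual logic may differ.

```python
import importlib


def is_importable(obj) -> bool:
    """Return True if the class of obj can be found again through an import.

    Hypothetical helper for illustration; not part of nqxpack.
    """
    cls = type(obj)
    module_name = cls.__module__
    # Classes defined in a script or notebook live in "__main__", which
    # another process cannot import to get the same class back.
    if module_name == "__main__":
        return False
    # Classes defined inside a function carry "<locals>" in their qualname
    # and cannot be reached as module attributes.
    if "<locals>" in cls.__qualname__:
        return False
    try:
        module = importlib.import_module(module_name)
    except ImportError:
        return False
    # Walk the dotted qualified name, e.g. "Outer.Inner".
    target = module
    for part in cls.__qualname__.split("."):
        target = getattr(target, part, None)
        if target is None:
            return False
    return target is cls
```

For example, is_importable(1.5) is True because float is reachable as builtins.float, while an instance of a class defined inside a function returns False.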

To load a saved state you can use the nqxpack.load function, which will return a NetKet variational state object.

Example usage:
>>> import netket as nk
>>> import nqxpack
>>> ...
>>> driver.run(
...     n_iter=50,
...     out="test",
...     callback=nk.logging.SaveVariationalState(path="optimization", interval=10),
... )
>>> nqxpack.load("optimization/state_00010.nk")

__init__(path, interval, *, file_name_root='state')

Constructs the callback to save the variational state at fixed intervals.

The variational state is saved once every N iterations, where N is the value of interval, into the directory given by path. Each saved file is named {file_name_root}_{step:05d}.nk, where step is the iteration number at which the state was saved, zero-padded to five digits.

Parameters:
  • path (str | Path) – The path where to save the variational state.

  • interval (float | int) – The number of iterations between two consecutive saves.

  • file_name_root (str, optional) – The root of the file name used for the saved states. Defaults to "state".
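To make the naming scheme concrete, the sketch below lists the file paths one would expect after a run. It assumes a state is written whenever step % interval == 0, including step 0; that trigger condition is an assumption, and the real callback may treat the first step differently. The helper saved_state_paths is hypothetical and not part of NetKet.

```python
from pathlib import Path


def saved_state_paths(path, interval, n_iter, file_name_root="state"):
    """Return the file paths expected after n_iter steps (hypothetical helper).

    Assumes a save happens whenever step % interval == 0.
    """
    directory = Path(path)
    return [
        # Same pattern as the documented {file_name_root}_{step:05d}.nk
        directory / f"{file_name_root}_{step:05d}.nk"
        for step in range(n_iter)
        if step % interval == 0
    ]


print([p.name for p in saved_state_paths("optimization", 10, 50)])
# prints ['state_00000.nk', 'state_00010.nk', 'state_00020.nk', 'state_00030.nk', 'state_00040.nk']
```

Under this assumption, the run in the example above would leave optimization/state_00010.nk on disk, which is the file passed to nqxpack.load.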

Attributes
callback_order

An integer representing the order in which this callback should be called.

Lower numbers are called first, and higher numbers are called later.

This can be redefined in subclasses to change the order in which callbacks are called. (Default: 0 for all callbacks, 10 for loggers.)
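A minimal sketch of ordering callbacks by callback_order. ToyCallback is a hypothetical stand-in rather than a NetKet class, and breaking ties by the order in which the callbacks were passed (what Python's stable sorted() does) is an assumption about the driver's behavior.

```python
from dataclasses import dataclass


@dataclass
class ToyCallback:
    # Hypothetical stand-in for a callback; only callback_order matters here.
    name: str
    callback_order: int = 0


def sort_callbacks(callbacks):
    # Lower callback_order runs first. sorted() is stable, so callbacks
    # with equal order keep their original relative position.
    return sorted(callbacks, key=lambda cb: cb.callback_order)


callbacks = [
    ToyCallback("logger", callback_order=10),
    ToyCallback("custom"),
    ToyCallback("observables"),
]
print([cb.name for cb in sort_callbacks(callbacks)])
# prints ['custom', 'observables', 'logger']
```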

Methods
before_parameter_update(step, log_data, driver)

Called after all update logic has been computed and the step has been accepted, but before the driver applies the parameter update.

At this point:

  • The loss and its gradient have been computed by compute_loss_and_update().

  • The step has been accepted (not rejected by on_compute_update_end()).

  • driver.step_count still refers to the current step; it has not yet been incremented.

  • The variational state parameters have not yet changed.

This is the right place to estimate additional observables, add data to log_data, or take a snapshot of the state for logging. Callbacks with a lower callback_order run first, so observable callbacks (order 0) are guaranteed to populate log_data before logger callbacks (order 10) read it.
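The guarantees listed above can be illustrated with a toy loop. ToyDriver, ObservableCallback and LoggerCallback are hypothetical, heavily simplified classes (not NetKet's driver or callbacks) that only mimic the documented contract: before_parameter_update sees the current step_count and unchanged parameters, and a lower callback_order runs first.

```python
class ToyDriver:
    """Hypothetical, heavily simplified driver loop (not NetKet's)."""

    def __init__(self, callbacks):
        self.step_count = 0
        self.params = 1.0
        # Lower callback_order runs first.
        self.callbacks = sorted(callbacks, key=lambda cb: cb.callback_order)

    def step(self, log_data):
        grad = 2.0 * self.params  # gradient of the toy loss params**2
        update = -0.1 * grad
        for cb in self.callbacks:
            cb.before_parameter_update(self.step_count, log_data, self)
        # Only now are the parameters changed and the step counter advanced.
        self.params += update
        self.step_count += 1


class ObservableCallback:
    callback_order = 0

    def before_parameter_update(self, step, log_data, driver):
        # Parameters and step_count still refer to the current step.
        log_data["params_snapshot"] = driver.params
        log_data["step"] = driver.step_count


class LoggerCallback:
    callback_order = 10

    def __init__(self):
        self.records = []

    def before_parameter_update(self, step, log_data, driver):
        # Runs after ObservableCallback, so log_data is already populated.
        self.records.append(dict(log_data))


logger = LoggerCallback()
driver = ToyDriver([logger, ObservableCallback()])
for _ in range(2):
    driver.step({})
print(logger.records)
```

Although the logger is passed first, it records the snapshot written by the observable callback, and each snapshot holds the parameters from before that step's update.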

on_compute_update_end(step, log_data, driver)

Called at the end of the compute-update phase, once the loss and its gradient have been computed.

This is called before the parameters are updated, so it can be used to implement custom logic for rejecting a step based on the computed loss or gradient.

Return type:

bool

Returns:

A boolean indicating whether to reject the step (i.e. repeat it with the same parameters). If it returns None, it is treated as False.
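A minimal sketch of a rejecting callback and of how its return value could be interpreted. The "loss" key in log_data and the rule that any single True rejects the step are assumptions for illustration; RejectNonFiniteLoss and should_reject are hypothetical names, not NetKet APIs.

```python
import math


class RejectNonFiniteLoss:
    """Hypothetical callback that asks for a step to be repeated when the
    loss is NaN or infinite. The "loss" key in log_data is an assumption."""

    callback_order = 0

    def on_compute_update_end(self, step, log_data, driver):
        loss = log_data.get("loss")
        if loss is not None and not math.isfinite(loss):
            return True  # reject: repeat the step with the same parameters
        return None  # no opinion; treated the same as False


def should_reject(callbacks, step, log_data, driver):
    # Assumed aggregation: the step is rejected if any callback returns True.
    # bool(None) is False, so a None result counts as "do not reject".
    results = [cb.on_compute_update_end(step, log_data, driver) for cb in callbacks]
    return any(bool(r) for r in results)
```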

on_compute_update_start(step, log_data, driver)
on_run_end(step, driver)
on_run_error(step, error, driver)
on_run_start(step, driver)
on_step_end(step, log_data, driver)
on_step_start(step, log_data, driver)
replace(**kwargs)

Replace the values of the fields of the object with the values of the keyword arguments. If the object is a dataclass, dataclasses.replace will be used. Otherwise, a new object will be created with the same type as the original object.

Return type:

TypeVar(P, bound=Pytree)

Parameters:
  • self (P)

  • kwargs (Any)
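The documented behavior of replace can be sketched with the standard dataclasses module. SaveConfig is a hypothetical dataclass that merely mirrors the constructor arguments above, and the non-dataclass fallback (rebuilding the object from its instance attributes) is an assumption about what "a new object will be created with the same type" involves.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class SaveConfig:
    # Hypothetical dataclass mirroring the constructor arguments above.
    path: str
    interval: int
    file_name_root: str = "state"


def replace(obj, **kwargs):
    """Sketch: use dataclasses.replace for dataclasses, otherwise build a
    new object of the same type."""
    if dataclasses.is_dataclass(obj):
        return dataclasses.replace(obj, **kwargs)
    # Fallback assumption: the constructor accepts every instance attribute
    # as a keyword argument.
    fields = {**vars(obj), **kwargs}
    return type(obj)(**fields)


original = SaveConfig(path="optimization", interval=10)
updated = replace(original, interval=25)
print(original.interval, updated.interval, updated.path)
# prints 10 25 optimization
```

The original object is left untouched, which is what makes this pattern safe for frozen, immutable objects.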