netket.driver.AbstractDynamicsDriver
- class netket.driver.AbstractDynamicsDriver
Bases: AbstractDriver
Abstract base class for time-evolution (dynamics) drivers.
Unlike optimization drivers, there is no optimizer: update_parameters applies the parameter delta directly and advances the simulation clock.
Note
How to implement a new dynamics driver
Subclass this class and implement:
- compute_loss_and_update(): compute one integration step. Return (loss_stats, Δθ), where Δθ is the full parameter delta for this time step (i.e. it already includes the factor of dt).
- dt property: the current time step size (may vary for adaptive integrators).
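The two requirements above can be illustrated with a library-free toy. The sketch below does not import NetKet: `ToyDynamicsDriver`, `ExplicitEulerDriver`, and the test equation dθ/dt = -θ are hypothetical stand-ins, and the parameters are a plain Python list rather than a variational state. It is only meant to show the division of labour: the subclass supplies `dt` and a `compute_loss_and_update()` that returns `(loss_stats, Δθ)` with `dt` already folded in, while the base class applies the delta and advances the clock.

```python
class ToyDynamicsDriver:
    """Hypothetical stand-in for the base class: no optimizer, the delta is applied directly."""

    def __init__(self, params, t0=0.0):
        self.params = list(params)
        self._t = t0
        self.step_count = 0

    @property
    def dt(self):
        raise NotImplementedError  # subclasses must override

    @property
    def t(self):
        return self._t

    def compute_loss_and_update(self):
        raise NotImplementedError  # subclasses must override

    def update_parameters(self, dtheta):
        # Apply the full parameter delta, then advance the simulation clock by dt.
        self.params = [p + d for p, d in zip(self.params, dtheta)]
        self._t += self.dt
        self.step_count += 1


class ExplicitEulerDriver(ToyDynamicsDriver):
    """Hypothetical fixed-step driver integrating dθ/dt = -θ with explicit Euler."""

    def __init__(self, params, dt, t0=0.0):
        super().__init__(params, t0=t0)
        self._dt = dt

    @property
    def dt(self):
        return self._dt

    def compute_loss_and_update(self):
        loss_stats = sum(p * p for p in self.params)  # placeholder monitored quantity
        dtheta = [-self.dt * p for p in self.params]  # factor of dt already included
        return loss_stats, dtheta


driver = ExplicitEulerDriver([1.0, 2.0], dt=0.1)
loss, dtheta = driver.compute_loss_and_update()
driver.update_parameters(dtheta)
```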
Optionally override t and the t setter if your integrator owns the time state rather than using the built-in _t field (e.g. when wrapping an external ODE integrator).
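Overriding `t` amounts to delegating the clock to whichever object owns it. A minimal sketch, assuming a hypothetical `ExternalIntegrator` whose `time` attribute is the authoritative simulation time; `WrappingDriver` is likewise hypothetical and only mimics the pattern, not NetKet's implementation.

```python
class ExternalIntegrator:
    """Hypothetical external ODE integrator that owns the time state."""

    def __init__(self, y, t0=0.0, h=0.05):
        self.y = list(y)
        self.time = t0
        self.h = h

    def increment(self, f):
        # Explicit Euler increment h * f(y); applying it is left to the caller.
        return [self.h * v for v in f(self.y)]


class WrappingDriver:
    """Hypothetical driver whose t property and setter delegate to the integrator."""

    def __init__(self, integrator):
        self._integrator = integrator

    @property
    def t(self):
        return self._integrator.time

    @t.setter
    def t(self, value):
        self._integrator.time = value

    @property
    def dt(self):
        return self._integrator.h

    def update_parameters(self, dtheta):
        self._integrator.y = [y + d for y, d in zip(self._integrator.y, dtheta)]
        self.t = self.t + self.dt  # goes through the setter, so the integrator's clock moves


integ = ExternalIntegrator([1.0], t0=0.0, h=0.05)
drv = WrappingDriver(integ)
drv.update_parameters(integ.increment(lambda y: [-v for v in y]))
```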
The run() method accepts either a total evolution time T (float) or a step count (int), and uses the StopRun mechanism internally so the full callback system works unchanged.
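The float-versus-int dispatch can be sketched as a small loop. Everything here is hypothetical: `StopRun` is a local exception standing in for the mechanism named above, `ConstantRateDriver` integrates dθ/dt = 1, and the check-before-step placement is an assumption; the real run() raises StopRun from within its callback machinery, which this toy only imitates.

```python
class StopRun(Exception):
    """Hypothetical stand-in for the StopRun signal."""


class ConstantRateDriver:
    """Hypothetical driver for dθ/dt = 1, so each step's delta is simply dt."""

    def __init__(self, dt):
        self.dt = dt
        self.t = 0.0
        self.theta = 0.0
        self.step_count = 0

    def compute_loss_and_update(self):
        return 0.0, self.dt

    def update_parameters(self, dtheta):
        self.theta += dtheta
        self.t += self.dt
        self.step_count += 1


def run(driver, T, max_steps=10_000):
    """Toy loop: an int is a step count, a float is a target simulation time."""
    if isinstance(T, int):
        n_steps, t_end = T, None
    else:
        n_steps, t_end = max_steps, float(T)
    try:
        for _ in range(n_steps):
            if t_end is not None and driver.t >= t_end:
                raise StopRun  # reaching T ends the loop through the same stop path
            _, dtheta = driver.compute_loss_and_update()
            driver.update_parameters(dtheta)
    except StopRun:
        pass
    return driver


by_time = run(ConstantRateDriver(dt=0.25), 1.0)
by_steps = run(ConstantRateDriver(dt=0.25), 3)
```

With dt = 0.25, which is exactly representable in binary, the float run stops after exactly four steps. With a value such as 0.1, accumulated rounding leaves t slightly below 1.0 after ten steps, so a float T can cost one extra step.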

- __init__(variational_state, *, t0=0.0, minimized_quantity_name='loss')
Initializes a dynamics driver.
- Parameters:
variational_state (VariationalState) – The variational state.
t0 (float) – Initial simulation time (default 0.0).
minimized_quantity_name (str) – Name of the monitored quantity in logged data.
- Attributes
- dt
Current time step size.
For fixed-step drivers this is a constant. For adaptive drivers this is the accepted step size of the last completed step.
Subclasses must override this property.
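For adaptive drivers, the distinction between a candidate step and the accepted one matters. A minimal sketch, assuming a hypothetical step-doubling error estimate for explicit Euler on dθ/dt = -θ (NetKet's own integrators may use a different error control): rejected candidates shrink the step, and `dt` reports only the size of the last accepted step.

```python
class ToyAdaptiveDriver:
    """Hypothetical adaptive explicit-Euler driver for dθ/dt = -θ (step doubling)."""

    def __init__(self, theta, dt0=0.5, tol=1e-3):
        self.theta = theta
        self.t = 0.0
        self.tol = tol
        self._candidate_dt = dt0  # size proposed for the next attempt
        self._accepted_dt = None  # size of the last completed step

    @property
    def dt(self):
        # Reports the accepted step, never a rejected candidate.
        return self._accepted_dt

    def step(self):
        while True:
            h = self._candidate_dt
            one_full = self.theta - h * self.theta
            half = self.theta - 0.5 * h * self.theta
            two_halves = half - 0.5 * h * half
            if abs(one_full - two_halves) <= self.tol:
                self.theta = two_halves
                self.t += h
                self._accepted_dt = h
                self._candidate_dt = 2.0 * h  # be more ambitious next time
                return
            self._candidate_dt = 0.5 * h  # reject and retry with half the step


drv = ToyAdaptiveDriver(theta=1.0, dt0=0.5, tol=1e-3)
drv.step()
```

The factor-of-two growth and shrink rules are deliberately crude; a production integrator would scale the step from the error estimate and the order of the method.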
- state
Returns the variational state evolved by this driver.
- step_count
Returns a monotonic integer labelling all the steps performed by this driver. This can be used, for example, to identify the line in a log file.
- t
Current simulation time.
- Methods
- compute_loss_and_update()
Performs one integration step of the dynamics driver, returning the monitored quantity and the parameter update for this time step.
Concrete drivers must override this method.
Note
When implementing this method on a subclass, you must return a tuple (loss_stats, Δθ). The update Δθ must match the pytree structure of the parameters of the variational state and must already include the factor of dt.
There is no optimizer: update_parameters applies Δθ directly to the parameters and advances the simulation clock.
loss_stats holds the statistics of the monitored quantity, which is logged under minimized_quantity_name to any logger during the evolution.
- Returns:
the tuple (loss_stats, Δθ), where Δθ is the update for the weights.
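The structural requirement on the returned update can be checked leafwise. The sketch below uses nested dictionaries as a stand-in for a parameter pytree, a hypothetical `tree_add` helper, and a hypothetical module-level `compute_loss_and_update` for dθ/dt = -θ; with JAX one would typically write `jax.tree_util.tree_map(lambda p, d: p + d, params, delta)` instead.

```python
def tree_add(params, delta):
    """Leafwise params + delta for nested dicts; raises if the structures differ."""
    if isinstance(params, dict) or isinstance(delta, dict):
        same_kind = isinstance(params, dict) and isinstance(delta, dict)
        if not same_kind or params.keys() != delta.keys():
            raise ValueError("update does not match the parameter pytree structure")
        return {k: tree_add(params[k], delta[k]) for k in params}
    return params + delta


def compute_loss_and_update(params, dt):
    """Hypothetical step for dθ/dt = -θ on every leaf: returns (loss_stats, Δθ)."""
    delta = {
        layer: {name: -dt * v for name, v in leaves.items()}
        for layer, leaves in params.items()
    }
    loss_stats = sum(v * v for leaves in params.values() for v in leaves.values())
    return loss_stats, delta  # delta mirrors the structure of params


params = {"dense": {"w": 1.0, "b": 0.5}}
loss_stats, delta = compute_loss_and_update(params, dt=0.1)
new_params = tree_add(params, delta)
```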
- estimate(observables, fullsum=False)
Return MCMC statistics for the expectation value of observables in the current state of the driver.
- Parameters:
observables – A pytree of operators for which statistics should be computed.
fullsum (bool)
- Returns:
A pytree of the same structure as the input, containing MCMC statistics for the corresponding operators as leaves.
- replace(**kwargs)
Replace the values of the fields of the object with the values of the keyword arguments. If the object is a dataclass, dataclasses.replace will be used. Otherwise, a new object will be created with the same type as the original object.
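The two branches described above can be sketched with the standard library. This is a hypothetical stand-alone `replace` function, not NetKet's method; in particular, building the non-dataclass result with a shallow `copy.copy` followed by `setattr` is an assumption about how "a new object with the same type" might be created.

```python
import copy
import dataclasses


def replace(obj, **kwargs):
    """Return a new object of the same type with the given fields replaced."""
    if dataclasses.is_dataclass(obj) and not isinstance(obj, type):
        return dataclasses.replace(obj, **kwargs)  # dataclass path
    new = copy.copy(obj)  # assumed fallback: shallow copy, then overwrite fields
    for name, value in kwargs.items():
        setattr(new, name, value)
    return new


@dataclasses.dataclass(frozen=True)
class FrozenConfig:
    dt: float
    t0: float = 0.0


class PlainConfig:
    def __init__(self, dt, t0=0.0):
        self.dt = dt
        self.t0 = t0


a = FrozenConfig(dt=0.1)
b = replace(a, dt=0.05)
p = PlainConfig(dt=0.1)
q = replace(p, t0=1.0)
```

In both branches the original object is left untouched, which is what makes the method safe to use on frozen dataclasses.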
- reset()
Deprecated since version 3.22: Use reset_step() to reset the sampler state at the beginning of a step. Note that the old reset() also reset step_count to 0; this behaviour is no longer supported.
- reset_step(hard=False)
Reset the sampler at the beginning of a step attempt.
On the first attempt (_step_attempt == 0), the sampler is reset normally. On subsequent attempts after a step rejection (e.g. from adaptive integrators), the reset is skipped; existing samples remain valid for the revised candidate dt.
- Parameters:
hard (bool)
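The attempt-dependent behaviour can be sketched with hypothetical `ToySampler` and `ToyDriver` classes that only count resets. The `hard` flag is not modelled, because its effect is not described here, and the retry loop standing in for an adaptive integrator's step rejection is likewise an assumption.

```python
class ToySampler:
    """Hypothetical sampler that only counts how often it is reset."""

    def __init__(self):
        self.n_resets = 0

    def reset(self):
        self.n_resets += 1


class ToyDriver:
    """Hypothetical driver reproducing the reset_step rule described above."""

    def __init__(self, sampler):
        self.sampler = sampler
        self._step_attempt = 0

    def reset_step(self):
        # Only the first attempt of a step resets; retries reuse existing samples.
        if self._step_attempt == 0:
            self.sampler.reset()

    def do_step(self, n_rejections):
        """Simulate a step that is rejected n_rejections times before acceptance."""
        self._step_attempt = 0
        while True:
            self.reset_step()
            if self._step_attempt == n_rejections:
                break  # accepted
            self._step_attempt += 1  # rejected: retry with a revised candidate dt


sampler = ToySampler()
driver = ToyDriver(sampler)
driver.do_step(n_rejections=2)
driver.do_step(n_rejections=0)
```

Two steps produce two resets, regardless of how many rejected attempts the first step needed.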
- run(T=None, out=(), obs=None, *, n_iter=None, show_progress=True, callback=None, timeit=False, step_size=0.0)
Run the time evolution.
- Parameters:
T – Total evolution time as a float (e.g. run(1.0)), or a fixed number of steps as an int (e.g. run(100)). When a float is given, the loop runs until driver.t >= T, using StopRun internally so all callbacks fire normally.
out (Iterable[AbstractLog] | None) – Logger or iterable of loggers for output.
obs – Observables to compute at each logging step.
max_steps – Safety cap on iterations when T is a float, to prevent infinite loops if dt is zero.
show_progress (bool) – Show a progress bar (default True).
callback (Union[Callable[[int, dict, AbstractDriver], bool], AbstractCallback, None]) – User callbacks.
timeit (bool) – If True, print timing information after the run.
n_iter (int | None)
step_size (float)
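A callback matching the documented type `Callable[[int, dict, AbstractDriver], bool]` can be sketched as below. It assumes, following the convention of NetKet's driver callbacks, that returning `False` asks the run to stop. `stop_after_time` and the small `toy_loop` that calls it are hypothetical helpers, and a `SimpleNamespace` stands in for a driver.

```python
from types import SimpleNamespace
from typing import Callable


def stop_after_time(t_max: float) -> Callable[[int, dict, object], bool]:
    """Hypothetical callback factory: request a stop once driver.t >= t_max."""

    def callback(step: int, log_data: dict, driver) -> bool:
        log_data["t"] = driver.t  # add the simulation time to this step's log entry
        return driver.t < t_max  # False is taken to mean "stop"

    return callback


def toy_loop(driver, callback, n_iter):
    """Hypothetical stand-in for the run loop: advance t, then consult the callback."""
    history = []
    for step in range(n_iter):
        driver.t += driver.dt
        log_data = {}
        keep_going = callback(step, log_data, driver)
        history.append(log_data)
        if not keep_going:
            break
    return history


driver = SimpleNamespace(t=0.0, dt=0.25)
history = toy_loop(driver, stop_after_time(0.5), n_iter=100)
```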