netket.logging.MLFlowLog

class netket.logging.MLFlowLog

Bases: AbstractCallback

Logger that streams metrics and optional model checkpoints to an MLflow tracking server.

The MLflow run is started lazily on the first call. As a result, constructing a logger that is never used does not create a run.
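The lazy-start behaviour can be pictured with the minimal sketch below. LazyRunLogger and its start_run argument are hypothetical names introduced for illustration. No MLflow call is made: run creation is replaced by an injected function, so the sketch shows only the deferral pattern and not the logger's actual implementation.

```python
from typing import Callable, Optional


class LazyRunLogger:
    """Hypothetical sketch of lazy run creation (not NetKet's implementation).

    `start_run` stands in for whatever actually opens the MLflow run; it is
    invoked on the first call to the logger, never in __init__.
    """

    def __init__(self, start_run: Callable[[], str]):
        self._start_run = start_run
        self.run_id: Optional[str] = None  # no run exists yet

    def __call__(self, step: int, item: dict) -> None:
        if self.run_id is None:
            # First call: only now is the run created.
            self.run_id = self._start_run()
        # ... the metrics in `item` would be sent to run `self.run_id` here ...


logger = LazyRunLogger(start_run=lambda: "run-0")
print(logger.run_id)  # None: constructing the logger created no run
logger(0, {"Energy": -1.0})
print(logger.run_id)  # run-0
```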

The mlflow package must be installed (pip install mlflow).

Nested metric keys (e.g. Energy/Mean) are mapped to MLflow metric names using "." as the separator (Energy.Mean), which the MLflow UI renders as a collapsible tree. Complex-valued scalars are split into separate <key>/re and <key>/im entries.
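The hypothetical helper to_mlflow_metrics below illustrates these naming rules by flattening a nested dict into MLflow metric names. It is a sketch and not the logger's actual code. Plain nested dicts and Python scalars stand in for NetKet's log data, which can contain objects such as Stats. The handling of any other value types is an assumption.

```python
from typing import Any, Dict


def to_mlflow_metrics(data: Dict[str, Any], prefix: str = "") -> Dict[str, float]:
    """Hypothetical helper applying the naming rules described above."""
    out: Dict[str, float] = {}
    for key, value in data.items():
        name = f"{prefix}.{key}" if prefix else key
        if isinstance(value, dict):
            # Nested keys are joined with "." (Energy -> Mean becomes Energy.Mean)
            out.update(to_mlflow_metrics(value, name))
        elif isinstance(value, complex):
            # Complex scalars are split into real and imaginary entries
            out[f"{name}/re"] = value.real
            out[f"{name}/im"] = value.imag
        else:
            out[name] = float(value)
    return out


log_data = {"Energy": {"Mean": -8.5 + 0.01j, "Variance": 0.25}, "acceptance": 0.6}
print(to_mlflow_metrics(log_data))
```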

This class is a full AbstractCallback, so it can be passed either as out=logger or inside the callbacks=[..., logger] list. When used as a callback, the logger automatically captures the snapshot of the variational state taken just before the parameter update.

Parameters:
  • experiment_name (str | None) – Name of the MLflow experiment. If None, the currently active experiment (or the MLflow default) is used.

  • run_name (str | None) – Human-readable label attached to the new run.

  • run_id (str | None) – If provided, resumes an existing run instead of starting a fresh one. Takes precedence over run_name.

  • tags (dict[str, str] | None) – Optional dict of string key/value tags to attach to the run.

  • save_params (bool) – If True, model parameters are periodically serialized and stored as an MLflow artifact (MessagePack binary, the same format used by JsonLog).

  • save_params_every (int) – Number of optimisation steps between parameter saves. Only relevant when save_params=True.

  • metadata (dict | None) – Optional flat dict of key/value pairs logged as MLflow params at the start of the run.
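The following minimal sketch shows how save_params and save_params_every could interact. It assumes that a checkpoint is written on every step that is a multiple of save_params_every. The function should_save_params is hypothetical. The logger's exact rule, for instance whether step 0 is included, may differ.

```python
def should_save_params(step: int, save_params: bool, save_params_every: int) -> bool:
    """Hypothetical rule: save on steps that are multiples of save_params_every."""
    if not save_params:
        return False  # saving disabled: save_params_every is ignored
    if save_params_every < 1:
        raise ValueError("save_params_every must be a positive integer")
    return step % save_params_every == 0


print([s for s in range(12) if should_save_params(s, True, 5)])   # [0, 5, 10]
print([s for s in range(12) if should_save_params(s, False, 5)])  # []
```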

Tip

Use metadata to attach a flat dict of hyper-parameters (learning rate, system size, model type, …) to the run. They are logged as MLflow params and appear next to the metrics in the MLflow UI, making it easy to filter and compare runs without external bookkeeping.
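MLflow stores params as strings, and metadata is documented as a flat dict. A nested configuration should therefore be flattened before it is passed. The hypothetical helper as_mlflow_params below sketches such a pre-check. It rejects dict-valued entries and shows the string form each value takes. It is an illustration and not part of the logger.

```python
from typing import Any, Dict


def as_mlflow_params(metadata: Dict[str, Any]) -> Dict[str, str]:
    """Hypothetical pre-check for the metadata argument.

    Rejects nested dicts (metadata should be flat) and returns the string
    form of each value, which is how MLflow stores params.
    """
    params: Dict[str, str] = {}
    for key, value in metadata.items():
        if isinstance(value, dict):
            raise TypeError(f"metadata[{key!r}] is a nested dict; flatten it first")
        params[str(key)] = str(value)
    return params


print(as_mlflow_params({"learning_rate": 0.01, "alpha": 1, "L": 20}))
# {'learning_rate': '0.01', 'alpha': '1', 'L': '20'}
```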

Examples

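Log an optimisation run to the local MLflow store. The driver gs built here is reused in the examples below.

>>> import pytest; pytest.skip("skip automated test of this docstring")
>>>
>>> import netket as nk
>>> hi = nk.hilbert.Spin(s=1 / 2, N=20)
>>> ha = nk.operator.Ising(hilbert=hi, graph=nk.graph.Chain(length=20), h=1.0)
>>> vs = nk.vqs.MCState(nk.sampler.MetropolisLocal(hi), nk.models.RBM(alpha=1))
>>> gs = nk.driver.VMC(ha, nk.optimizer.Sgd(learning_rate=0.01), variational_state=vs)
>>> logger = nk.logging.MLFlowLog(
...     experiment_name="Ising1d",
...     run_name="RBM_alpha1",
...     tags={"model": "RBM", "L": "20"},
... )
>>> gs.run(n_iter=500, out=logger)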

Attaching metadata to record hyper-parameters.

>>> import pytest; pytest.skip("skip automated test of this docstring")
>>>
>>> import netket as nk
>>> logger = nk.logging.MLFlowLog(
...     experiment_name="Ising1d",
...     run_name="RBM_lr0.01",
...     metadata={"learning_rate": 0.01, "alpha": 1, "L": 20},
... )
>>> gs.run(n_iter=500, out=logger)
>>> # 'learning_rate', 'alpha', 'L' appear as params in the MLflow UI

Resume a previous run.

>>> import pytest; pytest.skip("skip automated test of this docstring")
>>>
>>> import netket as nk
>>> logger = nk.logging.MLFlowLog(run_id="<existing-run-id>")
>>> gs.run(n_iter=500, out=logger)

Using the logger as a callback.

>>> import pytest; pytest.skip("skip automated test of this docstring")
>>>
>>> import netket as nk
>>> logger = nk.logging.MLFlowLog(experiment_name="Ising1d")
>>> gs.run(n_iter=500, callbacks=[logger])
Attributes
callback_order
Methods
__call__(step, item, variational_state=None)

Log the dictionary of data item at the integer step, optionally using variational_state to record additional data.
before_parameter_update(step, log_data, driver)

Called after all update logic has been computed and the step has been accepted, but before the driver applies the parameter update.

At this point:

  • The loss and its gradient have been computed by compute_loss_and_update().

  • The step has been accepted (not rejected by on_compute_update_end()).

  • driver.step_count still refers to the current step — it has not yet been incremented.

  • The variational state parameters have not yet changed.

This is the right place to estimate additional observables, add data to log_data, or take a snapshot of the state for logging. Callbacks with a lower callback_order run first, so observable callbacks (order 0) are guaranteed to populate log_data before logger callbacks (order 10) read it.
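The ordering guarantee can be sketched with a toy dispatch loop. ObservableCallback, LoggerCallback and run_before_parameter_update are hypothetical stand-ins and not NetKet classes. The sketch carries over only one assumption from the text: callbacks are invoked in increasing callback_order. Under that assumption, the observable is already in log_data when the logger reads it, even though the logger is listed first.

```python
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class ObservableCallback:
    """Hypothetical callback that adds an observable to log_data."""

    callback_order: int = 0

    def before_parameter_update(self, step: int, log_data: Dict, driver=None) -> None:
        log_data["magnetization"] = float(step)  # placeholder value


@dataclass
class LoggerCallback:
    """Hypothetical logger that records a copy of log_data when it runs."""

    callback_order: int = 10
    seen: List[Dict] = field(default_factory=list)

    def before_parameter_update(self, step: int, log_data: Dict, driver=None) -> None:
        self.seen.append(dict(log_data))


def run_before_parameter_update(callbacks, step: int, log_data: Dict, driver=None) -> None:
    # Lower callback_order runs first, regardless of position in the list.
    for cb in sorted(callbacks, key=lambda c: c.callback_order):
        cb.before_parameter_update(step, log_data, driver)


logger = LoggerCallback()
run_before_parameter_update([logger, ObservableCallback()], step=3, log_data={})
print(logger.seen)  # [{'magnetization': 3.0}]
```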

flush(variational_state=None)

Flushes pending data and optionally saves model parameters.

Parameters:

variational_state (VariationalState | None) – If provided and save_params=True, the current model parameters are uploaded as an artifact.

on_compute_update_end(step, log_data, driver)

Called at the end of the compute-update phase, after the loss and its gradient have been computed.

This is called before the parameters are updated, so it can be used to implement custom logic for rejecting a step based on the computed loss or gradient.

Return type:

bool

Returns:

A boolean indicating whether to reject the step (i.e. repeat it with the same parameters). A return value of None is treated as False.
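The following sketch shows a step-rejecting callback under several assumptions. RejectNonFiniteLoss and step_is_rejected are hypothetical. The loss is read from a hypothetical "loss" key as a plain float, whereas NetKet drivers store it under their own key, typically as a statistics object. step_is_rejected assumes that the step is rejected if any callback returns True, and that None counts as False.

```python
import math
from typing import Optional


class RejectNonFiniteLoss:
    """Hypothetical callback: reject the step if the loss is NaN or infinite."""

    def on_compute_update_end(self, step: int, log_data: dict, driver=None) -> Optional[bool]:
        loss = log_data.get("loss")  # hypothetical key, plain float assumed
        if loss is not None and not math.isfinite(loss):
            return True  # reject: repeat the step with the same parameters
        return None  # treated the same as False (accept the step)


def step_is_rejected(callbacks, step: int, log_data: dict, driver=None) -> bool:
    # Hypothetical driver-side handling: call every callback, map None to
    # False, and reject if at least one callback asked for it.
    results = [cb.on_compute_update_end(step, log_data, driver) for cb in callbacks]
    return any(bool(r) for r in results)


callbacks = [RejectNonFiniteLoss()]
print(step_is_rejected(callbacks, 0, {"loss": -1.5}))          # False
print(step_is_rejected(callbacks, 1, {"loss": float("nan")}))  # True
```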

on_compute_update_start(step, log_data, driver)
on_run_end(step, driver)
on_run_error(step, error, driver)
on_run_start(step, driver)
on_step_end(step, log_data, driver)
on_step_start(step, log_data, driver)
replace(**kwargs)

Return a copy of the object in which the fields named in the keyword arguments take the given values. If the object is a dataclass, dataclasses.replace is used; otherwise, a new object of the same type as the original is created.

Return type:

TypeVar(P, bound=Pytree)

Parameters:
  • self (P)

  • kwargs (Any)
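For the dataclass branch of replace, the behaviour is that of the standard-library dataclasses.replace. The example below uses a hypothetical frozen dataclass LoggerConfig, which is not a NetKet class. The original object is left unchanged, and a new instance carries the replaced fields.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class LoggerConfig:
    """Hypothetical frozen dataclass, used only to illustrate replace()."""

    experiment_name: str
    save_params: bool
    save_params_every: int


base = LoggerConfig(experiment_name="Ising1d", save_params=False, save_params_every=100)
tuned = dataclasses.replace(base, save_params=True, save_params_every=10)

print(base.save_params, tuned.save_params)  # False True
print(tuned.experiment_name)                # Ising1d (unchanged fields are copied)
```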