# Source code for netket.logging.json_log

# Copyright 2021 The NetKet Authors - All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import time

import os
from os import path as _path

from flax import serialization

from netket.jax.sharding import extract_replicated

from .runtime_log import RuntimeLog


class JsonLog(RuntimeLog):
    """
    This logger serializes expectation values and other log data to a JSON file
    and can save the latest model parameters in MessagePack encoding to a
    separate file.

    It can be passed with keyword argument `out` to Monte Carlo drivers in
    order to serialize the output data of the simulation.

    This logger inherits from :class:`netket.logging.RuntimeLog`, so it
    maintains the dictionary of all logged quantities in memory, which can be
    accessed through the attribute :attr:`~netket.logging.JsonLog.data`.

    If the model state is serialized, then it can be de-serialized using the
    msgpack protocol of flax. For more information on how to de-serialize the
    output, see
    `here <https://flax.readthedocs.io/en/latest/flax.serialization.html>`_.
    The target of the serialization is the variational state itself.

    Data is serialized to json as several nested dictionaries. You can
    deserialize by simply calling
    :func:`json.load(open(filename)) <json.load>`.
    Logged expectation values will be captured inside histories objects, so
    they will have a subfield `iter` with the iterations at which that quantity
    has been computed, then `Mean` and others.
    Complex numbers are logged as dictionaries
    :code:`{'real':list, 'imag':list}`.
    """

    def __init__(
        self,
        output_prefix: str,
        mode: str = "write",
        save_params_every: int = 50,
        write_every: int = 50,
        save_params: bool = True,
        autoflush_cost: float = 0.005,
    ):
        """
        Construct a Json Logger.

        Args:
            output_prefix: the name of the output files before the extension
            save_params_every: every how many iterations should machine
                parameters be flushed to file
            write_every: every how many iterations should data be flushed to
                file
            mode: Specify the behaviour in case the file already exists at this
                output_prefix. Options are

                - `[w]rite`: (default) overwrites file if it already exists;
                - `[x]` or `fail`: fails if file already exists;

                `[a]ppend` is still recognized as a name but is rejected with
                a :class:`ValueError` because it is no longer supported.
            save_params: bool flag indicating whether variables of the
                variational state should be serialized at some interval. The
                output file is overwritten every time variables are saved
                again.
            autoflush_cost: Maximum fraction of runtime that can be dedicated
                to serializing data. Defaults to 0.005 (0.5 per cent)

        Raises:
            ValueError: if `mode` is not recognized, if `mode` is append, or if
                `mode` is fail and `<output_prefix>.log` or
                `<output_prefix>.mpack` already exists.
        """
        super().__init__()

        # Shorthands for mode
        if mode == "w":
            mode = "write"
        elif mode == "a":
            mode = "append"
        elif mode == "x":
            mode = "fail"

        if mode not in ("write", "append", "fail"):
            # Note the trailing space after "or": the two literals are
            # concatenated by the parser and would otherwise run together.
            raise ValueError(
                "Mode not recognized: should be one of `[w]rite`, `[a]ppend` or "
                "`[x]`(fail)."
            )

        if mode == "append":
            raise ValueError("Append mode is no longer supported.")

        file_exists = _path.exists(output_prefix + ".log") or _path.exists(
            output_prefix + ".mpack"
        )

        if file_exists and mode == "fail":
            raise ValueError(
                "Output file already exists. Either delete it manually or "
                "change `output_prefix`."
            )

        # Create the destination directory (if the prefix names one) up front,
        # so that later flushes do not fail on a missing folder.
        dir_name = _path.dirname(output_prefix)
        if dir_name != "":
            os.makedirs(dir_name, exist_ok=True)

        self._prefix = output_prefix
        self._file_mode = mode

        self._write_every = write_every
        self._save_params_every = save_params_every
        self._old_step = 0
        # Number of `__call__` invocations since the last log / params flush.
        self._steps_notflushed_write = 0
        self._steps_notflushed_pars = 0
        self._save_params = save_params
        self._files_open = [output_prefix + ".log", output_prefix + ".mpack"]

        self._autoflush_cost = autoflush_cost
        # Wall-clock start and duration of the most recent flush of each kind.
        self._last_flush_time = time.time()
        self._last_flush_runtime = 0.0
        self._last_flush_pars_time = time.time()
        self._last_flush_pars_runtime = 0.0

        # Cumulative time spent flushing, reported by `__repr__`.
        self._flush_log_time = 0.0
        self._flush_pars_time = 0.0

    def __call__(self, step, item, variational_state=None):
        """
        Log `item` at iteration `step` and flush to disk when due.

        The data is first recorded in memory by the parent logger. The log
        (and, separately, the parameters) are then written to file if the
        number of unflushed steps is a multiple of the configured interval,
        if `step == old_step - 1`, or if the previous flush was cheap enough
        relative to the time elapsed since then (see `autoflush_cost`).

        Args:
            step: the current iteration.
            item: the data to log, forwarded to the parent logger.
            variational_state: optional state whose variables are saved when
                the parameters are flushed.
        """
        old_step = self._old_step
        super().__call__(step, item, variational_state)

        # Flush eagerly if the duration of the last flush, as a fraction of
        # the time elapsed since it started, is below the allowed cost.
        elapsed_time = time.time() - self._last_flush_time
        # On windows, the precision of `time.time` is much lower than that on
        # Linux, so `elapsed_time` may be essentially zero.
        # We add 1e-7 to avoid the zero division error.
        flush_anyway = (
            self._last_flush_runtime / (elapsed_time + 1e-7) < self._autoflush_cost
        )

        if (
            self._steps_notflushed_write % self._write_every == 0
            or step == old_step - 1
            or flush_anyway
        ):
            self._flush_log()

        # Same criterion, tracked independently for the parameters file.
        elapsed_time = time.time() - self._last_flush_pars_time
        flush_anyway = (
            self._last_flush_pars_runtime / (elapsed_time + 1e-7)
            < self._autoflush_cost
        )

        if (
            self._steps_notflushed_pars % self._save_params_every == 0
            or step == old_step - 1
            or flush_anyway
        ):
            self._flush_params(variational_state)

        self._old_step = step
        self._steps_notflushed_write += 1
        self._steps_notflushed_pars += 1

    def _flush_log(self):
        """Serialize the logged data to `<prefix>.log`, timing the operation."""
        # Time how long flushing data takes.
        self._last_flush_time = time.time()
        self.serialize(self._prefix + ".log")
        self._last_flush_runtime = time.time() - self._last_flush_time
        self._flush_log_time += self._last_flush_runtime

        self._steps_notflushed_write = 0

    def _flush_params(self, variational_state):
        """
        Write the variables of `variational_state` to `<prefix>.mpack`.

        Does nothing if parameter saving is disabled or no state is given.
        """
        if not self._save_params:
            return

        if variational_state is None:
            return

        self._last_flush_pars_time = time.time()
        # NOTE(review): `extract_replicated` presumably reduces sharded /
        # replicated arrays to a single copy before serialization - confirm
        # against `netket.jax.sharding`.
        binary_data = serialization.to_bytes(
            extract_replicated(variational_state.variables)
        )
        with open(self._prefix + ".mpack", "wb") as outfile:
            outfile.write(binary_data)
        self._last_flush_pars_runtime = time.time() - self._last_flush_pars_time
        self._flush_pars_time += self._last_flush_pars_runtime

        self._steps_notflushed_pars = 0

    def flush(self, variational_state=None):
        """
        Writes to file the content of this logger.

        Args:
            variational_state: optionally also writes the parameters of the
                machine.
        """
        self._flush_log()

        if variational_state is not None:
            self._flush_params(variational_state)

    def __del__(self):
        """Flush any pending data once when the logger is garbage collected."""
        # The counters may be missing if `__init__` raised before setting
        # them, hence the defaults. No variational state is available here,
        # so `flush()` can only write the log: a single call is sufficient
        # even when both counters are non-zero.
        pending_log = getattr(self, "_steps_notflushed_write", 0) > 0
        pending_pars = getattr(self, "_steps_notflushed_pars", 0) > 0
        if pending_log or pending_pars:
            self.flush()

    def __repr__(self):
        """Return a summary with the prefix, mode and cumulative flush times."""
        return "".join(
            (
                f"JsonLog('{self._prefix}', mode={self._file_mode}, ",
                f"autoflush_cost={self._autoflush_cost})",
                "\n Runtime cost:",
                f"\n \tLog: {self._flush_log_time}",
                f"\n \tParams: {self._flush_pars_time}",
            )
        )