# Source code for netket.experimental.vqs.io

# Copyright 2020, 2021 The NetKet Authors - All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import tarfile as _tarfile
from os import fspath as _fspath
from os import path as _path

from flax import serialization as _serialization
from netket.utils.types import PyTree as _PyTree


def variables_from_file(filename: str, variables: _PyTree):
    """
    Loads the variables of a variational state from a `.mpack` file.

    Args:
        filename: the file containing the variables, as a string or any
            :class:`os.PathLike` (e.g. :class:`pathlib.Path`). Assumes a
            .mpack extension and adds it if missing and no file exists.
        variables: An object variables with the same structure and shape of
            the object to be deserialized.

    Returns:
        A PyTree with the same structure as ``variables``, holding the
        values deserialized from the file.

    Raises:
        FileNotFoundError: if the resolved file (``filename`` as given, or
            ``filename`` with ``.mpack`` appended) does not exist.

    Examples:
       Serializing the data:

       >>> import netket as nk
       >>> import flax
       >>> # construct an RBM model on 10 spins
       >>> vstate = nk.vqs.MCState(
       ...      nk.sampler.MetropolisLocal(nk.hilbert.Spin(0.5)**10),
       ...      nk.models.RBM())
       >>> with open("test.mpack", 'wb') as file:
       ...     bytes_written = file.write(flax.serialization.to_bytes(vstate.variables))
       >>> print(bytes_written)
       1052
       >>>
       >>> # Deserialize the data
       >>>
       >>> del vstate
       >>> # construct an RBM model on 10 spins
       >>> vstate2 = nk.vqs.MCState(
       ...      nk.sampler.MetropolisLocal(nk.hilbert.Spin(0.5)**10),
       ...      nk.models.RBM())
       >>> # Load the data by passing the model
       >>> loaded_vars = nk.experimental.vqs.variables_from_file("test.mpack",
       ...                                                      vstate2.variables)
       >>> # update the variables of vstate with the loaded data.
       >>> vstate2.variables = loaded_vars
    """
    # Normalize path-like objects (e.g. pathlib.Path) to their filesystem
    # representation so the suffix check and concatenation below work on them.
    filename = _fspath(filename)

    # Only guess the extension when the name as given is not an existing file.
    if not _path.isfile(filename) and not filename.endswith(".mpack"):
        filename += ".mpack"

    with open(filename, "rb") as f:
        # `variables` is the template whose structure the bytes are decoded into.
        return _serialization.from_bytes(variables, f.read())
def variables_from_tar(filename: str, variables: _PyTree, i: int):
    """
    Loads the variables of a variational state from the i-th element of a
    `.tar` archive.

    The archive member that is read is the one named ``f"{i}.mpack"``.

    Args:
        filename: the tar archive name, as a string or any
            :class:`os.PathLike`. Assumes a .tar extension and adds it if
            missing and no file exists. Compressed archives are opened
            transparently.
        variables: An object variables with the same structure and shape of
            the object to be deserialized.
        i: the index of the variables to load.

    Returns:
        A PyTree with the same structure as ``variables``, holding the
        values deserialized from the selected archive member.

    Raises:
        FileNotFoundError: if the resolved archive file does not exist.
        KeyError: if the archive has no member named ``f"{i}.mpack"``.
        ValueError: if that member exists but is not a regular file.
    """
    # Normalize path-like objects (e.g. pathlib.Path) so the suffix check and
    # concatenation below work on them.
    filename = _fspath(filename)

    # Only guess the extension when the name as given is not an existing file.
    if not _path.isfile(filename) and not filename.endswith(".tar"):
        filename += ".tar"

    # tarfile.open with mode "r" reads plain tars as before and also
    # transparently handles gzip/bz2/xz-compressed archives.
    with _tarfile.open(filename, "r") as archive:
        member = archive.getmember(f"{i}.mpack")
        fileobj = archive.extractfile(member)
        # extractfile returns None for non-regular members (e.g. directories);
        # fail with a clear message instead of an opaque context-manager error.
        if fileobj is None:
            raise ValueError(
                f"Member {member.name!r} of archive {filename!r} is not a regular file."
            )
        with fileobj as f:
            return _serialization.from_bytes(variables, f.read())