# Source code for netket.vqs.mc.mc_mixed_state.state

# Copyright 2021 The NetKet Authors - All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

import jax
from jax import numpy as jnp

from flax import serialization

import netket
from netket import jax as nkjax
from netket.sampler import Sampler
from netket.stats import Stats
from netket.utils.types import PyTree
from netket.operator import AbstractOperator

from netket.jax.sharding import extract_replicated

from netket.vqs import VariationalMixedState

from netket.vqs.mc import MCState


def apply_diagonal(bare_afun, w, x, *args, **kwargs):
    """Evaluate ``bare_afun`` on ``x`` horizontally stacked with itself.

    Args:
        bare_afun: Callable invoked as ``bare_afun(w, stacked, *args, **kwargs)``.
        w: First positional argument forwarded untouched to ``bare_afun``.
        x: Array that is duplicated with ``jnp.hstack`` before the call.
        *args: Extra positional arguments forwarded to ``bare_afun``.
        **kwargs: Extra keyword arguments forwarded to ``bare_afun``.

    Returns:
        Whatever ``bare_afun`` returns for the stacked input.
    """
    # NOTE: hstack joins along axis 0 for 1-D inputs and along axis 1 otherwise.
    pair = (x, x)
    return bare_afun(w, jnp.hstack(pair), *args, **kwargs)


class MCMixedState(VariationalMixedState, MCState):
    """Variational State for a Mixed Variational Neural Quantum State.

    The state is sampled according to the provided sampler, and its diagonal is
    sampled according to another sampler.
    """

    def __init__(
        self,
        sampler,
        model=None,
        *,
        sampler_diag: Optional[Sampler] = None,
        n_samples_diag: Optional[int] = None,
        n_samples_per_rank_diag: Optional[int] = None,
        n_discard_per_chain_diag: Optional[int] = None,
        seed=None,
        sampler_seed: Optional[int] = None,
        variables=None,
        **kwargs,
    ):
        """
        Constructs the MCMixedState.
        Arguments are the same as :class:`MCState`.

        Arguments:
            sampler: The sampler
            model: (Optional) The model. If not provided, you must provide init_fun and apply_fun.
            sampler_diag: (Optional) The sampler used for the diagonal. If not provided, it is
                built from `sampler` by replacing its hilbert space with
                `sampler.hilbert.physical`.
            n_samples: the total number of samples across chains and processes when sampling (default=1000).
            n_samples_per_rank: the total number of samples across chains on one process when sampling. Cannot be
                specified together with n_samples (default=None).
            n_discard_per_chain: number of discarded samples at the beginning of each monte-carlo chain (default=n_samples/10).
            n_samples_diag: the total number of samples across chains and processes when sampling the diagonal
                of the density matrix (default=1000).
            n_samples_per_rank_diag: the total number of samples across chains on one process when sampling the diagonal.
                Cannot be specified together with `n_samples_diag` (default=None).
            n_discard_per_chain_diag: number of discarded samples at the beginning of each monte-carlo chain used when sampling
                the diagonal of the density matrix for observables (default=n_samples_diag/10).
            variables: Optional PyTree of weights from which to start.
            seed: rng seed used to generate a set of parameters (only if variables is not passed). Defaults to a random one.
            sampler_seed: rng seed used to initialise the sampler. Defaults to a random one.
            mutable: Name or list of names of mutable arguments. Use it to specify if the model has a state that can change
                during evaluation, but that should not be optimised. See also flax.linen.module.apply documentation
                (default=False)
            init_fun: Function of the signature f(model, shape, rng_key, dtype) -> Optional_state, parameters used to
                initialise the parameters. Defaults to the standard flax initialiser. Only specify if your network has
                a non-standard init method.
            apply_fun: Function of the signature f(model, variables, σ) that should evaluate the model. Defaults to
                `model.apply(variables, σ)`. specify only if your network has a non-standard apply method.
            training_kwargs: a dict containing the optional keyword arguments to be passed to the apply_fun during training.
                Useful for example when you have a batchnorm layer that constructs the average/mean only during training.
        """
        # Derive two independent keys so that the full state and the diagonal
        # state do not share the same random stream.
        seed, seed_diag = jax.random.split(nkjax.PRNGKey(seed))
        if sampler_seed is None:
            sampler_seed_diag = None
        else:
            sampler_seed, sampler_seed_diag = jax.random.split(
                nkjax.PRNGKey(sampler_seed)
            )

        # Must exist before the parent constructor runs: the `parameters` and
        # `model_state` setters below check `self.diagonal` before forwarding.
        self._diagonal = None

        hilbert_physical = sampler.hilbert.physical

        super().__init__(
            hilbert_physical,
            sampler,
            model,
            **kwargs,
            seed=seed,
            sampler_seed=sampler_seed,
            variables=variables,
        )

        if sampler_diag is None:
            sampler_diag = sampler.replace(hilbert=hilbert_physical)

        # The diagonal sampler always runs with machine_pow=1
        # (presumably because the diagonal is already a probability).
        sampler_diag = sampler_diag.replace(machine_pow=1)

        diagonal_apply_fun = nkjax.HashablePartial(apply_diagonal, self._apply_fun)

        # These options are given explicitly to the diagonal state through their
        # `*_diag` counterparts; leaving them in `kwargs` would pass the same
        # keyword twice and raise a TypeError.
        for kw in ("n_samples", "n_samples_per_rank", "n_discard_per_chain"):
            kwargs.pop(kw, None)

        self._diagonal = MCState(
            sampler_diag,
            apply_fun=diagonal_apply_fun,
            n_samples=n_samples_diag,
            n_samples_per_rank=n_samples_per_rank_diag,
            n_discard_per_chain=n_discard_per_chain_diag,
            variables=self.variables,
            seed=seed_diag,
            sampler_seed=sampler_seed_diag,
            **kwargs,
        )

    @property
    def diagonal(self):
        """The :class:`MCState` used to sample the diagonal of this mixed state
        (``None`` while the constructor is still running)."""
        return self._diagonal

    @property
    def sampler_diag(self) -> Sampler:
        """The Monte Carlo sampler used by this Monte Carlo variational state to
        sample the diagonal."""
        return self.diagonal.sampler

    @sampler_diag.setter
    def sampler_diag(self, sampler):
        self.diagonal.sampler = sampler

    @property
    def n_samples_diag(self) -> int:
        """The total number of samples generated at every sampling step
        when sampling the diagonal of this mixed state.
        """
        return self.diagonal.n_samples

    @n_samples_diag.setter
    def n_samples_diag(self, n_samples):
        self.diagonal.n_samples = n_samples

    @property
    def chain_length_diag(self) -> int:
        """
        Length of the markov chain used for sampling the diagonal configurations.

        If running under MPI, the total samples will be n_nodes * chain_length * n_batches.
        """
        return self.diagonal.chain_length

    @chain_length_diag.setter
    def chain_length_diag(self, length: int):
        self.diagonal.chain_length = length

    @property
    def n_discard_per_chain_diag(self) -> int:
        """Number of discarded samples at the beginning of the markov chain used to
        sample the diagonal of this mixed state.
        """
        return self.diagonal.n_discard_per_chain

    @n_discard_per_chain_diag.setter
    def n_discard_per_chain_diag(self, n_discard_per_chain: Optional[int]):
        self.diagonal.n_discard_per_chain = n_discard_per_chain

    @MCState.parameters.setter
    def parameters(self, pars: PyTree):
        # Update this state first, then mirror the value onto the diagonal
        # state (skipped while the diagonal has not been built yet).
        MCState.parameters.fset(self, pars)
        if self.diagonal is not None:
            self.diagonal.parameters = pars

    @MCState.model_state.setter
    def model_state(self, state: PyTree):
        # Same mirroring as for `parameters`.
        MCState.model_state.fset(self, state)
        if self.diagonal is not None:
            self.diagonal.model_state = state

    def reset(self):
        """Reset this state and, if it exists, the diagonal state."""
        super().reset()
        if self.diagonal is not None:
            self.diagonal.reset()

    def expect_and_grad_operator(
        self, Ô: AbstractOperator, is_hermitian=None
    ) -> tuple[Stats, PyTree]:
        """Not implemented for this state.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError

    def to_matrix(self, normalize: bool = True) -> jnp.ndarray:
        """Return the result of :func:`netket.nn.to_matrix` for this state.

        Args:
            normalize: Forwarded as the ``normalize`` option of
                :func:`netket.nn.to_matrix`.
        """
        return netket.nn.to_matrix(
            self.hilbert,
            self._apply_fun,
            self.variables,
            normalize=normalize,
            chunk_size=self.chunk_size,
        )

    def __repr__(self):
        return (
            "MCMixedState("
            + f"\n hilbert = {self.hilbert},"
            + f"\n sampler = {self.sampler},"
            + f"\n n_samples = {self.n_samples},"
            + f"\n n_discard_per_chain = {self.n_discard_per_chain},"
            + f"\n sampler_state = {self.sampler_state},"
            + f"\n sampler_diag = {self.sampler_diag},"
            + f"\n n_samples_diag = {self.n_samples_diag},"
            + f"\n n_discard_per_chain_diag = {self.n_discard_per_chain_diag},"
            + f"\n sampler_state_diag = {self.diagonal.sampler_state},"
            + f"\n n_parameters = {self.n_parameters})"
        )

    def __str__(self):
        return (
            "MCMixedState("
            + f"hilbert = {self.hilbert}, "
            + f"sampler = {self.sampler}, "
            + f"n_samples = {self.n_samples})"
        )
# serialization


def serialize_MCMixedState(vstate):
    """Build the state dict stored for an ``MCMixedState``.

    The dict holds the variables (after ``extract_replicated``), the previous
    sampler state, the serialized diagonal state, and the plain
    ``n_samples`` / ``n_discard_per_chain`` / ``chunk_size`` attributes.
    """
    variables_dict = serialization.to_state_dict(
        extract_replicated(vstate.variables)
    )
    sampler_dict = serialization.to_state_dict(vstate._sampler_state_previous)
    diagonal_dict = serialization.to_state_dict(vstate.diagonal)

    return {
        "variables": variables_dict,
        "sampler_state": sampler_dict,
        "diagonal": diagonal_dict,
        "n_samples": vstate.n_samples,
        "n_discard_per_chain": vstate.n_discard_per_chain,
        "chunk_size": vstate.chunk_size,
    }


def deserialize_MCMixedState(vstate, state_dict):
    """Return a shallow copy of ``vstate`` updated from ``state_dict``.

    ``vstate`` is used as the template for ``serialization.from_state_dict``;
    the returned object is a reset ``copy.copy`` of it.
    """
    import copy

    restored = copy.copy(vstate)
    restored.reset()

    # restore the diagonal first so we can relink the samples
    diagonal = serialization.from_state_dict(
        vstate._diagonal, state_dict["diagonal"]
    )
    restored._diagonal = diagonal

    variables = serialization.from_state_dict(
        vstate.variables, state_dict["variables"]
    )
    restored.variables = variables

    sampler_state = serialization.from_state_dict(
        vstate.sampler_state, state_dict["sampler_state"]
    )
    restored.sampler_state = sampler_state

    # Plain attributes are copied over in the same order as they were stored.
    for attr in ("n_samples", "n_discard_per_chain", "chunk_size"):
        setattr(restored, attr, state_dict[attr])

    return restored


serialization.register_serialization_state(
    MCMixedState,
    serialize_MCMixedState,
    deserialize_MCMixedState,
)