# Copyright 2021 The NetKet Authors - All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from functools import partial
from typing import Any, Callable, Optional
import numpy as np
import jax
from jax import numpy as jnp
import flax
from flax import serialization
from flax.core.scope import CollectionFilter, DenyList # noqa: F401
from netket import jax as nkjax
from netket import nn
from netket.stats import Stats
from netket.operator import AbstractOperator, Squared
from netket.sampler import Sampler, SamplerState
from netket.utils import (
maybe_wrap_module,
wrap_afun,
wrap_to_support_scalar,
timing,
)
from netket.utils.types import PyTree, SeedT, NNInitFunc
from netket.optimizer import LinearOperator
from netket.optimizer.qgt import QGTAuto
from netket.jax import sharding
from netket.vqs.base import VariationalState, expect, expect_and_grad, expect_and_forces
from netket.vqs.mc import get_local_kernel, get_local_kernel_arguments
def compute_chain_length(n_chains, n_samples):
if n_samples <= 0:
raise ValueError(f"Invalid number of samples: n_samples={n_samples}")
chain_length = int(np.ceil(n_samples / n_chains))
n_samples_new = chain_length * n_chains
n_samples_per_rank_new = n_samples_new // sharding.device_count()
if n_samples_new != n_samples:
n_samples_per_rank = n_samples // sharding.device_count()
warnings.warn(
f"n_samples={n_samples} ({n_samples_per_rank} per device/MPI rank) "
f"does not divide n_chains={n_chains}, increased to {n_samples_new} "
f"({n_samples_per_rank_new} per device/MPI rank)",
UserWarning,
stacklevel=3,
)
return chain_length
def check_chunk_size(n_samples, chunk_size):
n_samples_per_rank = n_samples // sharding.device_count()
if chunk_size is not None:
if chunk_size < n_samples_per_rank and n_samples_per_rank % chunk_size != 0:
raise ValueError(
f"chunk_size={chunk_size}`<`n_samples_per_rank={n_samples_per_rank}, "
"chunk_size is not an integer fraction of `n_samples_per rank`. This is"
"unsupported. Please change `chunk_size` so that it divides evenly the"
"number of samples per rank or set it to `None` to disable chunking."
)
def _is_power_of_two(n: int) -> bool:
return (n != 0) and (n & (n - 1) == 0)
@partial(jax.jit, static_argnums=0)
def jit_evaluate(fun: Callable, *args):
"""
call `fun(*args)` inside of a `jax.jit` frame.
Args:
fun: the hashable callable to be evaluated.
args: the arguments to the function.
"""
return fun(*args)
[docs]
class MCState(VariationalState):
"""Variational State for a Variational Neural Quantum State.
The state is sampled according to the provided sampler.
"""
# model: Any
# """The model"""
model_state: Optional[PyTree]
"""An Optional PyTree encoding a mutable state of the model that is not trained."""
_sampler: Sampler
"""The sampler used to sample the Hilbert space."""
sampler_state: SamplerState
"""The current state of the sampler."""
_previous_sampler_state: SamplerState = None
"""The sampler state before the last sampling has been effected.
This field is used so that we don't need to serialize the current samples
but we can always regenerate them.
"""
_chain_length: int = 0
"""Length of the Markov chain used for sampling configurations."""
_n_discard_per_chain: int = 0
"""Number of samples discarded at the beginning of every Markov chain."""
_samples: Optional[jnp.ndarray] = None
"""Cached samples obtained with the last sampling."""
_init_fun: Callable = None
"""The function used to initialise the parameters and model_state."""
_apply_fun: Callable = None
"""The function used to evaluate the model."""
_chunk_size: Optional[int] = None
[docs]
def __init__(
self,
sampler: Sampler,
model=None,
*,
n_samples: Optional[int] = None,
n_samples_per_rank: Optional[int] = None,
n_discard_per_chain: Optional[int] = None,
chunk_size: Optional[int] = None,
variables: Optional[PyTree] = None,
init_fun: Optional[NNInitFunc] = None,
apply_fun: Optional[Callable] = None,
seed: Optional[SeedT] = None,
sampler_seed: Optional[SeedT] = None,
mutable: CollectionFilter = False,
training_kwargs: dict = {},
):
"""
Constructs the MCState.
Args:
sampler: The sampler
model: (Optional) The neural quantum state ansatz, encoded into a model.
This should be a :class:`flax.linen.Module` instance, or any other supported
neural network framework. If not provided, you must specify init_fun and apply_fun.
n_samples: the total number of samples across chains and processes when sampling (default=1000).
n_samples_per_rank: the total number of samples across chains on one process when sampling. Cannot be
specified together with n_samples (default=None).
n_discard_per_chain: number of discarded samples at the beginning of each monte-carlo chain (default=0 for exact sampler,
and n_samples/10 for approximate sampler).
seed: rng seed used to generate a set of parameters (only if parameters is not passed). Defaults to a random one.
sampler_seed: rng seed used to initialise the sampler. Defaults to a random one.
mutable: Name or list of names of mutable arguments. Use it to specify if the model has a state that can change
during evaluation, but that should not be optimised. See also :meth:`flax.linen.Module.apply` documentation
(default=False)
init_fun: Function of the signature f(model, shape, rng_key, dtype) -> Optional_state, parameters used to
initialise the parameters. Defaults to the standard flax initialiser. Only specify if your network has
a non-standard init method.
variables: Optional initial value for the variables (parameters and model state) of the model.
apply_fun: Function of the signature f(model, variables, σ) that should evaluate the model. Defaults to
`model.apply(variables, σ)`. specify only if your network has a non-standard apply method.
training_kwargs: a dict containing the optional keyword arguments to be passed to the apply_fun during training.
Useful for example when you have a batchnorm layer that constructs the average/mean only during training.
chunk_size: (Defaults to `None`) If specified, calculations are split into chunks where the neural network
is evaluated at most on :code:`chunk_size` samples at once. This does not change the mathematical results,
but will trade a higher computational cost for lower memory cost.
"""
super().__init__(sampler.hilbert)
# Init type 1: pass in a model
if model is not None:
# extract init and apply functions
# Wrap it in an HashablePartial because if two instances of the same model are provided,
# model.apply and model2.apply will be different methods forcing recompilation, but
# model and model2 will have the same hash.
_, model = maybe_wrap_module(model)
self._model = model
self._init_fun = nkjax.HashablePartial(
lambda model, *args, **kwargs: model.init(*args, **kwargs), model
)
self._apply_fun = wrap_to_support_scalar(
nkjax.HashablePartial(
lambda model, pars, x, **kwargs: model.apply(pars, x, **kwargs),
model,
)
)
elif apply_fun is not None:
self._apply_fun = wrap_to_support_scalar(apply_fun)
if init_fun is not None:
self._init_fun = init_fun
elif variables is None:
raise ValueError(
"If you don't provide variables, you must pass a valid init_fun."
)
self._model = wrap_afun(apply_fun)
else:
raise ValueError("Must either pass the model or apply_fun.")
# default argument for n_samples/n_samples_per_rank
if n_samples is None and n_samples_per_rank is None:
# get the first multiple of sampler.n_chains above 1000 to avoid
# printing a warning on construction
n_samples = int(np.ceil(1000 / sampler.n_chains) * sampler.n_chains)
elif n_samples is not None and n_samples_per_rank is not None:
raise ValueError(
"Only one argument between `n_samples` and `n_samples_per_rank`"
"can be specified at the same time."
)
self.mutable = mutable
self.training_kwargs = flax.core.freeze(training_kwargs)
if variables is not None:
self.variables = variables
else:
self.init(seed, dtype=sampler.dtype)
if sampler_seed is None and seed is not None:
key, key2 = jax.random.split(nkjax.PRNGKey(seed), 2)
sampler_seed = key2
self._sampler_seed = nkjax.PRNGKey(sampler_seed)
self.sampler = sampler
if n_samples is not None:
self.n_samples = n_samples
else:
self.n_samples_per_rank = n_samples_per_rank
self.n_discard_per_chain = n_discard_per_chain
self.chunk_size = chunk_size
[docs]
def init(self, seed=None, dtype=None):
"""
Initialises the variational parameters of the variational state.
"""
if self._init_fun is None:
raise RuntimeError(
"Cannot initialise the parameters of this state"
"because you did not supply a valid init_function."
)
if dtype is None:
dtype = self.sampler.dtype
key = nkjax.PRNGKey(seed)
dummy_input = jnp.zeros((1, self.hilbert.size), dtype=dtype)
variables = jit_evaluate(self._init_fun, {"params": key}, dummy_input)
self.variables = variables
@property
def model(self) -> Optional[Any]:
"""Returns the model definition of this variational state."""
return self._model
@property
def _sampler_model(self):
"""Returns the model definition used for sampling this variational state.
Equal to `.model`.
"""
return self.model
@property
def _sampler_variables(self):
"""Returns the variables used for sampling this variational state.
Equal to `.variables`
"""
return self.variables
@property
def sampler(self) -> Sampler:
"""The Monte Carlo sampler used by this Monte Carlo variational state."""
return self._sampler
@sampler.setter
def sampler(self, sampler: Sampler):
if not isinstance(sampler, Sampler):
raise TypeError(
f"The sampler should be a subtype of netket.sampler.Sampler, but {type(sampler)} is not."
)
self._sampler_seed, seed = jax.random.split(self._sampler_seed, 2)
# Save the old `n_samples` before the new `sampler` is set.
# `_chain_length == 0` means that this `MCState` is being constructed.
if self._chain_length > 0:
n_samples_old = self.n_samples
self._sampler = sampler
self.sampler_state = self.sampler.init_state(
self._sampler_model, self._sampler_variables, seed=seed
)
self._sampler_state_previous = self.sampler_state
# Update `n_samples`, `n_samples_per_rank`, and `chain_length` according
# to the new `sampler.n_chains`.
# If `n_samples` is divisible by the new `sampler.n_chains`, it will be
# unchanged; otherwise it will be rounded up.
# If the new `n_samples_per_rank` is not divisible by `chunk_size`, a
# `ValueError` will be raised.
# `_chain_length == 0` means that this `MCState` is being constructed.
if self._chain_length > 0:
self.n_samples = n_samples_old
self.reset()
@property
def n_samples(self) -> int:
"""The total number of samples generated at every sampling step."""
return self.chain_length * self.sampler.n_chains
@n_samples.setter
def n_samples(self, n_samples: int):
chain_length = compute_chain_length(self.sampler.n_chains, n_samples)
self.chain_length = chain_length
@property
def n_samples_per_rank(self) -> int:
"""The number of samples generated on every jax device or MPI rank
at every sampling step."""
return self.chain_length * self.sampler.n_chains_per_rank
@n_samples_per_rank.setter
def n_samples_per_rank(self, n_samples_per_rank: int):
self.n_samples = n_samples_per_rank * sharding.device_count()
@property
def chain_length(self) -> int:
"""
Length of the markov chain used for sampling configurations.
If running under MPI, the total samples will be n_nodes * chain_length * n_batches.
"""
return self._chain_length
@chain_length.setter
def chain_length(self, chain_length: int):
if chain_length <= 0:
raise ValueError(f"Invalid chain length: chain_length={chain_length}")
n_samples = chain_length * self.sampler.n_chains
check_chunk_size(n_samples, self.chunk_size)
self._chain_length = chain_length
self.reset()
@property
def n_discard_per_chain(self) -> int:
"""
Number of discarded samples at the beginning of the markov chain.
"""
return self._n_discard_per_chain
@n_discard_per_chain.setter
def n_discard_per_chain(self, n_discard_per_chain: Optional[int]):
if n_discard_per_chain is not None and n_discard_per_chain < 0:
raise ValueError(
f"Invalid number of discarded samples: n_discard_per_chain={n_discard_per_chain}"
)
# don't discard if the sampler is exact
if self.sampler.is_exact:
if n_discard_per_chain is not None and n_discard_per_chain > 0:
warnings.warn(
"An exact sampler does not need to discard samples. Setting n_discard_per_chain to 0.",
stacklevel=2,
)
n_discard_per_chain = 0
self._n_discard_per_chain = (
int(n_discard_per_chain)
if n_discard_per_chain is not None
else self.n_samples // 10
)
@property
def chunk_size(self) -> int:
"""
Suggested *maximum size* of the chunks used in forward and backward evaluations
of the Neural Network model.
If your inputs are smaller than the chunk size this setting is ignored.
This can be used to lower the memory required to run a computation with a very
high number of samples or on a very large lattice. Notice that inputs and
outputs must still fit in memory, but the intermediate computations will now
require less memory.
This option comes at an increased computational cost. While this cost should
be negligible for large-enough chunk sizes, don't use it unless you are memory
bound!
This option is an hint: only some operations support chunking. If you perform
an operation that is not implemented with chunking support, it will fall back
to no chunking. To check if this happened, set the environment variable
`NETKET_DEBUG=1`.
"""
return self._chunk_size
@chunk_size.setter
def chunk_size(self, chunk_size: Optional[int]):
# disable chunks if it is None
if chunk_size is None:
self._chunk_size = None
return
if not isinstance(chunk_size, int) or chunk_size <= 0:
raise ValueError("Chunk size must be a positive INTEGER. ")
if not _is_power_of_two(chunk_size):
warnings.warn(
"For performance reasons, we suggest to use a power-of-two chunk size.",
stacklevel=2,
)
check_chunk_size(self.n_samples, chunk_size)
self._chunk_size = chunk_size
[docs]
def reset(self):
"""
Resets the sampled states. This method is called automatically every time
that the parameters/state is updated.
"""
self._samples = None
[docs]
@timing.timed
def sample(
self,
*,
chain_length: Optional[int] = None,
n_samples: Optional[int] = None,
n_discard_per_chain: Optional[int] = None,
) -> jnp.ndarray:
"""
Sample a certain number of configurations.
If one among chain_length or n_samples is defined, that number of samples
are generated. Otherwise the value set internally is used.
Args:
chain_length: The length of the markov chains.
n_samples: The total number of samples across all MPI ranks.
n_discard_per_chain: Number of discarded samples at the beginning of the markov chain.
"""
if n_samples is None and chain_length is None:
chain_length = self.chain_length
else:
if chain_length is not None and n_samples is not None:
raise ValueError("Cannot specify both `chain_length` and `n_samples`.")
elif chain_length is None:
chain_length = compute_chain_length(self.sampler.n_chains, n_samples)
if self.chunk_size is not None:
check_chunk_size(chain_length * self.sampler.n_chains, self.chunk_size)
if n_discard_per_chain is None:
n_discard_per_chain = self.n_discard_per_chain
# Store the previous sampler state, for serialization purposes
self._sampler_state_previous = self.sampler_state
self.sampler_state = self.sampler.reset(
self._sampler_model, self._sampler_variables, self.sampler_state
)
if self.n_discard_per_chain > 0:
with timing.timed_scope("sampling n_discarded samples"):
_, self.sampler_state = self.sampler.sample(
self.model,
self.variables,
state=self.sampler_state,
chain_length=n_discard_per_chain,
)
self._samples, self.sampler_state = self.sampler.sample(
self._sampler_model,
self._sampler_variables,
state=self.sampler_state,
chain_length=chain_length,
)
return self._samples
@property
def samples(self) -> jnp.ndarray:
"""
Returns the set of cached samples.
The samples returned are guaranteed valid for the current state of
the variational state. If no cached parameters are available, then
they are sampled first and then cached.
To obtain a new set of samples either use
:meth:`~MCState.reset` or :meth:`~MCState.sample`.
"""
if self._samples is None:
self.sample()
return self._samples
[docs]
def log_value(self, σ: jnp.ndarray) -> jnp.ndarray:
r"""
Evaluate the variational state for a batch of states and returns
the logarithm of the amplitude of the quantum state.
For pure states,
this is :math:`\log(\langle\sigma|\psi\rangle)`, whereas for mixed states
this is :math:`\log(\langle\sigma_r|\rho|\sigma_c\rangle)`, where
:math:`\psi` and :math:`\rho` are respectively a pure state
(wavefunction) and a mixed state (density matrix).
For the density matrix, the left and right-acting states (row and column)
are obtained as :code:`σr=σ[::,0:N]` and :code:`σc=σ[::,N:]`.
Given a batch of inputs :code:`(Nb, N)`, returns a batch of outputs
:code:`(Nb,)`.
"""
return jit_evaluate(self._apply_fun, self.variables, σ)
[docs]
@timing.timed
def local_estimators(
self, op: AbstractOperator, *, chunk_size: Optional[int] = None
):
r"""
Compute the local estimators for the operator :code:`op` (also known as local energies
when :code:`op` is the Hamiltonian) at the current configuration samples :code:`self.samples`.
.. math::
O_\mathrm{loc}(s) = \frac{\langle s | \mathtt{op} | \psi \rangle}{\langle s | \psi \rangle}
.. warning::
The samples differ between MPI processes, so returned the local estimators will
also take different values on each process. To compute sample averages and similar
quantities, you will need to perform explicit operations over all MPI ranks.
(Use functions like :code:`self.expect` to get process-independent quantities without
manual reductions.)
Args:
op: The operator.
chunk_size: Suggested maximum size of the chunks used in forward and backward evaluations
of the model. (Default: :code:`self.chunk_size`)
"""
return local_estimators(self, op, chunk_size=chunk_size)
# override to use chunks
[docs]
@timing.timed
def expect(self, O: AbstractOperator) -> Stats:
r"""Estimates the quantum expectation value for a given operator
:math:`O` or generic observable.
In the case of a pure state :math:`\psi` and an operator, this is
:math:`\langle O\rangle= \langle \Psi|O|\Psi\rangle/\langle\Psi|\Psi\rangle`
otherwise for a mixed state :math:`\rho`, this is
:math:`\langle O\rangle= \textrm{Tr}[\rho \hat{O}]/\textrm{Tr}[\rho]`.
Args:
O: the operator or observable for which to compute the expectation
value.
Returns:
An estimation of the quantum expectation value
:math:`\langle O\rangle`.
"""
return expect(self, O, self.chunk_size)
# override to use chunks
[docs]
@timing.timed
def expect_and_grad(
self,
O: AbstractOperator,
*,
mutable: Optional[CollectionFilter] = None,
**kwargs,
) -> tuple[Stats, PyTree]:
r"""Estimates the quantum expectation value and its gradient
for a given operator :math:`O`.
Args:
O: The operator :math:`O` for which expectation value and
gradient are computed.
mutable: Can be bool, str, or list. Specifies which collections in the
model_state should be treated as mutable: bool: all/no collections
are mutable. str: The name of a single mutable collection. list: A
list of names of mutable collections. This is used to mutate the state
of the model while you train it (for example to implement BatchNorm. Consult
`Flax's Module.apply documentation <https://flax.readthedocs.io/en/latest/_modules/flax/linen/module.html#Module.apply>`_
for a more in-depth explanation).
use_covariance: whether to use the covariance formula, usually reserved for
hermitian operators,
:math:`\textrm{Cov}[\partial\log\psi, O_{\textrm{loc}}\rangle]`
Returns:
An estimate of the quantum expectation value <O>.
An estimate of the gradient of the quantum expectation value <O>.
"""
if mutable is None:
mutable = self.mutable
return expect_and_grad(
self,
O,
self.chunk_size,
mutable=mutable,
**kwargs,
)
# override to use chunks
[docs]
@timing.timed
def expect_and_forces(
self,
O: AbstractOperator,
*,
mutable: Optional[CollectionFilter] = None,
) -> tuple[Stats, PyTree]:
r"""Estimates the quantum expectation value and the corresponding force
vector for a given operator O.
The force vector :math:`F_j` is defined as the covariance of log-derivative
of the trial wave function and the local estimators of the operator. For complex
holomorphic states, this is equivalent to the expectation gradient
:math:`\frac{\partial\langle O\rangle}{\partial(\theta_j)^\star} = F_j`. For real-parameter states,
the gradient is given by :math:`\frac{\partial\partial_j\langle O\rangle}{\partial\partial_j\theta_j} = 2 \textrm{Re}[F_j]`.
Args:
O: The operator O for which expectation value and force are computed.
mutable: Can be bool, str, or list. Specifies which collections in
the model_state should be treated as mutable: bool: all/no
collections are mutable. str: The name of a single mutable
collection. list: A list of names of mutable collections. This is
used to mutate the state of the model while you train it (for
example to implement BatchNorm. Consult
`Flax's Module.apply documentation <https://flax.readthedocs.io/en/latest/_modules/flax/linen/module.html#Module.apply>`_
for a more in-depth explanation).
Returns:
An estimate of the quantum expectation value <O>.
An estimate of the force vector
:math:`F_j = \textrm{Cov}[\partial_j\log\psi, O_{\textrm{loc}}]`.
"""
if isinstance(O, Squared):
raise NotImplementedError(
"expect_and_forces not yet implemented for `Squared`"
)
if mutable is None:
mutable = self.mutable
return expect_and_forces(self, O, self.chunk_size, mutable=mutable)
[docs]
def quantum_geometric_tensor(
self, qgt_T: Optional[LinearOperator] = None
) -> LinearOperator:
r"""Computes an estimate of the quantum geometric tensor G_ij.
This function returns a linear operator that can be used to apply G_ij to a given vector
or can be converted to a full matrix.
Args:
qgt_T: the optional type of the quantum geometric tensor. By default it's automatically selected.
Returns:
nk.optimizer.LinearOperator: A linear operator representing the quantum geometric tensor.
"""
if qgt_T is None:
qgt_T = QGTAuto()
return qgt_T(self)
[docs]
def to_array(self, normalize: bool = True) -> jnp.ndarray:
return nn.to_array(
self.hilbert,
self._apply_fun,
self.variables,
normalize=normalize,
chunk_size=self.chunk_size,
)
def __repr__(self):
return (
"MCState("
+ f"\n hilbert = {self.hilbert},"
+ f"\n sampler = {self.sampler},"
+ f"\n n_samples = {self.n_samples},"
+ f"\n n_discard_per_chain = {self.n_discard_per_chain},"
+ f"\n sampler_state = {self.sampler_state},"
+ f"\n n_parameters = {self.n_parameters})"
)
def __str__(self):
return (
"MCState("
+ f"hilbert = {self.hilbert}, "
+ f"sampler = {self.sampler}, "
+ f"n_samples = {self.n_samples})"
)
@partial(jax.jit, static_argnames=("kernel", "apply_fun", "shape"))
def _local_estimators_kernel(kernel, apply_fun, shape, variables, samples, extra_args):
O_loc = kernel(apply_fun, variables, samples, extra_args)
return O_loc.reshape(shape)
def local_estimators(
state: MCState, op: AbstractOperator, *, chunk_size: Optional[int]
):
s, extra_args = get_local_kernel_arguments(state, op)
shape = s.shape
if jnp.ndim(s) != 2:
# jit for gda
s = jax.jit(jax.lax.collapse, static_argnums=(1, 2))(s, 0, s.ndim - 1)
if chunk_size is None:
chunk_size = state.chunk_size # state.chunk_size can still be None
if chunk_size is None:
kernel = get_local_kernel(state, op)
else:
kernel = nkjax.HashablePartial(
get_local_kernel(state, op, chunk_size), chunk_size=chunk_size
)
return _local_estimators_kernel(
kernel, state._apply_fun, shape[:-1], state.variables, s, extra_args
)
# serialization
def serialize_MCState(vstate):
state_dict = {
"variables": serialization.to_state_dict(
sharding.extract_replicated(vstate.variables)
),
"sampler_state": serialization.to_state_dict(vstate._sampler_state_previous),
"n_samples": vstate.n_samples,
"n_discard_per_chain": vstate.n_discard_per_chain,
"chunk_size": vstate.chunk_size,
}
return state_dict
def deserialize_MCState(vstate, state_dict):
import copy
new_vstate = copy.copy(vstate)
new_vstate.reset()
new_vstate.variables = serialization.from_state_dict(
vstate.variables, state_dict["variables"]
)
new_vstate.sampler_state = serialization.from_state_dict(
vstate.sampler_state, state_dict["sampler_state"]
)
new_vstate.n_samples = state_dict["n_samples"]
new_vstate.n_discard_per_chain = state_dict["n_discard_per_chain"]
new_vstate.chunk_size = state_dict["chunk_size"]
return new_vstate
serialization.register_serialization_state(
MCState,
serialize_MCState,
deserialize_MCState,
)