netket.vqs.ExactState
- class netket.vqs.ExactState
Bases: netket.vqs.VariationalState
Variational state for a quantum state whose expectation values are computed by summing over the whole Hilbert space, without Monte Carlo sampling.
Expectation values and gradients are therefore deterministic. The only source of randomness is the seed used to initialize the parameters.
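To make "without Monte Carlo sampling" concrete, here is a minimal NumPy sketch (not NetKet's implementation) that enumerates every configuration of a 3-spin Hilbert space, turns log-amplitudes into exact Born probabilities, and computes an expectation value as a finite sum. The toy ansatz `log_psi` (one real parameter per spin) is a hypothetical stand-in for a model. Because nothing is sampled, repeated evaluations give identical results.

```python
import itertools

import numpy as np


# Hypothetical toy ansatz (not a NetKet model): log psi(sigma) = sum_i theta_i * sigma_i.
def log_psi(sigma, theta):
    return sigma @ theta


# Enumerate the whole Hilbert space of N spins with values in {-1, +1}.
N = 3
all_states = np.array(list(itertools.product([-1, 1], repeat=N)), dtype=float)

theta = np.array([0.1, -0.3, 0.25])

# Exact Born distribution |psi|^2 / sum |psi|^2, computed stably in log space.
log_prob = 2 * np.real(log_psi(all_states, theta))
probs = np.exp(log_prob - log_prob.max())
probs /= probs.sum()

# Exact expectation of a diagonal observable (total magnetization) as a full sum.
magnetization = probs @ all_states.sum(axis=1)
```

For this product-state ansatz each spin is independent, so the exact sum reproduces the closed form Σ_i tanh(2θ_i).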
- __init__(hilbert, model=None, *, variables=None, init_fun=None, apply_fun=None, seed=None, mutable=False, training_kwargs={}, dtype=<class 'float'>)
Constructs the ExactState.
- Parameters
hilbert (AbstractHilbert) – The Hilbert space.
model – (Optional) The model. If not provided, you must provide init_fun and apply_fun.
parameters – Optional PyTree of weights from which to start.
seed (Union[int, Any, None]) – RNG seed used to generate a set of parameters (only if parameters is not passed). Defaults to a random one.
mutable (bool) – Specifies which collections are mutable. Use it to specify if the model has a state that can change during evaluation but that should not be optimised. See also the flax.linen.Module.apply documentation (default=False).
init_fun (Optional[Callable[[Any, Sequence[int], Any], Union[ndarray, DeviceArray, Tracer]]]) – Function of the signature f(model, shape, rng_key, dtype) -> Optional_state, parameters, used to initialise the parameters. Defaults to the standard flax initialiser. Only specify it if your network has a non-standard init method.
variables (Optional[Any]) – Optional initial value for the variables (parameters and model state) of the model.
apply_fun (Optional[Callable]) – Function of the signature f(model, variables, σ) that should evaluate the model. Defaults to model.apply(variables, σ). Specify it only if your network has a non-standard apply method.
training_kwargs (Dict) – A dict containing the optional keyword arguments to be passed to the apply_fun during training. Useful, for example, when you have a batchnorm layer that constructs the average/mean only during training.
- Attributes
- hilbert
The descriptor of the Hilbert space on which this variational state is defined.
- Return type
AbstractHilbert
- model
Returns the model definition of this variational state.
This field is optional, and is set to None if the variational state has been initialized using a custom function.
- model_state: Optional[Any]
The optional pytree with the mutable state of the model.
- Methods
- expect(Ô)
Estimates the quantum expectation value for a given operator O.
In the case of a pure state $\psi$, this is $\langle O \rangle = \langle \psi | \hat{O} | \psi \rangle / \langle \psi | \psi \rangle$; otherwise, for a mixed state $\rho$, this is $\langle O \rangle = \mathrm{Tr}[\rho \hat{O}] / \mathrm{Tr}[\rho]$.
- Parameters
Ô – the operator O.
- Return type
Stats
- Returns
An estimation of the quantum expectation value <O>.
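The pure-state and mixed-state formulas can be checked against each other with dense matrices. In this NumPy sketch (illustrative only; `expect_pure` and `expect_mixed` are hypothetical helpers, not NetKet functions), the mixed-state formula applied to ρ = c|ψ⟩⟨ψ| reproduces the pure-state result for any c > 0, which is why neither ψ nor ρ needs to be normalized.

```python
import numpy as np

rng = np.random.default_rng(0)
dim = 4

# Random Hermitian operator O on a 4-dimensional Hilbert space.
A = rng.normal(size=(dim, dim)) + 1j * rng.normal(size=(dim, dim))
O = (A + A.conj().T) / 2


def expect_pure(O, psi):
    # <O> = <psi|O|psi> / <psi|psi>; psi need not be normalized.
    return (psi.conj() @ O @ psi) / (psi.conj() @ psi)


def expect_mixed(O, rho):
    # <O> = Tr[rho O] / Tr[rho]; rho need not have unit trace.
    return np.trace(rho @ O) / np.trace(rho)


# An unnormalized pure state and the (also unnormalized) projector built from it.
psi = rng.normal(size=dim) + 1j * rng.normal(size=dim)
rho_pure = 2.5 * np.outer(psi, psi.conj())

# A genuinely mixed, positive semidefinite density matrix.
B = rng.normal(size=(dim, dim)) + 1j * rng.normal(size=(dim, dim))
rho_mixed = B @ B.conj().T
```

For a Hermitian O both results are real and lie between the smallest and largest eigenvalue of O.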
- expect_and_forces(Ô, *, mutable=None)
Estimates the quantum expectation value and the corresponding force vector for a given operator O.
The force vector F_j is defined as the covariance of the log-derivative of the trial wave function and the local estimator of the operator. For complex holomorphic states, this is equivalent to the expectation gradient d<O>/d(θ_j)* = F_j. For real-parameter states, the gradient is given by d<O>/dθ_j = 2 Re[F_j].
- Parameters
Ô – The operator Ô for which expectation value and force are computed.
mutable (Optional[Any]) – Can be bool, str, or list. Specifies which collections in the model_state should be treated as mutable. bool: all/no collections are mutable. str: the name of a single mutable collection. list: a list of names of mutable collections. This is used to mutate the state of the model while you train it (for example, to implement BatchNorm; consult Flax's Module.apply documentation for a more in-depth explanation).
- Returns
An estimate of the quantum expectation value <O>. An estimate of the force vector F_j = cov[dlog(ψ)/dθ_j, O_loc].
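For real parameters and a Hermitian operator, the relation d⟨O⟩/dθ_j = 2 Re[F_j] can be verified numerically by computing the force vector exactly, as a covariance over the full Hilbert space with the log-derivative conjugated, and comparing it against finite differences. This NumPy sketch assumes a hypothetical log-amplitude that is linear in the parameters with fixed complex features; it illustrates the formula, not NetKet's code path.

```python
import numpy as np

rng = np.random.default_rng(1)
dim, n_params = 8, 2

# Hypothetical ansatz: log psi(sigma) = sum_j theta_j * a_j(sigma), with real
# parameters theta and fixed complex features a_j(sigma), so d log psi / d theta_j = a_j.
features = rng.normal(size=(dim, n_params)) + 1j * rng.normal(size=(dim, n_params))

# Random Hermitian operator in the computational basis.
A = rng.normal(size=(dim, dim)) + 1j * rng.normal(size=(dim, dim))
O = (A + A.conj().T) / 2


def amplitudes(theta):
    return np.exp(features @ theta)


def expectation(theta):
    psi = amplitudes(theta)
    return np.real(psi.conj() @ O @ psi / (psi.conj() @ psi))


def forces(theta):
    psi = amplitudes(theta)
    p = np.abs(psi) ** 2 / np.sum(np.abs(psi) ** 2)  # exact Born weights
    o_loc = (O @ psi) / psi                           # local estimator O_loc(sigma)
    d_log = features                                  # d log psi / d theta_j
    # F_j = <conj(d_log_j) O_loc> - <conj(d_log_j)> <O_loc>
    return p @ (d_log.conj() * o_loc[:, None]) - (p @ d_log.conj()) * (p @ o_loc)


# Central finite differences of <O> with respect to each real parameter.
theta = np.array([0.2, -0.1])
eps = 1e-6
fd_grad = np.array([
    (expectation(theta + eps * e) - expectation(theta - eps * e)) / (2 * eps)
    for e in np.eye(n_params)
])
```

Here `fd_grad` agrees with `2 * np.real(forces(theta))` to finite-difference accuracy.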
- expect_and_grad(Ô, *, mutable=None, use_covariance=None)
Estimates the quantum expectation value and its gradient for a given operator O.
- Parameters
Ô – The operator Ô for which expectation value and gradient are computed.
mutable (Optional[Any]) – Can be bool, str, or list. Specifies which collections in the model_state should be treated as mutable. bool: all/no collections are mutable. str: the name of a single mutable collection. list: a list of names of mutable collections. This is used to mutate the state of the model while you train it (for example, to implement BatchNorm; consult Flax's Module.apply documentation for a more in-depth explanation).
use_covariance (Optional[bool]) – Whether to use the covariance formula ⟨∂logψ Oˡᵒᶜ⟩ - ⟨∂logψ⟩⟨Oˡᵒᶜ⟩, usually reserved for Hermitian operators.
- Returns
An estimate of the quantum expectation value <O>. An estimate of the gradient of the quantum expectation value <O>.
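Returning the value together with its gradient is what an optimization loop needs. The NumPy sketch below performs plain gradient descent on an exact energy; everything in it is an assumption for illustration: the 3-spin Hamiltonian H = -Σ σˣ_i, the product ansatz log ψ(σ) = θ·σ, and the helper `expect_and_grad`, which applies the covariance formula but is not the NetKet method.

```python
import itertools

import numpy as np

N = 3
sx = np.array([[0.0, 1.0], [1.0, 0.0]])


def site_op(op, i, n):
    # Embed a single-site operator at site i of an n-site system (site 0 most significant).
    out = np.eye(1)
    for k in range(n):
        out = np.kron(out, op if k == i else np.eye(2))
    return out


# Toy Hamiltonian H = -sum_i sigma^x_i; its ground-state energy is -N.
H = -sum(site_op(sx, i, N) for i in range(N))

# Basis states in the same ordering as np.kron (index 0 <-> spin value +1).
states = np.array(list(itertools.product([1, -1], repeat=N)), dtype=float)


def expect_and_grad(theta):
    # Hypothetical helper for the product ansatz log psi(sigma) = theta . sigma (real theta).
    psi = np.exp(states @ theta)
    p = psi**2 / np.sum(psi**2)
    e_loc = (H @ psi) / psi
    energy = p @ e_loc
    # Covariance formula <d_log * E_loc> - <d_log><E_loc>; d_log = sigma for this ansatz.
    forces = p @ (states * e_loc[:, None]) - (p @ states) * energy
    return energy, 2 * np.real(forces)


theta0 = np.array([0.5, -0.3, 0.8])
e_initial, _ = expect_and_grad(theta0)

# Plain gradient descent on the exact energy.
theta = theta0.copy()
for _ in range(200):
    _, grad = expect_and_grad(theta)
    theta = theta - 0.1 * grad
e_final, _ = expect_and_grad(theta)
```

For this ansatz the exact energy is -Σ_i 1/cosh(2θ_i), so the descent drives every θ_i to 0 and the energy to the ground-state value -N.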
- grad(Ô, *, use_covariance=None, mutable=None)
Estimates the gradient of the quantum expectation value of a given operator O.
- Parameters
op (netket.operator.AbstractOperator) – The operator O.
is_hermitian – Optional override for whether or not to use the Hermitian logic. By default it is detected automatically.
- Returns
An estimation of the average gradient of the quantum expectation value <O>.
- Return type
array
- init(seed=None, dtype=None)
Initialises the variational parameters of the variational state.
- init_parameters(init_fun=None, *, seed=None)
Re-initializes all the parameters with the provided initialization function, defaulting to the normal distribution of standard deviation 0.01.
Warning
The init function will not change the dtype of the parameters, which is determined by the model. DO NOT SPECIFY IT INSIDE THE INIT FUNCTION.
- Parameters
init_fun (Optional[Callable[[Any, Sequence[int], Any], Union[ndarray, DeviceArray, Tracer]]]) – A JAX initializer such as jax.nn.initializers.normal(). Must be a Callable taking 3 inputs, the JAX PRNG key, the shape and the dtype, and outputting an array with the valid dtype and shape. If left unspecified, defaults to jax.nn.initializers.normal(stddev=0.01).
seed (Optional[Any]) – Optional seed to be used. The seed is synced across all MPI processes. If unspecified, uses a random seed.
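The initializer contract (a callable taking a PRNG key, a shape and a dtype) can be sketched without JAX. In this NumPy illustration, `normal_init` and `reinit` are hypothetical helpers, not NetKet or JAX functions: a NumPy `Generator` plays the role of the JAX key, and a nested dict stands in for the parameter PyTree. The dtype passed to the initializer is read from the existing parameters, consistent with the warning that the model, not the init function, fixes the dtype.

```python
import numpy as np


def normal_init(stddev=0.01):
    # Hypothetical stand-in for jax.nn.initializers.normal(stddev=0.01) with the same
    # (key, shape, dtype) calling convention; `key` is a NumPy Generator, not a JAX PRNG key.
    def init(key, shape, dtype=np.float64):
        return (stddev * key.standard_normal(shape)).astype(dtype)

    return init


def reinit(params, init_fun, seed):
    # Re-initialize every leaf of a nested-dict "pytree", keeping each leaf's
    # shape and dtype: the dtype comes from the existing parameters (the model).
    key = np.random.default_rng(seed)

    def go(tree):
        if isinstance(tree, dict):
            return {k: go(v) for k, v in tree.items()}
        return init_fun(key, tree.shape, tree.dtype)

    return go(params)


params = {"dense": {"kernel": np.zeros((4, 3), np.float32), "bias": np.zeros(3, np.float64)}}
new_params = reinit(params, normal_init(0.01), seed=0)
```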
- log_value(σ)
Evaluates the variational state for a batch of states and returns the logarithm of the amplitude of the quantum state. For pure states, this is \(\log\langle\sigma|\psi\rangle\), whereas for mixed states this is \(\log\langle\sigma_r|\rho|\sigma_c\rangle\), where ψ and ρ are respectively a pure state (wavefunction) and a mixed state (density matrix). For the density matrix, the left- and right-acting states (row and column) are obtained as σr = σ[::, 0:N] and σc = σ[::, N:].
Given a batch of inputs (Nb, N), returns a batch of outputs (Nb,).
- Return type
Array
- Parameters
σ (jax._src.basearray.Array) –
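The row/column split of σ for density matrices can be illustrated with a dense matrix. In this NumPy sketch, `rho`, `basis_index` and `log_value` are hypothetical (the real method evaluates a model, not a lookup table), and local degrees of freedom are assumed to be 0/1 occupations; the point is only how a (Nb, 2N) batch is split into σr and σc and mapped to log⟨σr|ρ|σc⟩.

```python
import numpy as np

rng = np.random.default_rng(2)
N = 2  # number of sites (assumed binary local degrees of freedom, 0/1)
dim = 2**N

# Hypothetical dense density matrix rho = B B^dagger (positive semidefinite).
B = rng.normal(size=(dim, dim)) + 1j * rng.normal(size=(dim, dim))
rho = B @ B.conj().T


def basis_index(bits):
    # Map a row of 0/1 occupations to its basis index (site 0 most significant).
    return int("".join(str(int(b)) for b in bits), 2)


def log_value(sigma):
    # sigma has shape (Nb, 2N): row states in the first N columns, column states after.
    sigma_r = sigma[::, 0:N]
    sigma_c = sigma[::, N:]
    vals = np.array([rho[basis_index(r), basis_index(c)] for r, c in zip(sigma_r, sigma_c)])
    return np.log(vals)  # complex logarithm, shape (Nb,)


batch = np.array([[0, 1, 1, 0], [1, 1, 1, 1]])
out = log_value(batch)
```

The first row of `batch` selects the element ρ[1, 2] and the second selects ρ[3, 3].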
- quantum_geometric_tensor(qgt_T=QGTAuto())
Computes an estimate of the quantum geometric tensor G_ij. This function returns a linear operator that can be used to apply G_ij to a given vector or can be converted to a full matrix.
- Parameters
qgt_T (LinearOperator) – The optional type of the quantum geometric tensor. By default it is automatically selected.
- Returns
A linear operator representing the quantum geometric tensor.
- Return type
nk.optimizer.LinearOperator
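Computed exactly, the quantum geometric tensor is the covariance matrix of the log-derivatives O_j = ∂ log ψ/∂θ_j under the normalized |ψ|². The NumPy sketch below assumes that standard definition (for real parameters one typically uses only its real part, which is not shown) and contrasts the dense matrix with a matrix-free product, the sense in which a linear operator can apply G_ij to a vector without forming it. The weights `p`, the array `d_log` and both helper functions are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(3)
n_states, n_params = 8, 3

# Hypothetical exact Born weights p(sigma) and log-derivatives
# d_log[sigma, j] = d log psi(sigma) / d theta_j.
p = rng.random(n_states)
p /= p.sum()
d_log = rng.normal(size=(n_states, n_params)) + 1j * rng.normal(size=(n_states, n_params))

# Center the log-derivatives with respect to the exact distribution.
d_log_c = d_log - p @ d_log


def qgt_matvec(v):
    # Apply G to v without building G: G v = sum_sigma p(sigma) conj(c(sigma)) (c(sigma) . v).
    return d_log_c.conj().T @ (p * (d_log_c @ v))


def qgt_dense():
    # G_ij = <conj(O_i) O_j> - <conj(O_i)> <O_j>, with O_j = d log psi / d theta_j.
    return d_log_c.conj().T @ (p[:, None] * d_log_c)
```

The dense form is Hermitian and positive semidefinite; the matrix-free form costs one pass over the states per product and never stores an n_params × n_params matrix.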
- reset()
Resets the sampled states. This method is called automatically every time that the parameters/state is updated.
- to_array(normalize=True, allgather=True)
Returns the dense-vector representation of this state.
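When the dense vector is built from log-amplitudes, normalization is conveniently done in log space. This NumPy sketch uses a hypothetical helper `dense_from_log_values` (not NetKet's implementation, which also has to evaluate the model on every basis state) to show the max-shift trick that keeps large log-amplitudes from overflowing when normalize=True.

```python
import numpy as np


def dense_from_log_values(log_values, normalize=True):
    # Hypothetical helper: turn log-amplitudes for every basis state into a dense vector.
    if not normalize:
        return np.exp(log_values)
    # Subtract the largest real part before exponentiating to avoid overflow;
    # the shift only rescales the vector, and normalization removes that factor.
    psi = np.exp(log_values - np.max(log_values.real))
    return psi / np.linalg.norm(psi)


# Log-amplitudes too large to exponentiate directly (exp(1000) overflows float64).
log_values = np.array([1000.0, 1000.0 + np.log(2.0), 999.0 + 0.5j * np.pi])
v = dense_from_log_values(log_values)
```

Ratios between amplitudes, including their phases, are preserved: v[1]/v[0] = 2 and v[2]/v[0] = i·e⁻¹.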
- to_qobj()
Convert the variational state to a qutip’s ket Qobj.
- Returns
A qutip.Qobj object.