Source code for netket.models.slater

# Copyright 2023 The NetKet Authors - All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import flax.linen as nn
import jax.numpy as jnp

from functools import partial

import numpy as np

from netket.experimental.hilbert import SpinOrbitalFermions
from netket.utils.types import NNInitFunc, DType
from netket.nn.masked_linear import default_kernel_init
from netket import jax as nkjax


class Slater2nd(nn.Module):
    r"""
    A slater determinant ansatz for second-quantised spinless or spin-full
    fermions.

    When working with spin-full fermionic hilbert spaces (where the number
    of degrees of freedom is a multiple of the number of orbitals)
    :class:`~netket.models.Slater2nd` may behave in 3 different ways, depending on
    some flags. Those modes differ by how the orbitals are represented.
    The more restrictions we impose, the lower the number of parameters, the higher
    the number of imposed symmetries, but the expressivity will be worse and it
    will be less likely to attain accurate ground-state energies.

    Those options are summarised in this table, while the details are discussed below.

    ================= ===================================================================== =============================================
    Hartree Fock type Number of Parameters                                                  Options
    ================= ===================================================================== =============================================
    Generalized       :math:`n_{\mathrm{M}} \times n_{\mathrm{f}}`                          :code:`generalized = True`
    Unrestricted      :math:`n_{\mathrm{S}} \times n_{\mathrm{L}} \times n_{\textrm{f, s}}` :code:`generalized=False, restricted=False`
    Restricted        :math:`n_{\mathrm{L}} \times n_{\textrm{f, s}}`                       :code:`generalized=False, restricted=True`
    ================= ===================================================================== =============================================

    Where we use :math:`n_{\textrm{M}}` to denote the number of fermionic modes,
    :math:`n_{\textrm{L}}` the number of spatial orbitals, and :math:`n_{\textrm{S}}`
    the number of spin states. The number of fermions is denoted :math:`n_{\mathrm{f}}`,
    or :math:`n_{\textrm{f}, \alpha}` for the number of fermions in a given spin
    sector :math:`\alpha`.
    We assume the same number of fermions in each spin sector for simplicity.

    Details about different Hartree-Fock types
    ==========================================

    Assume we introduce a set of orbitals :math:`\phi_\mu(r, s)` with orbital
    index :math:`\mu`.

    - **Generalized Hartree-Fock** (:code:`generalized = True`) is the most general
      case where we impose no restrictions. In particular, we do not restrict the
      orbitals to have definite spin or orbital quantum numbers. The total number
      of parameters is :math:`n_{\mathrm{M}} \times n_{\mathrm{f}}`. Hence, any
      fermion can occupy any of the fermionic modes.
    - **Hartree-Fock (Spin-Conserving)**
      (:code:`[generalized=False,] restricted=True/False`).
      Most physical Hamiltonians are spin conserving, and hence we can impose it
      also on the wave-function. In this case, we separate the orbital index
      :math:`\mu \to (l, \alpha)` into a spin and spatial orbital part:
      :math:`\phi_\mu(r, s)=\varphi_{l,\alpha}(r) \chi_{\alpha}(s)`.
      Here, :math:`l` and :math:`\alpha` indicate the orbital and spin quantum
      numbers associated with the orbital, and :math:`(r, s)` are the position
      vector and spin quantum number at which we aim to evaluate the orbital
      (i.e. properties of a given fermion). Furthermore,
      :math:`\varphi_{l,\alpha}(r)` is the spatial orbital at position :math:`r`,
      and :math:`\chi_\alpha(s)` the spin part.

      - **Unrestricted Hartree Fock (UHF)**
        (:code:`[generalized=False,] restricted=False`), the orbitals can have a
        different spatial orbital :math:`\varphi` for different spin states.
        Since e.g. the up spin fermions cannot occupy the down spin orbitals and
        vice versa, the Slater matrix becomes block diagonal. This allows us to
        write the determinant as a product of determinants of the two spin
        sectors. The total number of parameters is
        :math:`n_{\mathrm{S}} \times n_{\mathrm{L}} \times n_{\textrm{f, s}}`.
        For more information, see
        `Wikipedia: Unrestricted Hartree-Fock <https://en.wikipedia.org/wiki/Unrestricted_Hartree%E2%80%93Fock>`_
      - **Restricted Hartree-Fock (RHF)**
        (:code:`[generalized=False, restricted=True]`), which assumes that
        different spin states have the same spatial orbitals in
        :math:`\phi_\mu(r, s)=\varphi_l(r) \chi_\alpha(s)`, and hence
        :math:`\varphi_l` only depends on the spatial orbital index :math:`l`.
        The number of parameters now reduces to
        :math:`n_{\mathrm{L}} \times n_{\textrm{f, s}}`.
    """

    hilbert: SpinOrbitalFermions
    """The Hilbert space upon which this ansatz is defined. Used to
    determine the number of orbitals and spin subsectors."""

    generalized: bool = False
    """Uses Generalized Hartree-Fock if True (defaults to `False`, corresponding
    to the standard spin-conserving Hartree-Fock).
    """

    restricted: bool = True
    """Flag to select the restricted- or unrestricted- Hartree Fock orbitals
    (Defaults to restricted).

    This flag is ignored if :code:`generalized=True`.

    - If restricted, only one set of orbitals are parametrised, and they are
      used for all spin subsectors. This only works if every spin subsector
      holds the same number of fermions.
    - If unrestricted, a different set of orbitals are parametrised and used
      for each spin subsector.
    """

    kernel_init: NNInitFunc = default_kernel_init
    """Initializer for the orbital parameters."""

    param_dtype: DType = float
    """Dtype of the orbital amplitudes."""

    def __post_init__(self):
        """Validate the hilbert space against the selected Hartree-Fock variant.

        Raises:
            TypeError: if ``hilbert`` is not a :class:`SpinOrbitalFermions`, or
                if its ``n_fermions`` is ``None``.
            ValueError: if restricted (non-generalized) Hartree-Fock is selected
                and the spin sectors do not all hold the same number of fermions.
        """
        if not isinstance(self.hilbert, SpinOrbitalFermions):
            raise TypeError(
                "Slater2nd only supports 2nd quantised fermionic hilbert spaces."
            )
        if self.hilbert.n_fermions is None:
            raise TypeError(
                "Slater2nd only supports hilbert spaces with a "
                "fixed number of fermions."
            )
        # `restricted` is documented as ignored when `generalized=True`: in that
        # mode a single (size, n_fermions) matrix is used and the per-sector
        # counts never enter, so the equal-occupation constraint must not be
        # enforced there.
        if self.restricted and not self.generalized:
            n_per_spin = self.hilbert.n_fermions_per_spin
            if not all(np.equal(n_per_spin, n_per_spin[0])):
                raise ValueError(
                    "Restricted Hartree Fock only makes sense for spaces with "
                    "same number of fermions on every subspace."
                )
        super().__post_init__()

    def setup(self):
        """Declare the orbital parameters for the selected Hartree-Fock variant.

        ``self.orbitals`` ends up being either a single matrix of shape
        ``(hilbert.size, hilbert.n_fermions)`` (generalized), or a list with one
        ``(hilbert.n_orbitals, n_fermions_i)`` matrix per spin sector. In the
        restricted case every list entry is the same parameter ``M``.
        """
        # Every determinant is a matrix of shape (n_orbitals, n_fermions_i) where
        # n_fermions_i is the number of fermions in the i-th spin sector.
        if self.generalized:
            M = self.param(
                "M",
                self.kernel_init,
                (self.hilbert.size, self.hilbert.n_fermions),
                self.param_dtype,
            )
            self.orbitals = M
        else:
            if self.restricted:
                M = self.param(
                    "M",
                    self.kernel_init,
                    (self.hilbert.n_orbitals, self.hilbert.n_fermions_per_spin[0]),
                    self.param_dtype,
                )
                # The same parameter is shared by all spin sectors.
                self.orbitals = [M for _ in self.hilbert.n_fermions_per_spin]
            else:
                self.orbitals = [
                    self.param(
                        f"M_{i}",
                        self.kernel_init,
                        (self.hilbert.n_orbitals, nf_i),
                        self.param_dtype,
                    )
                    for i, nf_i in enumerate(self.hilbert.n_fermions_per_spin)
                ]

    def __call__(self, n):
        """
        Assumes inputs are strings of 0,1 that specify which orbitals are occupied.
        Spin sectors are assumed to follow the SpinOrbitalFermion's factorisation,
        meaning that the first `n_orbitals` entries correspond to sector -1, the
        second `n_orbitals` correspond to 0 ... etc.

        Args:
            n: occupation array whose last axis has length ``hilbert.size``;
                leading axes are vectorised over.

        Returns:
            For every configuration, the value of ``nkjax.logdet_cmplx`` of the
            occupied-orbital submatrix (generalized), or the sum of those values
            over all spin sectors (otherwise).

        Raises:
            ValueError: if the last axis of ``n`` does not match ``hilbert.size``.
        """
        if not n.shape[-1] == self.hilbert.size:
            raise ValueError(
                f"Dimension mismatch. Expected samples with {self.hilbert.size} "
                f"degrees of freedom, but got a sample of shape {n.shape} ({n.shape[-1]} dof)."
            )

        @partial(jnp.vectorize, signature="(n)->()")
        def log_sd(n):
            # Find the positions of the occupied sites. The static `size` keeps
            # the output shape fixed, which is needed under jit/vectorize.
            R = n.nonzero(size=self.hilbert.n_fermions)[0]

            if self.generalized:
                # extract Nf x Nf submatrix
                A = self.orbitals[R, :]
                return nkjax.logdet_cmplx(A)

            log_det_sum = 0
            i_start = 0
            for i, (n_fermions_i, M_i) in enumerate(
                zip(self.hilbert.n_fermions_per_spin, self.orbitals)
            ):
                # convert global orbital positions to spin-sector-local positions
                R_i = (
                    R[i_start : i_start + n_fermions_i]
                    - i * self.hilbert.n_orbitals
                )
                # extract the corresponding Nf x Nf submatrix
                A_i = M_i[R_i]

                log_det_sum = log_det_sum + nkjax.logdet_cmplx(A_i)
                # Advance the window into R by the fermions consumed so far.
                # This must accumulate: plain assignment is only correct for
                # the first two sectors and mis-slices R from the third on.
                i_start += n_fermions_i

            return log_det_sum

        return log_sd(n)
class MultiSlater2nd(nn.Module):
    r"""
    A slater determinant ansatz for second-quantised spinless or spin-full
    fermions with a sum of determinants.

    Refer to :class:`~netket.models.Slater2nd` for details about the different
    variants of Hartree Fock and the flags.
    """

    hilbert: SpinOrbitalFermions
    """The Hilbert space upon which this ansatz is defined. Used to
    determine the number of orbitals and spin subsectors."""

    n_determinants: int = 1
    """The number of determinants to be summed."""

    generalized: bool = False
    """Uses Generalized Hartree-Fock if True (defaults to `False`, corresponding
    to the standard spin-conserving Hartree-Fock).
    """

    restricted: bool = True
    """Flag to select the restricted- or unrestricted- Hartree Fock orbitals
    (Defaults to restricted).

    This flag is ignored if :code:`generalized=True`.

    - If restricted, only one set of orbitals are parametrised, and they are
      used for all spin subsectors. This only works if every spin subsector
      holds the same number of fermions.
    - If unrestricted, a different set of orbitals are parametrised and used
      for each spin subsector.
    """

    kernel_init: NNInitFunc = default_kernel_init
    """Initializer for the orbital parameters."""

    param_dtype: DType = float
    """Dtype of the orbital amplitudes."""

    @nn.compact
    def __call__(self, n):
        """
        Assumes inputs are strings of 0,1 that specify which orbitals are occupied.
        Spin sectors are assumed to follow the SpinOrbitalFermion's factorisation,
        meaning that the first `n_orbitals` entries correspond to sector -1, the
        second `n_orbitals` correspond to 0 ... etc.

        Args:
            n: occupation array, forwarded unchanged to each of the
                ``n_determinants`` underlying :class:`Slater2nd` modules.

        Returns:
            ``nkjax.logsumexp_cplx`` of the per-determinant outputs, reduced
            over the determinant axis.

        Raises:
            ValueError: if ``n_determinants`` is falsy (0 or ``None``) or
                smaller than 1.
        """
        # Reject negative counts as well as 0/None: the message promises
        # "greater than 0", and a negative value would otherwise only fail
        # later inside `broadcast_to` with an unrelated shape error.
        if not self.n_determinants or self.n_determinants < 1:
            raise ValueError(
                "Number of determinants must be an integer greater than 0."
            )

        # make extra axis with copies to run determinants in parallel
        n_bc = jnp.broadcast_to(n, (self.n_determinants, *n.shape))

        # Lift Slater2nd over the leading axis: parameters get their own
        # leading axis and a separate init RNG per determinant.
        batched_slater = nn.vmap(
            Slater2nd,
            in_axes=0,
            out_axes=0,  # vmap over copied axis
            variable_axes={"params": 0},
            split_rngs={"params": True},
        )
        multi_log_det = batched_slater(
            self.hilbert,
            restricted=self.restricted,
            generalized=self.generalized,
            kernel_init=self.kernel_init,
            param_dtype=self.param_dtype,
        )(n_bc)

        # sum the determinants
        log_det_sum = nkjax.logsumexp_cplx(multi_log_det, axis=0)

        return log_det_sum