# Source code for netket.models.tensor_networks.mpdo

# Copyright 2024 The NetKet Authors - All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

from flax import linen as nn
import jax
from jax import numpy as jnp

from jax.nn.initializers import normal

from netket.hilbert import HomogeneousHilbert
from netket.utils.types import NNInitFunc
from netket.jax import dtype_complex
from netket.utils.types import DType

# Default initializer for the trainable MPDO tensors (used as the default value
# of `kernel_init` below): normally distributed values with standard deviation 0.01.
default_kernel_init = normal(stddev=0.01)


class MPDOPeriodic(nn.Module):
    r"""
    A Matrix Product Density Operator (MPDO) with periodic boundary conditions for a
    quantum mixed state of discrete degrees of freedom. The purification is used.

    The MPDO is defined as (see `F. Verstraete, J. J. García-Ripoll, J. I. Cirac,
    Phys. Rev. Lett. 93, 207204 (2004) <https://doi.org/10.1103/PhysRevLett.93.207204>`_).

    .. math::

        \rho(s_1,\dots, s_N, s_1',\dots, s_N') = \mathrm{Tr} \left[
            M^{(1)}_{s_1,s_1'} M^{(2)}_{s_2,s_2'} \dots M^{(N)}_{s_N, s_N'}
        \right],

    for arbitrary local quantum numbers :math:`s_i` and :math:`s_i'`, where
    :math:`M^{(i)}_{s_i,s_i'}` are :math:`D^2 \times D^2` matrices that can be
    decomposed as

    .. math::

        M^{(i)}_{s_i,s_i'} = \sum_{a=1}^{\chi} A^{(i), a}_{s_i} \otimes (A^{(i), a}_{s_i'})^*,

    with :math:`A^{(i), a}_{s_i}` being :math:`D \times D` matrices. The bond
    dimension is denoted by :math:`D` and the Kraus dimension by :math:`\chi`,
    which corresponds to the variable `kraus_dim` in the code.

    The periodic boundary conditions imply that there are connections between the
    first and the last tensors, forming a trace over the product of matrices for
    the entire system.

    The implementation is based on `this paper <https://arxiv.org/abs/2401.14243>`_.
    """

    hilbert: HomogeneousHilbert
    """Hilbert space on which the state is defined."""
    bond_dim: int
    """Bond dimension of the MPDO tensors. See formula above."""
    kraus_dim: int = 2
    """The local Kraus dimension of the MPDO tensors. See formula above."""
    symperiod: Optional[int] = None
    """
    Periodicity in the chain of MPDO tensors.

    The chain of MPDO tensors is constructed as a sequence of identical
    unit cells consisting of symperiod tensors. If None, symperiod equals the
    number of physical degrees of freedom.
    """
    unroll: int = 1
    """the number of scan iterations to unroll within a single iteration of a loop."""
    checkpoint: bool = True
    """Whether to use jax.checkpoint on the scan function for memory efficiency."""
    kernel_init: NNInitFunc = default_kernel_init
    """the initializer for the MPDO weights. The initialized values are added to a
    constant offset tensor built from identity matrices."""
    param_dtype: DType = float
    """complex or float, whether the variational parameters of the MPDO are real or complex."""

    def setup(self):
        """
        Create the trainable unit-cell tensors of shape
        ``(symperiod, local_size, bond_dim, bond_dim, kraus_dim)``.

        Raises:
            AssertionError: if the period is not positive or does not divide the
                number of degrees of freedom of the Hilbert space.
        """
        L = self.hilbert.size
        d = self.hilbert.local_size
        D = self.bond_dim
        chi = self.kraus_dim
        self._L, self._d, self._D, self._chi = L, d, D, chi

        # A missing period means a unit cell spanning the whole chain.
        self._symperiod = L if self.symperiod is None else self.symperiod

        # Positivity is tested first so that a period of 0 is reported with the
        # message below instead of a ZeroDivisionError from the modulo.
        if self._symperiod <= 0 or L % self._symperiod != 0:
            raise AssertionError(
                "The number of degrees of freedom of the Hilbert space needs to be a multiple of the period of the MPS"
            )

        unit_cell_shape = (self._symperiod, d, D, D, chi)

        # Constant offset added to the trainable parameters.
        # NOTE(review): the identities are stacked as (symperiod * d * chi, D, D)
        # and then reshaped to (symperiod, d, D, D, chi). For kraus_dim > 1 this
        # reshape does not leave an identity on the (D, D) axes of each Kraus
        # slice. Kept unchanged because altering the offset would change the
        # output for existing parameters; confirm the intended layout.
        iden_tensors = jnp.repeat(
            jnp.eye(D, dtype=self.param_dtype)[jnp.newaxis, :, :],
            self._symperiod * d * chi,
            axis=0,
        )
        iden_tensors = iden_tensors.reshape(unit_cell_shape)

        self.tensors = (
            self.param("tensors", self.kernel_init, unit_cell_shape, self.param_dtype)
            + iden_tensors
        )

    def __call__(self, x):
        """
        Queries this MPDO for the input configurations x, which should contain
        rows and columns entries concatenated.

        Args:
            x: the input configurations. The last axis must have length
                ``2 * hilbert.size``; any number of leading batch axes is accepted.

        Returns:
            The complex logarithm of the density-matrix entries, with shape
            ``x.shape[:-1]`` (a 1D input is treated as a batch of one).
        """
        x = jnp.atleast_2d(x)
        assert (
            x.shape[-1] == 2 * self.hilbert.size
        ), "The last axis of x must hold the row and column configurations concatenated."

        # Flatten extra leading batch axes so that vmap runs over single samples.
        batch_shape = x.shape[:-1]
        if x.ndim != 2:
            x = x.reshape(-1, x.shape[-1])

        qn_x = self.hilbert.states_to_local_indices(x)
        rho = jax.vmap(self.contract_mpdo, in_axes=0)(qn_x)
        rho = rho.reshape(batch_shape)
        return jnp.log(rho.astype(dtype_complex(self.param_dtype)))

    def contract_mpdo(self, qn):
        """
        Internal function, used to contract the tensor network with some input tensor.

        Args:
            qn: The input tensor to be contracted with this MPDO: the local indices
                of a single sample, row indices followed by column indices.

        Returns:
            The trace of the product of the per-site transfer matrices.
        """
        # Build the full chain by repeating the unit cell.
        all_tensors = jnp.tile(self.tensors, (self._L // self._symperiod, 1, 1, 1, 1))

        transfer_dim = self._D**2
        edge = jnp.eye(transfer_dim, dtype=self.param_dtype)

        def base_scan_func(edge, pair):
            tensor, s_row, s_col = pair
            # Sum over the Kraus index k of A^k (row state) kron conj(A^k) (column state).
            tensor_contracted = jnp.einsum(
                "ijk,lmk->iljm", tensor[s_row, :], jnp.conj(tensor[s_col, :])
            )
            matrix = tensor_contracted.reshape(transfer_dim, transfer_dim)
            return edge @ matrix, None

        # Apply jax.checkpoint conditionally to the base_scan_func
        scan_func = (
            jax.checkpoint(base_scan_func) if self.checkpoint else base_scan_func
        )

        # First half of qn indexes the rows, second half the columns.
        qn_r, qn_c = jnp.split(qn, 2, axis=-1)
        edge, _ = jax.lax.scan(
            scan_func, edge, (all_tensors, qn_r, qn_c), unroll=self.unroll
        )

        return jnp.trace(edge)
class MPDOOpen(nn.Module):
    r"""
    A Matrix Product Density Operator (MPDO) with open boundary conditions for a
    quantum mixed state of discrete degrees of freedom. The purification is used.

    The MPDO is defined as (see `F. Verstraete, J. J. García-Ripoll, J. I. Cirac,
    Phys. Rev. Lett. 93, 207204 (2004) <https://doi.org/10.1103/PhysRevLett.93.207204>`_).

    .. math::

        \rho(s_1,\dots, s_N, s_1',\dots, s_N') =
            M^{(1)}_{s_1,s_1'} M^{(2)}_{s_2,s_2'} \dots M^{(N)}_{s_N, s_N'},

    for arbitrary local quantum numbers :math:`s_i` and :math:`s_i'`, where
    :math:`M^{(i)}_{s_i,s_i'}` are :math:`D^2 \times D^2` matrices for
    :math:`i=2, \dots, N-1`, and :math:`D^2`-dimensional vectors for :math:`i=1, N`
    (so that the product is a scalar), that can be decomposed as

    .. math::

        M^{(i)}_{s_i,s_i'} = \sum_{a=1}^{\chi} A^{(i), a}_{s_i} \otimes (A^{(i), a}_{s_i'})^*,

    with :math:`A^{(i), a}_{s_i}` being :math:`D \times D` matrices for the bulk of
    the chain (:math:`i=2, \dots, N-1`) and :math:`D`-dimensional vectors for the
    edges (:math:`i=1, N`). The bond dimension is denoted by :math:`D` and the Kraus
    dimension by :math:`\chi`, which corresponds to the variable `kraus_dim` in the
    code.

    The open boundary conditions imply that there are no connections between the
    first and the last tensors.

    The implementation is based on `this paper <https://arxiv.org/abs/2401.14243>`_.
    """

    hilbert: HomogeneousHilbert
    """Hilbert space on which the state is defined."""
    bond_dim: int
    """Bond dimension of the MPDO tensors. See formula above."""
    kraus_dim: int = 2
    """The local Kraus dimension of the MPDO tensors. See formula above."""
    unroll: int = 1
    """the number of scan iterations to unroll within a single iteration of a loop."""
    checkpoint: bool = True
    """Whether to use jax.checkpoint on the scan function for memory efficiency."""
    kernel_init: NNInitFunc = default_kernel_init
    """the initializer for the MPDO weights. The initialized values are added to a
    constant offset tensor."""
    param_dtype: DType = float
    """complex or float, whether the variational parameters of the MPDO are real or complex."""

    def setup(self):
        """
        Create the trainable tensors: two boundary tensors of shape
        ``(local_size, bond_dim, kraus_dim)`` and ``hilbert.size - 2`` bulk tensors
        of shape ``(local_size, bond_dim, bond_dim, kraus_dim)``.

        Raises:
            ValueError: if the Hilbert space has fewer than two sites, since one
                boundary tensor is needed at each end of the chain.
        """
        L = self.hilbert.size
        d = self.hilbert.local_size
        D = self.bond_dim
        chi = self.kraus_dim
        self._L, self._d, self._D, self._chi = L, d, D, chi
        self.param_dtype_cplx = dtype_complex(self.param_dtype)

        # Fail early with a clear message instead of an obscure error caused by
        # the negative bulk dimension (L - 2) further down.
        if L < 2:
            raise ValueError(
                "MPDOOpen requires a Hilbert space with at least 2 degrees of freedom."
            )

        # Constant offset of the boundary tensors (all ones).
        boundary_shape = (d, D, chi)
        iden_boundary_tensor = jnp.ones(boundary_shape, dtype=self.param_dtype)

        self.left_tensors = (
            self.param(
                "left_tensors", self.kernel_init, boundary_shape, self.param_dtype
            )
            + iden_boundary_tensor
        )
        self.right_tensors = (
            self.param(
                "right_tensors", self.kernel_init, boundary_shape, self.param_dtype
            )
            + iden_boundary_tensor
        )

        # One independent tensor for every bulk site.
        unit_cell_shape = (L - 2, d, D, D, chi)

        # Constant offset added to the trainable bulk parameters.
        # NOTE(review): the identities are stacked as ((L - 2) * d * chi, D, D)
        # and then reshaped to (L - 2, d, D, D, chi). For kraus_dim > 1 this
        # reshape does not leave an identity on the (D, D) axes of each Kraus
        # slice. Kept unchanged because altering the offset would change the
        # output for existing parameters; confirm the intended layout.
        iden_tensors = jnp.repeat(
            jnp.eye(D, dtype=self.param_dtype)[jnp.newaxis, :, :],
            (L - 2) * d * chi,
            axis=0,
        )
        iden_tensors = iden_tensors.reshape(unit_cell_shape)

        self.middle_tensors = (
            self.param(
                "middle_tensors", self.kernel_init, unit_cell_shape, self.param_dtype
            )
            + iden_tensors
        )

    def __call__(self, x):
        """
        Queries this MPDO for the input configurations x, which should contain
        rows and columns entries concatenated.

        Args:
            x: the input configurations. The last axis must have length
                ``2 * hilbert.size``; any number of leading batch axes is accepted.

        Returns:
            The complex logarithm of the density-matrix entries, with shape
            ``x.shape[:-1]`` (a 1D input is treated as a batch of one).
        """
        x = jnp.atleast_2d(x)
        assert (
            x.shape[-1] == 2 * self.hilbert.size
        ), "The last axis of x must hold the row and column configurations concatenated."

        # Flatten extra leading batch axes so that vmap runs over single samples.
        batch_shape = x.shape[:-1]
        if x.ndim != 2:
            x = x.reshape(-1, x.shape[-1])

        qn_x = self.hilbert.states_to_local_indices(x)
        rho = jax.vmap(self.contract_mpdo)(qn_x)
        rho = rho.reshape(batch_shape)
        return jnp.log(rho.astype(dtype_complex(self.param_dtype)))

    def contract_mpdo(self, qn):
        """
        Internal function, used to contract the tensor network with some input tensor.

        Args:
            qn: The input tensor to be contracted with this MPDO: the local indices
                of a single sample, row indices followed by column indices.

        Returns:
            The scalar obtained by contracting the chain from the left boundary
            to the right boundary.
        """
        # First half of qn indexes the rows, second half the columns.
        qn_r, qn_c = jnp.split(qn, 2, axis=-1)

        # Boundary environments: contract the Kraus index of the row-state
        # tensor with that of the conjugated column-state tensor.
        left_edge = (
            self.left_tensors[qn_r[0], :] @ jnp.conj(self.left_tensors[qn_c[0], :]).T
        )
        right_edge = (
            self.right_tensors[qn_r[-1], :]
            @ jnp.conj(self.right_tensors[qn_c[-1], :]).T
        )

        def base_scan_func(edge, pair):
            tensor, s_row, s_col = pair
            # Absorb the row-state tensor first, then the conjugated column-state
            # tensor (summing over its left bond and the shared Kraus index).
            triangle_tensor = jnp.einsum("ijk,il->jkl", tensor[s_row, :], edge)
            edge = jnp.einsum(
                "jkl,lmk->jm", triangle_tensor, jnp.conj(tensor[s_col, :])
            )
            return edge, None

        # Apply jax.checkpoint conditionally to the base_scan_func
        scan_func = (
            jax.checkpoint(base_scan_func) if self.checkpoint else base_scan_func
        )

        # Sweep over the bulk sites only; the two boundary sites are handled above.
        edge, _ = jax.lax.scan(
            scan_func,
            left_edge,
            (self.middle_tensors, qn_r[1:-1], qn_c[1:-1]),
            unroll=self.unroll,
        )

        return jnp.einsum("ij,ij->", edge, right_edge)