netket.operator.DiscreteJaxOperator

class netket.operator.DiscreteJaxOperator

Bases: DiscreteOperator

Abstract base class for discrete operators that can be manipulated inside of jax function transformations.

Any operator inheriting from this base class follows the netket.operator.DiscreteOperator interface, but can additionally be used inside jax.jit(), jax.grad(), jax.vmap() and similar transformations. Jax-compatible operators must be passed to those functions as standard arguments, not as static arguments; if only their coefficients change between calls, no recompilation is triggered.

Some operators, such as netket.operator.Ising or netket.operator.PauliStrings, can be converted to their jax-enabled counterparts by calling the method to_jax_operator(). Not all operators support this conversion, but netket.operator.PauliStrings is flexible: if you can convert or write your Hamiltonian as a sum of Pauli strings, you will be able to use netket.operator.PauliStringsJax.

Note

Jax does not support dynamically varying shapes, so not all operators can be written as jax operators, and even if they could be written as such, they might generate more connected elements than their Numba counterpart.

Note

Operators derived from netket.operator.DiscreteJaxOperator require the Jax-compatible version of the Hamiltonian sampling rule, netket.sampler.rules.HamiltonianRuleJax().

Defining custom discrete operators that are Jax-compatible

This class should be inherited by DiscreteOperators which wish to declare jax-compatibility.

Classes inheriting from DiscreteJaxOperator should be declared following a scheme like the following. Note in particular the declaration of the pytree flattening and unflattening, which follows the standard Jax APIs discussed in the Jax Pytree documentation.

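from jax.tree_util import register_pytree_node_class
from netket.operator import DiscreteJaxOperator

@register_pytree_node_class
class MyJaxOperator(DiscreteJaxOperator):
    def __init__(self, hilbert, ...):
        super().__init__(hilbert)
        ...  # store coefficients (arrays) and constant data as attributes

    def tree_flatten(self):
        array_data = ( ... )  # all arrays: pytree leaves, traced by Jax
        struct_data = {'hilbert': self.hilbert,
                       # ... all constant (static) data
                       }
        return array_data, struct_data

    @classmethod
    def tree_unflatten(cls, struct_data, array_data):
        ...  # unpack array_data
        # the Hilbert space is constant data, so it is read from struct_data
        return cls(struct_data['hilbert'], ...)

    @property
    def max_conn_size(self) -> int:
        return ...  # maximum number of connected states for any x

    def get_conn_padded(self, x):
        ...  # compute connected states and matrix elements
        return xp, mels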
Attributes
H

Returns the conjugate-transposed (Hermitian adjoint) operator.

T

Returns the transposed operator

dtype

The dtype of the operator’s matrix elements ⟨σ|Ô|σ’⟩.

hilbert

The Hilbert space associated with this observable.

is_hermitian

Returns True if this operator is Hermitian.

max_conn_size

The maximum number of nonzero matrix elements ⟨x|O|x’⟩ for any given x, i.e. the maximum number of states x’ connected to a single state x.

Methods
__call__(v)

Call self as a function.

Return type:

ndarray

Parameters:

v (ndarray)

apply(v)
Return type:

ndarray

Parameters:

v (ndarray)

collect()

Returns a guaranteed concrete instance of an operator.

As some operations on operators return lazy wrappers (such as transpose, hermitian conjugate…), this is used to obtain a guaranteed non-lazy operator.

Return type:

AbstractOperator

conj(*, concrete=False)
Return type:

AbstractOperator

conjugate(*, concrete=False)

Returns the complex-conjugate of this operator.

Parameters:

concrete – If True, returns a concrete operator rather than a lazy wrapper.

Return type:

AbstractOperator

Returns:

If concrete is not True, self or a lazy wrapper; otherwise, the complex-conjugated operator.

get_conn(x)

Finds the connected elements of the Operator. Starting from a given quantum number x, it finds all other quantum numbers x’ such that the matrix element \(O(x,x')\) is different from zero. In general there will be several different connected states x’ satisfying this condition, and they are denoted here \(x'(k)\), for \(k=0,1...N_{\mathrm{connected}}\).

Parameters:

x (ndarray) – An array of shape (hilbert.size, ) containing the quantum numbers x.

Returns:

matrix: The connected states x’, of shape (N_connected, hilbert.size). array: An array containing the matrix elements \(O(x,x')\) associated to each x’.

Return type:

(matrix, array)

Raises:

ValueError – If the given quantum number is not compatible with the hilbert space.

get_conn_flattened(x, sections)

Finds the connected elements of the Operator.

Starting from a given quantum number \(x\), it finds all other quantum numbers \(x'\) such that the matrix element \(O(x,x')\) is different from zero. In general there will be several different connected states \(x'\) satisfying this condition, and they are denoted here \(x'(k)\), for \(k=0,1...N_{\mathrm{connected}}\).

This is a batched version, where x is a matrix of shape (batch_size,hilbert.size).

Parameters:
  • x (ndarray) – A matrix of shape (batch_size, hilbert.size) containing the batch of quantum numbers x.

  • sections (ndarray) – An array of sections for the flattened x’. See numpy.split for the meaning of sections.

Returns:

The connected states x’, flattened together in a single matrix, and an array containing the matrix elements \(O(x,x')\) associated to each x’.

Return type:

(matrix, array)
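
Under the numpy.split convention referenced above, sections can be read as the cumulative end index of each sample's block of connected elements. The pure-numpy sketch below illustrates that reading by converting a padded representation into a flattened one. The helper padded_to_flattened and the choice to treat zero matrix elements as padding are assumptions for illustration, not NetKet's implementation.

```python
import numpy as np


def padded_to_flattened(xp, mels):
    """Hypothetical helper: convert padded connected elements of shapes
    (B, K, N) and (B, K) into the flattened layout plus `sections`.

    Assumes padding entries are marked by a zero matrix element."""
    keep = mels != 0              # (B, K) mask of genuine connections
    n_conn = keep.sum(axis=-1)    # number of connected states per sample
    sections = np.cumsum(n_conn)  # end index of each sample's block
    return xp[keep], mels[keep], sections


# Two samples on N=2 sites; the second has only two genuine connections,
# so its third slot is padding with a zero matrix element.
xp = np.array([[[1, 1], [-1, 1], [1, -1]],
               [[1, -1], [-1, -1], [1, -1]]])
mels = np.array([[0.5, 1.0, 1.0],
                 [0.2, 1.0, 0.0]])

x_flat, mels_flat, sections = padded_to_flattened(xp, mels)

# numpy.split with sections[:-1] recovers the per-sample blocks
blocks = np.split(mels_flat, sections[:-1])
```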

abstract get_conn_padded(x)

Finds the connected elements of the Operator. This method can be executed inside of a Jax function transformation.

Starting from a batch of quantum numbers \(x=\{x_1, \dots, x_B\}\) of size \(B \times M\), where \(B\) is the size of the batch and \(M\) is the size of the Hilbert space, finds all states \(y_i^1, \dots, y_i^K\) connected to every \(x_i\).

Returns a tensor of size \(B \times K_{max} \times M\), where \(K_{max}\) is the maximum number of connections over all \(x_i\), together with a \(B \times K_{max}\) matrix of the corresponding matrix elements.

Parameters:

x (ndarray) – An N-tensor of shape \((...,hilbert.size)\) containing the batch/batches of quantum numbers \(x\).

Returns:

The connected states x’, in an (N+1)-tensor, and an N-tensor containing the matrix elements \(O(x,x')\) associated to each x’ for every batch.

Return type:

(x_primes, mels)

n_conn(x, out=None)

Return the number of states connected to x.

Parameters:
  • x (matrix) – A matrix of shape (batch_size,hilbert.size) containing the batch of quantum numbers x.

  • out (array) – If None, an output array is allocated.

Returns:

The number of connected states x’ for each x[i].

Return type:

array

to_dense()

Returns the dense matrix representation of the operator. Note that, in general, the size of the matrix is exponential in the number of quantum numbers, and this operation should thus only be performed for low-dimensional Hilbert spaces or sufficiently sparse operators.

This method requires an indexable Hilbert space.

Return type:

ndarray

Returns:

The dense matrix representation of the operator as a jax Array.

to_linear_operator()
to_qobj()

Converts the operator to a QuTiP Qobj.

Returns:

A qutip.Qobj object.

to_sparse()

Returns the sparse matrix representation of the operator. Note that, in general, the size of the matrix is exponential in the number of quantum numbers, and this operation should thus only be performed for low-dimensional Hilbert spaces or sufficiently sparse operators.

This method requires an indexable Hilbert space.

Return type:

JAXSparse

Returns:

The sparse jax matrix representation of the operator.

transpose(*, concrete=False)

Returns the transpose of this operator.

Parameters:

concrete – If True, returns a concrete operator rather than a lazy wrapper.

Return type:

AbstractOperator

Returns:

If concrete is not True, self or a lazy wrapper; otherwise, the transposed operator.