netket.utils.struct.ShardedFieldSpec

class netket.utils.struct.ShardedFieldSpec

Bases: object

Specification of a sharded field.

Currently used to specify how a field in a NetKet-style Pytree is serialized and deserialized (using flax.serialization).

Attributes
deserialization_function: ShardedDeserializationFunction | str | None = 'relaxed'

Function to use to deserialize the data. Can be a callable with the signature:

def f(value_target: jax.Array, value_state: jax.Array, *, name: str = ".") -> jax.Array

or one of the following strings:

  • “fail”: Raise an error if the sharded data does not match the target data.

  • “strict”: Raise an error if the sharded data does not match the target data.

  • “relaxed”: Ignore extra data in the sharded data if the target data is smaller than the serialized data; error if the target data is larger.

  • “relaxed-ignore-errors”: Ignore extra data in the sharded data if the target data is smaller, and do nothing if the target data is larger.

  • “relaxed-rng-key”: Special case for RNG keys, where we can safely truncate the serialized data if the target data is larger.

The default is “relaxed”.
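As an illustration of the callable form, the following is a minimal sketch of a custom deserialization function with the signature given above. The name `truncate_leading_axis` is hypothetical, and the sketch assumes that `value_target` is the array held by the object being restored, that `value_state` is the array read from the serialized state, and that only the leading axis may differ in size. It mimics the behaviour described for “relaxed” and is not NetKet's own implementation.

```python
import jax
import jax.numpy as jnp


def truncate_leading_axis(
    value_target: jax.Array, value_state: jax.Array, *, name: str = "."
) -> jax.Array:
    """Hypothetical deserializer: keep only the leading rows that the target needs.

    Assumption: `value_target` is the in-memory array and `value_state` is the
    array loaded from the serialized state.
    """
    # Trailing dimensions must agree; only the leading axis may differ.
    if value_target.shape[1:] != value_state.shape[1:]:
        raise ValueError(
            f"Field '{name}': trailing shape {value_state.shape[1:]} "
            f"does not match target {value_target.shape[1:]}."
        )
    n_target = value_target.shape[0]
    n_state = value_state.shape[0]
    # A target larger than the saved data cannot be filled: raise an error.
    if n_state < n_target:
        raise ValueError(
            f"Field '{name}': target has {n_target} rows but the state has only {n_state}."
        )
    # Ignore the extra rows in the saved data.
    return value_state[:n_target].astype(value_target.dtype)


# Usage: the saved state has 6 rows, the target only needs 4.
target = jnp.zeros((4, 2))
saved = jnp.arange(12.0).reshape(6, 2)
restored = truncate_leading_axis(target, saved, name="samples")
```

A function of this shape could then be passed as `deserialization_function` in place of one of the strings listed above.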

mpi_sharded_axis: int | None = 0

Optional integer indicating which axis is to be considered sharded when running with MPI. Defaults to 0.

If this is specified, loading under MPI a state that was saved from a run using JAX sharding will scatter (partition) the data along the specified axis.

If None, MPI is not supported.
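To make the idea of partitioning along an axis concrete, here is a small sketch that splits a full array into equal blocks, one per rank. The helper `partition_along_axis` and the `n_ranks` parameter are hypothetical and only illustrate the concept. The sketch does not use MPI and does not reproduce NetKet's loading logic.

```python
import jax.numpy as jnp


def partition_along_axis(full, n_ranks, axis=0):
    """Split `full` into `n_ranks` equal blocks along `axis` (illustration only)."""
    if full.shape[axis] % n_ranks != 0:
        raise ValueError(
            f"Axis {axis} of size {full.shape[axis]} is not divisible by {n_ranks}."
        )
    # jnp.split returns a list with one block per (hypothetical) rank.
    return jnp.split(full, n_ranks, axis=axis)


# A saved array with 8 entries along axis 0, distributed over 4 hypothetical ranks.
full = jnp.arange(24).reshape(8, 3)
blocks = partition_along_axis(full, n_ranks=4, axis=0)
# Each block has shape (2, 3); concatenating them along axis 0 recovers `full`.
```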

sharded: bool = True

Boolean indicating whether the field is sharded or not. If False, all other fields are ignored.