netket.utils.struct.ShardedFieldSpec
- class netket.utils.struct.ShardedFieldSpec
Bases: object
Specification of a sharded field.
Used, for the time being, to specify how a field of a NetKet-style Pytree is handled during serialization and deserialization with flax.serialization.
- Attributes
- deserialization_function: ShardedDeserializationFunction | str | None = 'relaxed'
Function to use to deserialize the data. Can be a callable with the signature:
def f(value_target: jax.Array, value_state: jax.Array, *, name: str = ".") -> jax.Array
or one of the following strings:
  - "fail": Raise an error if the sharded data does not match the target data.
  - "strict": Raise an error if the sharded data does not match the target data.
  - "relaxed": If the target data is smaller than the serialized data, ignore the extra serialized data; raise an error if the target data is larger.
  - "relaxed-ignore-errors": If the target data is smaller, ignore the extra serialized data; if the target data is larger, do nothing.
  - "relaxed-rng-key": Special case for RNG keys, where the serialized data can safely be truncated if the target data is larger.
The default is "relaxed".
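To make the callable form concrete, here is a minimal sketch of a custom function with the signature above. It is not NetKet's implementation of any built-in option: the name `truncating_deserializer` is hypothetical, and it assumes the data is sharded along axis 0. It approximates the documented "relaxed" behaviour by dropping surplus serialized entries and raising when the target is larger.

```python
import jax
import jax.numpy as jnp


def truncating_deserializer(
    value_target: jax.Array, value_state: jax.Array, *, name: str = "."
) -> jax.Array:
    """Hypothetical deserializer approximating the 'relaxed' policy.

    Assumption (illustrative only): the sharded axis is axis 0.
    """
    # All non-sharded dimensions must agree exactly.
    if value_target.shape[1:] != value_state.shape[1:]:
        raise ValueError(
            f"{name}: trailing shapes differ: "
            f"{value_target.shape} vs {value_state.shape}"
        )
    n_target = value_target.shape[0]
    n_state = value_state.shape[0]
    if n_target > n_state:
        # The target needs more entries than were saved: nothing to fill them with.
        raise ValueError(
            f"{name}: target has {n_target} entries along axis 0, "
            f"but the serialized data only has {n_state}"
        )
    # Keep only as many serialized entries as the target can hold.
    return jnp.asarray(value_state[:n_target], dtype=value_target.dtype)


# Usage: a 4-row saved array restored into a 2-row target keeps the first 2 rows.
saved = jnp.arange(12.0).reshape(4, 3)
restored = truncating_deserializer(jnp.zeros((2, 3)), saved, name="samples")
```

Truncating along a single axis keeps the sketch simple; a real policy may need to know which axis is sharded, which is what `mpi_sharded_axis` records for the MPI case.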
- mpi_sharded_axis: int | None = 0
Optional integer indicating which axis is considered sharded when running with MPI. Defaults to 0.
If this is specified, loading under MPI a state that was saved from a run using JAX sharding scatters (partitions) the data along the specified axis.
If None, MPI is not supported for this field.