Initialises a jax.random.PRNGKey from an optional starting seed.
When sharding is used, the returned key is replicated on every process.
- Parameters:
seed (Union[int, Any, None]) – An optional integer to use as the seed.
root (int) – The rank of the master (root) process, used when running on multiple nodes (default 0).
- Return type:
Any
- Returns:
A sharded/broadcast jax.random.PRNGKey.
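The behaviour described above (one key, identical on every process, derived from the root's seed) can be sketched roughly as follows. This is a minimal illustration, not the library's implementation: the function name init_prng_key is hypothetical, drawing a random seed when seed is None is an assumption, and the sketch broadcasts the integer seed from root with jax.experimental.multihost_utils.broadcast_one_to_all before building a fully replicated jax.Array over all devices.

```python
import jax
import numpy as np
from jax.experimental import multihost_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec


def init_prng_key(seed=None, root=0):
    """Hypothetical sketch: build a PRNGKey that is identical on every process."""
    if seed is None:
        # Assumption: with no seed, draw one at random. In a multi-process run
        # each process draws its own, but only root's value survives the broadcast.
        seed = int(np.random.default_rng().integers(0, 2**31 - 1))
    if jax.process_count() > 1:
        # Send root's seed to every process so they all build the same key.
        seed = int(
            multihost_utils.broadcast_one_to_all(
                np.asarray(seed, dtype=np.int32),
                is_source=jax.process_index() == root,
            )
        )
    key = np.asarray(jax.random.PRNGKey(seed))
    # An empty PartitionSpec means the array is fully replicated over the mesh.
    mesh = Mesh(np.array(jax.devices()), ("devices",))
    sharding = NamedSharding(mesh, PartitionSpec())
    # Each device receives the slice of the (identical) host key it asks for.
    return jax.make_array_from_callback(key.shape, sharding, lambda idx: key[idx])
```

Broadcasting the seed rather than the key keeps the collective small; when an explicit seed is passed, every process already agrees, so the broadcast only changes the result when seed is None.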