Sharding & Model parallelism
https://kauldron.rtfd.io/en/latest/sharding.html
Model parallelism
Sharding is defined through the trainer.sharding attribute.
cfg.sharding = kd.sharding.ShardingStrategy(
    params={
        'encoder': kd.sharding.REPLICATED,
        'decoder': my_project.my_sharding_strategy,
    },
    opt_state=None,  # Let jax auto-infer the sharding
)
Each leaf of the sharding pytree can be:
- None: Sharding is auto-inferred by jax (propagated from the inputs).
- jax.sharding.Sharding: Explicitly set the sharding.
- Callable: Lazily compute the sharding from the array sub-tree:

  def my_sharding_strategy(params: PyTree[jax.Array]) -> kd.sharding.ShardingTree:
    devices = np.asarray(jax.devices())
    devices = devices.reshape((-1, jax.device_count() // 4))
    mesh = jax.sharding.Mesh(devices, axis_names=('data', 'params'))

    def _shard_param(path, x):
      if 'kernel' in path:
        return jax.sharding.NamedSharding(
            mesh, jax.sharding.PartitionSpec('params')
        )
      elif 'bias' in path:
        return kd.sharding.REPLICATED
      else:
        raise ValueError(f'Unexpected param: {path}: {x.shape}')

    return tree.map_structure_with_path(_shard_param, params)
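The callable case can be tried outside Kauldron with plain JAX. The following is a minimal sketch, not Kauldron's implementation: `make_mesh`, `shard_by_name`, and the toy `params` tree are hypothetical names, `jax.tree_util.tree_map_with_path` stands in for `tree.map_structure_with_path`, and a fully replicated `NamedSharding` stands in for `kd.sharding.REPLICATED`.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec


def make_mesh() -> Mesh:
  # Hypothetical layout: a size-1 'data' axis and every device on 'params'.
  devices = np.asarray(jax.devices()).reshape((1, -1))
  return Mesh(devices, axis_names=('data', 'params'))


def shard_by_name(params, mesh: Mesh):
  """Returns a pytree of shardings with the same structure as `params`."""

  def _shard_param(path, x):
    name = jax.tree_util.keystr(path)  # e.g. "['dense']['kernel']"
    if 'kernel' in name:
      # Split the first dimension of kernels across the 'params' axis.
      return NamedSharding(mesh, PartitionSpec('params'))
    if 'bias' in name:
      # An empty PartitionSpec replicates the array on every device.
      return NamedSharding(mesh, PartitionSpec())
    raise ValueError(f'Unexpected param: {name}: {x.shape}')

  return jax.tree_util.tree_map_with_path(_shard_param, params)


n = jax.device_count()
params = {
    'dense': {
        # First dimension is a multiple of the device count so it splits evenly.
        'kernel': jnp.ones((n * 2, 4)),
        'bias': jnp.zeros((4,)),
    },
}
mesh = make_mesh()
shardings = shard_by_name(params, mesh)
# device_put accepts a pytree of shardings matching the pytree of arrays.
placed = jax.device_put(params, shardings)
```

The strategy only inspects parameter names and shapes, so it can be evaluated lazily once the parameter tree is known, which is the reason the callable form exists.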
Available sharding
By default, Kauldron provides the following jax.sharding.Sharding:
- kd.sharding.REPLICATED: All devices hold the same data. This is the default params sharding.
- kd.sharding.FIRST_DIM: First dimension is sharded across devices. This is the default dataset sharding.
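To see what these two layouts mean per device, here is a small plain-JAX sketch. The `replicated` and `first_dim` objects below are assumed stand-ins built from `NamedSharding`; they are not the Kauldron constants themselves, and the one-axis mesh named `'devices'` is an illustrative choice.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# One-dimensional mesh over all available devices.
mesh = Mesh(np.asarray(jax.devices()), axis_names=('devices',))

# Assumed stand-ins for the two layouts described above.
replicated = NamedSharding(mesh, PartitionSpec())          # full copy per device
first_dim = NamedSharding(mesh, PartitionSpec('devices'))  # split along axis 0

n = jax.device_count()
batch = jnp.arange(n * 2 * 3, dtype=jnp.float32).reshape((n * 2, 3))
weights = jnp.ones((3, 4))

batch = jax.device_put(batch, first_dim)
weights = jax.device_put(weights, replicated)

# Shape of the piece stored on the first local device.
batch_shard_shape = batch.addressable_shards[0].data.shape
weights_shard_shape = weights.addressable_shards[0].data.shape
```

With `n` devices, each device holds `2` of the `n * 2` batch rows but the full `(3, 4)` weight matrix, which matches the usual data-parallel setup of sharded inputs and replicated parameters.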
Sharding is applied internally with kd.sharding.with_sharding_constraint:
@jax.jit
def _step(state: TrainState, batch) -> TrainState:
  ...
  return kd.sharding.with_sharding_constraint(state, trainer.sharding.state)

for ex in trainer.train_ds.device_put(trainer.sharding.ds):
  state = _step(state, ex)
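The same pattern can be reproduced in plain JAX with `jax.lax.with_sharding_constraint`. This is only a sketch under stated assumptions: the array-valued `state`, the toy update rule, and the in-memory `batches` list are hypothetical placeholders for `TrainState`, the real training step, and `trainer.train_ds`.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

mesh = Mesh(np.asarray(jax.devices()), axis_names=('devices',))
state_sharding = NamedSharding(mesh, PartitionSpec())           # replicated state
batch_sharding = NamedSharding(mesh, PartitionSpec('devices'))  # batch split on axis 0


@jax.jit
def step(state, batch):
  # Toy update: add the batch mean to every element of the state.
  new_state = state + batch.mean()
  # Pin the output layout so the state keeps the same sharding every step.
  return jax.lax.with_sharding_constraint(new_state, state_sharding)


n = jax.device_count()
state = jax.device_put(jnp.zeros((4,)), state_sharding)
# Three constant batches with values 0.0, 1.0 and 2.0.
batches = [jnp.full((n * 2, 3), float(i)) for i in range(3)]

for batch in batches:
  # Place each batch on the devices before calling the jitted step.
  batch = jax.device_put(batch, batch_sharding)
  state = step(state, batch)
```

Constraining the returned state inside the jitted function keeps its layout stable from one step to the next, so the compiler does not pick a different output sharding than the one the next call expects.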