Sharding & Model parallelism

https://kauldron.rtfd.io/en/latest/sharding.html

Model parallelism

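Sharding is configured through the trainer.sharding attribute (cfg.sharding in the config), which takes a kd.sharding.ShardingStrategy. Each field of the strategy is a pytree that can be a prefix of the corresponding part of the state. In the example below, a single leaf covers the whole encoder (or decoder) sub-tree: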

cfg.sharding = kd.sharding.ShardingStrategy(
    params={
        'encoder': kd.sharding.REPLICATED,
        'decoder': my_project.my_sharding_strategy,
    },
    opt_state=None,  # Let jax auto-infer the sharding
)

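Each leaf of the sharding pytree can be:

  • None: The sharding is auto-inferred by jax (propagated from the inputs).

  • jax.sharding.Sharding: The sharding is set explicitly for every array in the sub-tree.

  • Callable: The sharding is computed lazily by calling the function on the array sub-tree, which must return the matching sharding tree. For example: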

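    import jax
    import numpy as np
    import tree  # dm-tree
    from kauldron import kd
    from kauldron.typing import PyTree

    def my_sharding_strategy(params: PyTree[jax.Array]) -> kd.sharding.ShardingTree:
      # 2D device mesh: 4 devices along 'data', the rest along 'params'
      # (assumes the device count is a multiple of 4).
      devices = np.asarray(jax.devices())
      devices = devices.reshape((-1, jax.device_count() // 4))
      mesh = jax.sharding.Mesh(devices, axis_names=('data', 'params'))

      def _shard_param(path, x):
        # `path` is the tuple of keys leading to `x`, e.g. ('Dense_0', 'kernel').
        if 'kernel' in path:
          # Split the first dimension of the kernel across the 'params' axis.
          return jax.sharding.NamedSharding(
              mesh, jax.sharding.PartitionSpec('params')
          )
        elif 'bias' in path:
          # Biases are small: keep a full copy on every device.
          return kd.sharding.REPLICATED
        else:
          raise ValueError(f'Unexpected param: {path}: {x.shape}')

      return tree.map_structure_with_path(_shard_param, params)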
    
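To make the three kinds of leaves concrete, here is a minimal plain-JAX sketch of how a prefix tree of sharding specs could be expanded over the params. The helper resolve_shardings and the strategy shard_kernels are hypothetical illustrations, not Kauldron internals. The sketch uses jax.tree_util instead of dm-tree, and builds its mesh from whatever devices are available, so it also runs on a single CPU.

```python
from typing import Any

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec


def resolve_shardings(strategy: Any, params: Any) -> Any:
  """Hypothetical helper: expands a prefix tree of sharding specs over `params`.

  Each leaf of `strategy` applies to the matching sub-tree of `params`:
    * None: every array in the sub-tree gets None (left to jax to infer).
    * jax.sharding.Sharding: every array in the sub-tree gets that sharding.
    * Callable: called on the sub-tree, returns a matching sharding tree.
  """

  def _resolve(spec, subtree):
    if spec is None:
      return jax.tree_util.tree_map(lambda _: None, subtree)
    if isinstance(spec, jax.sharding.Sharding):
      return jax.tree_util.tree_map(lambda _: spec, subtree)
    if callable(spec):
      return spec(subtree)
    raise TypeError(f'Unsupported sharding spec: {spec!r}')

  # `strategy` is a prefix of `params`, so tree_map hands whole sub-trees of
  # `params` to `_resolve`. `is_leaf` keeps None as a leaf, not an empty node.
  return jax.tree_util.tree_map(
      _resolve, strategy, params, is_leaf=lambda x: x is None
  )


def shard_kernels(params):
  """Hypothetical strategy: shard kernels on 'params', replicate biases."""
  devices = np.asarray(jax.devices()).reshape((1, -1))  # ('data', 'params')
  mesh = Mesh(devices, axis_names=('data', 'params'))

  def _shard_param(path, x):
    name = path[-1].key  # DictKey of the innermost dict entry
    if name == 'kernel':
      return NamedSharding(mesh, PartitionSpec('params'))
    if name == 'bias':
      return NamedSharding(mesh, PartitionSpec())  # replicated
    raise ValueError(f'Unexpected param: {jax.tree_util.keystr(path)}: {x.shape}')

  return jax.tree_util.tree_map_with_path(_shard_param, params)


n = jax.device_count()
params = {
    'encoder': {'kernel': jnp.ones((8, 4)), 'bias': jnp.zeros((4,))},
    # First kernel dim is a multiple of the device count so it can be split.
    'decoder': {'kernel': jnp.ones((4 * n, 8)), 'bias': jnp.zeros((8,))},
}
replicated = NamedSharding(
    Mesh(np.asarray(jax.devices()), axis_names=('devices',)), PartitionSpec()
)
shardings = resolve_shardings(
    {'encoder': replicated, 'decoder': shard_kernels}, params
)
params = jax.device_put(params, shardings)  # place each array accordingly
```

Relying on tree_map's prefix matching keeps the strategy short: one leaf per sub-tree instead of one per array.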

Available shardings

Kauldron provides the following predefined jax.sharding.Sharding instances:

  • kd.sharding.REPLICATED: All devices hold the same data. This is the default params sharding.

  • kd.sharding.FIRST_DIM: The first dimension is sharded across devices. This is the default dataset sharding.
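
As a rough mental model (an assumption about their behavior, not Kauldron's actual definitions), the two constants act like NamedShardings over a 1-D mesh spanning all devices: an empty PartitionSpec for REPLICATED, and a spec naming the mesh axis on dimension 0 for FIRST_DIM. The axis name 'devices' is arbitrary.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# 1-D mesh over every available device (the axis name is arbitrary).
mesh = Mesh(np.asarray(jax.devices()), axis_names=('devices',))

# Every device holds a full copy of the array.
replicated = NamedSharding(mesh, PartitionSpec())
# Dimension 0 is split across devices; the other dimensions are not split.
first_dim = NamedSharding(mesh, PartitionSpec('devices'))

n = jax.device_count()
batch = jnp.arange(4 * n * 3).reshape((4 * n, 3))  # batch size divisible by n

sharded = jax.device_put(batch, first_dim)
copied = jax.device_put(batch, replicated)

# Each device holds a (4, 3) slice of the sharded batch...
print({shard.data.shape for shard in sharded.addressable_shards})
# ...but the full (4 * n, 3) array when replicated.
print({shard.data.shape for shard in copied.addressable_shards})
```

This is why the batch size must be divisible by the number of devices when the dataset uses first-dimension sharding.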

Internally, the shardings are applied with kd.sharding.with_sharding_constraint. Schematically, the training loop looks like:

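@jax.jit
def _step(state: TrainState, batch) -> TrainState:
  ...  # Compute the loss and gradients, then update the state.
  # Force the output state to follow the sharding declared in trainer.sharding.
  return kd.sharding.with_sharding_constraint(state, trainer.sharding.state)


# Each batch is put on the devices with the trainer.sharding.ds sharding.
for ex in trainer.train_ds.device_put(trainer.sharding.ds):
  state = _step(state, ex)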
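
The same pattern can be reproduced in plain JAX. This is a minimal sketch, assuming a toy linear-regression state held in a dict and using jax.lax.with_sharding_constraint, which also accepts a pytree of shardings. The step function, the state layout, the learning rate and the REPLICATED / FIRST_DIM stand-ins are illustrative, not Kauldron's trainer.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Plain-JAX stand-ins for the Kauldron constants (assumed semantics).
mesh = Mesh(np.asarray(jax.devices()), axis_names=('devices',))
REPLICATED = NamedSharding(mesh, PartitionSpec())
FIRST_DIM = NamedSharding(mesh, PartitionSpec('devices'))

# Target sharding of the state: here every leaf is replicated.
state_sharding = {'w': REPLICATED, 'step': REPLICATED}


@jax.jit
def step(state, batch):
  def loss_fn(w):
    pred = batch['x'] @ w
    return jnp.mean((pred - batch['y']) ** 2)

  grads = jax.grad(loss_fn)(state['w'])
  new_state = {'w': state['w'] - 0.1 * grads, 'step': state['step'] + 1}
  # Pin the output state to the declared sharding instead of letting jax
  # infer it from the inputs.
  return jax.lax.with_sharding_constraint(new_state, state_sharding)


n = jax.device_count()
rng = np.random.default_rng(0)
true_w = np.array([[1.0], [-2.0], [0.5]], dtype=np.float32)
state = {'w': jnp.zeros((3, 1)), 'step': jnp.array(0)}
for _ in range(200):
  x = rng.normal(size=(8 * n, 3)).astype(np.float32)
  # Batches are split along their first (batch) dimension.
  batch = jax.device_put({'x': x, 'y': x @ true_w}, FIRST_DIM)
  state = step(state, batch)
```

Constraining the output, rather than only the inputs, keeps the state layout stable across steps even if the compiler would otherwise pick a different sharding for the updated arrays.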