fixing wrong trainer assumption that batch dim is always the first one in the mesh
martin-gorner committed Dec 20, 2024
commit 6ecc55c042442728f38a0fb292191501acf3b29e
keras/src/backend/jax/distribution_lib.py: 13 additions & 2 deletions

@@ -100,7 +100,7 @@ def distribute_tensor(tensor, layout):
     return global_value
 
 
-def distribute_data_input(per_process_batch, layout):
+def distribute_data_input(per_process_batch, layout, batch_dim_name):
     """Distribute the input data with the corresponding layout.
 
     Note that the input here is a local worker batch. Within the local worker,
@@ -117,9 +117,20 @@ def distribute_data_input(per_process_batch, layout):
     if not isinstance(layout, jax.sharding.Sharding):
         layout = _to_jax_layout(layout)
 
     mesh_shape = list(layout.mesh.shape.values())
-    num_model_replicas_total = mesh_shape[0]  # batch dimension of the mesh
+    num_model_replicas_total = layout.mesh.shape[batch_dim_name]  # batch dimension of the mesh
 
+    # TODO: THIS IS COMPLETELY WRONG AS WELL FOR REPLICATING DATA ON "MODEL"
+    # dimensions: there may be more than one and the index is not always "1"
     mesh_model_dim_size = mesh_shape[1] if len(mesh_shape) > 1 else 1
 
+    # TODO: proper fix for this quick and dirty hack
+    # this only works for 2D meshes
+    mesh_model_dim_size = 1
+    for name, dim_size in layout.mesh.shape.items():
+        if not name == batch_dim_name:
+            mesh_model_dim_size = dim_size
 
     num_model_replicas_per_process = num_model_replicas_total / num_processes()
     per_process_batch_size = per_process_batch.shape[0]

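For context, a minimal sketch of why the by-name lookup matters. The axis names, mesh shape, and 8-device setup below are illustrative assumptions, not part of the commit:

    import jax
    import numpy as np
    from jax.sharding import Mesh

    # Hypothetical 2x4 device mesh that declares the batch axis SECOND
    # (assumes 8 visible devices; axis names are illustrative).
    devices = np.array(jax.devices()).reshape(2, 4)
    mesh = Mesh(devices, axis_names=("model", "batch"))

    # Old approach: assume the batch axis is the first mesh dimension.
    mesh_shape = list(mesh.shape.values())
    print(mesh_shape[0])        # 2, the "model" axis: wrong replica count

    # Fixed approach: look the axis up by name, independent of declaration order.
    print(mesh.shape["batch"])  # 4, the actual number of model replicas

Note the limitation the new TODO flags: the loop keeps only the size of the last non-batch axis it visits, so on meshes with more than two dimensions the model-replica factor (presumably the product of all non-batch axis sizes) would still come out wrong.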
keras/src/backend/jax/trainer.py: 1 addition & 1 deletion

@@ -995,7 +995,7 @@ def _distribute_data(data, layouts=None):
                 data,
             )
         return tree.map_structure(
-            jax_distribution_lib.distribute_data_input, data, layouts
+            jax_distribution_lib.distribute_data_input, data, layouts, distribution._batch_dim_name
         )
 
     return tree.map_structure(jax.device_put, data)
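One detail worth flagging: `tree.map_structure(func, *structures)` generally treats every argument after `func` as a structure to be zipped leaf-by-leaf, so passing the scalar `distribution._batch_dim_name` positionally may not behave as the call suggests for nested inputs. A sketch of the usual way to bind such an argument, assuming `distribute_data_input` accepts it as a keyword (an alternative, not what this commit shows):

    from functools import partial

    # Bind the batch axis name once; map_structure then zips only the
    # matching (data, layouts) structures, leaf for leaf.
    distribute_fn = partial(
        jax_distribution_lib.distribute_data_input,
        batch_dim_name=distribution._batch_dim_name,
    )
    return tree.map_structure(distribute_fn, data, layouts)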