lint
martin-gorner committed Dec 20, 2024
commit 2da7fdc013021f6a822a3af689b2c1aeb1d98c03
2 changes: 1 addition & 1 deletion keras/src/backend/jax/distribution_lib.py
@@ -117,7 +117,7 @@ def distribute_data_input(per_process_batch, layout, batch_dim_name):
     if not isinstance(layout, jax.sharding.Sharding):
         layout = _to_jax_layout(layout)
 
-    num_model_replicas_total = layout.mesh.shape[batch_dim_name] # batch dimension of the mesh
+    num_model_replicas_total = layout.mesh.shape[batch_dim_name]
     mesh_shape = list(layout.mesh.shape.values())
 
     # TODO: THIS IS COMPLETELY WRONG AS WELL FOR REPLICATING DATA ON "MODEL"
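For reference on the line above: `layout.mesh.shape` on a `jax.sharding.Mesh` is an ordered mapping from axis name to axis size, so indexing it with `batch_dim_name` reads the replica count off the data axis. A minimal sketch; the device count and axis names here are assumptions for illustration, not Keras' actual mesh:

```python
import numpy as np
import jax
from jax.sharding import Mesh

# Hypothetical mesh: 8 devices arranged as 4-way "batch" x 2-way "model".
# Requires 8 available devices; axis names are illustrative only.
devices = np.array(jax.devices()).reshape(4, 2)
mesh = Mesh(devices, axis_names=("batch", "model"))

print(mesh.shape)                 # OrderedDict([('batch', 4), ('model', 2)])
print(mesh.shape["batch"])        # 4  -> num_model_replicas_total in the diff
print(list(mesh.shape.values()))  # [4, 2] -> mesh_shape in the diff
```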
13 changes: 6 additions & 7 deletions keras/src/backend/jax/trainer.py
@@ -1,5 +1,6 @@
 import collections
 import itertools
+from functools import partial
 
 import jax
 import numpy as np
@@ -16,8 +17,6 @@
 from keras.src.trainers.epoch_iterator import EpochIterator
 from keras.src.utils import traceback_utils
 
-from functools import partial
-
 
 class JAXTrainer(base_trainer.Trainer):
     def __init__(self):
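Moving the import into the top-level import block is pure housekeeping; behavior is unchanged. For readers unfamiliar with it, `functools.partial` pre-binds arguments and returns a new callable, which is what the hunk below relies on. A toy sketch (all names here are made up):

```python
from functools import partial

def distribute(batch, layout, batch_dim_name):
    # Stand-in for jax_distribution_lib.distribute_data_input.
    return (batch, layout, batch_dim_name)

# Bind the keyword once; the result now takes only (batch, layout),
# matching the two structures mapped over in the next hunk.
dist = partial(distribute, batch_dim_name="batch")
print(dist("x0", "layout0"))  # ('x0', 'layout0', 'batch')
```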
@@ -990,17 +989,17 @@ def _get_jax_state(
 
 def _distribute_data(data, layouts=None):
     distribution = distribution_lib.distribution()
-    jax_dist_data_input = partial(jax_distribution_lib.distribute_data_input,
-                                  batch_dim_name=distribution._batch_dim_name)
+    jax_dist_data_input = partial(
+        jax_distribution_lib.distribute_data_input,
+        batch_dim_name=distribution._batch_dim_name,
+    )
     if distribution is not None:
         if layouts is None:
             layouts = tree.map_structure(
                 lambda d: distribution.get_data_layout(d.shape),
                 data,
             )
-        return tree.map_structure(
-            jax_dist_data_input, data, layouts
-        )
+        return tree.map_structure(jax_dist_data_input, data, layouts)
 
     return tree.map_structure(jax.device_put, data)
 
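The reformatted `partial` plus the one-line `tree.map_structure` call keeps the per-leaf pairing of data with layouts. A rough sketch of that mapping pattern using `jax.tree_util.tree_map`, which accepts parallel structures the same way (the leaf values and layouts below are stand-ins, not real sharding objects):

```python
import jax
import numpy as np

data = {"x": np.ones((8, 4)), "y": np.zeros((8,))}
layouts = {"x": "layout_x", "y": "layout_y"}  # stand-ins for jax.sharding layouts

def shard(leaf, layout):
    # Placeholder for the pre-bound distribute_data_input callable.
    return f"put {leaf.shape} with {layout}"

print(jax.tree_util.tree_map(shard, data, layouts))
# {'x': 'put (8, 4) with layout_x', 'y': 'put (8,) with layout_y'}
```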