Skip to content
Prev Previous commit
Next Next commit
fix for test failure
  • Loading branch information
martin-gorner committed Dec 20, 2024
commit a64e0933c81e1cfd68a7c1f8d2fc6ddd850781d7
2 changes: 1 addition & 1 deletion keras/src/backend/jax/distribution_lib_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,7 +337,7 @@ def test_distribute_data_input(self):
mesh, jax.sharding.PartitionSpec("batch", None)
)

result = backend_dlib.distribute_data_input(per_process_batch, layout)
result = backend_dlib.distribute_data_input(per_process_batch, layout, "batch")

# Check the shape of the global batch array
self.assertEqual(
Expand Down