Shuffle before batching in TF load_data_fashion_mnist #2700
Bug
`load_data_fashion_mnist` in the TensorFlow build of `d2l` (`chapter_appendix-tools-for-deep-learning/utils.md` → `d2l/tensorflow.py`) chains the train pipeline as `.batch(batch_size).shuffle(len(mnist_train[0]))`:
```python
return (
    tf.data.Dataset.from_tensor_slices(process(*mnist_train)).batch(
        batch_size).shuffle(len(mnist_train[0])).map(resize_fn),
    tf.data.Dataset.from_tensor_slices(process(*mnist_test)).batch(
        batch_size).map(resize_fn))
```
Because `tf.data.Dataset` transformations apply in the order they are chained, this pipeline first groups items into fixed `batch_size` blocks and only then shuffles those blocks. Batch composition is therefore frozen for the lifetime of the dataset: every epoch sees the exact same groups of examples, only in a different order. This defeats the item-level stochasticity SGD relies on.
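A minimal repro on toy data (hypothetical values, assuming TF 2.x eager execution) makes the difference visible:

```python
import tensorflow as tf

items = tf.data.Dataset.range(8)

# Buggy order: batch first, then shuffle. Only whole batches move;
# the groups {0..3} and {4..7} never mix across epochs.
buggy = items.batch(4).shuffle(2)
for _ in range(2):
    print([b.numpy().tolist() for b in buggy])

# Fixed order: shuffle items, then batch. Each epoch draws a fresh
# permutation of individual items before grouping them.
fixed = items.shuffle(8).batch(4)
for _ in range(2):
    print([b.numpy().tolist() for b in fixed])
```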
Root cause
The two transform calls are reversed. Everywhere else in the same module the idiomatic order is used, e.g. `DataModule.get_tensorloader` in `d2l/tensorflow.py` (`chapter_builders-guide/oo-design.md`):
```python
tf.data.Dataset.from_tensor_slices(tensors).shuffle(
    buffer_size=shuffle_buffer).batch(self.batch_size)
```
That is `.shuffle(...).batch(...)`: shuffle at item granularity, then batch from the shuffled stream. The Fashion-MNIST loader is the only place where the two transforms are composed in the wrong order. The PyTorch counterpart (`torch.utils.data.DataLoader(..., shuffle=True)`) and the MXNet counterpart (`gluon.data.DataLoader(..., shuffle=True)`) both correctly shuffle at item granularity.
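For comparison, a sketch of the PyTorch counterpart (omitting the resize and `num_workers` details of the real d2l loader):

```python
import torch
import torchvision
from torchvision import transforms

# DataLoader with shuffle=True draws a fresh permutation of individual
# examples every epoch, so batch composition changes between epochs.
mnist_train = torchvision.datasets.FashionMNIST(
    root='./data', train=True, transform=transforms.ToTensor(),
    download=True)
train_iter = torch.utils.data.DataLoader(mnist_train, batch_size=64,
                                         shuffle=True)
```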
Why the fix is correct
Swapping to `.shuffle(len(mnist_train[0])).batch(batch_size).map(resize_fn)` fills a shuffle buffer of size 60000 (the full training set) with individual items before batching, so each epoch produces a fresh, fully randomized set of batches, matching the torch/mxnet loaders and the rest of the `tf.data` pipelines in this package. The test pipeline is left untouched (no shuffle), and the `map(resize_fn)` stage still runs on the batched tensors exactly as before.
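Concretely, the train pipeline after the swap reads:

```python
return (
    tf.data.Dataset.from_tensor_slices(process(*mnist_train)).shuffle(
        len(mnist_train[0])).batch(batch_size).map(resize_fn),
    tf.data.Dataset.from_tensor_slices(process(*mnist_test)).batch(
        batch_size).map(resize_fn))
```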
The fix is applied in both the `utils.md` source (canonical) and the generated `d2l/tensorflow.py` so the change is visible to readers before `d2lbook build lib` is re-run.