Expose common dataloader args #2097
@@ -0,0 +1,153 @@

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

from torch.utils.data import IterableDataset

from torchtitan.components.dataloader import ParallelAwareDataloader
from torchtitan.components.tokenizer import BaseTokenizer
from torchtitan.config import ConfigManager


class DummyDataset(IterableDataset):
    """A simple dummy dataset for testing."""

    def __iter__(self):
        for i in range(100):
            yield {"input": i}, i


class DummyTokenizer(BaseTokenizer):
    """A dummy tokenizer for testing that implements the BaseTokenizer interface."""

    def __init__(self):
        super().__init__()
        self.eos_id = 2

    def encode(
        self, text: str, add_bos: bool = False, add_eos: bool = False
    ) -> list[int]:
        # Simple encoding: convert each character to its ASCII value
        tokens = [ord(c) for c in text]
        if add_bos:
            tokens.insert(0, 1)  # BOS token
        if add_eos:
            tokens.append(self.eos_id)
        return tokens

    def decode(self, token_ids: list[int]) -> str:
        # Simple decoding: convert ASCII values back to characters
        return "".join(chr(t) for t in token_ids if t > 2)

    def get_vocab_size(self) -> int:
        return 256  # ASCII range


class TestParallelAwareDataloader(unittest.TestCase):
    def test_dataloader_yields_correct_batches(self):
        """Test that the dataloader correctly yields batched data from the dataset."""
        dataset = DummyDataset()
        batch_size = 4

        dataloader = ParallelAwareDataloader(
            dataset,
            dp_rank=0,
            dp_world_size=1,
            batch_size=batch_size,
        )

        batches = list(dataloader)

        # DummyDataset yields 100 items, so we expect 25 batches of size 4
        self.assertEqual(len(batches), 25)

        # Check first batch structure and values
        first_batch_input, first_batch_label = batches[0]
        self.assertEqual(len(first_batch_input["input"]), batch_size)
        self.assertEqual(len(first_batch_label), batch_size)

        # Verify first batch contains expected values (0, 1, 2, 3)
        self.assertEqual(first_batch_input["input"].tolist(), [0, 1, 2, 3])
        self.assertEqual(first_batch_label.tolist(), [0, 1, 2, 3])

        # Check last batch
        last_batch_input, last_batch_label = batches[-1]
        self.assertEqual(last_batch_input["input"].tolist(), [96, 97, 98, 99])
        self.assertEqual(last_batch_label.tolist(), [96, 97, 98, 99])

    def test_validate_kwargs_rejects_invalid_kwargs(self):
        """Test that passing invalid kwargs raises ValueError."""
        dataset = DummyDataset()

        with self.assertRaises(ValueError) as context:
            ParallelAwareDataloader(
                dataset,
                dp_rank=0,
                dp_world_size=1,
                invalid_arg=42,
            )

        self.assertIn("Invalid dataloader kwargs", str(context.exception))
        self.assertIn("invalid_arg", str(context.exception))

    def test_config_batch_size_overwritten_by_explicit_batch_size(self):
        """Test that batch_size in config kwargs is overwritten by explicit batch_size."""
        dataset = DummyDataset()

        config_kwargs = {"batch_size": 2, "num_workers": 0}
        explicit_batch_size = 8

        # Merge kwargs with explicit args taking precedence (same pattern as in dataset files)
        dataloader_kwargs = {
            **config_kwargs,
            "batch_size": explicit_batch_size,
        }

        dataloader = ParallelAwareDataloader(
            dataset,
            dp_rank=0,
            dp_world_size=1,
            **dataloader_kwargs,
        )

        # Verify that batch_size is the explicit one, not the config one
        self.assertEqual(dataloader.batch_size, explicit_batch_size)

    def test_build_dataloader_with_job_config(self):
        """Verify batch_size from job_config.training.local_batch_size is correctly used."""
        from torchtitan.hf_datasets.text_datasets import build_text_dataloader

        tokenizer = DummyTokenizer()

        config_manager = ConfigManager()
        config = config_manager.parse_args(
            [
                "--training.dataset",
                "c4_test",
                "--training.local_batch_size",
                "8",
                "--training.seq_len",
                "512",
                "--training.dataloader.num_workers",
                "2",
            ]
        )

        dataloader = build_text_dataloader(
            tokenizer=tokenizer,
            dp_world_size=1,
            dp_rank=0,
            job_config=config,
        )

        self.assertEqual(dataloader.batch_size, 8)
        self.assertEqual(dataloader.num_workers, 2)


if __name__ == "__main__":
    unittest.main()
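
A note for readers following the tests: test_validate_kwargs_rejects_invalid_kwargs only asserts on the error message ("Invalid dataloader kwargs"), so the validation logic itself is not visible in this diff. Below is a minimal sketch of that kind of check, validating user-supplied kwargs against torch's DataLoader signature. It is an illustration under that assumption, not the PR's actual ParallelAwareDataloader code, and validate_dataloader_kwargs is a hypothetical helper name.

```python
import inspect

from torch.utils.data import DataLoader


def validate_dataloader_kwargs(kwargs: dict) -> None:
    # Hypothetical helper: accept only keyword arguments that the underlying
    # DataLoader constructor understands, and report the rest.
    valid = set(inspect.signature(DataLoader.__init__).parameters) - {"self", "dataset"}
    invalid = set(kwargs) - valid
    if invalid:
        raise ValueError(
            f"Invalid dataloader kwargs: {sorted(invalid)}; valid options are {sorted(valid)}"
        )


validate_dataloader_kwargs({"num_workers": 2, "pin_memory": True})  # passes silently
# validate_dataloader_kwargs({"invalid_arg": 42})  # would raise ValueError
```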
@@ -198,6 +198,40 @@ class LRScheduler:
    """


@dataclass
class DataLoader:
    """
    Configuration for PyTorch DataLoader settings.
    These settings are passed directly to StatefulDataLoader.

    Note:
        persistent_workers and prefetch_factor are only valid if num_workers > 0.

    Example (TOML config file):
        [training.dataloader]
        num_workers = 4
        pin_memory = true
        persistent_workers = true
        prefetch_factor = 2
    """

    num_workers: int = 0
    """Number of worker processes for data loading."""

    persistent_workers: bool = False
    """Keep workers alive between epochs. Only valid when num_workers > 0."""

    pin_memory: bool = False
    """Copy tensors to CUDA pinned memory before returning them."""

    prefetch_factor: int | None = None
    """
    Number of batches loaded in advance by each worker. Only valid when num_workers > 0.
    Default is 2 when num_workers > 0, otherwise None.
    """


@dataclass
class Training:
    dataset: str = "c4_test"

@@ -263,6 +297,9 @@ class Training:
    many temporary files.
    """

    dataloader: DataLoader = field(default_factory=DataLoader)
    """DataLoader configuration"""


@dataclass
class Parallelism:

@@ -914,6 +951,9 @@ class Validation:
    WARNING: When setting to -1 there could be hangs due to mismatch among ranks
    """

    dataloader: DataLoader = field(default_factory=DataLoader)
    """DataLoader configuration"""

    def __post_init__(self):
        assert (
            self.steps > 0 or self.steps == -1

Review comments on this file:

Contributor: could you add a dedicated test for dataloader with kwargs passed through?

Author (Contributor): Added a GPU integration test. To be able to use the CLI to pass in the kwargs I added a tyro rule. I am not super familiar with tyro, so please have a look. Also, shout out to the integration test setup. Love that we could do a quick mini-GPU run as part of feature testing.
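
Putting the config side and the test side together: the new [training.dataloader] section is presumably converted to keyword arguments and merged with explicitly computed values such as the local batch size, with the explicit values taking precedence, which is what test_config_batch_size_overwritten_by_explicit_batch_size asserts. Below is a rough sketch of that merge under those assumptions; the DataLoaderConfig stand-in and the build_dataloader_kwargs helper are illustrative names, not the PR's actual code.

```python
from dataclasses import asdict, dataclass


@dataclass
class DataLoaderConfig:
    # Illustrative stand-in mirroring the [training.dataloader] fields above.
    num_workers: int = 0
    persistent_workers: bool = False
    pin_memory: bool = False
    prefetch_factor: int | None = None


def build_dataloader_kwargs(cfg: DataLoaderConfig, local_batch_size: int) -> dict:
    # Config-derived kwargs are expanded first; explicit arguments come last so
    # they win, matching the precedence the tests check.
    return {
        **asdict(cfg),
        "batch_size": local_batch_size,
    }


kwargs = build_dataloader_kwargs(DataLoaderConfig(num_workers=2), local_batch_size=8)
assert kwargs["batch_size"] == 8 and kwargs["num_workers"] == 2
```

From the command line, the same settings appear as nested flags in the new test, e.g. --training.dataloader.num_workers 2, which is what the tyro rule mentioned in the review thread enables.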