15 changes: 15 additions & 0 deletions tests/integration_tests/features.py
@@ -557,6 +557,21 @@ def build_features_test_list() -> list[OverrideDefinitions]:
"validation_tp_cp_pp",
ngpu=8,
),
OverrideDefinitions(
[
[
"--training.dataloader.num_workers",
"2",
"--training.dataloader.pin_memory",
"--training.dataloader.persistent_workers",
"--training.dataloader.prefetch_factor",
"4",
],
],
"Dataloader kwargs (via CLI args)",
"dataloader_kwargs",
ngpu=2,
),
]

return integration_tests_flavors
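
For reference, the CLI overrides in this integration test should map onto the nested [training.dataloader] config section added below; assuming tyro's usual flag-to-field mapping, the equivalent TOML would look roughly like:

[training.dataloader]
num_workers = 2
pin_memory = true
persistent_workers = true
prefetch_factor = 4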
153 changes: 153 additions & 0 deletions tests/unit_tests/test_dataloader.py
@@ -0,0 +1,153 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

from torch.utils.data import IterableDataset

from torchtitan.components.dataloader import ParallelAwareDataloader
from torchtitan.components.tokenizer import BaseTokenizer
from torchtitan.config import ConfigManager


class DummyDataset(IterableDataset):
"""A simple dummy dataset for testing."""

def __iter__(self):
for i in range(100):
yield {"input": i}, i


class DummyTokenizer(BaseTokenizer):
"""A dummy tokenizer for testing that implements BaseTokenizer interface."""

def __init__(self):
super().__init__()
self.eos_id = 2

def encode(
self, text: str, add_bos: bool = False, add_eos: bool = False
) -> list[int]:
# Simple encoding: convert each character to its ASCII value
tokens = [ord(c) for c in text]
if add_bos:
tokens.insert(0, 1) # BOS token
if add_eos:
tokens.append(self.eos_id)
return tokens

def decode(self, token_ids: list[int]) -> str:
# Simple decoding: convert ASCII values back to characters
return "".join(chr(t) for t in token_ids if t > 2)

def get_vocab_size(self) -> int:
return 256 # ASCII range


class TestParallelAwareDataloader(unittest.TestCase):
def test_dataloader_yields_correct_batches(self):
"""Test that the dataloader correctly yields batched data from the dataset."""
dataset = DummyDataset()
batch_size = 4

dataloader = ParallelAwareDataloader(
dataset,
dp_rank=0,
dp_world_size=1,
batch_size=batch_size,
)

batches = list(dataloader)

# DummyDataset yields 100 items, so we expect 25 batches of size 4
self.assertEqual(len(batches), 25)

# Check first batch structure and values
first_batch_input, first_batch_label = batches[0]
self.assertEqual(len(first_batch_input["input"]), batch_size)
self.assertEqual(len(first_batch_label), batch_size)

# Verify first batch contains expected values (0, 1, 2, 3)
self.assertEqual(first_batch_input["input"].tolist(), [0, 1, 2, 3])
self.assertEqual(first_batch_label.tolist(), [0, 1, 2, 3])

# Check last batch
last_batch_input, last_batch_label = batches[-1]
self.assertEqual(last_batch_input["input"].tolist(), [96, 97, 98, 99])
self.assertEqual(last_batch_label.tolist(), [96, 97, 98, 99])

def test_validate_kwargs_rejects_invalid_kwargs(self):
"""Test that passing invalid kwargs raises ValueError."""
dataset = DummyDataset()

with self.assertRaises(ValueError) as context:
ParallelAwareDataloader(
dataset,
dp_rank=0,
dp_world_size=1,
invalid_arg=42,
)

self.assertIn("Invalid dataloader kwargs", str(context.exception))
self.assertIn("invalid_arg", str(context.exception))

def test_config_batch_size_overwritten_by_explicit_batch_size(self):
"""Test that batch_size in config kwargs is overwritten by explicit batch_size."""
dataset = DummyDataset()

config_kwargs = {"batch_size": 2, "num_workers": 0}

explicit_batch_size = 8

# Merge kwargs with explicit args taking precedence (same pattern as in dataset files)
dataloader_kwargs = {
**config_kwargs,
"batch_size": explicit_batch_size,
}

dataloader = ParallelAwareDataloader(
dataset,
dp_rank=0,
dp_world_size=1,
**dataloader_kwargs,
)

# Verify that batch_size is the explicit one, not the config one
self.assertEqual(dataloader.batch_size, explicit_batch_size)

def test_build_dataloader_with_job_config(self):
"""Verify batch_size from job_config.training.local_batch_size is correctly used."""
from torchtitan.hf_datasets.text_datasets import build_text_dataloader

tokenizer = DummyTokenizer()

config_manager = ConfigManager()
config = config_manager.parse_args(
[
"--training.dataset",
"c4_test",
"--training.local_batch_size",
"8",
"--training.seq_len",
"512",
"--training.dataloader.num_workers",
"2",
]
)

dataloader = build_text_dataloader(
tokenizer=tokenizer,
dp_world_size=1,
dp_rank=0,
job_config=config,
)

self.assertEqual(dataloader.batch_size, 8)
self.assertEqual(dataloader.num_workers, 2)


if __name__ == "__main__":
unittest.main()
52 changes: 44 additions & 8 deletions torchtitan/components/dataloader.py
@@ -6,16 +6,17 @@
#
# Copyright (c) Meta Platforms, Inc. All Rights Reserved.

import inspect
import pickle
from abc import ABC, abstractmethod
from collections.abc import Callable
from typing import Any

from torch.distributed.checkpoint.stateful import Stateful
from torch.utils.data import IterableDataset
from torchdata.stateful_dataloader import StatefulDataLoader
from torchtitan.tools.logging import logger


# NOTE: This class deliberately inherits from `Exception` and not `StopIteration`.
# According to PEP 479, raising a `StopIteration` or its subclass from within a
# generator will wrap it in a `RuntimeError`. Since this exception is designed
@@ -53,28 +54,63 @@ class ParallelAwareDataloader(StatefulDataLoader, BaseDataLoader):
dataset (IterableDataset): The dataset to iterate over.
dp_rank: Data parallelism rank for this dataloader.
dp_world_size: The world size of the data parallelism.
batch_size: The batch size to use for each iteration.
collate_fn: Optional function to collate samples in a batch.
**kwargs: Additional keyword arguments passed to StatefulDataLoader (e.g.,
batch_size, collate_fn, num_workers, persistent_workers, prefetch_factor,
pin_memory).
"""

dp_rank: int
dp_world_size: int
batch_size: int | None

def __init__(
self,
dataset: IterableDataset,
dp_rank: int,
dp_world_size: int,
batch_size: int,
collate_fn: Callable | None = None,
**kwargs,
):
self._validate_kwargs(kwargs)

self.dp_world_size = dp_world_size
self.dp_rank = dp_rank
self.batch_size = batch_size
super().__init__(dataset, batch_size, collate_fn=collate_fn)
self._rank_id = f"dp_rank_{dp_rank}"

super().__init__(dataset, **kwargs)

@staticmethod
def _validate_kwargs(kwargs: dict[str, Any]) -> None:
"""Validate and sanitize kwargs passed to the dataloader.
Args:
kwargs: Dictionary of keyword arguments to validate. This dict is
modified in-place to remove invalid combinations.
Raises:
ValueError: If 'dataset' is in kwargs or if any invalid kwargs are passed.
"""
if "dataset" in kwargs:
raise ValueError(
"'dataset' should not be passed in kwargs; "
"it must be provided as the first positional argument."
)

sig = inspect.signature(StatefulDataLoader.__init__)
valid_kwargs = frozenset(
name for name in sig.parameters.keys() if name not in ("self", "dataset")
)
invalid_kwargs = set(kwargs.keys()) - valid_kwargs
if invalid_kwargs:
raise ValueError(
f"Invalid dataloader kwargs: {invalid_kwargs}. "
f"Valid kwargs are: {sorted(valid_kwargs)}"
)

# persistent_workers and prefetch_factor are only valid when num_workers > 0.
# Remove them here if num_workers is 0 to avoid StatefulDataLoader errors.
if kwargs.get("num_workers", 0) == 0:
kwargs.pop("persistent_workers", None)
kwargs.pop("prefetch_factor", None)

def state_dict(self) -> dict[str, Any]:
# Store state only for dp rank to avoid replicating the same state across other dimensions.
return {
40 changes: 40 additions & 0 deletions torchtitan/config/job_config.py
@@ -198,6 +198,40 @@ class LRScheduler:
"""


@dataclass
class DataLoader:
"""
Configuration for PyTorch DataLoader settings.
These settings are passed directly to StatefulDataLoader.
Note:
persistent_workers and prefetch_factor are only valid if num_workers > 0.
Example (TOML config file):
[training.dataloader]
num_workers = 4
pin_memory = true
persistent_workers = true
prefetch_factor = 2
"""

num_workers: int = 0
"""Number of worker processes for data loading."""

persistent_workers: bool = False
"""Keep workers alive between epochs. Only valid when num_workers > 0."""

pin_memory: bool = False
"""Copy tensors to CUDA pinned memory before returning them."""

prefetch_factor: int | None = None
"""
Number of batches loaded in advance by each worker. Only valid when num_workers > 0.
Default is 2 when num_workers > 0, otherwise None.
"""

Contributor:
could you add a dedicated test for dataloader with kwargs passed through?
https://github.com/pytorch/torchtitan/blob/main/tests/integration_tests/features.py

Contributor Author:
Added a GPU integration test. To be able to pass the kwargs in via the CLI, I added a tyro rule. I am not super familiar with tyro, so please have a look.

Also, shout out to the integration test setup. Love that we could do a quick mini-GPU run as part of feature testing.


@dataclass
class Training:
dataset: str = "c4_test"
@@ -263,6 +297,9 @@ class Training:
many temporary files.
"""

dataloader: DataLoader = field(default_factory=DataLoader)
"""DataLoader configuration"""


@dataclass
class Parallelism:
@@ -914,6 +951,9 @@ class Validation:
WARNING: When setting to -1 there could be hangs due to mismatch among ranks
"""

dataloader: DataLoader = field(default_factory=DataLoader)
"""DataLoader configuration"""

def __post_init__(self):
assert (
self.steps > 0 or self.steps == -1
22 changes: 14 additions & 8 deletions torchtitan/experiments/vlm/datasets/mm_datasets.py
@@ -11,6 +11,7 @@
It supports both streaming and non-streaming datasets from HuggingFace.
"""

from dataclasses import asdict
from typing import Any, Callable

import torch
@@ -381,14 +382,14 @@ def build_mm_dataloader(
"""Build a data loader for multimodal datasets.
Args:
dp_world_size: Data parallel world size
dp_rank: Data parallel rank
tokenizer: Tokenizer for text processing
job_config: Job configuration
infinite: Whether to loop infinitely
dp_world_size: Data parallel world size.
dp_rank: Data parallel rank.
tokenizer: Tokenizer for text processing.
job_config: Job configuration containing dataset and DataLoader settings.
infinite: Whether to loop infinitely.
Returns:
DataLoader with appropriate parallelism handling
DataLoader with appropriate parallelism handling.
"""
dataset_path = job_config.training.dataset_path
batch_size = job_config.training.local_batch_size
@@ -429,12 +430,17 @@ def build_mm_dataloader(
special_tokens=special_tokens,
)

dataloader_kwargs = {
**asdict(job_config.training.dataloader),
"batch_size": batch_size,
"collate_fn": collate_fn,
}

base_dataloader = ParallelAwareDataloader(
dataset=dataset,
dp_rank=dp_rank,
dp_world_size=dp_world_size,
batch_size=batch_size,
collate_fn=collate_fn,
**dataloader_kwargs,
)

return base_dataloader