Use dl kwargs internally in titan, but expose only critical args to end-users through config
divyanshk committed Dec 19, 2025
commit b4a97efd91972fef4266d6fe3b224b58ecc222bf
9 changes: 7 additions & 2 deletions tests/integration_tests/features.py
@@ -560,10 +560,15 @@ def build_features_test_list() -> list[OverrideDefinitions]:
OverrideDefinitions(
[
[
'--training.dataloader.kwargs \'{"num_workers": 2, "pin_memory": true, "prefetch_factor": 2}\'',
"--training.dataloader.num_workers",
"2",
"--training.dataloader.pin_memory",
"--training.dataloader.persistent_workers",
"--training.dataloader.prefetch_factor",
"4",
],
],
"Dataloader kwargs",
"Dataloader kwargs (via CLI args)",
"dataloader_kwargs",
ngpu=2,
),
44 changes: 36 additions & 8 deletions tests/unit_tests/test_dataloader.py
@@ -48,6 +48,37 @@ def get_vocab_size(self) -> int:


class TestParallelAwareDataloader(unittest.TestCase):
def test_dataloader_yields_correct_batches(self):
"""Test that the dataloader correctly yields batched data from the dataset."""
dataset = DummyDataset()
batch_size = 4

dataloader = ParallelAwareDataloader(
dataset,
dp_rank=0,
dp_world_size=1,
batch_size=batch_size,
)

batches = list(dataloader)

# DummyDataset yields 100 items, so we expect 25 batches of size 4
self.assertEqual(len(batches), 25)

# Check first batch structure and values
first_batch_input, first_batch_label = batches[0]
self.assertEqual(len(first_batch_input["input"]), batch_size)
self.assertEqual(len(first_batch_label), batch_size)

# Verify first batch contains expected values (0, 1, 2, 3)
self.assertEqual(first_batch_input["input"].tolist(), [0, 1, 2, 3])
self.assertEqual(first_batch_label.tolist(), [0, 1, 2, 3])

# Check last batch
last_batch_input, last_batch_label = batches[-1]
self.assertEqual(last_batch_input["input"].tolist(), [96, 97, 98, 99])
self.assertEqual(last_batch_label.tolist(), [96, 97, 98, 99])

def test_validate_kwargs_rejects_invalid_kwargs(self):
"""Test that passing invalid kwargs raises ValueError."""
dataset = DummyDataset()
@@ -87,10 +118,8 @@ def test_config_batch_size_overwritten_by_explicit_batch_size(self):
# Verify that batch_size is the explicit one, not the config one
self.assertEqual(dataloader.batch_size, explicit_batch_size)

def test_build_dataloader_with_job_config_override(self):
"""Verify batch_size from job_config.training.local_batch_size
overrides batch_size in job_config.training.dataloader.kwargs.
"""
def test_build_dataloader_with_job_config(self):
"""Verify batch_size from job_config.training.local_batch_size is correctly used."""
from torchtitan.hf_datasets.text_datasets import build_text_dataloader

tokenizer = DummyTokenizer()
@@ -104,21 +133,20 @@ def test_build_dataloader_with_job_config_override(self):
"8",
"--training.seq_len",
"512",
"--training.dataloader.num_workers",
"2",
]
)

# Manually set batch_size in dataloader.kwargs to simulate conflict
config.training.dataloader.kwargs["batch_size"] = 2

dataloader = build_text_dataloader(
tokenizer=tokenizer,
dp_world_size=1,
dp_rank=0,
job_config=config,
)

# The dataloader should use training.local_batch_size (8), not kwargs batch_size (2)
self.assertEqual(dataloader.batch_size, 8)
self.assertEqual(dataloader.num_workers, 2)


if __name__ == "__main__":
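For context, a hypothetical DummyDataset consistent with the assertions in test_dataloader_yields_correct_batches above; the real fixture is defined earlier in this test file and is not shown in this diff:

class DummyDataset(torch.utils.data.IterableDataset):
    # Illustrative sketch only: 100 samples whose input and label are both the index,
    # so default collation yields batches like {"input": tensor([0, 1, 2, 3])}, tensor([0, 1, 2, 3]).
    def __iter__(self):
        for i in range(100):
            yield {"input": i}, i

    def __len__(self):
        return 100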
8 changes: 1 addition & 7 deletions tests/unit_tests/test_dataset_flux.py
@@ -45,10 +45,9 @@ def test_load_dataset(self):
for world_size in [2]:
for rank in range(world_size):
dataset_name = "cc12m-test-iterable"

batch_size = 1

num_steps = 15
num_workers = 4

# TODO: if num_steps * batch_size * world_size is larger than the number of samples
# in the dataset, then the test will fail, due to huggingface's
@@ -75,9 +74,6 @@
]
)

# Set num_workers via the kwargs dict
config.training.dataloader.kwargs["num_workers"] = num_workers

dl = build_flux_dataloader(
dp_world_size=world_size,
dp_rank=rank,
@@ -86,8 +82,6 @@
infinite=True,
)

assert dl.num_workers == num_workers

it = iter(dl)

for i in range(0, num_steps):
12 changes: 11 additions & 1 deletion torchtitan/components/dataloader.py
@@ -79,7 +79,11 @@ def __init__(

@staticmethod
def _validate_kwargs(kwargs: dict[str, Any]) -> None:
"""Validate kwargs passed to the dataloader.
"""Validate and sanitize kwargs passed to the dataloader.

Args:
kwargs: Dictionary of keyword arguments to validate. This dict is
modified in-place to remove invalid combinations.

Raises:
ValueError: If 'dataset' is in kwargs or if any invalid kwargs are passed.
@@ -101,6 +105,12 @@ def _validate_kwargs(kwargs: dict[str, Any]) -> None:
f"Valid kwargs are: {sorted(valid_kwargs)}"
)

# persistent_workers and prefetch_factor are only valid when num_workers > 0.
# Removing them here if num_workers is 0 to avoid StatefulDataLoader errors
if kwargs.get("num_workers", 0) == 0:
kwargs.pop("persistent_workers", None)
kwargs.pop("prefetch_factor", None)

def state_dict(self) -> dict[str, Any]:
# Store state only for dp rank to avoid replicating the same state across other dimensions.
return {
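To make the effect of the new sanitization concrete, here is a minimal standalone sketch (illustrative only, not the actual torchtitan call path) of what _validate_kwargs now does when no worker processes are configured:

kwargs = {"num_workers": 0, "persistent_workers": True, "prefetch_factor": 2}
if kwargs.get("num_workers", 0) == 0:
    # These options only apply when worker processes exist; dropping them
    # avoids StatefulDataLoader raising for num_workers == 0.
    kwargs.pop("persistent_workers", None)
    kwargs.pop("prefetch_factor", None)
assert kwargs == {"num_workers": 0}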
34 changes: 19 additions & 15 deletions torchtitan/config/job_config.py
@@ -203,29 +203,33 @@ class DataLoader:
"""
Configuration for PyTorch DataLoader settings.

This is a flexible kwargs container that passes all fields directly to
StatefulDataLoader. Common options include:
- num_workers: Number of worker processes for data loading (default: 0)
- persistent_workers: Keep workers alive between epochs (default: False)
- prefetch_factor: Batches loaded in advance per worker (default: None)
- pin_memory: Copy tensors to CUDA pinned memory (default: False)

Warning:
Some arguments are set internally and will override any values provided
in kwargs. These include:
- batch_size: Determined by training.local_batch_size
- collate_fn: Set by the dataset-specific collator
These settings are passed directly to StatefulDataLoader.

Note:
persistent_workers and prefetch_factor are only valid if num_workers > 0.

Example (TOML config file):
[Review comment thread]
Contributor: could you add a dedicated test for dataloader with kwargs passed through?
https://github.com/pytorch/torchtitan/blob/main/tests/integration_tests/features.py

Contributor Author: Added a GPU integration test. To be able to use the CLI to pass in the kwargs I added a tyro rule. I am not super familiar with tyro so please have a look.
Also, shout out to the integration test setup. Love that we could do a quick mini-GPU run as part of feature testing.
[End of review thread]

[training.dataloader.kwargs]
[training.dataloader]
num_workers = 4
pin_memory = true
persistent_workers = true
prefetch_factor = 2
"""

kwargs: dict[str, Any] = field(default_factory=dict)
"""Keyword arguments passed to StatefulDataLoader."""
num_workers: int = 0
"""Number of worker processes for data loading."""

persistent_workers: bool = False
"""Keep workers alive between epochs. Only valid when num_workers > 0."""

pin_memory: bool = False
"""Copy tensors to CUDA pinned memory before returning them."""

prefetch_factor: int | None = None
"""
Number of batches loaded in advance by each worker. Only valid when num_workers > 0.
Default is 2 when num_workers > 0, otherwise None.
"""


@dataclass
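A minimal sketch of how these fields reach the dataloader builders changed below, assuming a parsed JobConfig-like object named cfg (hypothetical variable name); batch_size is still supplied internally from training.local_batch_size rather than by this config section:

from dataclasses import asdict

dataloader_kwargs = {
    **asdict(cfg.training.dataloader),  # num_workers, pin_memory, persistent_workers, prefetch_factor
    "batch_size": cfg.training.local_batch_size,  # always set by the trainer
}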
36 changes: 0 additions & 36 deletions torchtitan/config/manager.py
@@ -5,7 +5,6 @@
# LICENSE file in the root directory of this source tree.

import importlib
import json
import os
import sys

@@ -237,41 +236,6 @@ def list_str_rule(type_info: tyro.constructors.PrimitiveTypeInfo):
str_from_instance=lambda instance: [",".join(instance)],
)

@registry.primitive_rule
def dict_str_any_rule(type_info: tyro.constructors.PrimitiveTypeInfo):
"""Support for dict[str, Any] parsing from CLI.

Accepts JSON format: {"key": "value", "num": 123, "flag": true}

Note: When using from command line, wrap in single quotes for shell escaping:
--training.dataloader.kwargs '{"num_workers": 2, "pin_memory": true}'

The single quotes prevent bash from interpreting {}, spaces, and double quotes.
"""
if type_info.type != dict[str, Any]:
return None

def parse_dict(args: list[str]) -> dict[str, Any]:
if not args or not args[0]:
return {}
try:
return json.loads(args[0])
except json.JSONDecodeError as e:
raise ValueError(
f"Invalid JSON for dict argument: {args[0]}. Error: {e}"
) from e

def dict_to_str(instance: dict[str, Any]) -> list[str]:
return [json.dumps(instance)]

return tyro.constructors.PrimitiveConstructorSpec(
nargs=1,
metavar='{"key": value, ...}',
instance_from_str=parse_dict,
is_instance=lambda instance: isinstance(instance, dict),
str_from_instance=dict_to_str,
)


# Initialize the custom registry for tyro
custom_registry = tyro.constructors.ConstructorRegistry()
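With the JSON rule removed, each dataloader option becomes its own CLI flag instead of a single JSON blob. For example, the setting that the old rule accepted as

--training.dataloader.kwargs '{"num_workers": 2, "pin_memory": true, "prefetch_factor": 2}'

is now written with per-field flags, as exercised in the integration test above:

--training.dataloader.num_workers 2 --training.dataloader.pin_memory --training.dataloader.prefetch_factor 2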
4 changes: 2 additions & 2 deletions torchtitan/experiments/vlm/datasets/mm_datasets.py
@@ -11,6 +11,7 @@
It supports both streaming and non-streaming datasets from HuggingFace.
"""

from dataclasses import asdict
from typing import Any, Callable

import torch
@@ -429,9 +430,8 @@ def build_mm_dataloader(
special_tokens=special_tokens,
)

# Merge config kwargs with explicit args (explicit args take precedence)
dataloader_kwargs = {
**job_config.training.dataloader.kwargs,
**asdict(job_config.training.dataloader),
"batch_size": batch_size,
"collate_fn": collate_fn,
}
7 changes: 3 additions & 4 deletions torchtitan/hf_datasets/text_datasets.py
@@ -4,6 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import asdict
from functools import partial
from typing import Any, Callable

@@ -196,9 +197,8 @@ def build_text_dataloader(
infinite=infinite,
)

# Merge config kwargs with explicit args (explicit args take precedence)
dataloader_kwargs = {
**job_config.training.dataloader.kwargs,
**asdict(job_config.training.dataloader),
"batch_size": batch_size,
}

@@ -241,9 +241,8 @@ def build_text_validation_dataloader(
infinite=infinite,
)

# Merge config kwargs with explicit args (explicit args take precedence)
dataloader_kwargs = {
**job_config.validation.dataloader.kwargs,
**asdict(job_config.validation.dataloader),
"batch_size": batch_size,
}

7 changes: 3 additions & 4 deletions torchtitan/models/flux/flux_datasets.py
@@ -6,6 +6,7 @@

import itertools
import math
from dataclasses import asdict
from typing import Any, Callable, Optional

import numpy as np
@@ -342,9 +343,8 @@ def build_flux_dataloader(
infinite=infinite,
)

# Merge config kwargs with explicit args (explicit args take precedence)
dataloader_kwargs = {
**job_config.training.dataloader.kwargs,
**asdict(job_config.training.dataloader),
"batch_size": batch_size,
}

@@ -444,9 +444,8 @@ def build_flux_validation_dataloader(
infinite=infinite,
)

# Merge config kwargs with explicit args (explicit args take precedence)
dataloader_kwargs = {
**job_config.validation.dataloader.kwargs,
**asdict(job_config.validation.dataloader),
"batch_size": batch_size,
}
