
Conversation

@divyanshk

This diff introduces common dataloader args that are supported by StatefulDataLoader (and torch.utils.data.DataLoader). Users should be able to set them in their config files.

I was thinking about introducing a catch-all kwargs to make it easier to specify args, but that can easily complicate things (validation checks, duplication, args already defined as named parameters in function definitions, etc.).

@meta-cla bot added the CLA Signed label on Dec 2, 2025.
@divyanshk force-pushed the divyanshk/dataloader_args branch from 6763cc0 to 990d654 on December 2, 2025 01:04.
@divyanshk marked this pull request as ready for review on December 3, 2025 17:09.
@wwwjn (Contributor) left a comment:

Thank you! I'm slightly leaning towards using kwargs instead of adding these parameters one by one, because StatefulDataLoader() supports a lot of fields and it's hard to say which of them are "common" across different use cases.

Can you explain more about "but that can easily complicate things"? We could just pass all the kwargs to StatefulDataLoader and let it check correctness. wdyt @tianyu-l
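
A toy sketch of the alternative described here, assuming the dataloader construction ultimately calls torchdata's StatefulDataLoader; the helper name and argument names are illustrative, not the actual torchtitan code:

    from torchdata.stateful_dataloader import StatefulDataLoader

    def build_dataloader(dataset, batch_size, collate_fn, user_kwargs):
        # Every config-provided kwarg is forwarded unchanged; an unsupported
        # key (e.g. a typo like "num_worker") raises a TypeError from
        # StatefulDataLoader.__init__ itself instead of being validated up front.
        return StatefulDataLoader(
            dataset,
            batch_size=batch_size,
            collate_fn=collate_fn,
            **user_kwargs,
        )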

Comment on lines 89 to 92
num_workers=num_workers,
persistent_workers=persistent_workers,
prefetch_factor=prefetch_factor,
pin_memory=pin_memory,
Contributor commented:

"I was thinking about introducing a catch-all kwargs to make it easier to specify args, but that can easily complicate things (validation checks, duplication, args already defined as named parameters in function definitions, etc.)."

These are valid concerns. For now I'm leaning towards keeping things simple by passing **kwargs around.

Does it make sense to only make these args explicit when passing them to the actual init of StatefulDataLoader, and not pass in all **kwargs from the input of ParallelAwareDataloader? The point is to avoid accidentally hitting an error inside StatefulDataLoader.
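
A rough sketch of that suggestion, assuming ParallelAwareDataloader wraps torchdata's StatefulDataLoader; parameter names beyond the four quoted above are illustrative:

    from torchdata.stateful_dataloader import StatefulDataLoader

    class ParallelAwareDataloader(StatefulDataLoader):
        def __init__(self, dataset, dp_rank, dp_world_size, batch_size,
                     collate_fn=None, **kwargs):
            self.dp_rank = dp_rank
            self.dp_world_size = dp_world_size
            # Only the known-supported options are named explicitly when
            # calling StatefulDataLoader's init; anything else left in
            # **kwargs is simply not forwarded, so it cannot trip an error
            # inside StatefulDataLoader.
            super().__init__(
                dataset,
                batch_size=batch_size,
                collate_fn=collate_fn,
                num_workers=kwargs.get("num_workers", 0),
                persistent_workers=kwargs.get("persistent_workers", False),
                prefetch_factor=kwargs.get("prefetch_factor", None),
                pin_memory=kwargs.get("pin_memory", False),
            )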

self,
dataset: IterableDataset,
dp_rank: int,
dp_world_size: int,
Contributor commented:

Could you help change this: let's keep at most one positional arg (dataset) and make the others keyword arguments.
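
One way to read this request, sketched with a bare-bones class; only the parameter names quoted above are taken from the PR, the rest is illustrative:

    from torch.utils.data import IterableDataset

    class ParallelAwareDataloader:
        # Illustrative shell only: `dataset` stays positional, and the bare
        # `*` forces every argument after it to be passed by keyword.
        def __init__(
            self,
            dataset: IterableDataset,
            *,
            dp_rank: int,
            dp_world_size: int,
            **kwargs,
        ) -> None:
            ...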

@divyanshk (Author):

Thanks @tianyu-l @wwwjn. Updated the PR with a kwargs-based approach. I initially didn't do this to avoid confusion on the user's part, since we provide batch_size and collate_fn (in mm_datasets) internally. I resolved that by making the explicit args defined internally take precedence, and added a warning for users in config.py, so that should help. Any error from wrong kwargs will be thrown in torchtitan itself and won't go down to StatefulDataLoader.
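
A minimal sketch of the precedence rule described here; the helper name and warning text are assumptions, not the actual torchtitan code:

    import warnings

    def merge_dataloader_kwargs(user_kwargs, *, batch_size, collate_fn=None):
        # Args that are set internally (e.g. batch_size from
        # training.local_batch_size, collate_fn from the dataset) take
        # precedence over anything the user put in the config-level kwargs.
        internal = {"batch_size": batch_size, "collate_fn": collate_fn}
        overridden = internal.keys() & user_kwargs.keys()
        if overridden:
            warnings.warn(
                f"Dataloader kwargs {sorted(overridden)} are set internally "
                "and will override the values given in the config."
            )
        # In a dict union the right-hand side wins, so internal values prevail.
        return {**user_kwargs, **internal}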

@tianyu-l (Contributor) left a comment:

Looks good in general.

The CPU unit test in CI didn't run. Could you double check?

Also, please add a GPU integration test; see inline comments.

- batch_size: Determined by training.local_batch_size
- collate_fn: Set by the dataset-specific collator
Example (TOML config file):
Contributor commented:

Could you add a dedicated test for the dataloader with kwargs passed through?
https://github.com/pytorch/torchtitan/blob/main/tests/integration_tests/features.py

@divyanshk (Author) replied:

Added a GPU integration test. To be able to use the CLI to pass in the kwargs, I added a tyro rule. I'm not super familiar with tyro, so please take a look.

Also, shout out to the integration test setup. Love that we could do a quick mini-GPU run as part of feature testing.

OverrideDefinitions(
    [
        [
            '--training.dataloader.kwargs \'{"num_workers": 2, "pin_memory": true, "prefetch_factor": 2}\'',
Contributor commented:

Instead of letting the CLI accept a dict, can we just do

--training.dataloader.kwargs.num_workers 2 --training.dataloader.kwargs.pin_memory true, ...

@divyanshk (Author) replied:

@tianyu-l That won't be possible because of how tyro operates. If we want an arbitrary dict to act as a catch-all kwargs, the dotted notation won't work, because those fields are not pre-defined.

Contributor replied:

hmm got it. I feel this makes the CLI a bit too hard to use.

If there's a common set of kwargs that people often use, maybe we should restrict ourselves to those and start there?

Sorry if this sounds like going back to where we started. Do you think a middle ground makes sense, where we wrap the explicit args into a kwargs dict and pass that around internally, after reading them from job_config.training.data_loader?

Happy to discuss more if you think there are better alternatives.

@divyanshk (Author) replied:

This can be done. I updated the code to pass around kwargs internally but expose explicit args through the user config.
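
A minimal sketch of this middle ground: explicit fields in the user config, packed into a kwargs dict internally before reaching StatefulDataLoader. The dataclass and field names are assumptions based on the args discussed in this PR, not the exact torchtitan definitions:

    from dataclasses import asdict, dataclass
    from typing import Optional

    @dataclass
    class DataLoaderConfig:
        # Explicit, documented knobs exposed through the TOML config / CLI.
        num_workers: int = 0
        pin_memory: bool = False
        prefetch_factor: Optional[int] = None
        persistent_workers: bool = False

    def dataloader_kwargs(cfg: DataLoaderConfig) -> dict:
        # Internally the explicit fields are wrapped back into a plain dict so
        # the dataloader construction path can keep passing **kwargs around.
        return asdict(cfg)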

@tianyu-l linked an issue on Dec 14, 2025 that may be closed by this pull request.
@divyanshk force-pushed the divyanshk/dataloader_args branch from 604ac82 to 435543d on December 17, 2025 19:28.
@tianyu-l (Contributor):

Not sure why the CPU unit test is taking forever. It doesn't happen for other PRs; could you take a look?


Labels: CLA Signed (managed by the Meta Open Source bot)

Linked issue this pull request may close: Slow Dataloader should use num_worker > 1

3 participants