Conversation

@francesco-bertolotti (Author)

This PR addresses issue #2140.

Briefly, the bug is in the validation step: the validate method did not pass the attention mask to the model, which leads to an error when flash attention is used.
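
For concreteness, a minimal sketch of the shape of the fix. The names model_parts, use_flex_attn, and get_attention_masks come from the torchtitan code discussed in this thread, but validation_forward is a hypothetical helper, and the keyword arguments of get_attention_masks and the exact validate.py call site are assumptions:

from typing import Any

import torch


def validation_forward(model_parts, model_args, inputs: torch.Tensor, tokenizer) -> torch.Tensor:
    # Build the attention masks (mirroring train.py) and thread them into
    # the forward call, instead of calling the model without a mask.
    extra_kwargs: dict[str, Any] = {}
    if getattr(model_args, "use_flex_attn", False):
        extra_kwargs["attention_masks"] = model_parts[0].get_attention_masks(
            input_batch=inputs,
            tokenizer=tokenizer,
        )
    return model_parts[0](inputs, **extra_kwargs)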

The meta-cla bot added the CLA Signed label (managed by the Meta Open Source bot) on Dec 11, 2025.
@wwwjn (Contributor) left a comment

Thanks for the fix! Please fix the PP-enabled branch as well; I think we would do something similar to train.py: https://github.com/pytorch/torchtitan/blob/refs/heads/main/torchtitan/train.py#L507-L516
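
For reference, the pattern in that train.py range looks roughly like the following. This is a paraphrased sketch, not a verbatim quote; extra_kwargs is the dict carrying the attention masks, and the pp_* names stand in for the trainer's pipeline-parallel state:

# Sketch of the PP-enabled branch: the attention masks travel in
# extra_kwargs and must reach the pipeline schedule as well.
targets, losses = (labels, []) if pp_has_last_stage else (None, None)
if pp_has_first_stage:
    pp_schedule.step(inputs, target=targets, losses=losses, **extra_kwargs)
else:
    pp_schedule.step(target=targets, losses=losses, **extra_kwargs)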

break

# prepare attention mask
extra_kwargs: dict[str, Any] = {"attention_masks": model_parts[0].get_attention_masks(

@francesco-bertolotti (Author)

The new commit propagates the fix to the pipeline-parallelism case. In addition, I have encapsulated the creation of the attention mask in a method.

"unequal sample counts across ranks when dataset is exhausted."
)

def post_dataloading_process(
Contributor

can we reuse the one in train.py?

@francesco-bertolotti (Author)

In principle it would be possible, but I would need to make some modifications, since their signatures differ; in particular, the trainer accesses model_parts through self.

One possible solution is to turn post_dataloading_process into a utility function that both the trainer and validator can call. However, this would mean the behavior can no longer be modified through inheritance, unless the current post_dataloading_process simply becomes a wrapper around the new utility function.

If you have any suggestions, I can add a commit.
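
For illustration, the wrapper option could look roughly like this; all names here are hypothetical, not the actual torchtitan API:

# Shared utility, importable by both the trainer and the validator.
def post_dataloading_process(batch, device):
    inputs, labels = batch
    return inputs.to(device), labels.to(device)


class Trainer:
    def post_dataloading_process(self, batch):
        # The default behavior defers to the shared utility, so subclasses
        # can still override this method through inheritance.
        return post_dataloading_process(batch, self.device)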

Contributor

I see.
If you want to "modify the behavior", we probably should have a PostDataloadingProcessor class.
If not, I think it's fine to call a util directly in both trainer and validator.

WDYT?
cc @fegin
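
For concreteness, the PostDataloadingProcessor idea might look something like this (purely illustrative naming and batch layout):

import torch


class PostDataloadingProcessor:
    # Hypothetical shared processor: the trainer and validator would hold
    # an instance, and users would subclass it to change the behavior.
    def __call__(self, batch, device: torch.device):
        inputs, labels = batch
        return inputs.to(device), labels.to(device)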

@fegin (Contributor), Dec 12, 2025

I'm thinking this as well: I'm seeing the duplicated code in the trainer and validator. model_parts is not a big problem, you can add it (see #2144). But the duplicated logic is pretty concerning.

My initial proposal is to have a util function in distributed.utils that both methods can call. My only worry was whether distributed.utils is the best place to put it. It isn't right now, because the util function purely unwraps the input, but after #2144 the function will encapsulate the CP logic as well, so distributed.utils seems to be a good place.

@francesco-bertolotti you can do it, as your PR may land first. I can rebase on top of your change.

Contributor

Agreed we should share code, but I'm not sure about distributed.utils. Let's discuss offline.

@rakkit mentioned this pull request on Dec 12, 2025.
@tianyu-l (Contributor) left a comment

unblocking

@tianyu-l (Contributor)

@francesco-bertolotti CPU unit test failed, could you take a look?

@tianyu-l linked an issue on Dec 14, 2025 that may be closed by this pull request: Validation breaking with FlashAttention.
@francesco-bertolotti (Author)

Hi @tianyu-l,

I looked into the error, and it turns out to be a circular import issue.

In validate.py, I need access to model_args. Since Validator doesn’t have direct access to it, I currently construct it via get_train_spec. However, importing get_train_spec triggers an import of BaseValidator, which lives in the same module as Validator, leading to the circular dependency.

I see two possible ways to resolve this:

  1. Pass model_args directly to Validator, eliminating the need to import get_train_spec.
  2. Move BaseValidator into its own module.

I’m leaning slightly toward changing the Validator signature (a rough sketch below), as it feels cleaner overall, though it would be a breaking change.

Curious to hear your thoughts.
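
To make option 1 concrete, a rough sketch; the constructor parameters other than model_args are hypothetical:

class Validator:
    def __init__(self, job_config, model_parts, model_args):
        self.job_config = job_config
        self.model_parts = model_parts
        # model_args is now injected by the caller (which already has it),
        # so validate.py no longer needs to import get_train_spec.
        self.model_args = model_args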

@tianyu-l (Contributor)

It seems this is being fixed in #2144, with the first approach you mentioned.

Are you OK if we consolidate effort there?
