fixed validation error when using flash attention #2142
Conversation
Co-authored-by: Francesco Bertolotti <[email protected]>
wwwjn left a comment:
Thanks for the fix! Please fix the PP-enabled branch as well; I think we would do something similar to train.py: https://github.com/pytorch/torchtitan/blob/refs/heads/main/torchtitan/train.py#L507-L516
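For context, here is a rough sketch of what the PP-enabled branch might look like after the fix. Apart from `get_attention_masks`, which appears in this PR's diff, every name below (`pp_schedule`, `pp_has_first_stage`, the `step` signature, the argument passed to `get_attention_masks`) is a hypothetical placeholder modeled loosely on the linked train.py snippet, not the actual torchtitan API:

```python
from typing import Any

import torch
import torch.nn as nn


def validation_forward(
    model_parts: list[nn.Module],
    inputs: torch.Tensor,
    pp_enabled: bool,
    pp_schedule=None,          # hypothetical pipeline schedule object
    pp_has_first_stage=True,   # hypothetical flag
):
    # Build the attention masks once, then pass them in both branches.
    extra_kwargs: dict[str, Any] = {
        "attention_masks": model_parts[0].get_attention_masks(inputs)  # args assumed
    }
    if pp_enabled:
        # PP-enabled branch: forward the masks to the pipeline schedule too.
        if pp_has_first_stage:
            pp_schedule.step(inputs, **extra_kwargs)
        else:
            pp_schedule.step(**extra_kwargs)
    else:
        # Non-PP branch: call the model directly with the masks.
        return model_parts[0](inputs, **extra_kwargs)
```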
torchtitan/components/validate.py (outdated)

```python
            break

    # prepare attention mask
    extra_kwargs: dict[str, Any] = {"attention_masks": model_parts[0].get_attention_masks(
```
For readability, can you make this into a function, similar to: https://github.com/pytorch/torchtitan/blob/refs/heads/main/torchtitan/train.py#L416C9-L416C33
The new commit propagated the fix to the pipeline-parallelism case. Further, I have encapsulated the creation of the attention mask in a method (sketched below).
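For illustration, a minimal sketch of such an encapsulating helper (the name `build_attention_masks` and the argument passed to `get_attention_masks` are hypothetical; only `get_attention_masks` itself appears in the diff above):

```python
from typing import Any

import torch
import torch.nn as nn


def build_attention_masks(
    model_parts: list[nn.Module], inputs: torch.Tensor
) -> dict[str, Any]:
    # Hypothetical helper: build the extra kwargs once so that both the PP
    # and non-PP validation branches pass the same attention masks.
    return {"attention_masks": model_parts[0].get_attention_masks(inputs)}
```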
Co-authored-by: Francesco Bertolotti <[email protected]>
Force-pushed from f8693e5 to 628dbb6.
| "unequal sample counts across ranks when dataset is exhausted." | ||
| ) | ||
|
|
||
| def post_dataloading_process( |
can we reuse the one in train.py?
In principle, it would be possible, but I would need to make some modifications since their signatures differ. In particular, the trainer accesses `model_parts` through `self`.
One possible solution is to turn `post_dataloading_process` into a utility function that both the trainer and validator can call. However, this would mean the behavior can no longer be modified through inheritance, unless the current `post_dataloading_process` simply becomes a wrapper around the new utility function (see the sketch below).
If you have any suggestions, I can add a commit.
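A minimal sketch of the wrapper-around-utility idea (module layout, names, and signatures are all hypothetical; in real code the shared utility and the trainer would live in separate modules):

```python
from typing import Any

import torch
import torch.nn as nn


def post_dataloading_process_util(
    model_parts: list[nn.Module],
    inputs: torch.Tensor,
    labels: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, dict[str, Any]]:
    # Hypothetical shared utility that both trainer and validator call.
    extra_kwargs: dict[str, Any] = {
        "attention_masks": model_parts[0].get_attention_masks(inputs)  # args assumed
    }
    return inputs, labels, extra_kwargs


class Trainer:
    model_parts: list[nn.Module]

    def post_dataloading_process(self, inputs, labels):
        # The existing method becomes a thin wrapper, so subclasses can
        # still customize behavior through inheritance.
        return post_dataloading_process_util(self.model_parts, inputs, labels)
```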
I see.
If you want to "modify the behavior", we probably should have a PostDataloadingProcessor class (a sketch follows below).
If not, I think it's fine to call a util directly in both trainer and validator.
WDYT?
cc @fegin
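For illustration, a hypothetical `PostDataloadingProcessor` along the lines suggested above (the class and method names, and the argument to `get_attention_masks`, are assumptions):

```python
from typing import Any

import torch
import torch.nn as nn


class PostDataloadingProcessor:
    """Hypothetical base class shared by trainer and validator; subclasses
    override process() to customize post-dataloading behavior."""

    def process(
        self, model_parts: list[nn.Module], inputs: torch.Tensor
    ) -> dict[str, Any]:
        # Default behavior: build the attention masks once per batch.
        return {"attention_masks": model_parts[0].get_attention_masks(inputs)}
```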
I'm thinking this as well. I see the duplicated code in the trainer and validator; `model_parts` is not a big problem, you can add it. See #2144. But the duplicated logic is pretty concerning.
My initial proposal is to have a util function in distributed.utils so that both methods can call it. The only worry is that I don't know whether distributed.utils is the best place to put it. It isn't right now, because the util function purely unwraps the input, but the util function will encapsulate the CP logic after #2144, so distributed.utils seems like a good place.
@francesco-bertolotti you can do it, as your PR may land first. I can rebase on top of your change.
agreed we should share code, but not sure about distributed.utils. Let's discuss offline.
tianyu-l left a comment:
unblocking
@francesco-bertolotti CPU unit test failed, could you take a look?
Hi @tianyu-l, I looked into the error, and it turns out to be a circular import issue in […]. I see two possible ways to resolve this: […]
I'm leaning slightly toward changing the […]. Curious to hear your thoughts.
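As general background (the specific options above are elided, so this is not necessarily either of them), one common way to break a circular import in Python is to defer one of the imports to call time:

```python
# module_a.py  (hypothetical modules, for illustration only)

def make_validator():
    # Importing inside the function defers resolution until call time,
    # which breaks an import cycle between module_a and module_b.
    from module_b import Validator

    return Validator()
```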
It seems this is being fixed in #2144, with the first approach you mentioned. Are you OK if we consolidate effort there?
This PR addresses issue #2140.
Briefly, the bug is in the validation step: the validate method did not pass the attention mask to the model, which leads to an error when using flash attention.
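For reference, a minimal sketch of the shape of the fix in the non-PP validation path, based on the diff fragment quoted above (the surrounding function and the exact arguments to `get_attention_masks` are assumptions):

```python
from typing import Any

import torch
import torch.nn as nn


def validate_step(model_parts: list[nn.Module], inputs: torch.Tensor) -> torch.Tensor:
    # Before the fix, the model was called without attention masks, which
    # fails under flash attention. The fix builds the masks and forwards them.
    extra_kwargs: dict[str, Any] = {
        "attention_masks": model_parts[0].get_attention_masks(inputs)  # args assumed
    }
    with torch.no_grad():
        return model_parts[0](inputs, **extra_kwargs)
```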