Parallel patch embed for faster execution #17

Merged
brandstetter-johannes merged 4 commits into main from jkg/fast on Jun 15, 2023

@rejuvyesh
Collaborator

We can use grouped convolutions to compute the patch embeddings of all variables in parallel, in a single GPU kernel call, instead of serially in a for loop over the variables.

TODO:

  • Correctly convert checkpoints saved with the non-parallel patch embed
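
For illustration, here is a minimal PyTorch sketch of the idea, not the code from this PR. The class names `SerialPatchEmbed` and `ParallelPatchEmbed`, the helper `copy_serial_weights`, and the `(batch, variable, height, width)` input layout are all assumptions made for the example. It relies on the fact that `nn.Conv2d` with `groups=num_vars` gives each input channel its own block of `embed_dim` output filters, so one call reproduces the per-variable loop. The helper shows one possible way to carry per-variable weights over to the grouped layer under that assumed layout.

```python
import torch
import torch.nn as nn


class SerialPatchEmbed(nn.Module):
    """Hypothetical baseline: one Conv2d per variable, applied in a loop."""

    def __init__(self, num_vars, patch_size, embed_dim):
        super().__init__()
        self.projs = nn.ModuleList(
            [
                nn.Conv2d(1, embed_dim, kernel_size=patch_size, stride=patch_size)
                for _ in range(num_vars)
            ]
        )

    def forward(self, x):  # x: (B, V, H, W)
        outs = []
        for i, proj in enumerate(self.projs):
            y = proj(x[:, i : i + 1])  # (B, D, H/p, W/p)
            outs.append(y.flatten(2).transpose(1, 2))  # (B, L, D)
        return torch.stack(outs, dim=1)  # (B, V, L, D)


class ParallelPatchEmbed(nn.Module):
    """Hypothetical grouped version: a single Conv2d call for all variables."""

    def __init__(self, num_vars, patch_size, embed_dim):
        super().__init__()
        self.num_vars = num_vars
        self.embed_dim = embed_dim
        # groups=num_vars: each input channel is convolved only with its own
        # block of embed_dim filters, so variables do not mix.
        self.proj = nn.Conv2d(
            num_vars,
            num_vars * embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
            groups=num_vars,
        )

    def forward(self, x):  # x: (B, V, H, W)
        b = x.shape[0]
        y = self.proj(x)  # (B, V*D, H/p, W/p)
        y = y.reshape(b, self.num_vars, self.embed_dim, -1)  # (B, V, D, L)
        return y.transpose(2, 3)  # (B, V, L, D)


@torch.no_grad()
def copy_serial_weights(serial, parallel):
    """Concatenate per-variable filters along the output-channel axis."""
    weight = torch.cat([p.weight for p in serial.projs], dim=0)  # (V*D, 1, p, p)
    bias = torch.cat([p.bias for p in serial.projs], dim=0)  # (V*D,)
    parallel.proj.weight.copy_(weight)
    parallel.proj.bias.copy_(bias)


torch.manual_seed(0)
serial = SerialPatchEmbed(num_vars=3, patch_size=2, embed_dim=8)
parallel = ParallelPatchEmbed(num_vars=3, patch_size=2, embed_dim=8)
copy_serial_weights(serial, parallel)

x = torch.randn(2, 3, 8, 16)
out_serial = serial(x)
out_parallel = parallel(x)
```

After the weights are copied, the two modules should agree up to floating-point tolerance, which is a convenient check when converting an existing checkpoint. A real checkpoint may store the per-variable weights under different keys or shapes, so the concatenation step would need to be adapted to the actual state dict.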

@rejuvyesh rejuvyesh assigned tung-nd and unassigned tung-nd Apr 24, 2023
@rejuvyesh rejuvyesh requested a review from tung-nd April 24, 2023 02:42
@rejuvyesh
Collaborator Author

rejuvyesh commented Apr 24, 2023

Some rough timing:

parallel_patch_embed: True, img_size: [32, 64]
forward time: 0.009
---
parallel_patch_embed: False, img_size: [32, 64]
forward time: 0.021
---
parallel_patch_embed: True, img_size: [128, 256]
forward time: 0.097
---
parallel_patch_embed: False, img_size: [128, 256]
forward time: 0.109
---
parallel_patch_embed: True, img_size: [32, 64]
backward time: 0.034
---
parallel_patch_embed: False, img_size: [32, 64]
backward time: 0.052
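
To make the comparison easier to read, the short script below turns the timings above into speedup ratios (serial time divided by parallel time). The numbers are copied from the comment; the dictionary layout and the `speedup` helper are just for this example, and the units are not stated in the original.

```python
# Timings copied from the comment above (units are not stated there).
timings = {
    ("forward", "32x64"): {"parallel": 0.009, "serial": 0.021},
    ("forward", "128x256"): {"parallel": 0.097, "serial": 0.109},
    ("backward", "32x64"): {"parallel": 0.034, "serial": 0.052},
}


def speedup(entry):
    """Ratio of serial time to parallel time; above 1 means parallel is faster."""
    return entry["serial"] / entry["parallel"]


speedups = {key: round(speedup(entry), 2) for key, entry in timings.items()}

for (pass_name, size), ratio in speedups.items():
    print(f"{pass_name:8s} {size:8s} speedup: {ratio:.2f}x")
```

On these rough numbers the forward pass is about 2.33x faster at 32x64 and about 1.12x faster at 128x256, and the backward pass is about 1.53x faster at 32x64.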

@brandstetter-johannes brandstetter-johannes merged commit efd6de4 into main Jun 15, 2023
@rejuvyesh rejuvyesh deleted the jkg/fast branch June 28, 2023 15:47
3 participants