Add TOVA press #12
Conversation
maxjeblick
left a comment
Thanks a lot for adding TOVA, LGTM!
Thanks a lot for adding TOVA! Just a small clarification below. Thanks again, and sorry for the misunderstanding :)
After #L58, the shape of attn_weights is [batch, heads, window_size=1, seq_len]. There is no need to compute the full attention weight matrix, which would cost a lot of memory: we only need the attention of the last (window) token over the whole sequence. So in our code:

```python
# attn_weights -> shape [batch, heads, window_size=1, seq_len]
scores = attn_weights.mean(1)  # -> shape [batch, window_size=1, seq_len]
scores = scores.repeat(1, keys.shape[1], 1)  # -> shape [batch, num_key_value_heads, seq_len]
```

Do you confirm everything is as you expect?
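(For context, here is a self-contained sketch of the scoring step described above. The function name and the toy shapes are illustrative, not the repository's API; it only shows the last-token attention being averaged over the query heads and broadcast to the key/value heads.)

```python
import torch

def tova_style_scores(attn_weights: torch.Tensor, num_key_value_heads: int) -> torch.Tensor:
    """Illustrative sketch of the scoring above (not the repository's exact code).

    attn_weights: attention of the last query token over the sequence,
                  shape [batch, heads, window_size=1, seq_len].
    Returns scores of shape [batch, num_key_value_heads, seq_len].
    """
    scores = attn_weights.mean(1)                      # [batch, window_size=1, seq_len]
    scores = scores.repeat(1, num_key_value_heads, 1)  # [batch, num_key_value_heads, seq_len]
    return scores

# Toy example: batch=2, 8 query heads, 4 key/value heads, seq_len=16.
attn = torch.softmax(torch.randn(2, 8, 1, 16), dim=-1)
print(tova_style_scores(attn, num_key_value_heads=4).shape)  # torch.Size([2, 4, 16])
```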
Thanks for the quick response! I'm sorry, but if I understand correctly, this is not exactly the same computation for deeper layers (it is the same for the first attention layer). In the prefilling case, if token i attended to previous tokens in earlier attention layers, its hidden representation will differ from the one where the same token did not attend to other tokens in those layers. In the current implementation, the non-window tokens (which are all tokens but one) never attend to any other tokens, which results in non-contextualized representations for the non-window tokens and probably leads to (at least) slightly worse performance. Regarding the memory cost, I agree that this implementation is more costly, since we need to store (for a short period) the whole attention matrix of a specific layer, but the KV cache memory stays compressed.
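(To put rough numbers on the memory point above: the figures below are illustrative assumptions, fp32 attention weights, 32 heads and an 8k-token prompt, not measurements from this PR. A full attention matrix for one layer scales with seq_len squared, while the single last-token row scales with seq_len.)

```python
# Back-of-the-envelope memory estimate for one layer's attention weights (illustrative numbers).
batch, heads, seq_len, bytes_per_elem = 1, 32, 8192, 4  # fp32

full_matrix = batch * heads * seq_len * seq_len * bytes_per_elem  # [batch, heads, seq_len, seq_len]
last_row = batch * heads * 1 * seq_len * bytes_per_elem           # [batch, heads, 1, seq_len]

print(f"full attention matrix: {full_matrix / 2**30:.1f} GiB")  # 8.0 GiB
print(f"last-token row only:   {last_row / 2**20:.1f} MiB")     # 1.0 MiB
```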
The current code does exactly the same thing as what you first posted in the issue here.
No, the press is applied after the forward pass (in a PyTorch hook) and only modifies the KV cache, not the hidden states. So whatever you prune in the first layers, the input hidden states of the current layer are not impacted. Hidden states during pre-filling are 100% independent of the press. I hope this clarifies things.
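(A toy sketch of the hook mechanism described above, with a made-up module and tensor shapes rather than the actual repository code: the forward hook runs after the layer has produced its output, so it can shrink the cached keys/values without changing the hidden states that were just computed.)

```python
import torch
from torch import nn

class ToyAttention(nn.Module):
    """Stand-in attention layer that returns its output and an (untouched) KV cache."""
    def forward(self, hidden_states, kv_cache):
        return hidden_states, kv_cache

def prune_kv_hook(module, inputs, output, keep=4):
    # Runs after forward: trim the cache, leave the hidden states as they are.
    # Keeping the last `keep` positions is a stand-in for the score-based selection.
    hidden_states, (keys, values) = output
    return hidden_states, (keys[..., -keep:, :], values[..., -keep:, :])

layer = ToyAttention()
layer.register_forward_hook(prune_kv_hook)

h = torch.randn(1, 16, 64)                                   # [batch, seq_len, hidden_dim]
kv = (torch.randn(1, 8, 16, 64), torch.randn(1, 8, 16, 64))  # [batch, kv_heads, seq_len, head_dim]
out, (k, v) = layer(h, kv)
print(out.shape)  # torch.Size([1, 16, 64]) -> hidden states unchanged
print(k.shape)    # torch.Size([1, 8, 4, 64]) -> cache pruned to the last 4 positions
```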
Thank you for clarifying it, and I apologize for the misunderstanding. As I'm not very familiar with this repository, I wanted to ensure that the implemented code functions as intended. LGTM, and thanks again!
TOVA press as requested in #2