
Conversation

SimJeg (Collaborator) commented Nov 25, 2024

TOVA press as requested in #2

SimJeg linked an issue Nov 25, 2024 that may be closed by this pull request
SimJeg requested a review from maxjeblick November 25, 2024 10:55
maxjeblick (Collaborator) left a comment


Thanks a lot for adding TOVA, LGTM!

hassidm commented Nov 26, 2024

Thanks a lot for adding TOVA!

Just a small clarification:
Are the attn_weights at this line of shape [batch, heads, seq_len, seq_len] or [batch, heads, window_size, seq_len]? For TOVA in the prefilling phase, the former should be implemented (and if I understand correctly, the former is what is implemented, according to link or link).
After computing the attention weights with shape [batch, heads, seq_len, seq_len], the compression should correspond to the last token's attention weights (averaged across heads):
torch.mean(attn_weights[:, :, -1:, :], dim=1) (which is implemented correctly in the PR :) )

Thanks again, and sorry for the misunderstanding :)

SimJeg (Collaborator, Author) commented Nov 26, 2024

After #L58, the shape is [batch, heads, window_size, seq_len] and not [batch, heads, seq_len, seq_len], because we only compute attention for the last window_size tokens.

There is no need to compute the full attention weight matrix, which would cost a lot of memory. We only need attn_weights = full_attn_weights[:, :, -window_size:, :].

So in our code:

```python
# attn_weights -> shape [batch, heads, window_size=1, seq_len]
scores = attn_weights.mean(1)                # -> shape [batch, window_size=1, seq_len]
scores = scores.repeat(1, keys.shape[1], 1)  # -> shape [batch, num_key_value_heads, seq_len]
```
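
For completeness, here is a quick sanity check with random tensors (dimensions and variable names are illustrative, this is not code from the repo) showing that computing attention only for the last window_size queries gives exactly the rows you would get by building the full matrix and slicing it:

```python
import torch

batch, heads, seq_len, head_dim, window_size = 2, 8, 16, 64, 1
q = torch.randn(batch, heads, seq_len, head_dim)
k = torch.randn(batch, heads, seq_len, head_dim)

# Full attention matrix, then keep the last row(s): [batch, heads, window_size, seq_len]
full_attn = torch.softmax(q @ k.transpose(-1, -2) / head_dim**0.5, dim=-1)
sliced = full_attn[:, :, -window_size:, :]

# Attention computed only for the last window_size queries: same values, far less memory
windowed = torch.softmax(q[:, :, -window_size:] @ k.transpose(-1, -2) / head_dim**0.5, dim=-1)

assert torch.allclose(sliced, windowed, atol=1e-6)

# TOVA scores in both cases: the last token's attention averaged over heads
scores = windowed[:, :, -1:, :].mean(dim=1)  # [batch, 1, seq_len]
```

(The causal mask is omitted here since, with window_size=1, the last query attends to all keys anyway.)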

Can you confirm everything is as you expect?

hassidm commented Nov 26, 2024

Thanks for the quick response!

I'm sorry, but if I understand correctly, this is not exactly the same computation for the deeper layers (it is the same for the first attention layer).

For the prefilling case, if token i attended to previous tokens in earlier attention layers, its hidden representation will differ from the one obtained when the same token did not attend to any other tokens in those layers. In the current implementation, the non-window tokens (which are all tokens but one) never attend to any other tokens, which results in non-contextualized representations for the non-window tokens and probably leads to (at least) slightly worse performance.

Regarding the memory cost, I agree that this approach is more costly, as we need to store (for a short period) the whole attention matrix of a specific layer, but the KV cache memory stays compressed.
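
For reference, a rough back-of-the-envelope estimate of that extra memory (hypothetical sizes, fp16, per layer and per sequence, not measured):

```python
heads, seq_len, window_size, bytes_per_el = 32, 32_768, 1, 2  # assumed sizes, fp16

full_matrix = heads * seq_len * seq_len * bytes_per_el   # [heads, seq_len, seq_len]
windowed = heads * window_size * seq_len * bytes_per_el  # [heads, window_size, seq_len]

print(f"full attention matrix: {full_matrix / 2**30:.0f} GiB")  # 64 GiB
print(f"windowed weights:      {windowed / 2**20:.0f} MiB")     # 2 MiB
```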

SimJeg (Collaborator, Author) commented Nov 26, 2024

The current code does exactly the same thing as what you first posted in the issue here.

> In the current implementation, the non-window tokens (which are all tokens but one) never attend to any other tokens

No, the press is applied after the forward pass (in a PyTorch hook) and only modifies the KV cache, not the hidden states. So whatever you prune in the first layers, the input hidden states of the current layer are not impacted. Hidden states during pre-filling are 100% independent of the press.

I hope this clarifies things.
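
To make the mechanism concrete, here is a rough sketch of what such a hook can look like (kv_cache, the scoring, and the registration call are placeholders, not the repo's actual API): the hook fires after the layer's forward pass, so it can only prune the cached keys/values, and the hidden states computed during pre-filling are never touched.

```python
import torch

def make_press_hook(compression_ratio):
    def hook(module, inputs, output):
        cache = getattr(module, "kv_cache", None)  # hypothetical attribute holding (keys, values)
        if cache is None:
            return output
        keys, values = cache                        # each [batch, kv_heads, seq_len, head_dim]
        scores = torch.rand(keys.shape[:-1])        # stand-in for TOVA scores: [batch, kv_heads, seq_len]
        n_kept = int(keys.shape[2] * (1 - compression_ratio))
        idx = scores.topk(n_kept, dim=-1).indices.unsqueeze(-1).expand(-1, -1, -1, keys.shape[-1])
        module.kv_cache = (keys.gather(2, idx), values.gather(2, idx))
        return output                               # hidden states pass through unchanged
    return hook

# Hypothetical usage: attention_layer.register_forward_hook(make_press_hook(0.5))
```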

hassidm commented Nov 26, 2024

Thank you for clarifying it, and I apologize for the misunderstanding. As I'm not very familiar with this repository, I wanted to ensure that the implemented code functions as intended.

LGTM, and thanks again!

SimJeg merged commit 3ca0ce4 into main Nov 26, 2024
2 checks passed
SimJeg deleted the simon/tova-press branch November 26, 2024 12:56