Commit 4100647

Add QFilterPress (#54)
1 parent 5c8bb37 commit 4100647

5 files changed, 72 insertions(+), 0 deletions(-)

README.md

Lines changed: 2 additions & 0 deletions
@@ -66,6 +66,7 @@ Several presses inherit from `ScorerPress` ([source](kvpress/presses/scorer_pres
 - `StreamingLLMPress` ([source](kvpress/presses/streaming_llm_press.py), [paper](https://arxiv.org/abs/2309.17453)): keep only the initial and recent tokens
 - `TOVAPress` ([source](kvpress/presses/tova_press.py), [paper](https://arxiv.org/abs/2401.06104)): attention weight of the last query averaged across heads
 - `ObservedAttentionPress` ([source](kvpress/presses/observed_attention_press.py), [paper](https://arxiv.org/abs/2306.14048)): average attention weight observed during the pre-filling phase
+- `QFilterPress` ([source](kvpress/presses/qfilter_press.py), [paper](https://arxiv.org/abs/2503.02812)): project the Key representations onto the main SVD component of the Query vectors to approximate the attention scores.
 
 Some presses rely on a different logic:
 - `ThinKPress` ([source](kvpress/presses/think_press.py), [paper](https://arxiv.org/pdf/2407.21018)): compress the dimensions of the keys based on the channel attention score on the last queries
@@ -81,6 +82,7 @@ Finally we provide wrapper presses that can be combined with other presses:
 - `ChunkPress` ([source](kvpress/presses/chunk_press.py), [paper](https://direct.mit.edu/tacl/article/doi/10.1162/tacl_a_00716/125280)): compress the KV cache on each sequence chunk separately. This can yield more uniform compression across long sequences
 - `CriticalKVPress` and `CriticalAdaKVPress` ([source](kvpress/presses/criticalkv_press.py), [paper](https://arxiv.org/abs/2502.03805)): refine the scores using the L1 norm of Wo @ values, coupled with a two-stage selection.
 
+
 For a detailed list of existing KV cache compression methods, check [Awesome-KV-Cache-Compression](https://github.com/October2001/Awesome-KV-Cache-Compression) or [Awesome-LLM-Compression](https://github.com/HuangOwen/Awesome-LLM-Compression?tab=readme-ov-file#kv-cache-compression)
 
 ## Evaluation
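The new bullet compresses the Q-Filters idea into one sentence: approximate attention scores by projecting keys onto the dominant SVD direction of the query vectors. Below is a minimal editorial sketch of that per-head computation (not part of the commit; the shapes, random data, and sign-fixing step are illustrative assumptions). Its last lines mirror the `score` method added in `kvpress/presses/qfilter_press.py` further down.

```python
import torch

# Illustrative shapes: per-head queries (n_queries, head_dim) and keys (seq_len, head_dim)
queries = torch.randn(1024, 128)
keys = torch.randn(2048, 128)

# Dominant right-singular vector of the query matrix: the main direction of the queries
_, _, Vh = torch.linalg.svd(queries, full_matrices=False)
q_filter = Vh[0]  # (head_dim,)

# The SVD sign is arbitrary; orient the filter so queries project onto it positively
if (queries @ q_filter).mean() < 0:
    q_filter = -q_filter

# Score each key by its negative projection (as in QFilterPress.score) and keep, e.g., the top half
scores = -(keys @ q_filter)  # (seq_len,)
kept_indices = scores.topk(k=keys.shape[0] // 2).indices
```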

kvpress/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -23,6 +23,7 @@
 from kvpress.presses.streaming_llm_press import StreamingLLMPress
 from kvpress.presses.think_press import ThinKPress
 from kvpress.presses.tova_press import TOVAPress
+from kvpress.presses.qfilter_press import QFilterPress
 
 # Patch the attention functions to support head-wise compression
 patch_attention_functions()
@@ -49,4 +50,5 @@
     "ChunkPress",
     "DuoAttentionPress",
     "ChunkKVPress",
+    "QFilterPress",
 ]

kvpress/presses/qfilter_press.py

Lines changed: 58 additions & 0 deletions
@@ -0,0 +1,58 @@
+# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+from contextlib import contextmanager
+from dataclasses import dataclass
+
+import torch
+from huggingface_hub import PyTorchModelHubMixin, get_collection
+
+from kvpress.presses.scorer_press import ScorerPress
+
+
+class QFilters(torch.nn.Module, PyTorchModelHubMixin):
+    def __init__(self, num_layers: int, num_kv_heads: int, kv_head_dim: int):
+        super().__init__()
+        self.q_filters = torch.nn.Parameter(torch.randn(num_layers, num_kv_heads, kv_head_dim))
+
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path):
+        return super().from_pretrained(pretrained_model_name_or_path)
+
+
+@dataclass
+class QFilterPress(ScorerPress):
+    """
+    Prune KV pairs with Q-filters
+    """
+
+    def __post_init_from_model__(self, model):
+        model_name = model.config.name_or_path.split("/")[-1]
+        self.q_filters = self.load_q_filters(model_name)
+        self.q_filters = self.q_filters.to(model.dtype)
+
+    @staticmethod
+    def load_q_filters(model_name):
+        try:
+            return QFilters.from_pretrained(f"nthngdy/{model_name}_qfilt").q_filters
+        except TypeError:
+            raise ValueError(
+                f"Could not load Q-filters for {model_name}. Available models: {QFilterPress.available_qfilters()}"
+            )
+
+    @staticmethod
+    def available_qfilters():
+        collection = get_collection("nthngdy/q-filters-67a4994dcb302a3d37f3d119", token=False)
+        return [x.item_id.split("/")[-1][:-6] for x in collection.items]
+
+    def score(self, module, hidden_states, keys, values, attentions, kwargs):
+        q_filter = self.q_filters[module.layer_idx][None, :, None]
+        q_filter = q_filter.to(keys.device)
+        scores = -(q_filter * keys).sum(dim=-1)
+        return scores
+
+    @contextmanager
+    def __call__(self, model):
+        self.__post_init_from_model__(model)
+        with super().__call__(model):
+            yield
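As an editorial aside (not part of the commit), here is a short usage sketch for the new press. The checkpoint name is a placeholder: Q-filters must already be published on the Hub for the chosen model, and `QFilterPress.available_qfilters()` lists the valid names. The `compression_ratio` kwarg is the same one exercised in `tests/default_presses.py` below.

```python
# Usage sketch (assumes transformers is installed; the model name is an illustrative placeholder)
from transformers import AutoModelForCausalLM, AutoTokenizer

from kvpress import QFilterPress

print(QFilterPress.available_qfilters())  # model names with published Q-filters

model_name = "meta-llama/Llama-3.1-8B-Instruct"  # placeholder; pick one from the list above
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(model_name)

press = QFilterPress(compression_ratio=0.5)

inputs = tokenizer("The KV cache of this prompt is compressed during pre-filling.", return_tensors="pt")
with press(model):  # loads the Q-filters for this model and applies the press while generate() pre-fills the prompt
    outputs = model.generate(**inputs.to(model.device), max_new_tokens=32)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```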

tests/default_presses.py

Lines changed: 2 additions & 0 deletions
@@ -13,6 +13,7 @@
     StreamingLLMPress,
     ThinKPress,
     TOVAPress,
+    QFilterPress,
 )
 
 
@@ -31,6 +32,7 @@ def load_attention_pattern(model):
     {"cls": ExpectedAttentionPress, "kwargs": [{"compression_ratio": 0.2}, {"compression_ratio": 0.8}]},
     {"cls": RandomPress, "kwargs": [{"compression_ratio": 0.2}, {"compression_ratio": 0.8}]},
     {"cls": StreamingLLMPress, "kwargs": [{"compression_ratio": 0.2}, {"compression_ratio": 0.8}]},
+    {"cls": QFilterPress, "kwargs": [{"compression_ratio": 0.2}, {"compression_ratio": 0.8}]},
     {
         "cls": SnapKVPress,
         "kwargs": [{"compression_ratio": 0.2, "window_size": 2}, {"compression_ratio": 0.8, "window_size": 2}],
Lines changed: 8 additions & 0 deletions
@@ -0,0 +1,8 @@
+# SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+from kvpress.presses.qfilter_press import QFilterPress
+
+
+def test_load_qfilters():
+    for model_name in QFilterPress.available_qfilters():
+        QFilterPress.load_q_filters(model_name)
