|
| 1 | +# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. |
| 2 | +# SPDX-License-Identifier: Apache-2.0 |
| 3 | + |
| 4 | +from contextlib import contextmanager |
| 5 | +from dataclasses import dataclass |
| 6 | + |
| 7 | +import torch |
| 8 | +from huggingface_hub import PyTorchModelHubMixin, get_collection |
| 9 | + |
| 10 | +from kvpress.presses.scorer_press import ScorerPress |
| 11 | + |
| 12 | + |
class QFilters(torch.nn.Module, PyTorchModelHubMixin):
    """
    Module holding a single ``q_filters`` parameter of shape
    ``(num_layers, num_kv_heads, kv_head_dim)``.

    The parameter is initialised with samples from ``torch.randn``; loading real
    values is left to ``from_pretrained``.
    """

    def __init__(self, num_layers: int, num_kv_heads: int, kv_head_dim: int):
        super().__init__()
        filter_shape = (num_layers, num_kv_heads, kv_head_dim)
        self.q_filters = torch.nn.Parameter(torch.randn(*filter_shape))

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path):
        """
        Delegate to the parent ``from_pretrained``, forwarding only the
        model name or path (no other loading options are exposed here).
        """
        parent = super(QFilters, cls)
        return parent.from_pretrained(pretrained_model_name_or_path)
| 21 | + |
| 22 | + |
@dataclass
class QFilterPress(ScorerPress):
    """
    Prune KV pairs with Q-filters.

    Every key is scored by the negated dot product between the key and the
    Q-filter stored for its layer (see ``score``). The filters are fetched from
    the ``nthngdy/<model_name>_qfilt`` hub repository each time the press is
    applied to a model.
    """

    def __post_init_from_model__(self, model):
        """
        Load the Q-filters that match ``model`` and cast them to ``model.dtype``.

        The filter repository is selected from the last ``/``-separated component
        of ``model.config.name_or_path``.
        """
        model_name = model.config.name_or_path.split("/")[-1]
        self.q_filters = self.load_q_filters(model_name)
        self.q_filters = self.q_filters.to(model.dtype)

    @staticmethod
    def load_q_filters(model_name):
        """
        Return the ``q_filters`` parameter stored in ``nthngdy/{model_name}_qfilt``.

        Raises
        ------
        ValueError
            If loading raises ``TypeError``. The message lists the names returned
            by ``available_qfilters`` and the original error is attached as the cause.
        """
        try:
            return QFilters.from_pretrained(f"nthngdy/{model_name}_qfilt").q_filters
        except TypeError as err:
            # NOTE(review): only TypeError is translated; presumably that is what the
            # hub loader surfaces for a missing repository - confirm.
            # Chain the original error so the root cause stays visible in the traceback.
            raise ValueError(
                f"Could not load Q-filters for {model_name}. Available models: {QFilterPress.available_qfilters()}"
            ) from err

    @staticmethod
    def available_qfilters():
        """
        List the model names found in the Q-filters hub collection.

        Each name is the last path component of a collection item id with a
        trailing ``_qfilt`` removed. Ids without that suffix are returned unchanged
        rather than being truncated.
        """
        suffix = "_qfilt"
        collection = get_collection("nthngdy/q-filters-67a4994dcb302a3d37f3d119", token=False)
        names = []
        for item in collection.items:
            repo_name = item.item_id.split("/")[-1]
            if repo_name.endswith(suffix):
                repo_name = repo_name[: -len(suffix)]
            names.append(repo_name)
        return names

    def score(self, module, hidden_states, keys, values, attentions, kwargs):
        """
        Score keys by their negated projection onto the layer's Q-filter.

        Only ``module.layer_idx`` and ``keys`` are used; the remaining arguments are
        accepted but ignored here.
        """
        # Select this layer's filters and insert singleton axes at positions 0 and 2
        # so they broadcast against ``keys``.
        # assumes keys is (batch, num_kv_heads, seq_len, head_dim) - TODO confirm
        q_filter = self.q_filters[module.layer_idx][None, :, None]
        q_filter = q_filter.to(keys.device)
        # Dot product over the last axis, negated.
        scores = -(q_filter * keys).sum(dim=-1)
        return scores

    @contextmanager
    def __call__(self, model):
        """
        Context manager that reloads the Q-filters for ``model`` and then enters
        the parent press context for the duration of the ``with`` body.
        """
        self.__post_init_from_model__(model)
        with super().__call__(model):
            yield
0 commit comments