Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
423c616
ENH: Implement ensure_weight_tying for trainable_token_indices (#2864)
sambhavnoobcoder Oct 26, 2025
232c6e7
maintainers' comments addressed
sambhavnoobcoder Oct 29, 2025
213b47f
make style ran
sambhavnoobcoder Oct 29, 2025
3874be0
Merge remote-tracking branch 'origin/main' into trainable-tokens-weig…
sambhavnoobcoder Nov 5, 2025
30d2d01
comments fixed
sambhavnoobcoder Nov 5, 2025
3f19ed3
Merge main into trainable-tokens-weight-tying and resolve conflict in…
sambhavnoobcoder Nov 18, 2025
88f5f22
Use _get_module_names_tied_with_embedding helper for cleaner code
sambhavnoobcoder Nov 19, 2025
c62aa56
test fixes
sambhavnoobcoder Nov 27, 2025
769b8b0
Fix embedding name matching for nested paths and ensure weight tying …
sambhavnoobcoder Dec 9, 2025
8252ccf
Apply ruff formatting
sambhavnoobcoder Dec 9, 2025
616fb80
Fix embedding name matching to use full paths and endswith for disamb…
sambhavnoobcoder Dec 12, 2025
456cd36
Add _tied_weights_keys to MegaModel test and enable ensure_weight_tyi…
sambhavnoobcoder Dec 15, 2025
6ed72f7
Add tests for targeting both embedding and tied layers explicitly
sambhavnoobcoder Dec 15, 2025
7e1821d
Use mapping format for _tied_weights_keys in MegaModel test
sambhavnoobcoder Dec 16, 2025
5f86f3a
Clarify docstrings for tied layer tests
sambhavnoobcoder Dec 16, 2025
69dbbe7
Remove unnecessary reference to maintainer in docstring
sambhavnoobcoder Dec 17, 2025
1969250
Move BartConfig and BartModel imports to top of file
sambhavnoobcoder Dec 17, 2025
eb20b83
Rename MegaModel to CompositeModel throughout tests
sambhavnoobcoder Dec 17, 2025
d86bac8
style changes
sambhavnoobcoder Dec 18, 2025
518a3a7
Apply doc-builder style formatting to docstrings
sambhavnoobcoder Dec 18, 2025
b56d14c
Fix CI failure by setting sub-models _tied_weights_keys to None
sambhavnoobcoder Dec 20, 2025
3d265f2
Merge remote-tracking branch 'origin/main' into trainable-tokens-weig…
sambhavnoobcoder Jan 6, 2026
d9177fb
Add format check for _tied_weights_keys to avoid mixing list and dict…
sambhavnoobcoder Jan 7, 2026
a652214
Skip composite model test for transformers <v5
sambhavnoobcoder Jan 8, 2026
e9dbf1b
Apply formatting to skipif decorator
sambhavnoobcoder Jan 9, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Fix embedding name matching for nested paths and ensure weight tying …
…in inject_adapter
  • Loading branch information
sambhavnoobcoder committed Dec 9, 2025
commit 769b8b0eb822b55ed091a174c833e808187af474
34 changes: 23 additions & 11 deletions src/peft/tuners/trainable_tokens/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,17 +71,29 @@ def inject_adapter(
if matched_keys:
parent, target, target_name = _get_submodules(model, name)

peft_config = self.peft_config[adapter_name].to_dict()
peft_config["tied_adapter"] = self.model.get_input_embeddings()

self._create_and_replace_dict(
peft_config,
adapter_name,
target,
target_name,
parent,
matched_keys[0],
)
# If the module is already a TrainableTokensLayer, we need to replace it with a tied version
# instead of just updating it. This handles the case where the user explicitly targeted
# both the embedding and tied layers in target_modules.
if isinstance(target, TrainableTokensLayer):
# Replace the existing layer with a new one that's tied to the embedding
peft_config = self.peft_config[adapter_name].to_dict()
peft_config["tied_adapter"] = self.model.get_input_embeddings()

new_module = self._create_new_module(peft_config, adapter_name, target.base_layer, **peft_config)
self._replace_module(parent, target_name, new_module, target.base_layer)
else:
# Module hasn't been wrapped yet, create and replace normally
peft_config = self.peft_config[adapter_name].to_dict()
peft_config["tied_adapter"] = self.model.get_input_embeddings()

self._create_and_replace_dict(
peft_config,
adapter_name,
target,
target_name,
parent,
matched_keys[0],
)

def _get_tied_target_modules(self, *args, **kwargs):
# Normally this method would return the layers that target tied layers.
Expand Down
4 changes: 2 additions & 2 deletions src/peft/utils/other.py
Original file line number Diff line number Diff line change
Expand Up @@ -1488,8 +1488,8 @@ def set_additional_trainable_modules(model, peft_config, model_config, adapter_n

# Find which target layers are in the tied weights (including the embedding source)
for target_layer_name in target_layers:
# Check if this is the embedding layer
if target_layer_name == embedding_name:
# Check if this is the embedding layer (check both exact match and endswith for nested structures)
if target_layer_name == embedding_name or target_layer_name.endswith(embedding_name):
tied_layer_keys.append(target_layer_name)
continue
# Check if this target layer matches any tied module (considering nested structures)
Expand Down