
modules_to_save resulting in empty tensor with deepspeed zero3 LoRA training #2450

@agokrani

Description


System Info

peft==0.15.0
transformers==4.49.0

Who can help?

@BenjaminBossan I am training a model with LoRA using DeepSpeed ZeRO-3, and I have set modules_to_save=["embed_tokens", "lm_head"]. The model trained without any errors. However, when loading the adapter to merge it back into the model, I am encountering the following error:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/Users/agokrani/Documents/git/zett/env-zett-311/lib/python3.11/site-packages/transformers/integrations/peft.py", line 239, in load_adapter
    incompatible_keys = set_peft_model_state_dict(
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/agokrani/Documents/git/zett/env-zett-311/lib/python3.11/site-packages/peft/utils/save_and_load.py", line 444, in set_peft_model_state_dict
    load_result = model.load_state_dict(peft_model_state_dict, strict=False)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/agokrani/Documents/git/zett/env-zett-311/lib/python3.11/site-packages/torch/nn/modules/module.py", line 2581, in load_state_dict
    raise RuntimeError(
RuntimeError: Error(s) in loading state_dict for Phi3ForCausalLM:
	size mismatch for model.embed_tokens.modules_to_save.default.weight: copying a param with shape torch.Size([0]) from checkpoint, the shape in current model is torch.Size([151669, 3072]).
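For context, the traceback above comes from a reload call along these lines (the base model identifier and adapter path are placeholders, not the exact ones I use):

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("path/to/base-model")
# Fails with the size mismatch shown above, because the saved
# embed_tokens / lm_head tensors have shape torch.Size([0]).
model.load_adapter("path/to/adapter")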

Upon inspecting the adapter_model.safetensors file, I noticed that both embed_tokens and lm_head tensors were saved, but they are empty tensors. I have debugged the root cause, and it seems that the following code was removed from the get_peft_model_state_dict function:

if getattr(model, "modules_to_save", None) is not None:
    for key, value in state_dict.items():
        if any(f"{module_name}.modules_to_save.{adapter_name}" in key for module_name in model.modules_to_save):
            to_return[key.replace("modules_to_save.", "")] = value

Instead, this code was added:

# ADDITIONAL TRAINING MODULES / MODULES_TO_SAVE
for name, module in model.named_modules():
    if isinstance(module, AuxiliaryTrainingWrapper):
        to_return.update({f"{name}.{k}": v for k, v in module.adapter_state_dict(adapter_name).items()})

Although the code above adds the embedding layers to the to_return dict, it only adds empty tensors instead of the actual weights. The easiest way to solve this problem is to re-add the removed code, but I am not 100% sure that is the correct approach. Do you have any idea why this might be happening? My guess is that DeepSpeed ZeRO-3 is causing it. If the solution is indeed just to restore the removed code, please let me know and I can create the PR.
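If ZeRO-3 partitioning is indeed the cause, each rank only holds a shard of embed_tokens and lm_head, so reading their weights outside a gather context returns zero-sized tensors. A rough sketch of how the full weights could be materialized before they are written to the state dict (the helper below is mine, not PEFT's actual code):

import deepspeed

def gathered_module_state_dict(module):
    # Under ZeRO-3 each rank only keeps a shard of every parameter; the
    # local tensor has shape torch.Size([0]) until it is gathered.
    params = list(module.parameters())
    with deepspeed.zero.GatheredParameters(params, modifier_rank=None):
        # Inside this context the parameters are temporarily whole, so the
        # copied tensors have their real shapes.
        return {k: v.detach().cpu().clone() for k, v in module.state_dict().items()}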

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder
  • My own task or dataset (give details below)

Reproduction

Use DeepSpeed ZeRO-3 with the latest version of PEFT, include the embedding layers in modules_to_save, train, save the adapter, and then try to load it again; see the sketch below.
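Roughly like this (the model identifier and output path are placeholders):

from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

base_model = AutoModelForCausalLM.from_pretrained("path/to/base-model")

lora_config = LoraConfig(
    r=16,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    # These are the modules that end up as empty tensors in adapter_model.safetensors
    modules_to_save=["embed_tokens", "lm_head"],
)

peft_model = get_peft_model(base_model, lora_config)
# ... train under DeepSpeed ZeRO-3 (e.g. via the Trainer with a stage-3 config) ...
peft_model.save_pretrained("path/to/adapter")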

Expected behavior

The saved embed_tokens and lm_head tensors shouldn't be empty; they should contain the actual modified weights of the embedding layers.
