7 changes: 0 additions & 7 deletions src/peft/tuners/lora/model.py
@@ -247,13 +247,6 @@ def _replace_module(self, parent, child_name, new_module, child):
if hasattr(child, "base_layer"):
child = child.base_layer

if getattr(child, "state", None) is not None:
if hasattr(new_module, "base_layer"):
new_module.base_layer.state = child.state
else:
new_module.state = child.state
new_module.to(child.weight.device)

meta = torch.device("meta")
# dispatch to correct device
for name, module in new_module.named_modules():
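For context, a minimal, self-contained torch sketch (not PEFT code, and only an assumed reading of why the block above was dropped): with low_cpu_mem_usage=True, freshly created replacement modules live on the meta device, and an unconditional new_module.to(child.weight.device) like the removed one cannot work there, because meta tensors carry no data to copy.

import torch
import torch.nn as nn

new_module = nn.Linear(4, 4, device="meta")  # parameters allocated without storage

try:
    new_module.to("cpu")  # .to() has to copy data, but meta tensors have none
except NotImplementedError as exc:
    print(exc)  # "Cannot copy out of meta tensor; no data! ..."

# Materializing a meta module requires to_empty(), which allocates uninitialized
# storage on the target device instead of copying:
new_module = new_module.to_empty(device="cpu")
print(new_module.weight.device)  # cpu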
34 changes: 32 additions & 2 deletions tests/test_gpu_examples.py
@@ -3870,15 +3870,17 @@ def test_p_tuning_exactly_reproducible_after_loading(self, tmp_path):
class TestLowCpuMemUsageDifferentDevices:
"""Test for the low CPU memory usage option for loading PEFT models.

There are already tests for this in test_initialization.py but here we want to specifically test diverging devices
for the model and state_dict.
There are already tests for low_cpu_mem_usage=True in test_initialization.py but here we want to run tests that
require a GPU.

"""

model_id = "hf-internal-testing/tiny-random-OPTForCausalLM"
device = infer_device()

@pytest.mark.parametrize("device_model, device_sd", [("cpu", "cuda"), ("cuda", "cpu")])
def test_low_cpu_mem_usage_model_model_on_gpu_state_dict_on_cpu_works(self, device_model, device_sd):
# specifically test diverging devices for the model and state_dict
inputs = {"input_ids": torch.randint(0, 100, (1, 10)), "attention_mask": torch.ones(1, 10)}
inputs = {k: v.to(device_model) for k, v in inputs.items()}

@@ -3912,6 +3914,34 @@ def test_low_cpu_mem_usage_model_model_on_gpu_state_dict_on_cpu_works(self, device_model, device_sd):
assert torch.allclose(logits_low_cpu_mem, logits_not_low_cpu_mem)
assert {p.device.type for p in model.parameters()} == {device_model}

@pytest.mark.parametrize("quantization_method", ["bnb-4bit", "bnb-8bit"])
def test_low_cpu_mem_usage_with_quantization(self, quantization_method):
# Ensure that low_cpu_mem_usage works with quantization
# See also https://github.com/huggingface/diffusers/issues/10550
if quantization_method == "bnb-4bit":
quantization_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=torch.float32,
bnb_4bit_quant_storage=torch.float32,
bnb_4bit_use_double_quant=True,
)
elif quantization_method == "bnb-8bit":
quantization_config = BitsAndBytesConfig(load_in_8bit=True)
else:
raise ValueError(f"Unknown quantization method {quantization_method}")

model = AutoModelForCausalLM.from_pretrained(self.model_id, quantization_config=quantization_config)
if model.device.type != self.device:
# calling model.to("cuda") with 8 bit bnb raises an error, thus guard against it
model = model.to(self.device)

lora_config = LoraConfig(init_lora_weights=False, target_modules="all-linear")

# We use get_peft_model with low_cpu_mem_usage=True here. This is not typically done in practice (the option is
# mostly interesting for loading trained adapters), but it does the job for testing purposes.
model = get_peft_model(model, lora_config, low_cpu_mem_usage=True) # this should not raise
assert {p.device.type for p in model.parameters()} == {self.device, "meta"}


class TestEvaInitializationGPU:
"""GPU tests for the Eva initialization method."""
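The comment in the new quantization test notes that low_cpu_mem_usage=True is mostly interesting for loading trained adapters rather than for get_peft_model. A hedged sketch of that use case, assuming the public PeftModel.from_pretrained API and using a placeholder adapter path:

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel

# Quantized base model, mirroring the bnb-4bit branch of the test above.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float32,
)
base_model = AutoModelForCausalLM.from_pretrained(
    "hf-internal-testing/tiny-random-OPTForCausalLM",
    quantization_config=quantization_config,
)

# low_cpu_mem_usage=True creates the adapter parameters on the meta device and only
# materializes them when the saved state_dict is assigned, skipping a redundant
# CPU-side initialization. "path/to/trained-adapter" is a placeholder.
model = PeftModel.from_pretrained(base_model, "path/to/trained-adapter", low_cpu_mem_usage=True)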