13 changes: 13 additions & 0 deletions examples/olora_finetuning/README.md
@@ -39,6 +39,19 @@ OLoRA also supports quantization. To use 4-bit quantization try:
```bash
python3 examples/olora_finetuning/olora_finetuning.py --base_model facebook/opt-350m --quantize
```
Alternatively, you can pass an already-quantized model and omit the `--quantize` flag.
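For example, with a checkpoint that was already quantized with bitsandbytes (the model id below is just an illustration; any 4-bit bnb checkpoint should behave the same way):
```bash
python3 examples/olora_finetuning/olora_finetuning.py --base_model unsloth/llama-3-8b-bnb-4bit
```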

If you want to run DDP with [Accelerate](https://huggingface.co/docs/accelerate/en/index), first run `accelerate config` to set up your DDP configuration, then run:
```bash
accelerate launch examples/olora_finetuning/olora_finetuning.py --base_model facebook/opt-350m
```
Add `--device_map cpu` if you want to run the finetuning on CPU.
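For instance, a CPU-only run (fine for a small model such as `facebook/opt-350m`, but slow for larger ones):
```bash
python3 examples/olora_finetuning/olora_finetuning.py --base_model facebook/opt-350m --device_map cpu
```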

If you want to train a model quantized with a method that does not support the OLoRA init, such as AWQ or GPTQ (OLoRA needs to read and update the base weights at initialization, which these formats do not allow), pass `--init_lora_weights gaussian`. For example:
```bash
python3 examples/olora_finetuning/olora_finetuning.py --base_model hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4 --init_lora_weights gaussian
```


## Use the model
37 changes: 26 additions & 11 deletions examples/olora_finetuning/olora_finetuning.py
@@ -13,12 +13,13 @@
# limitations under the License.


import os
from typing import List, Optional

import torch
import transformers
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, set_seed

from peft import (
LoraConfig,
@@ -43,23 +44,33 @@ def train(
    lora_alpha: int = 16,
    lora_dropout: float = 0.05,
    lora_target_modules: List[str] = None,
    torch_dtype: str = "float16",
    init_lora_weights="olora",
    seed: Optional[int] = None,
):
    # Set device_map to the right place when enabling DDP.
    world_size = int(os.environ.get("WORLD_SIZE", 0)) or int(os.environ.get("PMI_SIZE", 0))
    if world_size > 1 and device_map != "cpu":
        from accelerate import Accelerator

        device_map = {"": Accelerator().process_index}
    # Set seed
    if seed is not None:
        set_seed(seed)
    model_kwargs = {"torch_dtype": getattr(torch, torch_dtype), "device_map": device_map}
    if quantize:
        model_kwargs["quantization_config"] = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
        )
    model = AutoModelForCausalLM.from_pretrained(base_model, **model_kwargs)

    tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
    # Some tokenizers, like Llama's, have no pad token; fall back to the EOS token.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    def tokenize(prompt, add_eos_token=True):
        result = tokenizer(
@@ -112,7 +123,6 @@ def generate_and_tokenize_prompt(example):
            warmup_steps=100,
            num_train_epochs=num_epochs,
            learning_rate=learning_rate,
            logging_steps=100,
            optim="adamw_torch",
            evaluation_strategy="steps",
@@ -122,6 +132,7 @@ def generate_and_tokenize_prompt(example):
            output_dir=output_dir,
            save_total_limit=3,
            load_best_model_at_end=True,
            ddp_find_unused_parameters=False if world_size > 1 else None,
        ),
        data_collator=transformers.DataCollatorForSeq2Seq(
            tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
@@ -159,7 +170,9 @@ def generate_prompt(example):
parser.add_argument("--lora_alpha", type=int, default=16)
parser.add_argument("--lora_dropout", type=float, default=0.05)
parser.add_argument("--lora_target_modules", type=str, default=None)
parser.add_argument("--torch_dtype", type=str, default="float16")
parser.add_argument("--init_lora_weights", type=str, default="olora")
parser.add_argument("--seed", type=int, default=None)

    args = parser.parse_args()

@@ -180,5 +193,7 @@ def generate_prompt(example):
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
        lora_target_modules=args.lora_target_modules,
        torch_dtype=args.torch_dtype,
        init_lora_weights=args.init_lora_weights,
        seed=args.seed,
    )
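Two of the new knobs deserve a note. `--torch_dtype` is resolved with `getattr(torch, torch_dtype)`, so any dtype attribute on `torch` (`float16`, `bfloat16`, `float32`) is accepted. And under DDP, each rank must load the whole model on its own device rather than sharding it, which is what `device_map = {"": Accelerator().process_index}` does. A standalone sketch of that placement logic (assumes `accelerate` is installed; the function name is ours, not the script's):

```python
import os


def resolve_device_map(device_map):
    # WORLD_SIZE is set by torchrun/accelerate launchers; PMI_SIZE by MPI-style ones.
    world_size = int(os.environ.get("WORLD_SIZE", 0)) or int(os.environ.get("PMI_SIZE", 0))
    if world_size > 1 and device_map != "cpu":
        from accelerate import Accelerator

        # "" refers to the root module: place the entire model on this rank's device.
        return {"": Accelerator().process_index}
    return device_map
```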
2 changes: 1 addition & 1 deletion src/peft/utils/integrations.py
@@ -88,7 +88,7 @@ def dequantize_bnb_weight(weight: torch.nn.Parameter, state=None):
    # BNB requires CUDA weights
    device = weight.device
    is_cpu = device.type == torch.device("cpu").type
    if is_cpu and torch.cuda.is_available():
        weight = weight.to(torch.device("cuda"))

    cls_name = weight.__class__.__name__
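This one-line fix guards the CUDA transfer: on a CPU-only machine, dequantization previously tried to move the weight to a CUDA device unconditionally and failed. The pattern in isolation (a minimal sketch, not PEFT's code):

```python
import torch


def maybe_to_cuda(weight: torch.Tensor) -> torch.Tensor:
    # Only migrate CPU tensors when CUDA actually exists; otherwise keep them on CPU.
    if weight.device.type == "cpu" and torch.cuda.is_available():
        return weight.to(torch.device("cuda"))
    return weight
```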