Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
enable FSDP example for model `hugging-quants/Meta-Llama-3.1-8B-Instruct-GPTQ-INT4`

Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>
  • Loading branch information
kaixuanliu committed Jul 2, 2025
commit 2a54712a42b44a311287d15adb12a88a43112138
2 changes: 1 addition & 1 deletion examples/sft/run_peft.sh
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ python train.py \
--lora_alpha 16 \
--lora_dropout 0.1 \
--lora_target_modules "all-linear" \
--use_4bit_quantization True \
--use_bnb_4bit_quantization True \
--use_nested_quant True \
--bnb_4bit_compute_dtype "bfloat16" \
--use_flash_attn True
2 changes: 1 addition & 1 deletion examples/sft/run_peft_deepspeed.sh
Original file line number Diff line number Diff line change
Expand Up @@ -36,4 +36,4 @@ accelerate launch --config_file "configs/deepspeed_config.yaml" train.py \
--lora_alpha 16 \
--lora_dropout 0.1 \
--lora_target_modules "all-linear" \
--use_4bit_quantization False
--use_bnb_4bit_quantization False
2 changes: 1 addition & 1 deletion examples/sft/run_peft_fsdp.sh
Original file line number Diff line number Diff line change
Expand Up @@ -36,4 +36,4 @@ accelerate launch --config_file "configs/fsdp_config.yaml" train.py \
--lora_alpha 16 \
--lora_dropout 0.1 \
--lora_target_modules "all-linear" \
--use_4bit_quantization False
--use_bnb_4bit_quantization False
2 changes: 1 addition & 1 deletion examples/sft/run_peft_multigpu.sh
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ torchrun --nproc_per_node 8 --nnodes 1 train.py \
--lora_alpha 16 \
--lora_dropout 0.1 \
--lora_target_modules "all-linear" \
--use_4bit_quantization True \
--use_bnb_4bit_quantization True \
--use_nested_quant True \
--bnb_4bit_compute_dtype "bfloat16" \
--use_flash_attn True
2 changes: 1 addition & 1 deletion examples/sft/run_peft_qlora_deepspeed_stage3.sh
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ accelerate launch --config_file "configs/deepspeed_config_z3_qlora.yaml" train.py \
--lora_alpha 16 \
--lora_dropout 0.1 \
--lora_target_modules "all-linear" \
--use_4bit_quantization True \
--use_bnb_4bit_quantization True \
--use_nested_quant True \
--bnb_4bit_compute_dtype "bfloat16" \
--bnb_4bit_quant_storage_dtype "bfloat16"
2 changes: 1 addition & 1 deletion examples/sft/run_peft_qlora_fsdp.sh
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ accelerate launch --config_file "configs/fsdp_config_qlora.yaml" train.py \
--lora_alpha 16 \
--lora_dropout 0.1 \
--lora_target_modules "all-linear" \
--use_4bit_quantization True \
--use_bnb_4bit_quantization True \
--use_nested_quant True \
--bnb_4bit_compute_dtype "bfloat16" \
--bnb_4bit_quant_storage_dtype "bfloat16"
2 changes: 1 addition & 1 deletion examples/sft/run_unsloth_peft.sh
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ python train.py \
--lora_alpha 16 \
--lora_dropout 0.1 \
--lora_target_modules "q_proj,k_proj,v_proj,o_proj,down_proj,up_proj,gate_proj" \
--use_4bit_quantization True \
--use_bnb_4bit_quantization True \
--use_nested_quant True \
--bnb_4bit_compute_dtype "bfloat16" \
--use_flash_attn True
34 changes: 22 additions & 12 deletions examples/sft/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,19 +97,19 @@ def create_and_prepare_model(args, data_args, training_args):
):
raise NotImplementedError("Unsloth is not supported in distributed training")

if args.use_4bit_quantization:
if args.use_bnb_4bit_quantization:
compute_dtype = getattr(torch, args.bnb_4bit_compute_dtype)
quant_storage_dtype = getattr(torch, args.bnb_4bit_quant_storage_dtype)

bnb_config = BitsAndBytesConfig(
load_in_4bit=args.use_4bit_quantization,
load_in_4bit=args.use_bnb_4bit_quantization,
bnb_4bit_quant_type=args.bnb_4bit_quant_type,
bnb_4bit_compute_dtype=compute_dtype,
bnb_4bit_use_double_quant=args.use_nested_quant,
bnb_4bit_quant_storage=quant_storage_dtype,
)

if compute_dtype == torch.float16 and args.use_4bit_quantization:
if compute_dtype == torch.float16 and args.use_bnb_4bit_quantization:
major, _ = torch.cuda.get_device_capability()
if major >= 8:
print("=" * 80)
Expand All @@ -124,19 +124,25 @@ def create_and_prepare_model(args, data_args, training_args):
model_name=args.model_name_or_path,
max_seq_length=training_args.max_seq_length,
dtype=None,
load_in_4bit=args.use_4bit_quantization,
load_in_4bit=args.use_bnb_4bit_quantization,
)
else:
torch_dtype = (
quant_storage_dtype if quant_storage_dtype and quant_storage_dtype.is_floating_point else torch.float32
)
model = AutoModelForCausalLM.from_pretrained(
args.model_name_or_path,
quantization_config=bnb_config,
trust_remote_code=True,
attn_implementation="flash_attention_2" if args.use_flash_attn else "eager",
torch_dtype=torch_dtype,
)

# Prepare model loading arguments
model_kwargs = {
"trust_remote_code": True,
"attn_implementation": "flash_attention_2" if args.use_flash_attn else "eager",
"torch_dtype": torch_dtype,
}

# Only add quantization_config if bnb_config is not None
if bnb_config is not None:
model_kwargs["quantization_config"] = bnb_config

model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path, **model_kwargs)

peft_config = None
chat_template = None
Expand Down Expand Up @@ -178,7 +184,11 @@ def create_and_prepare_model(args, data_args, training_args):
# ante). See https://github.com/huggingface/accelerate/issues/1620.
uses_transformers_4_46 = packaging.version.parse(transformers.__version__) >= packaging.version.parse("4.46.0")
uses_fsdp = os.environ.get("ACCELERATE_USE_FSDP", "false").lower() == "true"
if (bnb_config is not None) and uses_fsdp and uses_transformers_4_46:
    if (
        (bnb_config is not None or (hasattr(model, "hf_quantizer") and model.hf_quantizer is not None))
        and uses_fsdp
        and uses_transformers_4_46
    ):

Review comment (Member), on the condition line above:
For better readability, let's assign this line to a variable like is_quantized, WDYT?

Reply (Contributor Author):
Good advice! Have adjusted the code
model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=8, mean_resizing=False)
else:
model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=8)
Expand Down