3131from accelerate import PartialState
3232from accelerate .utils import is_deepspeed_available , tqdm
3333from datasets import Dataset
34- from huggingface_hub .utils ._deprecation import _deprecate_arguments
3534from torch .utils .data import DataLoader
3635from transformers import (
3736 AutoModelForCausalLM ,
@@ -172,6 +171,9 @@ class DPOTrainer(Trainer):
172171            This supersedes the `tokenizer` argument, which is now deprecated.
173172 model_init (`Callable[[], transformers.PreTrainedModel]`):
174173 The model initializer to use for training. If None is specified, the default model initializer will be used.
174+ compute_metrics (`Callable[[EvalPrediction], Dict]`, *optional*):
175+            The function to use to compute the metrics. Must take an `EvalPrediction` and return
176+            a dictionary mapping strings to metric values.
175177 callbacks (`List[transformers.TrainerCallback]`):
176178 The callbacks to use for training.
177179 optimizers (`Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
@@ -180,90 +182,35 @@ class DPOTrainer(Trainer):
180182 The function to use to preprocess the logits before computing the metrics.
181183 peft_config (`Dict`, defaults to `None`):
182184 The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model.
183- compute_metrics (`Callable[[EvalPrediction], Dict]`, *optional*):
184- The function to use to compute the metrics. Must take a `EvalPrediction` and return
185- a dictionary string to metric values.
186185 """
187186
188187 _tag_names = ["trl" , "dpo" ]
189188
190- @_deprecate_arguments (
191- version = "0.13.0" ,
192- deprecated_args = [
193- "beta" ,
194- "label_smoothing" ,
195- "loss_type" ,
196- "label_pad_token_id" ,
197- "padding_value" ,
198- "truncation_mode" ,
199- "max_length" ,
200- "max_prompt_length" ,
201- "max_target_length" ,
202- "is_encoder_decoder" ,
203- "disable_dropout" ,
204- "generate_during_eval" ,
205- "precompute_ref_log_probs" ,
206- "dataset_num_proc" ,
207- "model_init_kwargs" ,
208- "ref_model_init_kwargs" ,
209- "model_adapter_name" ,
210- "ref_adapter_name" ,
211- "reference_free" ,
212- "force_use_ref_model" ,
213- ],
214- custom_message = "Deprecated positional argument(s) used in DPOTrainer, please use the DPOConfig to set these arguments instead." ,
215- )
216189 @deprecate_kwarg ("tokenizer" , new_name = "processing_class" , version = "0.14.0" , raise_if_both_names = True )
217190 def __init__ (
218191 self ,
219192 model : Optional [Union [PreTrainedModel , nn .Module , str ]] = None ,
220193 ref_model : Optional [Union [PreTrainedModel , nn .Module , str ]] = None ,
221- beta : float = 0.1 ,
222- label_smoothing : float = 0 ,
223- loss_type : Optional [str ] = None ,
224194 args : Optional [DPOConfig ] = None ,
225195 data_collator : Optional [DataCollator ] = None ,
226- label_pad_token_id : int = - 100 ,
227- padding_value : Optional [int ] = None ,
228- truncation_mode : str = "keep_end" ,
229196 train_dataset : Optional [Dataset ] = None ,
230197 eval_dataset : Optional [Union [Dataset , Dict [str , Dataset ]]] = None ,
231198 processing_class : Optional [
232199 Union [PreTrainedTokenizerBase , BaseImageProcessor , FeatureExtractionMixin , ProcessorMixin ]
233200 ] = None ,
234201 model_init : Optional [Callable [[], PreTrainedModel ]] = None ,
202+ compute_metrics : Optional [Callable [[EvalLoopOutput ], Dict ]] = None ,
235203 callbacks : Optional [List [TrainerCallback ]] = None ,
236204 optimizers : Tuple [torch .optim .Optimizer , torch .optim .lr_scheduler .LambdaLR ] = (None , None ),
237205 preprocess_logits_for_metrics : Optional [Callable [[torch .Tensor , torch .Tensor ], torch .Tensor ]] = None ,
238- max_length : Optional [int ] = None ,
239- max_prompt_length : Optional [int ] = None ,
240- max_target_length : Optional [int ] = None ,
241206 peft_config : Optional [Dict ] = None ,
242- is_encoder_decoder : Optional [bool ] = None ,
243- disable_dropout : bool = True ,
244- generate_during_eval : bool = False ,
245- compute_metrics : Optional [Callable [[EvalLoopOutput ], Dict ]] = None ,
246- precompute_ref_log_probs : bool = False ,
247- dataset_num_proc : Optional [int ] = None ,
248- model_init_kwargs : Optional [Dict ] = None ,
249- ref_model_init_kwargs : Optional [Dict ] = None ,
250- model_adapter_name : Optional [str ] = None ,
251- ref_adapter_name : Optional [str ] = None ,
252- reference_free : bool = False ,
253- force_use_ref_model : bool = False ,
254207 ):
255208 if not isinstance (model , str ) and ref_model is model :
256209 raise ValueError (
257210 "`model` and `ref_model` cannot be the same object. If you want `ref_model` to be the "
258211                 "same as `model`, you must pass a copy of it, or `None` if you use peft."
259212 )
260213
261- if model_init_kwargs is not None :
262- warnings .warn (
263- "You passed `model_init_kwargs` to the DPOTrainer, the value you passed will override the one in the `DPOConfig`."
264- )
265- args .model_init_kwargs = model_init_kwargs
266-
267214 if args .model_init_kwargs is None :
268215 model_init_kwargs = {}
269216 elif not isinstance (model , str ):
@@ -283,12 +230,6 @@ def __init__(
283230 )
284231 model_init_kwargs ["torch_dtype" ] = torch_dtype
285232
286- if ref_model_init_kwargs is not None :
287- warnings .warn (
288- "You passed `ref_model_init_kwargs` to the DPOTrainer, the value you passed will override the one in the `DPOConfig`."
289- )
290- args .ref_model_init_kwargs = ref_model_init_kwargs
291-
292233 if args .ref_model_init_kwargs is None :
293234 ref_model_init_kwargs = {}
294235 elif not isinstance (ref_model , str ):
@@ -326,12 +267,6 @@ def __init__(
326267 # has been called in order to properly call autocast if needed.
327268 self ._peft_has_been_casted_to_bf16 = False
328269
329- if force_use_ref_model :
330- warnings .warn (
331- "You passed `force_use_ref_model` to the DPOTrainer, the value you passed will override the one in the `DPOConfig`."
332- )
333- args .force_use_ref_model = force_use_ref_model
334-
335270 if not is_peft_available () and peft_config is not None :
336271 raise ValueError (
337272 "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models"
@@ -393,22 +328,12 @@ def make_inputs_require_grad(module, input, output):
393328
394329 model .get_input_embeddings ().register_forward_hook (make_inputs_require_grad )
395330
396- if generate_during_eval :
397- warnings .warn (
398- "You passed `generate_during_eval` to the DPOTrainer, the value you passed will override the one in the `DPOConfig`."
399- )
400- args .generate_during_eval = generate_during_eval
401331 if args .generate_during_eval and not is_wandb_available ():
402332 raise ValueError (
403333 "`generate_during_eval=True` requires Weights and Biases to be installed."
404334 " Please install `wandb` to resolve."
405335 )
406336
407- if is_encoder_decoder is not None :
408- warnings .warn (
409- "You passed `is_encoder_decoder` to the DPOTrainer, the value you passed will override the one in the `DPOConfig`."
410- )
411- args .is_encoder_decoder = is_encoder_decoder
412337 if model is not None :
413338 self .is_encoder_decoder = model .config .is_encoder_decoder
414339 elif args .is_encoder_decoder is None :
@@ -427,33 +352,10 @@ def make_inputs_require_grad(module, input, output):
427352 self .is_vision_model = False
428353
429354 self .is_peft_model = is_peft_available () and isinstance (model , PeftModel )
430- if model_adapter_name is not None :
431- warnings .warn (
432- "You passed `model_adapter_name` to the DPOTrainer, the value you passed will override the one in the `DPOConfig`."
433- )
434- args .model_adapter_name = model_adapter_name
435355 self .model_adapter_name = args .model_adapter_name
436-
437- if ref_adapter_name is not None :
438- warnings .warn (
439- "You passed `ref_adapter_name` to the DPOTrainer, the value you passed will override the one in the `DPOConfig`."
440- )
441- args .ref_adapter_name = ref_adapter_name
442356 self .ref_adapter_name = args .ref_adapter_name
443-
444- if reference_free :
445- warnings .warn (
446- "You passed `reference_free` to the DPOTrainer, the value you passed will override the one in the `DPOConfig`."
447- )
448- args .reference_free = reference_free
449357 self .reference_free = args .reference_free
450358
451- if precompute_ref_log_probs :
452- warnings .warn (
453- "You passed `precompute_ref_log_probs` to the DPOTrainer, the value you passed will override the one in the `DPOConfig`."
454- )
455- args .precompute_ref_log_probs = precompute_ref_log_probs
456-
457359 if ref_model :
458360 self .ref_model = ref_model
459361 elif self .is_peft_model or args .precompute_ref_log_probs :
@@ -465,36 +367,6 @@ def make_inputs_require_grad(module, input, output):
465367 if processing_class is None :
466368 raise ValueError ("processing_class must be specified to tokenize a DPO dataset." )
467369
468- if max_length is not None :
469- warnings .warn (
470- "You passed `max_length` to the DPOTrainer, the value you passed will override the one in the `DPOConfig`."
471- )
472- args .max_length = max_length
473-
474- if max_prompt_length is not None :
475- warnings .warn (
476- "You passed `max_prompt_length` to the DPOTrainer, the value you passed will override the one in the `DPOConfig`."
477- )
478- args .max_prompt_length = max_prompt_length
479-
480- if max_target_length is not None :
481- warnings .warn (
482- "You passed `max_target_length` to the DPOTrainer, the value you passed will override the one in the `DPOConfig`."
483- )
484- args .max_completion_length = max_target_length
485-
486- if label_pad_token_id != - 100 :
487- warnings .warn (
488- "You passed `label_pad_token_id` to the DPOTrainer, the value you passed will override the one in the `DPOConfig`."
489- )
490- args .label_pad_token_id = label_pad_token_id
491-
492- if padding_value is not None :
493- warnings .warn (
494- "You passed `padding_value` to the DPOTrainer, the value you passed will override the one in the `DPOConfig`."
495- )
496- args .padding_value = padding_value
497-
498370 if args .padding_value is not None :
499371 self .padding_value = args .padding_value
500372 else :
@@ -512,11 +384,6 @@ def make_inputs_require_grad(module, input, output):
512384 if data_collator is None :
513385 data_collator = PreferenceCollator (pad_token_id = self .padding_value )
514386
515- if not disable_dropout :
516- warnings .warn (
517- "You passed `disable_dropout` to the DPOTrainer, the value you passed will override the one in the `DPOConfig`."
518- )
519- args .disable_dropout = disable_dropout
520387 if args .disable_dropout :
521388 disable_dropout_in_model (model )
522389 if self .ref_model is not None :
@@ -526,11 +393,6 @@ def make_inputs_require_grad(module, input, output):
526393 self .generate_during_eval = args .generate_during_eval
527394 self .label_pad_token_id = args .label_pad_token_id
528395 self .max_prompt_length = args .max_prompt_length
529- if truncation_mode != "keep_end" :
530- warnings .warn (
531- "You passed `truncation_mode` to the DPOTrainer, the value you passed will override the one in the `DPOConfig`."
532- )
533- args .truncation_mode = truncation_mode
534396 self .truncation_mode = args .truncation_mode
535397 self .max_completion_length = args .max_completion_length
536398 self .precompute_ref_log_probs = args .precompute_ref_log_probs
@@ -540,16 +402,6 @@ def make_inputs_require_grad(module, input, output):
540402 self ._precomputed_train_ref_log_probs = False
541403 self ._precomputed_eval_ref_log_probs = False
542404
543- if loss_type is not None :
544- warnings .warn (
545- "You passed `loss_type` to the DPOTrainer, the value you passed will override the one in the `DPOConfig`."
546- )
547- args .loss_type = loss_type
548- if label_smoothing != 0 :
549- warnings .warn (
550- "You passed `label_smoothing` to the DPOTrainer, the value you passed will override the one in the `DPOConfig`."
551- )
552- args .label_smoothing = label_smoothing
553405 if (
554406 args .loss_type in ["hinge" , "ipo" , "bco_pair" , "sppo_hard" , "nca_pair" , "apo_zero" , "apo_down" ]
555407 and args .label_smoothing > 0
@@ -560,11 +412,6 @@ def make_inputs_require_grad(module, input, output):
560412 if args .loss_type == "kto_pair" :
561413 raise ValueError ("Support for kto_pair has been removed in DPOTrainer. Please use KTOTrainer." )
562414
563- if beta != 0.1 :
564- warnings .warn (
565- "You passed `beta` to the DPOTrainer, the value you passed will override the one in the `DPOConfig`."
566- )
567- args .beta = beta
568415 self .beta = args .beta
569416 self .label_smoothing = args .label_smoothing
570417 self .loss_type = args .loss_type
@@ -578,15 +425,8 @@ def make_inputs_require_grad(module, input, output):
578425 )
579426
580427 self ._stored_metrics = defaultdict (lambda : defaultdict (list ))
581-
582428 self .f_divergence_type = args .f_divergence_type
583429 self .f_divergence_params = {FDivergenceConstants .ALPHA_DIVERGENCE_COEF_KEY : args .f_alpha_divergence_coef }
584-
585- if dataset_num_proc is not None :
586- warnings .warn (
587- "You passed `dataset_num_proc` to the DPOTrainer, the value you passed will override the one in the `DPOConfig`."
588- )
589- args .dataset_num_proc = dataset_num_proc
590430 self .dataset_num_proc = args .dataset_num_proc
591431
592432 # Compute that only on the main process for faster data processing.
@@ -683,7 +523,7 @@ def make_inputs_require_grad(module, input, output):
683523 self .ref_model = self .accelerator .prepare_model (self .ref_model , evaluation_mode = True )
684524
685525 if args .sync_ref_model :
686- if precompute_ref_log_probs :
526+ if self . precompute_ref_log_probs :
687527 raise ValueError (
688528 "You cannot use `precompute_ref_log_probs=True` with TR-DPO method. Please set `precompute_ref_log_probs=False`."
689529 )
0 commit comments