Skip to content

Commit 5e90682

Browse files
authored
⚰️ Remove deprecated args, script arguments, and PPOv2 (huggingface#2306)
* Remove deprecated args
* Remove deprecated args in SFTTrainer
* Remove deprecated script argument classes
* Remove deprecated PPOv2Config and PPOv2Trainer classes
* Commented out sync_ref_model line in test_trainers_args.py
1 parent 3b43996 commit 5e90682

File tree

7 files changed

+7
-353
lines changed

7 files changed

+7
-353
lines changed

tests/test_trainers_args.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ def test_dpo(self):
159159
force_use_ref_model=True,
160160
f_divergence_type="js_divergence",
161161
f_alpha_divergence_coef=0.5,
162-
sync_ref_model=True,
162+
# sync_ref_model=True, # cannot be True when precompute_ref_log_probs=True. Don't test this.
163163
ref_model_mixup_alpha=0.5,
164164
ref_model_sync_steps=32,
165165
rpo_alpha=0.5,
@@ -189,7 +189,7 @@ def test_dpo(self):
189189
self.assertEqual(trainer.args.force_use_ref_model, True)
190190
self.assertEqual(trainer.args.f_divergence_type, "js_divergence")
191191
self.assertEqual(trainer.args.f_alpha_divergence_coef, 0.5)
192-
self.assertEqual(trainer.args.sync_ref_model, True)
192+
# self.assertEqual(trainer.args.sync_ref_model, True)
193193
self.assertEqual(trainer.args.ref_model_mixup_alpha, 0.5)
194194
self.assertEqual(trainer.args.ref_model_sync_steps, 32)
195195
self.assertEqual(trainer.args.rpo_alpha, 0.5)

trl/__init__.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -79,8 +79,6 @@
7979
"PairRMJudge",
8080
"PPOConfig",
8181
"PPOTrainer",
82-
"PPOv2Config",
83-
"PPOv2Trainer",
8482
"RandomPairwiseJudge",
8583
"RandomRankJudge",
8684
"RewardConfig",
@@ -170,8 +168,6 @@
170168
PairRMJudge,
171169
PPOConfig,
172170
PPOTrainer,
173-
PPOv2Config,
174-
PPOv2Trainer,
175171
RandomPairwiseJudge,
176172
RandomRankJudge,
177173
RewardConfig,

trl/commands/cli_utils.py

Lines changed: 0 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,6 @@
2525
import yaml
2626
from transformers import HfArgumentParser
2727

28-
from ..utils import ScriptArguments
29-
3028

3129
logger = logging.getLogger(__name__)
3230

@@ -81,33 +79,6 @@ def warning_handler(message, category, filename, lineno, file=None, line=None):
8179
warnings.showwarning = warning_handler
8280

8381

84-
@dataclass
85-
class SFTScriptArguments(ScriptArguments):
86-
def __post_init__(self):
87-
logger.warning(
88-
"`SFTScriptArguments` is deprecated, and will be removed in v0.13. Please use "
89-
"`ScriptArguments` instead."
90-
)
91-
92-
93-
@dataclass
94-
class RewardScriptArguments(ScriptArguments):
95-
def __post_init__(self):
96-
logger.warning(
97-
"`RewardScriptArguments` is deprecated, and will be removed in v0.13. Please use "
98-
"`ScriptArguments` instead."
99-
)
100-
101-
102-
@dataclass
103-
class DPOScriptArguments(ScriptArguments):
104-
def __post_init__(self):
105-
logger.warning(
106-
"`DPOScriptArguments` is deprecated, and will be removed in v0.13. Please use "
107-
"`ScriptArguments` instead."
108-
)
109-
110-
11182
@dataclass
11283
class ChatArguments:
11384
# general settings

trl/trainer/dpo_trainer.py

Lines changed: 5 additions & 165 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@
3131
from accelerate import PartialState
3232
from accelerate.utils import is_deepspeed_available, tqdm
3333
from datasets import Dataset
34-
from huggingface_hub.utils._deprecation import _deprecate_arguments
3534
from torch.utils.data import DataLoader
3635
from transformers import (
3736
AutoModelForCausalLM,
@@ -172,6 +171,9 @@ class DPOTrainer(Trainer):
172171
This supercedes the `tokenizer` argument, which is now deprecated.
173172
model_init (`Callable[[], transformers.PreTrainedModel]`):
174173
The model initializer to use for training. If None is specified, the default model initializer will be used.
174+
compute_metrics (`Callable[[EvalPrediction], Dict]`, *optional*):
175+
The function to use to compute the metrics. Must take a `EvalPrediction` and return
176+
a dictionary string to metric values.
175177
callbacks (`List[transformers.TrainerCallback]`):
176178
The callbacks to use for training.
177179
optimizers (`Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
@@ -180,90 +182,35 @@ class DPOTrainer(Trainer):
180182
The function to use to preprocess the logits before computing the metrics.
181183
peft_config (`Dict`, defaults to `None`):
182184
The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model.
183-
compute_metrics (`Callable[[EvalPrediction], Dict]`, *optional*):
184-
The function to use to compute the metrics. Must take a `EvalPrediction` and return
185-
a dictionary string to metric values.
186185
"""
187186

188187
_tag_names = ["trl", "dpo"]
189188

190-
@_deprecate_arguments(
191-
version="0.13.0",
192-
deprecated_args=[
193-
"beta",
194-
"label_smoothing",
195-
"loss_type",
196-
"label_pad_token_id",
197-
"padding_value",
198-
"truncation_mode",
199-
"max_length",
200-
"max_prompt_length",
201-
"max_target_length",
202-
"is_encoder_decoder",
203-
"disable_dropout",
204-
"generate_during_eval",
205-
"precompute_ref_log_probs",
206-
"dataset_num_proc",
207-
"model_init_kwargs",
208-
"ref_model_init_kwargs",
209-
"model_adapter_name",
210-
"ref_adapter_name",
211-
"reference_free",
212-
"force_use_ref_model",
213-
],
214-
custom_message="Deprecated positional argument(s) used in DPOTrainer, please use the DPOConfig to set these arguments instead.",
215-
)
216189
@deprecate_kwarg("tokenizer", new_name="processing_class", version="0.14.0", raise_if_both_names=True)
217190
def __init__(
218191
self,
219192
model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
220193
ref_model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
221-
beta: float = 0.1,
222-
label_smoothing: float = 0,
223-
loss_type: Optional[str] = None,
224194
args: Optional[DPOConfig] = None,
225195
data_collator: Optional[DataCollator] = None,
226-
label_pad_token_id: int = -100,
227-
padding_value: Optional[int] = None,
228-
truncation_mode: str = "keep_end",
229196
train_dataset: Optional[Dataset] = None,
230197
eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,
231198
processing_class: Optional[
232199
Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
233200
] = None,
234201
model_init: Optional[Callable[[], PreTrainedModel]] = None,
202+
compute_metrics: Optional[Callable[[EvalLoopOutput], Dict]] = None,
235203
callbacks: Optional[List[TrainerCallback]] = None,
236204
optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
237205
preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
238-
max_length: Optional[int] = None,
239-
max_prompt_length: Optional[int] = None,
240-
max_target_length: Optional[int] = None,
241206
peft_config: Optional[Dict] = None,
242-
is_encoder_decoder: Optional[bool] = None,
243-
disable_dropout: bool = True,
244-
generate_during_eval: bool = False,
245-
compute_metrics: Optional[Callable[[EvalLoopOutput], Dict]] = None,
246-
precompute_ref_log_probs: bool = False,
247-
dataset_num_proc: Optional[int] = None,
248-
model_init_kwargs: Optional[Dict] = None,
249-
ref_model_init_kwargs: Optional[Dict] = None,
250-
model_adapter_name: Optional[str] = None,
251-
ref_adapter_name: Optional[str] = None,
252-
reference_free: bool = False,
253-
force_use_ref_model: bool = False,
254207
):
255208
if not isinstance(model, str) and ref_model is model:
256209
raise ValueError(
257210
"`model` and `ref_model` cannot be the same object. If you want `ref_model` to be the "
258211
"same as `model`, you must mass a copy of it, or `None` if you use peft."
259212
)
260213

261-
if model_init_kwargs is not None:
262-
warnings.warn(
263-
"You passed `model_init_kwargs` to the DPOTrainer, the value you passed will override the one in the `DPOConfig`."
264-
)
265-
args.model_init_kwargs = model_init_kwargs
266-
267214
if args.model_init_kwargs is None:
268215
model_init_kwargs = {}
269216
elif not isinstance(model, str):
@@ -283,12 +230,6 @@ def __init__(
283230
)
284231
model_init_kwargs["torch_dtype"] = torch_dtype
285232

286-
if ref_model_init_kwargs is not None:
287-
warnings.warn(
288-
"You passed `ref_model_init_kwargs` to the DPOTrainer, the value you passed will override the one in the `DPOConfig`."
289-
)
290-
args.ref_model_init_kwargs = ref_model_init_kwargs
291-
292233
if args.ref_model_init_kwargs is None:
293234
ref_model_init_kwargs = {}
294235
elif not isinstance(ref_model, str):
@@ -326,12 +267,6 @@ def __init__(
326267
# has been called in order to properly call autocast if needed.
327268
self._peft_has_been_casted_to_bf16 = False
328269

329-
if force_use_ref_model:
330-
warnings.warn(
331-
"You passed `force_use_ref_model` to the DPOTrainer, the value you passed will override the one in the `DPOConfig`."
332-
)
333-
args.force_use_ref_model = force_use_ref_model
334-
335270
if not is_peft_available() and peft_config is not None:
336271
raise ValueError(
337272
"PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models"
@@ -393,22 +328,12 @@ def make_inputs_require_grad(module, input, output):
393328

394329
model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
395330

396-
if generate_during_eval:
397-
warnings.warn(
398-
"You passed `generate_during_eval` to the DPOTrainer, the value you passed will override the one in the `DPOConfig`."
399-
)
400-
args.generate_during_eval = generate_during_eval
401331
if args.generate_during_eval and not is_wandb_available():
402332
raise ValueError(
403333
"`generate_during_eval=True` requires Weights and Biases to be installed."
404334
" Please install `wandb` to resolve."
405335
)
406336

407-
if is_encoder_decoder is not None:
408-
warnings.warn(
409-
"You passed `is_encoder_decoder` to the DPOTrainer, the value you passed will override the one in the `DPOConfig`."
410-
)
411-
args.is_encoder_decoder = is_encoder_decoder
412337
if model is not None:
413338
self.is_encoder_decoder = model.config.is_encoder_decoder
414339
elif args.is_encoder_decoder is None:
@@ -427,33 +352,10 @@ def make_inputs_require_grad(module, input, output):
427352
self.is_vision_model = False
428353

429354
self.is_peft_model = is_peft_available() and isinstance(model, PeftModel)
430-
if model_adapter_name is not None:
431-
warnings.warn(
432-
"You passed `model_adapter_name` to the DPOTrainer, the value you passed will override the one in the `DPOConfig`."
433-
)
434-
args.model_adapter_name = model_adapter_name
435355
self.model_adapter_name = args.model_adapter_name
436-
437-
if ref_adapter_name is not None:
438-
warnings.warn(
439-
"You passed `ref_adapter_name` to the DPOTrainer, the value you passed will override the one in the `DPOConfig`."
440-
)
441-
args.ref_adapter_name = ref_adapter_name
442356
self.ref_adapter_name = args.ref_adapter_name
443-
444-
if reference_free:
445-
warnings.warn(
446-
"You passed `reference_free` to the DPOTrainer, the value you passed will override the one in the `DPOConfig`."
447-
)
448-
args.reference_free = reference_free
449357
self.reference_free = args.reference_free
450358

451-
if precompute_ref_log_probs:
452-
warnings.warn(
453-
"You passed `precompute_ref_log_probs` to the DPOTrainer, the value you passed will override the one in the `DPOConfig`."
454-
)
455-
args.precompute_ref_log_probs = precompute_ref_log_probs
456-
457359
if ref_model:
458360
self.ref_model = ref_model
459361
elif self.is_peft_model or args.precompute_ref_log_probs:
@@ -465,36 +367,6 @@ def make_inputs_require_grad(module, input, output):
465367
if processing_class is None:
466368
raise ValueError("processing_class must be specified to tokenize a DPO dataset.")
467369

468-
if max_length is not None:
469-
warnings.warn(
470-
"You passed `max_length` to the DPOTrainer, the value you passed will override the one in the `DPOConfig`."
471-
)
472-
args.max_length = max_length
473-
474-
if max_prompt_length is not None:
475-
warnings.warn(
476-
"You passed `max_prompt_length` to the DPOTrainer, the value you passed will override the one in the `DPOConfig`."
477-
)
478-
args.max_prompt_length = max_prompt_length
479-
480-
if max_target_length is not None:
481-
warnings.warn(
482-
"You passed `max_target_length` to the DPOTrainer, the value you passed will override the one in the `DPOConfig`."
483-
)
484-
args.max_completion_length = max_target_length
485-
486-
if label_pad_token_id != -100:
487-
warnings.warn(
488-
"You passed `label_pad_token_id` to the DPOTrainer, the value you passed will override the one in the `DPOConfig`."
489-
)
490-
args.label_pad_token_id = label_pad_token_id
491-
492-
if padding_value is not None:
493-
warnings.warn(
494-
"You passed `padding_value` to the DPOTrainer, the value you passed will override the one in the `DPOConfig`."
495-
)
496-
args.padding_value = padding_value
497-
498370
if args.padding_value is not None:
499371
self.padding_value = args.padding_value
500372
else:
@@ -512,11 +384,6 @@ def make_inputs_require_grad(module, input, output):
512384
if data_collator is None:
513385
data_collator = PreferenceCollator(pad_token_id=self.padding_value)
514386

515-
if not disable_dropout:
516-
warnings.warn(
517-
"You passed `disable_dropout` to the DPOTrainer, the value you passed will override the one in the `DPOConfig`."
518-
)
519-
args.disable_dropout = disable_dropout
520387
if args.disable_dropout:
521388
disable_dropout_in_model(model)
522389
if self.ref_model is not None:
@@ -526,11 +393,6 @@ def make_inputs_require_grad(module, input, output):
526393
self.generate_during_eval = args.generate_during_eval
527394
self.label_pad_token_id = args.label_pad_token_id
528395
self.max_prompt_length = args.max_prompt_length
529-
if truncation_mode != "keep_end":
530-
warnings.warn(
531-
"You passed `truncation_mode` to the DPOTrainer, the value you passed will override the one in the `DPOConfig`."
532-
)
533-
args.truncation_mode = truncation_mode
534396
self.truncation_mode = args.truncation_mode
535397
self.max_completion_length = args.max_completion_length
536398
self.precompute_ref_log_probs = args.precompute_ref_log_probs
@@ -540,16 +402,6 @@ def make_inputs_require_grad(module, input, output):
540402
self._precomputed_train_ref_log_probs = False
541403
self._precomputed_eval_ref_log_probs = False
542404

543-
if loss_type is not None:
544-
warnings.warn(
545-
"You passed `loss_type` to the DPOTrainer, the value you passed will override the one in the `DPOConfig`."
546-
)
547-
args.loss_type = loss_type
548-
if label_smoothing != 0:
549-
warnings.warn(
550-
"You passed `label_smoothing` to the DPOTrainer, the value you passed will override the one in the `DPOConfig`."
551-
)
552-
args.label_smoothing = label_smoothing
553405
if (
554406
args.loss_type in ["hinge", "ipo", "bco_pair", "sppo_hard", "nca_pair", "apo_zero", "apo_down"]
555407
and args.label_smoothing > 0
@@ -560,11 +412,6 @@ def make_inputs_require_grad(module, input, output):
560412
if args.loss_type == "kto_pair":
561413
raise ValueError("Support for kto_pair has been removed in DPOTrainer. Please use KTOTrainer.")
562414

563-
if beta != 0.1:
564-
warnings.warn(
565-
"You passed `beta` to the DPOTrainer, the value you passed will override the one in the `DPOConfig`."
566-
)
567-
args.beta = beta
568415
self.beta = args.beta
569416
self.label_smoothing = args.label_smoothing
570417
self.loss_type = args.loss_type
@@ -578,15 +425,8 @@ def make_inputs_require_grad(module, input, output):
578425
)
579426

580427
self._stored_metrics = defaultdict(lambda: defaultdict(list))
581-
582428
self.f_divergence_type = args.f_divergence_type
583429
self.f_divergence_params = {FDivergenceConstants.ALPHA_DIVERGENCE_COEF_KEY: args.f_alpha_divergence_coef}
584-
585-
if dataset_num_proc is not None:
586-
warnings.warn(
587-
"You passed `dataset_num_proc` to the DPOTrainer, the value you passed will override the one in the `DPOConfig`."
588-
)
589-
args.dataset_num_proc = dataset_num_proc
590430
self.dataset_num_proc = args.dataset_num_proc
591431

592432
# Compute that only on the main process for faster data processing.
@@ -683,7 +523,7 @@ def make_inputs_require_grad(module, input, output):
683523
self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
684524

685525
if args.sync_ref_model:
686-
if precompute_ref_log_probs:
526+
if self.precompute_ref_log_probs:
687527
raise ValueError(
688528
"You cannot use `precompute_ref_log_probs=True` with TR-DPO method. Please set `precompute_ref_log_probs=False`."
689529
)

0 commit comments

Comments (0)