diff --git a/VERSION b/VERSION
index a7f3fc27a7a..934e07c6e0a 100644
--- a/VERSION
+++ b/VERSION
@@ -1 +1 @@
-0.22.0
\ No newline at end of file
+0.22.2
\ No newline at end of file
diff --git a/tests/test_dpo_trainer.py b/tests/test_dpo_trainer.py
index e83e7f73b77..dd34a68e03b 100644
--- a/tests/test_dpo_trainer.py
+++ b/tests/test_dpo_trainer.py
@@ -172,12 +172,6 @@ def setUp(self):
         self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
         self.tokenizer.pad_token = self.tokenizer.eos_token
 
-        # get t5 as seq2seq example:
-        model_id = "trl-internal-testing/tiny-T5ForConditionalGeneration"
-        self.t5_model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
-        self.t5_ref_model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
-        self.t5_tokenizer = AutoTokenizer.from_pretrained(model_id)
-
     def test_train(self):
         model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
         dataset = load_dataset("trl-internal-testing/zen", "standard_preference", split="train")
@@ -255,6 +249,39 @@ def test_train_loss_types(self, loss_type):
             if param.sum() != 0:  # ignore 0 biases
                 self.assertFalse(torch.allclose(param, new_param, rtol=1e-12, atol=1e-12))
 
+    @require_liger_kernel
+    def test_train_encoder_decoder_liger(self):
+        model_id = "trl-internal-testing/tiny-BartModel"
+        model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
+        dataset = load_dataset("trl-internal-testing/zen", "standard_preference", split="train")
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+
+        training_args = DPOConfig(
+            output_dir=self.tmp_dir,
+            per_device_train_batch_size=2,
+            learning_rate=9e-1,
+            report_to="none",
+            use_liger_loss=True,
+        )
+        trainer = DPOTrainer(
+            model=model,
+            args=training_args,
+            processing_class=tokenizer,
+            train_dataset=dataset,
+        )
+
+        previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
+
+        trainer.train()
+
+        self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+
+        # Check that the parameters have changed
+        for n, param in previous_trainable_params.items():
+            new_param = trainer.model.get_parameter(n)
+            if param.sum() != 0:  # ignore 0 biases
+                self.assertFalse(torch.allclose(param, new_param, rtol=1e-12, atol=1e-12))
+
     def test_dpo_trainer_with_weighting(self):
         dataset = load_dataset("trl-internal-testing/zen", "standard_preference", split="train")
 
diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py
index fbe2071b537..b76f8ef5233 100644
--- a/tests/test_sft_trainer.py
+++ b/tests/test_sft_trainer.py
@@ -162,6 +162,18 @@ def test_pad_to_multiple_of(self):
         torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 1, 2, 0], [0, 1, 0, 0]]))
         torch.testing.assert_close(result["labels"], torch.tensor([[1, 2, 3, -100], [4, 5, -100, -100]]))
 
+    def test_pad_to_multiple_of_and_padding_free(self):
+        """Test padding to a multiple of the specified value combined with padding_free."""
+        collator = DataCollatorForLanguageModeling(pad_token_id=0, padding_free=True, pad_to_multiple_of=4)
+        examples = [{"input_ids": [1, 2, 3]}, {"input_ids": [4, 5]}]
+
+        result = collator(examples)
+
+        torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3, 4, 5, 0, 0, 0]]))
+        torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1, 1, 1, 0, 0, 0]]))
+        torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 1, 2, 0, 1, 0, 0, 0]]))
+        torch.testing.assert_close(result["labels"], torch.tensor([[1, 2, 3, 4, 5, -100, -100, -100]]))
+
     def test_custom_position_ids(self):
         """Test handling of custom position IDs in examples."""
         self.collator = DataCollatorForLanguageModeling(pad_token_id=0)
diff --git a/trl/__init__.py b/trl/__init__.py
index 29a3a8fc910..52f3dec9d25 100644
--- a/trl/__init__.py
+++ b/trl/__init__.py
@@ -12,18 +12,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from importlib.metadata import PackageNotFoundError, version
 from pathlib import Path
 from typing import TYPE_CHECKING
 
 from .import_utils import OptionalDependencyNotAvailable, _LazyModule, is_diffusers_available
 
 
-# Read version from VERSION file
-_version_file = Path(__file__).parent.parent / "VERSION"
 try:
-    with open(_version_file, encoding="utf-8") as f:
-        __version__ = f.read().strip()
-except FileNotFoundError:
+    __version__ = version("trl")
+except PackageNotFoundError:
     __version__ = "unknown"
 
 _import_structure = {
diff --git a/trl/data_utils.py b/trl/data_utils.py
index 28acceaf906..429832ac026 100644
--- a/trl/data_utils.py
+++ b/trl/data_utils.py
@@ -663,8 +663,9 @@ def pack_dataset(
     >>> dataset = Dataset.from_dict(examples)
     >>> packed_dataset = pack_dataset(dataset, seq_length=4, strategy="bfd")
     >>> packed_dataset[:]
-    {'input_ids': [[1, 2, 3, 9], [6, 7, 8, 4, 5]],
-     'attention_mask': [[1, 1, 0, 1], [1, 0, 0, 1, 0]]}
+    {'input_ids': [[1, 2, 3, 9], [6, 7, 8], [4, 5]],
+     'attention_mask': [[1, 1, 0, 1], [1, 0, 0], [1, 0]],
+     'seq_lengths': [[3, 1], [3], [2]]}
     ```
     """
     if map_kwargs is None:
diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py
index 126b58dbafc..19afc4eed3e 100644
--- a/trl/trainer/dpo_trainer.py
+++ b/trl/trainer/dpo_trainer.py
@@ -102,6 +102,7 @@ def shift_tokens_right(input_ids: torch.Tensor, decoder_start_token_id: int) ->
     shifted_input_ids = input_ids.new_zeros(input_ids.shape)
     shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
     shifted_input_ids[:, 0] = decoder_start_token_id
+    return shifted_input_ids
 
 
 @dataclass
diff --git a/trl/trainer/grpo_config.py b/trl/trainer/grpo_config.py
index 43c7e966aae..426998ff4de 100644
--- a/trl/trainer/grpo_config.py
+++ b/trl/trainer/grpo_config.py
@@ -605,6 +605,8 @@ def __post_init__(self):
         super().__post_init__()
 
+        self.scale_rewards = {True: "group", False: "none"}.get(self.scale_rewards, self.scale_rewards)
+
         num_processes = self.world_size
         # The current default effective batch size
         if self.generation_batch_size is None and self.steps_per_generation is None:
diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py
index 62626cc54e7..5aeac0f0551 100644
--- a/trl/trainer/grpo_trainer.py
+++ b/trl/trainer/grpo_trainer.py
@@ -352,7 +352,7 @@ def __init__(
         self.vllm_tensor_parallel_size = args.vllm_tensor_parallel_size  # only applies to colocation mode
         self.use_liger_loss = args.use_liger_loss
         self.loss_type = args.loss_type
-        self.scale_rewards = {True: "group", False: "none"}.get(args.scale_rewards, args.scale_rewards)
+        self.scale_rewards = args.scale_rewards
         self.importance_sampling_level = args.importance_sampling_level
         self.mask_truncated_completions = args.mask_truncated_completions
         self.top_entropy_quantile = args.top_entropy_quantile
@@ -1398,11 +1398,11 @@ def _generate_and_score_completions(
         mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
         advantages = rewards - mean_grouped_rewards
-        if self.scale_rewards in ["batch", "none"]:
+        if self.scale_rewards in ["group", "none"]:
             # If self.scale_rewards = "none", we'll still log group level std
             std_rewards = rewards.view(-1, self.num_generations).std(dim=1)
             std_rewards = std_rewards.repeat_interleave(self.num_generations, dim=0)
-        elif self.scale_rewards == "group":
+        elif self.scale_rewards == "batch":
             # Compute global std
             std_rewards = rewards.std().expand_as(rewards)
         else:
             raise ValueError(
@@ -1411,7 +1411,7 @@ def _generate_and_score_completions(
             )
 
         is_std_zero = torch.isclose(std_rewards, torch.zeros_like(std_rewards))
-        if self.scale_rewards in ["batch", "none"]:
+        if self.scale_rewards != "none":
             advantages = advantages / (std_rewards + 1e-4)
 
         # Slice to keep only the local part of the data
diff --git a/trl/trainer/sft_trainer.py b/trl/trainer/sft_trainer.py
index ff5dc97fe3e..c06911332fd 100644
--- a/trl/trainer/sft_trainer.py
+++ b/trl/trainer/sft_trainer.py
@@ -124,7 +124,7 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):
             that are not in the completion.
         padding_free (`bool`, *optional*, defaults to `False`):
             If set to `True`, the sequences will be flattened into a single sequence, and the position IDs will be
-            generated accordingly. The attention mask will be set to 1 for all tokens.
+            generated accordingly.
         pad_to_multiple_of (`int` or `None`, *optional*, defaults to `None`):
             If set, the sequences will be padded to a multiple of this value.
         return_tensors (`str`, *optional*, defaults to `"pt"`):
@@ -206,48 +206,48 @@ def torch_call(self, examples: list[Union[list[int], Any, dict[str, Any]]]) -> d
         if "assistant_masks" in examples[0]:
             assistant_masks = [torch.tensor(example["assistant_masks"]) for example in examples]
 
-        # Pad
+        # If padding_free, flatten everything into a single sequence
         output = {}
         if self.padding_free:
-            output["input_ids"] = torch.cat(input_ids, dim=0).unsqueeze(0)
+            input_ids = [torch.cat(input_ids, dim=0)]
             if not has_packed_position_ids:
-                output["attention_mask"] = torch.cat(attention_mask, dim=0).unsqueeze(0)
+                attention_mask = [torch.cat(attention_mask, dim=0)]
             if self.return_position_ids:
-                output["position_ids"] = torch.cat(position_ids, dim=0).unsqueeze(0)
-            output["labels"] = torch.cat(labels, dim=0).unsqueeze(0)
+                position_ids = [torch.cat(position_ids, dim=0)]
+            labels = [torch.cat(labels, dim=0)]
             if self.completion_only_loss and "completion_mask" in examples[0]:
-                completion_mask = torch.cat(completion_mask, dim=0).unsqueeze(0)
-                output["labels"][completion_mask == 0] = -100
+                completion_mask = [torch.cat(completion_mask, dim=0)]
             if "assistant_masks" in examples[0]:
-                assistant_masks = torch.cat(assistant_masks, dim=0).unsqueeze(0)
-                output["labels"][assistant_masks == 0] = -100
-        else:
-            output["input_ids"] = pad(
-                input_ids,
-                padding_value=self.pad_token_id,
-                padding_side="right",
-                pad_to_multiple_of=self.pad_to_multiple_of,
-            )
+                assistant_masks = [torch.cat(assistant_masks, dim=0)]
+
+        # Pad
+        output["input_ids"] = pad(
+            input_ids,
+            padding_value=self.pad_token_id,
+            padding_side="right",
+            pad_to_multiple_of=self.pad_to_multiple_of,
+        )
+        if not has_packed_position_ids:
             output["attention_mask"] = pad(
                 attention_mask, padding_value=0, padding_side="right", pad_to_multiple_of=self.pad_to_multiple_of
             )
-            if self.return_position_ids:
-                output["position_ids"] = pad(
-                    position_ids, padding_value=0, padding_side="right", pad_to_multiple_of=self.pad_to_multiple_of
-                )
-            output["labels"] = pad(
-                labels, padding_value=-100, padding_side="right", pad_to_multiple_of=self.pad_to_multiple_of
+        if self.return_position_ids:
+            output["position_ids"] = pad(
+                position_ids, padding_value=0, padding_side="right", pad_to_multiple_of=self.pad_to_multiple_of
             )
-            if self.completion_only_loss and "completion_mask" in examples[0]:
-                completion_mask = pad(
-                    completion_mask, padding_value=0, padding_side="right", pad_to_multiple_of=self.pad_to_multiple_of
-                )
-                output["labels"][completion_mask == 0] = -100  # mask everything that is not in the completion
-            if "assistant_masks" in examples[0]:
-                assistant_masks = pad(
-                    assistant_masks, padding_value=0, padding_side="right", pad_to_multiple_of=self.pad_to_multiple_of
-                )
-                output["labels"][assistant_masks == 0] = -100
+        output["labels"] = pad(
+            labels, padding_value=-100, padding_side="right", pad_to_multiple_of=self.pad_to_multiple_of
+        )
+        if self.completion_only_loss and "completion_mask" in examples[0]:
+            completion_mask = pad(
+                completion_mask, padding_value=0, padding_side="right", pad_to_multiple_of=self.pad_to_multiple_of
+            )
+            output["labels"][completion_mask == 0] = -100  # mask everything that is not in the completion
+        if "assistant_masks" in examples[0]:
+            assistant_masks = pad(
+                assistant_masks, padding_value=0, padding_side="right", pad_to_multiple_of=self.pad_to_multiple_of
+            )
+            output["labels"][assistant_masks == 0] = -100
         return output
 
     @staticmethod
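
Usage sketch (not part of the diff): a minimal snippet showing the collator behavior pinned down by test_pad_to_multiple_of_and_padding_free above. The import path via trl.trainer.sft_trainer mirrors the test module and is an assumption; the expected values are copied from the test.

# Hypothetical usage example; expected outputs are taken from the new test above.
from trl.trainer.sft_trainer import DataCollatorForLanguageModeling

collator = DataCollatorForLanguageModeling(pad_token_id=0, padding_free=True, pad_to_multiple_of=4)
batch = collator([{"input_ids": [1, 2, 3]}, {"input_ids": [4, 5]}])

# padding_free flattens both examples into a single row, and pad_to_multiple_of=4
# then right-pads that row from length 5 up to 8.
print(batch["input_ids"].tolist())       # [[1, 2, 3, 4, 5, 0, 0, 0]]
print(batch["attention_mask"].tolist())  # [[1, 1, 1, 1, 1, 0, 0, 0]]
print(batch["position_ids"].tolist())    # [[0, 1, 2, 0, 1, 0, 0, 0]]
print(batch["labels"].tolist())          # [[1, 2, 3, 4, 5, -100, -100, -100]]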
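
Config sketch (not part of the diff): with the mapping moved into GRPOConfig.__post_init__, legacy boolean values of scale_rewards are normalized to their string form before GRPOTrainer reads them. Constructing GRPOConfig here is only illustrative and assumes the usual transformers/accelerate setup is available.

# Hypothetical example of the boolean-to-string normalization added in __post_init__.
from trl import GRPOConfig

args = GRPOConfig(output_dir="tmp", scale_rewards=True)
print(args.scale_rewards)   # "group" (True maps to per-group std scaling)

args = GRPOConfig(output_dir="tmp", scale_rewards=False)
print(args.scale_rewards)   # "none" (False disables scaling; group-level std is still logged)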