diff --git a/src/chronos/chronos2/pipeline.py b/src/chronos/chronos2/pipeline.py
index 3eddcd7f..d0a89d8d 100644
--- a/src/chronos/chronos2/pipeline.py
+++ b/src/chronos/chronos2/pipeline.py
@@ -40,10 +40,10 @@ class Chronos2Pipeline(BaseChronosPipeline):
     forecast_type: ForecastType = ForecastType.QUANTILES
     default_context_length: int = 2048
 
-    def __init__(self, model: Chronos2Model):
+    def __init__(self, model: Chronos2Model, revision: str | None = None):
         super().__init__(inner_model=model)
         self.model = model
-
+        self.revision = revision
     @staticmethod
     def _get_prob_mass_per_quantile_level(quantile_levels: torch.Tensor) -> torch.Tensor:
         """
@@ -195,6 +195,7 @@ def fit(
         model.load_state_dict(self.model.state_dict())
 
         if finetune_mode == "lora":
+            lora_revision = self.revision
             if lora_config is None:
                 lora_config = LoraConfig(
                     r=8,
@@ -206,8 +207,10 @@
                         "self_attention.o",
                         "output_patch_embedding.output_layer",
                     ],
+                    revision=lora_revision,
                 )
             elif isinstance(lora_config, dict):
+                lora_config.setdefault("revision", lora_revision)
                 lora_config = LoraConfig(**lora_config)
             else:
                 assert isinstance(lora_config, LoraConfig), (
@@ -1161,6 +1164,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
         Load the model, either from a local path, S3 prefix or from the HuggingFace Hub.
         Supports the same arguments as ``AutoConfig`` and ``AutoModel`` from ``transformers``.
         """
+        revision = kwargs.get("revision")
 
         # Check if the model is on S3 and cache it locally first
         # NOTE: Only base models (not LoRA adapters) are supported via S3
@@ -1178,7 +1182,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
 
             model = AutoPeftModel.from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
             model = model.merge_and_unload()
-            return cls(model=model)
+            return cls(model=model, revision=revision)
 
         # Handle the case for the base model
         config = AutoConfig.from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
@@ -1192,7 +1196,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
             class_ = Chronos2Model
 
         model = class_.from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
-        return cls(model=model)
+        return cls(model=model, revision=revision)
 
     def save_pretrained(self, save_directory: str | Path, *args, **kwargs):
         """
diff --git a/test/test_chronos2.py b/test/test_chronos2.py
index 3fac7261..14f81796 100644
--- a/test/test_chronos2.py
+++ b/test/test_chronos2.py
@@ -1132,3 +1132,38 @@ def test_eager_and_sdpa_produce_identical_outputs(pipeline):
     for out_eager, out_sdpa in zip(outputs_eager_grouped, outputs_sdpa_grouped):
         # Should match exactly or very close (numerical precision)
         assert torch.allclose(out_eager, out_sdpa, atol=1e-5, rtol=1e-4)
+
+
+@pytest.mark.parametrize("source_revision", ["my-test-branch", None])
+def test_lora_config_uses_source_revision_from_instantiation(
+    pipeline: Chronos2Pipeline, tmpdir, source_revision
+):
+    """
+    Test that fit in 'lora' mode correctly uses the 'revision'
+    stored in the pipeline instance.
+    """
+    output_dir = Path(tmpdir)
+    dummy_inputs = [torch.rand(100)]
+
+    pipeline.revision = source_revision
+
+    pipeline.fit(
+        inputs=dummy_inputs,
+        prediction_length=10,
+        finetune_mode="lora",
+        output_dir=output_dir,
+        num_steps=1,  # Keep it fast
+        batch_size=32,
+    )
+
+    adapter_config_path = output_dir / "finetuned-ckpt" / "adapter_config.json"
+    assert adapter_config_path.exists(), "adapter_config.json was not created"
+
+    with open(adapter_config_path, "r") as f:
+        adapter_config = json.load(f)
+
+    if source_revision is not None:
+        assert "revision" in adapter_config
+        assert adapter_config["revision"] == source_revision
+    else:
+        assert "revision" not in adapter_config or adapter_config["revision"] is None