diff --git a/lumigator/backend/backend/services/exceptions/tracking_exceptions.py b/lumigator/backend/backend/services/exceptions/tracking_exceptions.py new file mode 100644 index 000000000..a77711ce8 --- /dev/null +++ b/lumigator/backend/backend/services/exceptions/tracking_exceptions.py @@ -0,0 +1,16 @@ +from backend.services.exceptions.base_exceptions import ( + NotFoundError, +) + + +class RunNotFoundError(NotFoundError): + """Raised when a run cannot be found.""" + + def __init__(self, resource_id: str, message: str | None = None, exc: Exception | None = None): + """Creates a RunNotFoundError. + + :param resource_id: ID of run + :param message: optional error message + :param exc: optional exception, where possible raise ``from exc`` to preserve the original traceback + """ + super().__init__("Run", str(resource_id), message, exc) diff --git a/lumigator/backend/backend/tests/conftest.py b/lumigator/backend/backend/tests/conftest.py index fbbf8dc64..065ab5ff1 100644 --- a/lumigator/backend/backend/tests/conftest.py +++ b/lumigator/backend/backend/tests/conftest.py @@ -4,6 +4,7 @@ import time import uuid from collections.abc import Generator +from datetime import datetime from pathlib import Path from unittest.mock import MagicMock from uuid import UUID @@ -26,6 +27,7 @@ JobType, ) from lumigator_schemas.models import ModelsResponse +from mlflow.entities import Metric, Param, Run, RunData, RunInfo, RunTag from s3fs import S3FileSystem from sqlalchemy import Engine, create_engine from sqlalchemy.orm import Session @@ -541,3 +543,60 @@ def model_specs_data() -> list[ModelsResponse]: models = [ModelsResponse.model_validate(item) for item in model_specs] return models + + +@pytest.fixture(scope="function") +def fake_mlflow_tracking_client(fake_s3fs): + """Fixture for MLflowTrackingClient using the real MLflowClient.""" + return MLflowTrackingClient(tracking_uri="http://mlflow.mock", s3_file_system=fake_s3fs) + + +@pytest.fixture +def sample_mlflow_run(): + """Fixture for a sample MlflowRun with mock data.""" + return Run( + run_info=RunInfo( + run_uuid="d34dbeef-1000-0000-0000-000000000000", + experiment_id="exp-1", + user_id="user", + status="FINISHED", + start_time=123456789, + end_time=None, + lifecycle_stage="active", + artifact_uri="", + ), + run_data=RunData( + metrics=[ + Metric(key="accuracy", value=0.75, timestamp=123456789, step=0), + ], + params=[ + Param(key="batch_size", value="32"), + ], + tags=[ + RunTag(key="description", value="A sample workflow"), + RunTag(key="mlflow.runName", value="Run2"), + RunTag(key="model", value="SampleModel"), + RunTag(key="system_prompt", value="Prompt text"), + RunTag(key="status", value="COMPLETED"), + ], + ), + ) + + +@pytest.fixture +def fake_mlflow_run_deleted(): + """Fixture for a deleted MLflow run.""" + run_info = RunInfo( + run_uuid="d34dbeef-1000-0000-0000-000000000000", + experiment_id="exp-456", + user_id="user-789", + status="FAILED", + start_time=int(datetime(2024, 1, 1).timestamp() * 1000), + end_time=None, + lifecycle_stage="deleted", + artifact_uri="s3://some-bucket", + ) + + run_data = RunData(metrics={}, params={}, tags={}) + + return Run(run_info=run_info, run_data=run_data) diff --git a/lumigator/backend/backend/tests/unit/tracking/test_mlflow.py b/lumigator/backend/backend/tests/unit/tracking/test_mlflow.py index 1fb5f72e2..8cd92b475 100644 --- a/lumigator/backend/backend/tests/unit/tracking/test_mlflow.py +++ b/lumigator/backend/backend/tests/unit/tracking/test_mlflow.py @@ -1,10 +1,315 @@ +import uuid +from unittest.mock 
import MagicMock
+
 import pytest
+from lumigator_schemas.jobs import JobResults
 from lumigator_schemas.workflows import WorkflowStatus
-from mlflow.entities import RunStatus
+from mlflow.entities import Metric, Param, Run, RunData, RunInfo, RunStatus, RunTag
 
+from backend.services.exceptions.tracking_exceptions import RunNotFoundError
 from backend.tracking.mlflow import MLflowTrackingClient
 
 
+@pytest.mark.asyncio
+async def test_get_job_success(fake_mlflow_tracking_client, sample_mlflow_run):
+    """Test fetching job results successfully."""
+    job_id = uuid.UUID("d34dbeef-1000-0000-0000-000000000000")
+
+    fake_mlflow_tracking_client._client.get_run = MagicMock(return_value=sample_mlflow_run)
+
+    result = await fake_mlflow_tracking_client.get_job(str(job_id))
+
+    assert isinstance(result, JobResults)
+    assert result.id == job_id
+    assert result.metrics == [{"name": "accuracy", "value": 0.75}]
+    assert result.parameters == [{"name": "batch_size", "value": "32"}]
+
+    fake_mlflow_tracking_client._client.get_run.assert_called_once_with(str(job_id))
+
+
+@pytest.mark.asyncio
+async def test_get_job_deleted(fake_mlflow_tracking_client, fake_mlflow_run_deleted):
+    """Test that fetching a deleted job raises RunNotFoundError."""
+    job_id = uuid.UUID("d34dbeef-1000-0000-0000-000000000000")
+
+    fake_mlflow_tracking_client._client.get_run = MagicMock(return_value=fake_mlflow_run_deleted)
+
+    with pytest.raises(RunNotFoundError):
+        await fake_mlflow_tracking_client.get_job(str(job_id))
+
+    fake_mlflow_tracking_client._client.get_run.assert_called_once_with(str(job_id))
+
+
+def test_compile_metrics(fake_mlflow_tracking_client):
+    """Test metric compilation across multiple job runs."""
+    job_ids = {
+        "job1": "d34dbeef-1000-0000-0000-000000000001",
+        "job2": "d34dbeef-1000-0000-0000-000000000002",
+    }
+    runs = {
+        job_ids["job1"]: Run(
+            run_info=RunInfo(
+                run_uuid=job_ids["job1"],
+                experiment_id="exp-1",
+                user_id="user",
+                status="FINISHED",
+                start_time=123456789,
+                end_time=None,
+                lifecycle_stage="active",
+            ),
+            run_data=RunData(
+                metrics=[Metric(key="accuracy", value=0.95, timestamp=123456789, step=0)],
+                params=[Param(key="other_thing", value="0.01")],
+                tags=[
+                    RunTag(key="mlflow.runName", value="Run1"),
+                ],
+            ),
+        ),
+        job_ids["job2"]: Run(
+            run_info=RunInfo(
+                run_uuid=job_ids["job2"],
+                experiment_id="exp-1",
+                user_id="user",
+                status="FINISHED",
+                start_time=123456789,
+                end_time=None,
+                lifecycle_stage="active",
+            ),
+            run_data=RunData(
+                metrics=[Metric(key="loss", value=0.2, timestamp=123456789, step=0)],
+                params=[Param(key="learning_rate", value="0.02")],
+                tags=[
+                    RunTag(key="mlflow.runName", value="Run2"),
+                ],
+            ),
+        ),
+    }
+
+    fake_mlflow_tracking_client._client.get_run = MagicMock(side_effect=lambda job_id: runs[job_id])
+
+    result = fake_mlflow_tracking_client._compile_metrics([job_id for job_id in job_ids.values()])
+
+    assert result == {"loss": 0.2, "accuracy": 0.95}
+    fake_mlflow_tracking_client._client.get_run.assert_any_call(job_ids["job1"])
+    fake_mlflow_tracking_client._client.get_run.assert_any_call(job_ids["job2"])
+
+
+def test_compile_metrics_conflict(fake_mlflow_tracking_client):
+    """Test that duplicate metrics across job runs raise a ValueError."""
+    job_ids = {
+        "job1": "d34dbeef-1000-0000-0000-000000000001",
+        "job2": "d34dbeef-1000-0000-0000-000000000002",
+    }
+    runs = {
+        job_ids["job1"]: Run(
+            run_info=RunInfo(
+                run_uuid=job_ids["job1"],
+                experiment_id="exp-1",
+                user_id="user",
+                status="FINISHED",
+                start_time=123456789,
+                end_time=None,
+                lifecycle_stage="active",
+                artifact_uri="",
+            ),
+            run_data=RunData(
metrics=[Metric(key="accuracy", value=0.95, timestamp=123456789, step=0)], + params=[Param(key="other_thing", value="0.01")], + tags=[ + RunTag(key="description", value="A sample workflow"), + RunTag(key="mlflow.runName", value="Run1"), + RunTag(key="model", value="SampleModel"), + RunTag(key="system_prompt", value="Prompt text"), + RunTag(key="status", value="COMPLETED"), + ], + ), + ), + job_ids["job2"]: Run( + run_info=RunInfo( + run_uuid=job_ids["job2"], + experiment_id="exp-1", + user_id="user", + status="FINISHED", + start_time=123456789, + end_time=None, + lifecycle_stage="active", + artifact_uri="", + ), + run_data=RunData( + metrics=[Metric(key="accuracy", value=0.75, timestamp=123456789, step=0)], + params=[Param(key="learning_rate", value="0.02")], + tags=[ + RunTag(key="description", value="A sample workflow"), + RunTag(key="mlflow.runName", value="Run2"), + RunTag(key="model", value="SampleModel"), + RunTag(key="system_prompt", value="Prompt text"), + RunTag(key="status", value="COMPLETED"), + ], + ), + ), + } + + fake_mlflow_tracking_client._client.get_run = MagicMock(side_effect=lambda job_id: runs[job_id]) + + with pytest.raises(ValueError) as e: + fake_mlflow_tracking_client._compile_metrics([job_id for job_id in job_ids.values()]) + + assert str(e.value) == ( + "Duplicate metric 'accuracy' found in job 'd34dbeef-1000-0000-0000-000000000002'. " + "Stored value: 0.95, this value: 0.75" + ) + + +def test_compile_parameters(fake_mlflow_tracking_client): + """Test parameter compilation across multiple job runs.""" + job_ids = { + "job1": "d34dbeef-1000-0000-0000-000000000001", + "job2": "d34dbeef-1000-0000-0000-000000000002", + } + runs = { + job_ids["job1"]: Run( + run_info=RunInfo( + run_uuid=job_ids["job1"], + experiment_id="exp-1", + user_id="user", + status="FINISHED", + start_time=123456789, + end_time=None, + lifecycle_stage="active", + ), + run_data=RunData( + params=[Param(key="other_thing", value="0.01")], + tags=[ + RunTag(key="mlflow.runName", value="Run1"), + ], + ), + ), + job_ids["job2"]: Run( + run_info=RunInfo( + run_uuid=job_ids["job2"], + experiment_id="exp-1", + user_id="user", + status="FINISHED", + start_time=123456789, + end_time=None, + lifecycle_stage="active", + ), + run_data=RunData( + params=[Param(key="learning_rate", value="5")], + tags=[ + RunTag(key="mlflow.runName", value="Run2"), + ], + ), + ), + } + + fake_mlflow_tracking_client._client.get_run = MagicMock(side_effect=lambda job_id: runs[job_id]) + + result = fake_mlflow_tracking_client._compile_parameters([job_id for job_id in job_ids.values()]) + + assert result == { + "other_thing": { + "value": "0.01", + "jobs": { + "Run1": "0.01", + }, + }, + "learning_rate": { + "value": "5", + "jobs": { + "Run2": "5", + }, + }, + } + + +def test_compile_parameters_conflict(fake_mlflow_tracking_client): + """Test parameter conflicts result in no 'value' key being set.""" + job_ids = { + "job1": "d34dbeef-1000-0000-0000-000000000001", + "job2": "d34dbeef-1000-0000-0000-000000000002", + "job3": "d34dbeef-1000-0000-0000-000000000003", + } + runs = { + job_ids["job1"]: Run( + run_info=RunInfo( + run_uuid=job_ids["job1"], + experiment_id="exp-1", + user_id="user", + status="FINISHED", + start_time=123456789, + end_time=None, + lifecycle_stage="active", + ), + run_data=RunData( + params=[ + Param(key="other_thing", value="0.01"), + Param(key="learning_rate", value="7"), + ], + tags=[ + RunTag(key="mlflow.runName", value="Run1"), + ], + ), + ), + job_ids["job2"]: Run( + run_info=RunInfo( + 
run_uuid=job_ids["job2"], + experiment_id="exp-1", + user_id="user", + status="FINISHED", + start_time=123456789, + end_time=None, + lifecycle_stage="active", + ), + run_data=RunData( + params=[ + Param(key="other_thing", value="0.01"), + Param(key="learning_rate", value="5"), + ], + tags=[ + RunTag(key="mlflow.runName", value="Run2"), + ], + ), + ), + job_ids["job3"]: Run( + run_info=RunInfo( + run_uuid=job_ids["job3"], + experiment_id="exp-1", + user_id="user", + status="FINISHED", + start_time=123456789, + end_time=None, + lifecycle_stage="active", + ), + run_data=RunData( + params=[Param(key="learning_rate", value="8")], + tags=[ + RunTag(key="mlflow.runName", value="Run3"), + ], + ), + ), + } + + fake_mlflow_tracking_client._client.get_run = MagicMock(side_effect=lambda job_id: runs[job_id]) + + result = fake_mlflow_tracking_client._compile_parameters([job_id for job_id in job_ids.values()]) + + assert result == { + "other_thing": { + "value": "0.01", + "jobs": { + "Run1": "0.01", + "Run2": "0.01", + }, + }, + "learning_rate": { + "jobs": { + "Run1": "7", + "Run2": "5", + "Run3": "8", + }, + }, + } + + @pytest.mark.asyncio @pytest.mark.parametrize( "tracking_status, workflow_status, expected", diff --git a/lumigator/backend/backend/tracking/mlflow.py b/lumigator/backend/backend/tracking/mlflow.py index f63b940fe..9a1dbf1b7 100644 --- a/lumigator/backend/backend/tracking/mlflow.py +++ b/lumigator/backend/backend/tracking/mlflow.py @@ -1,9 +1,11 @@ +import asyncio import contextlib import http import json from collections.abc import Generator from datetime import datetime from http import HTTPStatus +from typing import Any from urllib.parse import urljoin from uuid import UUID @@ -22,6 +24,7 @@ from s3fs import S3FileSystem from backend.services.exceptions.job_exceptions import JobNotFoundError, JobUpstreamError +from backend.services.exceptions.tracking_exceptions import RunNotFoundError from backend.settings import settings from backend.tracking.schemas import RunOutputs from backend.tracking.tracking_interface import TrackingClient @@ -107,33 +110,62 @@ async def delete_experiment(self, experiment_id: str) -> None: # delete the experiment self._client.delete_experiment(experiment_id) - def _compile_metrics(self, job_ids: list) -> dict: - """Take the individual metrics from each run and compile them into a single dict - for now, assert that each run has no overlapping metrics + def _compile_metrics(self, job_ids: list) -> dict[str, float]: + """Aggregate metrics from job runs, ensuring no duplicate keys. + + :param job_ids: A list of job IDs to aggregate metrics from. + :return: A dictionary mapping metric names to their values. + :raises ValueError: If a duplicate metric is found across jobs. """ metrics = {} + for job_id in job_ids: run = self._client.get_run(job_id) - for metric in run.data.metrics: - assert metric not in metrics - metrics[metric] = run.data.metrics[metric] + for metric, value in run.data.metrics.items(): + if metric in metrics: + raise ValueError( + f"Duplicate metric '{metric}' found in job '{job_id}'. 
" + f"Stored value: {metrics[metric]}, this value: {value}" + ) + metrics[metric] = value return metrics - def _compile_parameters(self, job_ids: list) -> dict: - """Take the individual parameters from each run and compile them into a single dict - for now, assert that each run has no overlapping parameters + def _compile_parameters(self, job_ids: list) -> dict[str, dict[str, Any]]: + """Aggregate parameters across runs while associating each value with its specific run name/job ID. + + If all values for a parameter across jobs are the same, + the 'value' key will exist for that parameter with the shared value. + + :param job_ids: A list of job IDs to aggregate parameters from. + :return: A dictionary where each key is a parameter name, and the value is a dictionary containing a 'runs' + key mapping a run's parameter value. Also, when all values are the same, a 'value' key is present. + :raises MlflowException: If there is a problem getting the run data for a job. + :raises ValueError: If a run name is not found in the tags for a particular job. """ parameters = {} + value_key = "value" + jobs_key = "jobs" + for job_id in job_ids: run = self._client.get_run(job_id) - for parameter in run.data.params: - # if the parameter shows up in multiple runs, prepend the run_name to the key - # TODO: this is a hacky way to handle this, - # we should come up with a better solution but at least it keeps the info - if parameter in parameters: - parameters[f"{run.data.tags['mlflow.runName']}_{parameter}"] = run.data.params[parameter] - parameters[parameter] = run.data.params[parameter] + run_name = run.data.tags.get("mlflow.runName") + if not run_name: + raise ValueError(f"Cannot compile parameters, run name not found in tags for job: {job_id}") + + for param_key, param_value in run.data.params.items(): + param_entry = parameters.setdefault(param_key, {jobs_key: {}}) + + # Remove the shared value if it exists but doesn't match the current value. + if value_key in param_entry and param_entry[value_key] != param_value: + param_entry.pop(value_key, None) + # Add the shared value if we haven't added data for any runs yet + elif not bool(param_entry.get(jobs_key)): + param_entry[value_key] = param_value + + # Finally add the run-specific value. 
+ param_entry[jobs_key][run_name] = param_value + return parameters def _find_workflows(self, experiment_id: str) -> list: @@ -165,7 +197,11 @@ async def get_experiment(self, experiment_id: str) -> GetExperimentResponse | No async def _format_experiment(self, experiment: MlflowExperiment) -> GetExperimentResponse: # now get all the workflows associated with that experiment workflow_ids = self._find_workflows(experiment.experiment_id) - workflows = [await self.get_workflow(workflow_id) for workflow_id in workflow_ids] + workflows = [ + workflow + for workflow in await asyncio.gather(*(self.get_workflow(wid) for wid in workflow_ids)) + if workflow is not None + ] task_definition_json = experiment.tags.get("task_definition") task_definition = TypeAdapter(TaskDefinition).validate_python(json.loads(task_definition_json)) return GetExperimentResponse( @@ -259,9 +295,10 @@ async def get_workflow(self, workflow_id: str) -> WorkflowDetailsResponse | None experiment_ids=[workflow.info.experiment_id], filter_string=f"tags.{MLFLOW_PARENT_RUN_ID} = '{workflow_id}'", ) - all_job_ids = [run.info.run_id for run in all_jobs] + # sort the jobs by created_at, with the oldest last all_jobs = sorted(all_jobs, key=lambda x: x.info.start_time) + all_job_ids = [run.info.run_id for run in all_jobs] workflow_details = WorkflowDetailsResponse( id=workflow_id, @@ -286,32 +323,21 @@ async def get_workflow(self, workflow_id: str) -> WorkflowDetailsResponse | None # check if the compiled results already exist workflow_s3_uri = self._get_s3_uri(workflow_id) if not self._s3_file_system.exists(workflow_s3_uri): - compiled_results = {"metrics": {}, "parameters": {}, "artifacts": {}} + compiled_results: dict[str, JobResultObject] = {} for job in workflow_details.jobs: # look for all parameter keys that end in "_s3_path" and download the file for param in job.parameters: if param["name"].endswith("_s3_path"): - # download the file # get the file from the S3 bucket with self._s3_file_system.open(f"{param['value']}") as f: job_results = JobResultObject.model_validate_json(f.read()) - # if any keys are the same, log a warning and then overwrite the key - for job_result_item in job_results: - if job_result_item[1] is None: - loguru.logger.info(f"No {job_result_item[0]} found for job {job.id}.") - continue - for key in job_result_item[1].keys(): - if key in compiled_results[job_result_item[0]]: - loguru.logger.warning(f"Key '{key}' already exists in the results. Overwriting.") - # merge the results into the compiled results - compiled_results[job_result_item[0]][key] = job_result_item[1][key] + compiled_results[str(job.id)] = job_results.model_dump() # Upload the compiled results to S3. self._upload_to_s3(workflow_s3_uri, compiled_results) # Update the download URL in the response as compiled results are available. workflow_details.artifacts_download_url = await self._generate_presigned_url(workflow_id) - return workflow_details async def update_workflow_status(self, workflow_id: str, status: WorkflowStatus) -> None: @@ -412,17 +438,25 @@ async def update_job(self, job_id: str, data: RunOutputs): for parameter, value in data.parameters.items(): self._client.log_param(job_id, parameter, value) - async def get_job(self, job_id: str): - """Get the results of a job.""" + async def get_job(self, job_id: str) -> JobResults: + """Get the results of a job (known as a Run in MLFlow). + + This method is used to get the metrics and parameters of a job. + + :param job_id: The ID of the job. + :return: The results of the job. 
+ :raises RunNotFoundError: If the job is not found. + """ run = self._client.get_run(job_id) if run.info.lifecycle_stage == "deleted": - return None + raise RunNotFoundError(job_id, "deleted") + return JobResults( id=job_id, - metrics=[{"name": metric[0], "value": metric[1]} for metric in run.data.metrics.items()], - parameters=[{"name": param[0], "value": param[1]} for param in run.data.params.items()], - metric_url="TODO", - artifact_url="TODO", + metrics=[{"name": key, "value": value} for key, value in run.data.metrics.items()], + parameters=[{"name": key, "value": value} for key, value in run.data.params.items()], + metric_url="", # TODO: Implement + artifact_url="", # TODO: Implement ) async def delete_job(self, job_id: str):
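For reference, the parameter-aggregation rule implemented by _compile_parameters can be sketched standalone on plain dictionaries. The compile_parameters helper below is illustrative only (it is not part of the patch) and assumes parameters are supplied per run name; it mirrors the shared-'value'/per-run-'jobs' behavior exercised by test_compile_parameters and test_compile_parameters_conflict.

    # Standalone sketch of the aggregation rule, using {run_name: {param: value}} input.
    from typing import Any


    def compile_parameters(params_by_run: dict[str, dict[str, str]]) -> dict[str, dict[str, Any]]:
        parameters: dict[str, dict[str, Any]] = {}
        for run_name, params in params_by_run.items():
            for key, value in params.items():
                entry = parameters.setdefault(key, {"jobs": {}})
                if "value" in entry and entry["value"] != value:
                    # Conflicting values across runs: drop the shared 'value' key.
                    entry.pop("value", None)
                elif not entry["jobs"]:
                    # First run to report this parameter: record it as the shared value.
                    entry["value"] = value
                # Always record the run-specific value.
                entry["jobs"][run_name] = value
        return parameters


    # All runs agree -> both a shared 'value' and per-run 'jobs' entries are present.
    print(compile_parameters({"Run1": {"batch_size": "32"}, "Run2": {"batch_size": "32"}}))
    # Runs disagree -> only the per-run 'jobs' entries remain.
    print(compile_parameters({"Run1": {"learning_rate": "7"}, "Run2": {"learning_rate": "5"}}))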
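The gather-and-filter list comprehension introduced in _format_experiment fetches all workflows concurrently and drops any that resolve to None. A minimal toy illustration of the pattern, with a hypothetical get_workflow standing in for the real MLflowTrackingClient.get_workflow:

    import asyncio


    async def get_workflow(workflow_id: str) -> str | None:
        # Stand-in for the real client method; returns None for workflows that cannot be resolved.
        await asyncio.sleep(0)
        return None if workflow_id == "missing" else f"workflow-{workflow_id}"


    async def main() -> None:
        workflow_ids = ["a", "missing", "b"]
        workflows = [
            workflow
            for workflow in await asyncio.gather(*(get_workflow(wid) for wid in workflow_ids))
            if workflow is not None
        ]
        print(workflows)  # ['workflow-a', 'workflow-b']


    asyncio.run(main())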