From 5b651a5822ed8fc50246185b22a68f9a2ed70096 Mon Sep 17 00:00:00 2001
From: Peter Wilson
Date: Fri, 21 Mar 2025 18:46:53 +0000
Subject: [PATCH 1/3] MLFlow: updated logic for creating compiled.json

Updated get_job, _compile_metrics and _compile_parameters
Add exception for workflow run not found
Added unit tests
---
 .../exceptions/tracking_exceptions.py         |  16 +
 lumigator/backend/backend/tests/conftest.py   |  59 ++++
 .../tests/unit/tracking/test_mlflow.py        | 301 ++++++++++++++++++
 lumigator/backend/backend/tracking/mlflow.py  | 106 +++---
 .../backend/tracking/tracking_interface.py    |   2 +-
 5 files changed, 447 insertions(+), 37 deletions(-)
 create mode 100644 lumigator/backend/backend/services/exceptions/tracking_exceptions.py
 create mode 100644 lumigator/backend/backend/tests/unit/tracking/test_mlflow.py

diff --git a/lumigator/backend/backend/services/exceptions/tracking_exceptions.py b/lumigator/backend/backend/services/exceptions/tracking_exceptions.py
new file mode 100644
index 000000000..a77711ce8
--- /dev/null
+++ b/lumigator/backend/backend/services/exceptions/tracking_exceptions.py
@@ -0,0 +1,16 @@
+from backend.services.exceptions.base_exceptions import (
+    NotFoundError,
+)
+
+
+class RunNotFoundError(NotFoundError):
+    """Raised when a run cannot be found."""
+
+    def __init__(self, resource_id: str, message: str | None = None, exc: Exception | None = None):
+        """Creates a RunNotFoundError.
+
+        :param resource_id: ID of run
+        :param message: optional error message
+        :param exc: optional exception, where possible raise ``from exc`` to preserve the original traceback
+        """
+        super().__init__("Run", str(resource_id), message, exc)
diff --git a/lumigator/backend/backend/tests/conftest.py b/lumigator/backend/backend/tests/conftest.py
index b6b167cc4..1783391c2 100644
--- a/lumigator/backend/backend/tests/conftest.py
+++ b/lumigator/backend/backend/tests/conftest.py
@@ -4,6 +4,7 @@
 import time
 import uuid
 from collections.abc import Generator
+from datetime import datetime
 from pathlib import Path
 from unittest.mock import MagicMock
 from uuid import UUID
@@ -26,6 +27,7 @@
     JobType,
 )
 from lumigator_schemas.models import ModelsResponse
+from mlflow.entities import Metric, Param, Run, RunData, RunInfo, RunTag
 from s3fs import S3FileSystem
 from sqlalchemy import Engine, create_engine
 from sqlalchemy.orm import Session
@@ -532,3 +534,60 @@ def model_specs_data() -> list[ModelsResponse]:
     models = [ModelsResponse.model_validate(item) for item in model_specs]
 
     return models
+
+
+@pytest.fixture(scope="function")
+def fake_mlflow_tracking_client(fake_s3fs):
+    """Fixture for MLflowTrackingClient using the real MLflowClient."""
+    return MLflowTrackingClient(tracking_uri="http://mlflow.mock", s3_file_system=fake_s3fs)
+
+
+@pytest.fixture
+def sample_mlflow_run():
+    """Fixture for a sample MlflowRun with mock data."""
+    return Run(
+        run_info=RunInfo(
+            run_uuid="d34dbeef-1000-0000-0000-000000000000",
+            experiment_id="exp-1",
+            user_id="user",
+            status="FINISHED",
+            start_time=123456789,
+            end_time=None,
+            lifecycle_stage="active",
+            artifact_uri="",
+        ),
+        run_data=RunData(
+            metrics=[
+                Metric(key="accuracy", value=0.75, timestamp=123456789, step=0),
+            ],
+            params=[
+                Param(key="batch_size", value="32"),
+            ],
+            tags=[
+                RunTag(key="description", value="A sample workflow"),
+                RunTag(key="mlflow.runName", value="Run2"),
+                RunTag(key="model", value="SampleModel"),
+                RunTag(key="system_prompt", value="Prompt text"),
+                RunTag(key="status", value="COMPLETED"),
+            ],
+        ),
+    )
+
+
+@pytest.fixture
+def fake_mlflow_run_deleted():
+    """Fixture for a deleted MLflow run."""
+    run_info = RunInfo(
+        run_uuid="d34dbeef-1000-0000-0000-000000000000",
+        experiment_id="exp-456",
+        user_id="user-789",
+        status="FAILED",
+        start_time=int(datetime(2024, 1, 1).timestamp() * 1000),
+        end_time=None,
+        lifecycle_stage="deleted",
+        artifact_uri="s3://some-bucket",
+    )
+
+    run_data = RunData(metrics={}, params={}, tags={})
+
+    return Run(run_info=run_info, run_data=run_data)
diff --git a/lumigator/backend/backend/tests/unit/tracking/test_mlflow.py b/lumigator/backend/backend/tests/unit/tracking/test_mlflow.py
new file mode 100644
index 000000000..0b11d9b0c
--- /dev/null
+++ b/lumigator/backend/backend/tests/unit/tracking/test_mlflow.py
@@ -0,0 +1,301 @@
+import uuid
+from unittest.mock import MagicMock
+
+import pytest
+from lumigator_schemas.jobs import JobResults
+from mlflow.entities import Metric, Param, Run, RunData, RunInfo, RunTag
+
+from backend.services.exceptions.tracking_exceptions import RunNotFoundError
+
+
+def test_get_job_success(fake_mlflow_tracking_client, sample_mlflow_run):
+    """Test fetching job results successfully."""
+    job_id = uuid.UUID("d34dbeef-1000-0000-0000-000000000000")
+
+    fake_mlflow_tracking_client._client.get_run = MagicMock(return_value=sample_mlflow_run)
+
+    result = fake_mlflow_tracking_client.get_job(str(job_id))
+
+    assert isinstance(result, JobResults)
+    assert result.id == job_id
+    assert result.metrics == [{"name": "accuracy", "value": 0.75}]
+    assert result.parameters == [{"name": "batch_size", "value": "32"}]
+
+    fake_mlflow_tracking_client._client.get_run.assert_called_once_with(str(job_id))
+
+
+def test_get_job_deleted(fake_mlflow_tracking_client, fake_mlflow_run_deleted):
+    """Test that fetching a deleted job raises RunNotFoundError."""
+    job_id = uuid.UUID("d34dbeef-1000-0000-0000-000000000000")
+
+    fake_mlflow_tracking_client._client.get_run = MagicMock(return_value=fake_mlflow_run_deleted)
+
+    with pytest.raises(RunNotFoundError):
+        fake_mlflow_tracking_client.get_job(str(job_id))
+
+    fake_mlflow_tracking_client._client.get_run.assert_called_once_with(str(job_id))
+
+
+def test_compile_metrics(fake_mlflow_tracking_client):
+    """Test metric compilation across multiple job runs."""
+    job_ids = {
+        "job1": "d34dbeef-1000-0000-0000-000000000001",
+        "job2": "d34dbeef-1000-0000-0000-000000000002",
+    }
+    runs = {
+        job_ids["job1"]: Run(
+            run_info=RunInfo(
+                run_uuid=job_ids["job1"],
+                experiment_id="exp-1",
+                user_id="user",
+                status="FINISHED",
+                start_time=123456789,
+                end_time=None,
+                lifecycle_stage="active",
+            ),
+            run_data=RunData(
+                metrics=[Metric(key="accuracy", value=0.95, timestamp=123456789, step=0)],
+                params=[Param(key="other_thing", value="0.01")],
+                tags=[
+                    RunTag(key="mlflow.runName", value="Run1"),
+                ],
+            ),
+        ),
+        job_ids["job2"]: Run(
+            run_info=RunInfo(
+                run_uuid=job_ids["job2"],
+                experiment_id="exp-1",
+                user_id="user",
+                status="FINISHED",
+                start_time=123456789,
+                end_time=None,
+                lifecycle_stage="active",
+            ),
+            run_data=RunData(
+                metrics=[Metric(key="loss", value=0.2, timestamp=123456789, step=0)],
+                params=[Param(key="learning_rate", value="0.02")],
+                tags=[
+                    RunTag(key="mlflow.runName", value="Run2"),
+                ],
+            ),
+        ),
+    }
+
+    fake_mlflow_tracking_client._client.get_run = MagicMock(side_effect=lambda job_id: runs[job_id])
+
+    result = fake_mlflow_tracking_client._compile_metrics([job_id for job_id in job_ids.values()])
+
+    assert result == {"loss": 0.2, "accuracy": 0.95}
+    fake_mlflow_tracking_client._client.get_run.assert_any_call(job_ids["job1"])
+    fake_mlflow_tracking_client._client.get_run.assert_any_call(job_ids["job2"])
+
+
+def test_compile_metrics_conflict(fake_mlflow_tracking_client):
+    """Test that a metric conflict across job runs raises a ValueError."""
+    job_ids = {
+        "job1": "d34dbeef-1000-0000-0000-000000000001",
+        "job2": "d34dbeef-1000-0000-0000-000000000002",
+    }
+    runs = {
+        job_ids["job1"]: Run(
+            run_info=RunInfo(
+                run_uuid=job_ids["job1"],
+                experiment_id="exp-1",
+                user_id="user",
+                status="FINISHED",
+                start_time=123456789,
+                end_time=None,
+                lifecycle_stage="active",
+                artifact_uri="",
+            ),
+            run_data=RunData(
+                metrics=[Metric(key="accuracy", value=0.95, timestamp=123456789, step=0)],
+                params=[Param(key="other_thing", value="0.01")],
+                tags=[
+                    RunTag(key="description", value="A sample workflow"),
+                    RunTag(key="mlflow.runName", value="Run1"),
+                    RunTag(key="model", value="SampleModel"),
+                    RunTag(key="system_prompt", value="Prompt text"),
+                    RunTag(key="status", value="COMPLETED"),
+                ],
+            ),
+        ),
+        job_ids["job2"]: Run(
+            run_info=RunInfo(
+                run_uuid=job_ids["job2"],
+                experiment_id="exp-1",
+                user_id="user",
+                status="FINISHED",
+                start_time=123456789,
+                end_time=None,
+                lifecycle_stage="active",
+                artifact_uri="",
+            ),
+            run_data=RunData(
+                metrics=[Metric(key="accuracy", value=0.75, timestamp=123456789, step=0)],
+                params=[Param(key="learning_rate", value="0.02")],
+                tags=[
+                    RunTag(key="description", value="A sample workflow"),
+                    RunTag(key="mlflow.runName", value="Run2"),
+                    RunTag(key="model", value="SampleModel"),
+                    RunTag(key="system_prompt", value="Prompt text"),
+                    RunTag(key="status", value="COMPLETED"),
+                ],
+            ),
+        ),
+    }
+
+    fake_mlflow_tracking_client._client.get_run = MagicMock(side_effect=lambda job_id: runs[job_id])
+
+    with pytest.raises(ValueError) as e:
+        fake_mlflow_tracking_client._compile_metrics([job_id for job_id in job_ids.values()])
+
+    assert str(e.value) == (
+        "Duplicate metric 'accuracy' found in job 'd34dbeef-1000-0000-0000-000000000002'. "
+        "Stored value: 0.95, this value: 0.75"
+    )
+
+
+def test_compile_parameters(fake_mlflow_tracking_client):
+    """Test parameter compilation across multiple job runs."""
+    job_ids = {
+        "job1": "d34dbeef-1000-0000-0000-000000000001",
+        "job2": "d34dbeef-1000-0000-0000-000000000002",
+    }
+    runs = {
+        job_ids["job1"]: Run(
+            run_info=RunInfo(
+                run_uuid=job_ids["job1"],
+                experiment_id="exp-1",
+                user_id="user",
+                status="FINISHED",
+                start_time=123456789,
+                end_time=None,
+                lifecycle_stage="active",
+            ),
+            run_data=RunData(
+                params=[Param(key="other_thing", value="0.01")],
+                tags=[
+                    RunTag(key="mlflow.runName", value="Run1"),
+                ],
+            ),
+        ),
+        job_ids["job2"]: Run(
+            run_info=RunInfo(
+                run_uuid=job_ids["job2"],
+                experiment_id="exp-1",
+                user_id="user",
+                status="FINISHED",
+                start_time=123456789,
+                end_time=None,
+                lifecycle_stage="active",
+            ),
+            run_data=RunData(
+                params=[Param(key="learning_rate", value="5")],
+                tags=[
+                    RunTag(key="mlflow.runName", value="Run2"),
+                ],
+            ),
+        ),
+    }
+
+    fake_mlflow_tracking_client._client.get_run = MagicMock(side_effect=lambda job_id: runs[job_id])
+
+    result = fake_mlflow_tracking_client._compile_parameters([job_id for job_id in job_ids.values()])
+
+    assert result == {
+        "other_thing": {
+            "value": "0.01",
+            "jobs": {
+                "Run1": "0.01",
+            },
+        },
+        "learning_rate": {
+            "value": "5",
+            "jobs": {
+                "Run2": "5",
+            },
+        },
+    }
+
+
+def test_compile_parameters_conflict(fake_mlflow_tracking_client):
+    """Test parameter conflicts result in no 'value' key being set."""
+    job_ids = {
+        "job1": "d34dbeef-1000-0000-0000-000000000001",
+        "job2": "d34dbeef-1000-0000-0000-000000000002",
+        "job3": "d34dbeef-1000-0000-0000-000000000003",
+    }
+    runs = {
+        job_ids["job1"]: Run(
+            run_info=RunInfo(
+                run_uuid=job_ids["job1"],
+                experiment_id="exp-1",
+                user_id="user",
+                status="FINISHED",
+                start_time=123456789,
+                end_time=None,
+                lifecycle_stage="active",
+            ),
+            run_data=RunData(
+                params=[Param(key="other_thing", value="0.01"), Param(key="learning_rate", value="7")],
+                tags=[
+                    RunTag(key="mlflow.runName", value="Run1"),
+                ],
+            ),
+        ),
+        job_ids["job2"]: Run(
+            run_info=RunInfo(
+                run_uuid=job_ids["job2"],
+                experiment_id="exp-1",
+                user_id="user",
+                status="FINISHED",
+                start_time=123456789,
+                end_time=None,
+                lifecycle_stage="active",
+            ),
+            run_data=RunData(
+                params=[Param(key="learning_rate", value="5")],
+                tags=[
+                    RunTag(key="mlflow.runName", value="Run2"),
+                ],
+            ),
+        ),
+        job_ids["job3"]: Run(
+            run_info=RunInfo(
+                run_uuid=job_ids["job3"],
+                experiment_id="exp-1",
+                user_id="user",
+                status="FINISHED",
+                start_time=123456789,
+                end_time=None,
+                lifecycle_stage="active",
+            ),
+            run_data=RunData(
+                params=[Param(key="learning_rate", value="8")],
+                tags=[
+                    RunTag(key="mlflow.runName", value="Run3"),
+                ],
+            ),
+        ),
+    }
+
+    fake_mlflow_tracking_client._client.get_run = MagicMock(side_effect=lambda job_id: runs[job_id])
+
+    result = fake_mlflow_tracking_client._compile_parameters([job_id for job_id in job_ids.values()])
+
+    assert result == {
+        "other_thing": {
+            "value": "0.01",
+            "jobs": {
+                "Run1": "0.01",
+            },
+        },
+        "learning_rate": {
+            "jobs": {
+                "Run1": "7",
+                "Run2": "5",
+                "Run3": "8",
+            },
+        },
+    }
diff --git a/lumigator/backend/backend/tracking/mlflow.py b/lumigator/backend/backend/tracking/mlflow.py
index 2281404c4..67ce1ed1e 100644
--- a/lumigator/backend/backend/tracking/mlflow.py
+++ b/lumigator/backend/backend/tracking/mlflow.py
@@ -4,6 +4,7 @@
 from collections.abc import Generator
 from datetime import datetime
 from http import HTTPStatus
+from typing import Any
 from urllib.parse import urljoin
 from uuid import UUID
 
@@ -21,6 +22,7 @@
 from s3fs import S3FileSystem
 
 from backend.services.exceptions.job_exceptions import JobNotFoundError, JobUpstreamError
+from backend.services.exceptions.tracking_exceptions import RunNotFoundError
 from backend.settings import settings
 from backend.tracking.schemas import RunOutputs
 from backend.tracking.tracking_interface import TrackingClient
@@ -91,33 +93,62 @@ def delete_experiment(self, experiment_id: str) -> None:
         # delete the experiment
         self._client.delete_experiment(experiment_id)
 
-    def _compile_metrics(self, job_ids: list) -> dict:
-        """Take the individual metrics from each run and compile them into a single dict
-        for now, assert that each run has no overlapping metrics
+    def _compile_metrics(self, job_ids: list) -> dict[str, float]:
+        """Aggregate metrics from job runs, ensuring no duplicate keys.
+
+        :param job_ids: A list of job IDs to aggregate metrics from.
+        :return: A dictionary mapping metric names to their values.
+        :raises ValueError: If a duplicate metric is found across jobs.
         """
         metrics = {}
+
        for job_id in job_ids:
             run = self._client.get_run(job_id)
-            for metric in run.data.metrics:
-                assert metric not in metrics
-                metrics[metric] = run.data.metrics[metric]
+            for metric, value in run.data.metrics.items():
+                if metric in metrics:
+                    raise ValueError(
+                        f"Duplicate metric '{metric}' found in job '{job_id}'. "
+                        f"Stored value: {metrics[metric]}, this value: {value}"
+                    )
+                metrics[metric] = value
 
         return metrics
 
-    def _compile_parameters(self, job_ids: list) -> dict:
-        """Take the individual parameters from each run and compile them into a single dict
-        for now, assert that each run has no overlapping parameters
+    def _compile_parameters(self, job_ids: list) -> dict[str, dict[str, Any]]:
+        """Aggregate parameters across runs while associating each value with its specific run name/job ID.
+
+        If all values for a parameter across jobs are the same,
+        the 'value' key will exist for that parameter with the shared value.
+
+        :param job_ids: A list of job IDs to aggregate parameters from.
+        :return: A dictionary where each key is a parameter name, and the value is a dictionary containing a 'jobs'
+            key mapping each run name to its parameter value. Also, when all values are the same, a 'value' key is present.
+        :raises MlflowException: If there is a problem getting the run data for a job.
+        :raises ValueError: If a run name is not found in the tags for a particular job.
         """
         parameters = {}
+        value_key = "value"
+        jobs_key = "jobs"
+
         for job_id in job_ids:
             run = self._client.get_run(job_id)
-            for parameter in run.data.params:
-                # if the parameter shows up in multiple runs, prepend the run_name to the key
-                # TODO: this is a hacky way to handle this,
-                # we should come up with a better solution but at least it keeps the info
-                if parameter in parameters:
-                    parameters[f"{run.data.tags['mlflow.runName']}_{parameter}"] = run.data.params[parameter]
-                parameters[parameter] = run.data.params[parameter]
+            run_name = run.data.tags.get("mlflow.runName")
+            if not run_name:
+                raise ValueError(f"Cannot compile parameters, run name not found in tags for job: {job_id}")
+
+            for param_key, param_value in run.data.params.items():
+                param_entry = parameters.setdefault(param_key, {jobs_key: {}})
+
+                # Remove the shared value if it exists but doesn't match the current value.
+                if value_key in param_entry and param_entry[value_key] != param_value:
+                    param_entry.pop(value_key, None)
+                # Add the shared value if we haven't added data for any runs yet
+                elif not bool(param_entry.get(jobs_key)):
+                    param_entry[value_key] = param_value
+
+                # Finally add the run-specific value.
+                param_entry[jobs_key][run_name] = param_value
+
         return parameters
 
     def _find_workflows(self, experiment_id: str) -> list:
@@ -243,9 +274,10 @@ async def get_workflow(self, workflow_id: str) -> WorkflowDetailsResponse | None
             experiment_ids=[workflow.info.experiment_id],
             filter_string=f"tags.{MLFLOW_PARENT_RUN_ID} = '{workflow_id}'",
         )
-        all_job_ids = [run.info.run_id for run in all_jobs]
+
         # sort the jobs by created_at, with the oldest last
         all_jobs = sorted(all_jobs, key=lambda x: x.info.start_time)
+        all_job_ids = [run.info.run_id for run in all_jobs]
 
         workflow_details = WorkflowDetailsResponse(
             id=workflow_id,
@@ -269,7 +301,8 @@ async def get_workflow(self, workflow_id: str) -> WorkflowDetailsResponse | None
         # a presigned URL for that file
         # check if the compiled results already exist
         if not self._s3_file_system.exists(f"{settings.S3_BUCKET}/{workflow_id}/compiled.json"):
-            compiled_results = {"metrics": {}, "parameters": {}, "artifacts": {}}
+            compiled_results: dict[str, JobResultObject] = {}
+
             for job in workflow_details.jobs:
                 # look for all parameter keys that end in "_s3_path" and download the file
                 for param in job.parameters:
@@ -278,19 +311,12 @@ async def get_workflow(self, workflow_id: str) -> WorkflowDetailsResponse | None
                         # get the file from the S3 bucket
                         with self._s3_file_system.open(f"{param['value']}") as f:
                             job_results = JobResultObject.model_validate(json.loads(f.read()))
-                            # if any keys are the same, log a warning and then overwrite the key
-                            for job_result_item in job_results:
-                                if job_result_item[1] is None:
-                                    loguru.logger.info(f"No {job_result_item[0]} found for job {job.id}.")
-                                    continue
-                                for key in job_result_item[1].keys():
-                                    if key in compiled_results[job_result_item[0]]:
-                                        loguru.logger.warning(f"Key '{key}' already exists in the results. Overwriting.")
-                                    # merge the results into the compiled results
-                                    compiled_results[job_result_item[0]][key] = job_result_item[1][key]
+                            compiled_results[job.id] = job_results
+
             with self._s3_file_system.open(f"{settings.S3_BUCKET}/{workflow_id}/compiled.json", "w") as f:
                 f.write(json.dumps(compiled_results))
-            # Generate presigned download URL for the object
+
+            # Generate presigned download URL for the object
             download_url = await self._s3_file_system.s3.generate_presigned_url(
                 "get_object",
                 Params={
@@ -391,17 +417,25 @@ def update_job(self, job_id: str, data: RunOutputs):
         for parameter, value in data.parameters.items():
             self._client.log_param(job_id, parameter, value)
 
-    def get_job(self, job_id: str):
-        """Get the results of a job."""
+    def get_job(self, job_id: str) -> JobResults:
+        """Get the results of a job.
+
+        This method is used to get the metrics and parameters of a job.
+
+        :param job_id: The ID of the job.
+        :return: The results of the job.
+        :raises RunNotFoundError: If the job is not found.
+        """
         run = self._client.get_run(job_id)
         if run.info.lifecycle_stage == "deleted":
-            return None
+            raise RunNotFoundError(job_id, "deleted")
+
         return JobResults(
             id=job_id,
-            metrics=[{"name": metric[0], "value": metric[1]} for metric in run.data.metrics.items()],
-            parameters=[{"name": param[0], "value": param[1]} for param in run.data.params.items()],
-            metric_url="TODO",
-            artifact_url="TODO",
+            metrics=[{"name": key, "value": value} for key, value in run.data.metrics.items()],
+            parameters=[{"name": key, "value": value} for key, value in run.data.params.items()],
+            metric_url="",  # TODO: Implement
+            artifact_url="",  # TODO: Implement
         )
 
     def delete_job(self, job_id: str):
diff --git a/lumigator/backend/backend/tracking/tracking_interface.py b/lumigator/backend/backend/tracking/tracking_interface.py
index 4b142b1e8..2c6614e79 100644
--- a/lumigator/backend/backend/tracking/tracking_interface.py
+++ b/lumigator/backend/backend/tracking/tracking_interface.py
@@ -89,7 +89,7 @@ def update_workflow(self, workflow_id: str, data: RunOutputs):
         """Update the outputs of a workflow"""
         ...
 
-    def get_job(self, job_id: str) -> JobResults | None:
+    def get_job(self, job_id: str) -> JobResults:
         """Get a job."""
         ...
 

From 2c07d0dd1b112efb34d523fdc3ae5b39dbc32fd1 Mon Sep 17 00:00:00 2001
From: Peter Wilson
Date: Tue, 25 Mar 2025 08:33:43 +0000
Subject: [PATCH 2/3] Small tweaks

---
 .../backend/tests/unit/tracking/test_mlflow.py | 11 +++++++++--
 lumigator/backend/backend/tracking/mlflow.py   |  2 +-
 2 files changed, 10 insertions(+), 3 deletions(-)

diff --git a/lumigator/backend/backend/tests/unit/tracking/test_mlflow.py b/lumigator/backend/backend/tests/unit/tracking/test_mlflow.py
index 0b11d9b0c..473d4e7c3 100644
--- a/lumigator/backend/backend/tests/unit/tracking/test_mlflow.py
+++ b/lumigator/backend/backend/tests/unit/tracking/test_mlflow.py
@@ -238,7 +238,10 @@ def test_compile_parameters_conflict(fake_mlflow_tracking_client):
                 lifecycle_stage="active",
             ),
             run_data=RunData(
-                params=[Param(key="other_thing", value="0.01"), Param(key="learning_rate", value="7")],
+                params=[
+                    Param(key="other_thing", value="0.01"),
+                    Param(key="learning_rate", value="7"),
+                ],
                 tags=[
                     RunTag(key="mlflow.runName", value="Run1"),
                 ],
@@ -255,7 +258,10 @@ def test_compile_parameters_conflict(fake_mlflow_tracking_client):
                 lifecycle_stage="active",
             ),
             run_data=RunData(
-                params=[Param(key="learning_rate", value="5")],
+                params=[
+                    Param(key="other_thing", value="0.01"),
+                    Param(key="learning_rate", value="5"),
+                ],
                 tags=[
                     RunTag(key="mlflow.runName", value="Run2"),
                 ],
@@ -289,6 +295,7 @@ def test_compile_parameters_conflict(fake_mlflow_tracking_client):
             "value": "0.01",
             "jobs": {
                 "Run1": "0.01",
+                "Run2": "0.01",
             },
         },
         "learning_rate": {
diff --git a/lumigator/backend/backend/tracking/mlflow.py b/lumigator/backend/backend/tracking/mlflow.py
index 67ce1ed1e..f19a9d57e 100644
--- a/lumigator/backend/backend/tracking/mlflow.py
+++ b/lumigator/backend/backend/tracking/mlflow.py
@@ -418,7 +418,7 @@ def update_job(self, job_id: str, data: RunOutputs):
         self._client.log_param(job_id, parameter, value)
 
     def get_job(self, job_id: str) -> JobResults:
-        """Get the results of a job.
+        """Get the results of a job (known as a Run in MLflow).
 
         This method is used to get the metrics and parameters of a job.
 

From 873466158fe4ced584aa37043e066d26667d1277 Mon Sep 17 00:00:00 2001
From: Peter Wilson
Date: Tue, 25 Mar 2025 20:08:38 +0000
Subject: [PATCH 3/3] Ignore jobs that are not found, or deleted.

Fix compiled.json serialization
---
 lumigator/backend/backend/tracking/mlflow.py | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/lumigator/backend/backend/tracking/mlflow.py b/lumigator/backend/backend/tracking/mlflow.py
index f19a9d57e..ae9c1ad96 100644
--- a/lumigator/backend/backend/tracking/mlflow.py
+++ b/lumigator/backend/backend/tracking/mlflow.py
@@ -1,3 +1,4 @@
+import asyncio
 import contextlib
 import http
 import json
@@ -180,7 +181,11 @@ async def get_experiment(self, experiment_id: str) -> GetExperimentResponse | No
     async def _format_experiment(self, experiment: MlflowExperiment) -> GetExperimentResponse:
         # now get all the workflows associated with that experiment
         workflow_ids = self._find_workflows(experiment.experiment_id)
-        workflows = [await self.get_workflow(workflow_id) for workflow_id in workflow_ids]
+        workflows = [
+            workflow
+            for workflow in await asyncio.gather(*(self.get_workflow(wid) for wid in workflow_ids))
+            if workflow is not None
+        ]
         task_definition_json = experiment.tags.get("task_definition")
         task_definition = TypeAdapter(TaskDefinition).validate_python(json.loads(task_definition_json))
         return GetExperimentResponse(
@@ -311,7 +316,7 @@ async def get_workflow(self, workflow_id: str) -> WorkflowDetailsResponse | None
                         # get the file from the S3 bucket
                         with self._s3_file_system.open(f"{param['value']}") as f:
                             job_results = JobResultObject.model_validate(json.loads(f.read()))
-                            compiled_results[job.id] = job_results
+                            compiled_results[str(job.id)] = job_results.model_dump()
 
             with self._s3_file_system.open(f"{settings.S3_BUCKET}/{workflow_id}/compiled.json", "w") as f:
                 f.write(json.dumps(compiled_results))
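
Note (illustrative only, not part of the patches above): after PATCH 3, compiled.json is keyed by job ID and stores each job's dumped JobResultObject rather than a single merged metrics/parameters/artifacts dict. A minimal sketch of how a consumer might read it back, assuming the same settings and S3 filesystem wiring used by MLflowTrackingClient; the helper name below is hypothetical:

import json

from backend.settings import settings


def load_compiled_results(s3_file_system, workflow_id: str) -> dict[str, dict]:
    # Hypothetical helper: read the compiled.json written by get_workflow().
    # Expected shape: {"<job-id>": {"metrics": {...}, "parameters": {...}, "artifacts": {...}}}
    path = f"{settings.S3_BUCKET}/{workflow_id}/compiled.json"
    with s3_file_system.open(path) as f:
        return json.loads(f.read())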