Skip to content

Commit 8b07c31

Browse files
committed
Fix async issues after moving to S3FileSystem from boto3 client
* Update the Tracking Client interface to be async for all methods by default
* Update MLFlow Tracking Client to adhere to the updated interface
* Bubble up async calls through services and routes (workflows, experiments)
* Configure S3FileSystem used for tests (via MagicMock) to emulate the correct params for 'storage_options'
* Remove un-required (S3-related) dependencies, and update dependencies from backend's pyproject.toml
* Fix: Add 'integration' pytest marker to test config to prevent warnings being emitted during test runs
* Chore: uv.lock updated
1 parent 0961539 commit 8b07c31

File tree

12 files changed

+94
-352
lines changed

12 files changed

+94
-352
lines changed

lumigator/backend/backend/api/routes/experiments.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,10 @@ def experiment_exception_mappings() -> dict[type[ServiceError], HTTPStatus]:
2121

2222

2323
@router.post("/", status_code=status.HTTP_201_CREATED)
24-
def create_experiment_id(service: ExperimentServiceDep, request: ExperimentCreate) -> GetExperimentResponse:
24+
async def create_experiment_id(service: ExperimentServiceDep, request: ExperimentCreate) -> GetExperimentResponse:
2525
"""Create an experiment ID."""
26-
return GetExperimentResponse.model_validate(service.create_experiment(request).model_dump())
26+
experiment = await service.create_experiment(request)
27+
return GetExperimentResponse.model_validate(experiment.model_dump())
2728

2829

2930
@router.get("/{experiment_id}")
@@ -45,6 +46,6 @@ async def list_experiments(
4546

4647

4748
@router.delete("/{experiment_id}")
48-
def delete_experiment(service: ExperimentServiceDep, experiment_id: str) -> None:
49+
async def delete_experiment(service: ExperimentServiceDep, experiment_id: str) -> None:
4950
"""Delete an experiment by ID."""
50-
service.delete_experiment(experiment_id)
51+
await service.delete_experiment(experiment_id)

lumigator/backend/backend/api/routes/workflows.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -44,13 +44,14 @@ async def get_workflow(service: WorkflowServiceDep, workflow_id: str) -> Workflo
4444

4545
# get the logs
4646
@router.get("/{workflow_id}/logs")
47-
def get_workflow_logs(service: WorkflowServiceDep, workflow_id: str) -> JobLogsResponse:
47+
async def get_workflow_logs(service: WorkflowServiceDep, workflow_id: str) -> JobLogsResponse:
4848
"""Get the logs for a workflow."""
49-
return JobLogsResponse.model_validate(service.get_workflow_logs(workflow_id).model_dump())
49+
logs = await service.get_workflow_logs(workflow_id)
50+
return JobLogsResponse.model_validate(logs.model_dump())
5051

5152

5253
@router.get("/{workflow_id}/result/download")
53-
def get_workflow_result_download(
54+
async def get_workflow_result_download(
5455
service: WorkflowServiceDep,
5556
workflow_id: str,
5657
) -> str:
@@ -60,17 +61,20 @@ def get_workflow_result_download(
6061
service: Workflow service dependency
6162
workflow_id: ID of the workflow whose results will be returned
6263
"""
63-
return service.get_workflow_result_download(workflow_id)
64+
return await service.get_workflow_result_download(workflow_id)
6465

6566

6667
# delete a workflow
6768
@router.delete("/{workflow_id}")
68-
def delete_workflow(service: WorkflowServiceDep, workflow_id: str, force: bool = False) -> WorkflowDetailsResponse:
69+
async def delete_workflow(
70+
service: WorkflowServiceDep, workflow_id: str, force: bool = False
71+
) -> WorkflowDetailsResponse:
6972
"""Delete a workflow by ID.
7073
7174
Args:
7275
service: Workflow service dependency
7376
workflow_id: ID of the workflow to delete
7477
force: If True, force deletion even if the workflow is active or has dependencies
7578
"""
76-
return WorkflowDetailsResponse.model_validate(service.delete_workflow(workflow_id, force=force).model_dump())
79+
result = await service.delete_workflow(workflow_id, force=force)
80+
return WorkflowDetailsResponse.model_validate(result.model_dump())

lumigator/backend/backend/services/datasets.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,6 @@ def _save_dataset_to_s3(self, temp_fname, record):
140140
# Upload to S3
141141
dataset_key = self._get_s3_key(record.id, record.filename)
142142
dataset_path = self._get_s3_path(dataset_key)
143-
# Deprecated!!!
144143
dataset_hf.save_to_disk(dataset_path, storage_options=self.s3_filesystem.storage_options)
145144

146145
# Use the converted HF format files to rebuild the CSV and store it as 'dataset.csv'.

lumigator/backend/backend/services/experiments.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,8 @@ def __init__(
2525
self._dataset_service = dataset_service
2626
self._tracking_session = tracking_session
2727

28-
def create_experiment(self, request: ExperimentCreate) -> GetExperimentResponse:
29-
experiment = self._tracking_session.create_experiment(
28+
async def create_experiment(self, request: ExperimentCreate) -> GetExperimentResponse:
29+
experiment = await self._tracking_session.create_experiment(
3030
request.name,
3131
request.description,
3232
request.task_definition,
@@ -50,5 +50,5 @@ async def list_experiments(self, skip: int, limit: int) -> ListingResponse[GetEx
5050
items=[GetExperimentResponse.model_validate(x) for x in records],
5151
)
5252

53-
def delete_experiment(self, experiment_id: str):
54-
self._tracking_session.delete_experiment(experiment_id)
53+
async def delete_experiment(self, experiment_id: str):
54+
await self._tracking_session.delete_experiment(experiment_id)

lumigator/backend/backend/services/workflows.py

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
WorkflowStatus,
2323
)
2424
from pydantic_core._pydantic_core import ValidationError
25+
from typing_extensions import deprecated
2526

2627
from backend.repositories.jobs import JobRepository
2728
from backend.services.datasets import DatasetService
@@ -92,12 +93,12 @@ async def _handle_workflow_failure(self, workflow_id: str):
9293
loguru.logger.error("Workflow failed: {} ... updating status and stopping jobs", workflow_id)
9394

9495
# Mark the workflow as failed.
95-
self._tracking_client.update_workflow_status(workflow_id, WorkflowStatus.FAILED)
96+
await self._tracking_client.update_workflow_status(workflow_id, WorkflowStatus.FAILED)
9697

9798
# Get the list of jobs in the workflow to stop any that are still running.
9899
stop_tasks = [
99100
self._job_service.stop_job(UUID(ray_job_id))
100-
for job in self._tracking_client.list_jobs(workflow_id)
101+
for job in await self._tracking_client.list_jobs(workflow_id)
101102
if (ray_job_id := job.data.params.get("ray_job_id"))
102103
]
103104
# Wait for all stop tasks to complete concurrently
@@ -147,8 +148,8 @@ async def _run_inference_eval_pipeline(
147148
return
148149

149150
# Track the workflow status as running and add the inference job.
150-
self._tracking_client.update_workflow_status(workflow.id, WorkflowStatus.RUNNING)
151-
inference_run_id = self._tracking_client.create_job(
151+
await self._tracking_client.update_workflow_status(workflow.id, WorkflowStatus.RUNNING)
152+
inference_run_id = await self._tracking_client.create_job(
152153
request.experiment_id, workflow.id, "inference", inference_job.id
153154
)
154155

@@ -228,7 +229,7 @@ async def _run_inference_eval_pipeline(
228229
metrics=inf_output.metrics,
229230
ray_job_id=str(inference_job.id),
230231
)
231-
self._tracking_client.update_job(inference_run_id, inference_job_output)
232+
await self._tracking_client.update_job(inference_run_id, inference_job_output)
232233
except Exception as e:
233234
loguru.logger.error(
234235
"Workflow pipeline error: Workflow {}. Inference job: {}. Cannot update DB with with result data: {}",
@@ -272,7 +273,7 @@ async def _run_inference_eval_pipeline(
272273
return
273274

274275
# Track the evaluation job.
275-
eval_run_id = self._tracking_client.create_job(
276+
eval_run_id = await self._tracking_client.create_job(
276277
request.experiment_id, workflow.id, "evaluation", evaluation_job.id
277278
)
278279

@@ -323,9 +324,9 @@ async def _run_inference_eval_pipeline(
323324
parameters={"eval_output_s3_path": f"{settings.S3_BUCKET}/{result_key}"},
324325
ray_job_id=str(evaluation_job.id),
325326
)
326-
self._tracking_client.update_job(eval_run_id, outputs)
327-
self._tracking_client.update_workflow_status(workflow.id, WorkflowStatus.SUCCEEDED)
328-
self._tracking_client.get_workflow(workflow.id)
327+
await self._tracking_client.update_job(eval_run_id, outputs)
328+
await self._tracking_client.update_workflow_status(workflow.id, WorkflowStatus.SUCCEEDED)
329+
await self._tracking_client.get_workflow(workflow.id)
329330
except Exception as e:
330331
loguru.logger.error(
331332
"Workflow pipeline error: Workflow {}. Evaluation job: {} Error validating results: {}",
@@ -336,13 +337,13 @@ async def _run_inference_eval_pipeline(
336337
await self._handle_workflow_failure(workflow.id)
337338
return
338339

339-
def get_workflow_result_download(self, workflow_id: str) -> str:
340+
async def get_workflow_result_download(self, workflow_id: str) -> str:
340341
"""Return workflow results file URL for downloading.
341342
342343
Args:
343344
workflow_id: ID of the workflow whose results will be returned
344345
"""
345-
workflow_details = self.get_workflow(workflow_id)
346+
workflow_details = await self.get_workflow(workflow_id)
346347
if workflow_details.artifacts_download_url:
347348
return workflow_details.artifacts_download_url
348349
else:
@@ -391,7 +392,7 @@ async def create_workflow(self, request: WorkflowCreateRequest) -> WorkflowRespo
391392
)
392393
request.system_prompt = default_system_prompt
393394

394-
workflow = self._tracking_client.create_workflow(
395+
workflow = await self._tracking_client.create_workflow(
395396
experiment_id=request.experiment_id,
396397
description=request.description,
397398
name=request.name,
@@ -406,17 +407,18 @@ async def create_workflow(self, request: WorkflowCreateRequest) -> WorkflowRespo
406407

407408
return workflow
408409

409-
def delete_workflow(self, workflow_id: str, force: bool) -> WorkflowResponse:
410+
async def delete_workflow(self, workflow_id: str, force: bool) -> WorkflowResponse:
410411
"""Delete a workflow by ID."""
411412
# if the workflow is running, we should throw an error
412413
workflow = self.get_workflow(workflow_id)
413414
if workflow.status == WorkflowStatus.RUNNING and not force:
414415
raise WorkflowValidationError("Cannot delete a running workflow")
415-
return self._tracking_client.delete_workflow(workflow_id)
416+
return await self._tracking_client.delete_workflow(workflow_id)
416417

417-
def get_workflow_logs(self, workflow_id: str) -> JobLogsResponse:
418+
@deprecated("get_workflow_logs is deprecated, it will be removed in future versions.")
419+
async def get_workflow_logs(self, workflow_id: str) -> JobLogsResponse:
418420
"""Get the logs for a workflow."""
419-
job_list = self._tracking_client.list_jobs(workflow_id)
421+
job_list = await self._tracking_client.list_jobs(workflow_id)
420422
# sort the jobs by created_at, with the oldest last
421423
job_list = sorted(job_list, key=lambda x: x.info.start_time)
422424
all_ray_job_ids = [run.data.params.get("ray_job_id") for run in job_list]

lumigator/backend/backend/tests/conftest.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -277,14 +277,23 @@ def boto_s3fs() -> Generator[S3FileSystem, None, None]:
277277
aws_endpoint_url = os.environ.get("AWS_ENDPOINT_URL", "http://localhost:9000")
278278
aws_default_region = os.environ.get("AWS_DEFAULT_REGION", "us-east-2")
279279

280+
# Mock the S3 'storage_options' property to match the real client.
280281
s3fs = S3FileSystem(
281282
key=aws_access_key_id,
282283
secret=aws_secret_access_key,
283284
endpoint_url=aws_endpoint_url,
284285
client_kwargs={"region_name": aws_default_region},
285286
)
286287

287-
mock_s3fs = MagicMock(wraps=s3fs, storage_options={"endpoint_url": aws_endpoint_url})
288+
mock_s3fs = MagicMock(
289+
wraps=s3fs,
290+
storage_options={
291+
"client_kwargs": {"region_name": aws_default_region},
292+
"key": aws_access_key_id,
293+
"secret": aws_secret_access_key,
294+
"endpoint_url": aws_endpoint_url,
295+
},
296+
)
288297

289298
yield mock_s3fs
290299
logger.info(f"intercepted s3fs calls: {str(mock_s3fs.mock_calls)}")

lumigator/backend/backend/tracking/mlflow.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def __init__(self, tracking_uri: str, s3_file_system: S3FileSystem):
3333
self._client = MlflowClient(tracking_uri=tracking_uri)
3434
self._s3_file_system = s3_file_system
3535

36-
def create_experiment(
36+
async def create_experiment(
3737
self,
3838
name: str,
3939
description: str,
@@ -79,15 +79,15 @@ def create_experiment(
7979
created_at=datetime.fromtimestamp(experiment.creation_time / 1000),
8080
)
8181

82-
def delete_experiment(self, experiment_id: str) -> None:
82+
async def delete_experiment(self, experiment_id: str) -> None:
8383
"""Delete an experiment. Although Mflow has a delete_experiment method,
8484
We will use the functions of this class instead, so that we make sure we correctly
8585
clean up all the artifacts/runs/etc. associated with the experiment.
8686
"""
8787
workflow_ids = self._find_workflows(experiment_id)
8888
# delete all the workflows
8989
for workflow_id in workflow_ids:
90-
self.delete_workflow(workflow_id)
90+
await self.delete_workflow(workflow_id)
9191
# delete the experiment
9292
self._client.delete_experiment(experiment_id)
9393

@@ -164,7 +164,7 @@ async def _format_experiment(self, experiment: MlflowExperiment) -> GetExperimen
164164
workflows=workflows,
165165
)
166166

167-
def update_experiment(self, experiment_id: str, new_name: str) -> None:
167+
async def update_experiment(self, experiment_id: str, new_name: str) -> None:
168168
"""Update the name of an experiment."""
169169
raise NotImplementedError
170170

@@ -199,7 +199,7 @@ async def experiments_count(self):
199199
# this corresponds to creating a run in MLflow.
200200
# The run will have n number of nested runs,
201201
# which correspond to what we call "jobs" in our system
202-
def create_workflow(
202+
async def create_workflow(
203203
self, experiment_id: str, description: str, name: str, model: str, system_prompt: str
204204
) -> WorkflowResponse:
205205
"""Create a new workflow."""
@@ -256,7 +256,7 @@ async def get_workflow(self, workflow_id: str) -> WorkflowDetailsResponse | None
256256
system_prompt=workflow.data.tags.get("system_prompt"),
257257
status=WorkflowStatus(workflow.data.tags.get("status")),
258258
created_at=datetime.fromtimestamp(workflow.info.start_time / 1000),
259-
jobs=[self.get_job(job_id) for job_id in all_job_ids],
259+
jobs=[await self.get_job(job_id) for job_id in all_job_ids],
260260
metrics=self._compile_metrics(all_job_ids),
261261
parameters=self._compile_parameters(all_job_ids),
262262
)
@@ -302,7 +302,7 @@ async def get_workflow(self, workflow_id: str) -> WorkflowDetailsResponse | None
302302
workflow_details.artifacts_download_url = download_url
303303
return workflow_details
304304

305-
def update_workflow_status(self, workflow_id: str, status: WorkflowStatus) -> None:
305+
async def update_workflow_status(self, workflow_id: str, status: WorkflowStatus) -> None:
306306
"""Update the status of a workflow."""
307307
self._client.set_tag(workflow_id, "status", status.value)
308308

@@ -328,7 +328,7 @@ def _get_ray_job_logs(self, ray_job_id: str):
328328
loguru.logger.error(f"Response text: {resp.text}")
329329
raise JobUpstreamError(ray_job_id, "JSON decode error in Ray response") from e
330330

331-
def get_workflow_logs(self, workflow_id: str) -> JobLogsResponse:
331+
async def get_workflow_logs(self, workflow_id: str) -> JobLogsResponse:
332332
workflow_run = self._client.get_run(workflow_id)
333333
# get the jobs associated with the workflow
334334
all_jobs = self._client.search_runs(
@@ -343,7 +343,7 @@ def get_workflow_logs(self, workflow_id: str) -> JobLogsResponse:
343343
# TODO: This is not a great solution but it matches the current API
344344
return JobLogsResponse(logs="\n================\n".join([log.logs for log in logs]))
345345

346-
def delete_workflow(self, workflow_id: str) -> WorkflowResponse:
346+
async def delete_workflow(self, workflow_id: str) -> WorkflowResponse:
347347
"""Delete a workflow."""
348348
# first, get the workflow
349349
workflow = self._client.get_run(workflow_id)
@@ -370,11 +370,11 @@ def delete_workflow(self, workflow_id: str) -> WorkflowResponse:
370370
created_at=datetime.fromtimestamp(workflow.info.start_time / 1000),
371371
)
372372

373-
def list_workflows(self, experiment_id: str) -> list:
373+
async def list_workflows(self, experiment_id: str) -> list:
374374
"""List all workflows in an experiment."""
375375
raise NotImplementedError
376376

377-
def create_job(self, experiment_id: str, workflow_id: str, name: str, job_id: str) -> str:
377+
async def create_job(self, experiment_id: str, workflow_id: str, name: str, job_id: str) -> str:
378378
"""Link a started job to an experiment and a workflow."""
379379
run = self._client.create_run(
380380
experiment_id=experiment_id,
@@ -384,14 +384,14 @@ def create_job(self, experiment_id: str, workflow_id: str, name: str, job_id: st
384384
self._client.log_param(run.info.run_id, "ray_job_id", job_id)
385385
return run.info.run_id
386386

387-
def update_job(self, job_id: str, data: RunOutputs):
387+
async def update_job(self, job_id: str, data: RunOutputs):
388388
"""Update the metrics and parameters of a job."""
389389
for metric, value in data.metrics.items():
390390
self._client.log_metric(job_id, metric, value)
391391
for parameter, value in data.parameters.items():
392392
self._client.log_param(job_id, parameter, value)
393393

394-
def get_job(self, job_id: str):
394+
async def get_job(self, job_id: str):
395395
"""Get the results of a job."""
396396
run = self._client.get_run(job_id)
397397
if run.info.lifecycle_stage == "deleted":
@@ -404,11 +404,11 @@ def get_job(self, job_id: str):
404404
artifact_url="TODO",
405405
)
406406

407-
def delete_job(self, job_id: str):
407+
async def delete_job(self, job_id: str):
408408
"""Delete a job."""
409409
self._client.delete_run(job_id)
410410

411-
def list_jobs(self, workflow_id: str):
411+
async def list_jobs(self, workflow_id: str):
412412
"""List all jobs in a workflow."""
413413
workflow_run = self._client.get_run(workflow_id)
414414
# get the jobs associated with the workflow

0 commit comments

Comments (0)