Merged
9 changes: 5 additions & 4 deletions lumigator/backend/backend/api/routes/experiments.py
@@ -21,9 +21,10 @@ def experiment_exception_mappings() -> dict[type[ServiceError], HTTPStatus]:


@router.post("/", status_code=status.HTTP_201_CREATED)
def create_experiment_id(service: ExperimentServiceDep, request: ExperimentCreate) -> GetExperimentResponse:
async def create_experiment_id(service: ExperimentServiceDep, request: ExperimentCreate) -> GetExperimentResponse:
"""Create an experiment ID."""
return GetExperimentResponse.model_validate(service.create_experiment(request).model_dump())
experiment = await service.create_experiment(request)
return GetExperimentResponse.model_validate(experiment.model_dump())


@router.get("/{experiment_id}")
@@ -45,6 +46,6 @@ async def list_experiments(


@router.delete("/{experiment_id}")
def delete_experiment(service: ExperimentServiceDep, experiment_id: str) -> None:
async def delete_experiment(service: ExperimentServiceDep, experiment_id: str) -> None:
"""Delete an experiment by ID."""
service.delete_experiment(experiment_id)
await service.delete_experiment(experiment_id)
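
The handlers above now follow the async-route/async-service pattern: the coroutine endpoint awaits its service dependency, then validates the result into the response model. A minimal, self-contained sketch of that pattern, assuming stand-in names (FakeExperimentService, ExperimentOut, the route path) rather than the real Lumigator dependencies:

```python
# Hedged sketch of the async route -> awaited service -> response-model validation pattern.
# All names here are illustrative stand-ins; the real app wires the service via ExperimentServiceDep.
from fastapi import FastAPI, status
from pydantic import BaseModel

app = FastAPI()

class ExperimentOut(BaseModel):
    id: str
    name: str

class FakeExperimentService:
    async def create_experiment(self, name: str) -> ExperimentOut:
        # A real service would await the tracking client here.
        return ExperimentOut(id="exp-1", name=name)

service = FakeExperimentService()

@app.post("/experiments/", status_code=status.HTTP_201_CREATED)
async def create_experiment_id(name: str) -> ExperimentOut:
    experiment = await service.create_experiment(name)
    return ExperimentOut.model_validate(experiment.model_dump())
```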
16 changes: 10 additions & 6 deletions lumigator/backend/backend/api/routes/workflows.py
@@ -44,13 +44,14 @@ async def get_workflow(service: WorkflowServiceDep, workflow_id: str) -> Workflo

# get the logs
@router.get("/{workflow_id}/logs")
def get_workflow_logs(service: WorkflowServiceDep, workflow_id: str) -> JobLogsResponse:
async def get_workflow_logs(service: WorkflowServiceDep, workflow_id: str) -> JobLogsResponse:
"""Get the logs for a workflow."""
return JobLogsResponse.model_validate(service.get_workflow_logs(workflow_id).model_dump())
logs = await service.get_workflow_logs(workflow_id)
return JobLogsResponse.model_validate(logs.model_dump())


@router.get("/{workflow_id}/result/download")
def get_workflow_result_download(
async def get_workflow_result_download(
service: WorkflowServiceDep,
workflow_id: str,
) -> str:
@@ -60,17 +61,20 @@ def get_workflow_result_download(
service: Workflow service dependency
workflow_id: ID of the workflow whose results will be returned
"""
return service.get_workflow_result_download(workflow_id)
return await service.get_workflow_result_download(workflow_id)


# delete a workflow
@router.delete("/{workflow_id}")
def delete_workflow(service: WorkflowServiceDep, workflow_id: str, force: bool = False) -> WorkflowDetailsResponse:
async def delete_workflow(
service: WorkflowServiceDep, workflow_id: str, force: bool = False
) -> WorkflowDetailsResponse:
"""Delete a workflow by ID.

Args:
service: Workflow service dependency
workflow_id: ID of the workflow to delete
force: If True, force deletion even if the workflow is active or has dependencies
"""
return WorkflowDetailsResponse.model_validate(service.delete_workflow(workflow_id, force=force).model_dump())
result = await service.delete_workflow(workflow_id, force=force)
return WorkflowDetailsResponse.model_validate(result.model_dump())
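
The delete route takes `force` as a plain `bool` parameter, which FastAPI exposes as a query parameter defaulting to `false`. A hedged sketch of how a client or test might exercise it; the `/workflows` prefix and the importable `app` are assumptions, not confirmed by this diff:

```python
# Hedged sketch: calling the delete endpoint with the `force` query parameter.
# The route prefix and the app object passed in are assumptions for illustration only.
from fastapi.testclient import TestClient

def force_delete_workflow(app, workflow_id: str) -> dict:
    client = TestClient(app)
    # `force: bool = False` on the handler becomes ?force=true on the wire.
    response = client.delete(f"/workflows/{workflow_id}", params={"force": "true"})
    response.raise_for_status()
    return response.json()
```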
1 change: 0 additions & 1 deletion lumigator/backend/backend/services/datasets.py
@@ -140,7 +140,6 @@ def _save_dataset_to_s3(self, temp_fname, record):
# Upload to S3
dataset_key = self._get_s3_key(record.id, record.filename)
dataset_path = self._get_s3_path(dataset_key)
# Deprecated!!!
dataset_hf.save_to_disk(dataset_path, storage_options=self.s3_filesystem.storage_options)

# Use the converted HF format files to rebuild the CSV and store it as 'dataset.csv'.
8 changes: 4 additions & 4 deletions lumigator/backend/backend/services/experiments.py
@@ -25,8 +25,8 @@ def __init__(
self._dataset_service = dataset_service
self._tracking_session = tracking_session

def create_experiment(self, request: ExperimentCreate) -> GetExperimentResponse:
experiment = self._tracking_session.create_experiment(
async def create_experiment(self, request: ExperimentCreate) -> GetExperimentResponse:
experiment = await self._tracking_session.create_experiment(
request.name,
request.description,
request.task_definition,
@@ -50,5 +50,5 @@ async def list_experiments(self, skip: int, limit: int) -> ListingResponse[GetEx
items=[GetExperimentResponse.model_validate(x) for x in records],
)

def delete_experiment(self, experiment_id: str):
self._tracking_session.delete_experiment(experiment_id)
async def delete_experiment(self, experiment_id: str):
await self._tracking_session.delete_experiment(experiment_id)
34 changes: 18 additions & 16 deletions lumigator/backend/backend/services/workflows.py
@@ -22,6 +22,7 @@
WorkflowStatus,
)
from pydantic_core._pydantic_core import ValidationError
from typing_extensions import deprecated

from backend.repositories.jobs import JobRepository
from backend.services.datasets import DatasetService
@@ -92,12 +93,12 @@ async def _handle_workflow_failure(self, workflow_id: str):
loguru.logger.error("Workflow failed: {} ... updating status and stopping jobs", workflow_id)

# Mark the workflow as failed.
self._tracking_client.update_workflow_status(workflow_id, WorkflowStatus.FAILED)
await self._tracking_client.update_workflow_status(workflow_id, WorkflowStatus.FAILED)

# Get the list of jobs in the workflow to stop any that are still running.
stop_tasks = [
self._job_service.stop_job(UUID(ray_job_id))
for job in self._tracking_client.list_jobs(workflow_id)
for job in await self._tracking_client.list_jobs(workflow_id)
if (ray_job_id := job.data.params.get("ray_job_id"))
]
# Wait for all stop tasks to complete concurrently
@@ -147,8 +148,8 @@ async def _run_inference_eval_pipeline(
return

# Track the workflow status as running and add the inference job.
self._tracking_client.update_workflow_status(workflow.id, WorkflowStatus.RUNNING)
inference_run_id = self._tracking_client.create_job(
await self._tracking_client.update_workflow_status(workflow.id, WorkflowStatus.RUNNING)
inference_run_id = await self._tracking_client.create_job(
request.experiment_id, workflow.id, "inference", inference_job.id
)

@@ -228,7 +229,7 @@ async def _run_inference_eval_pipeline(
metrics=inf_output.metrics,
ray_job_id=str(inference_job.id),
)
self._tracking_client.update_job(inference_run_id, inference_job_output)
await self._tracking_client.update_job(inference_run_id, inference_job_output)
except Exception as e:
loguru.logger.error(
"Workflow pipeline error: Workflow {}. Inference job: {}. Cannot update DB with with result data: {}",
@@ -272,7 +273,7 @@ async def _run_inference_eval_pipeline(
return

# Track the evaluation job.
eval_run_id = self._tracking_client.create_job(
eval_run_id = await self._tracking_client.create_job(
request.experiment_id, workflow.id, "evaluation", evaluation_job.id
)

@@ -323,9 +324,9 @@ async def _run_inference_eval_pipeline(
parameters={"eval_output_s3_path": f"{settings.S3_BUCKET}/{result_key}"},
ray_job_id=str(evaluation_job.id),
)
self._tracking_client.update_job(eval_run_id, outputs)
self._tracking_client.update_workflow_status(workflow.id, WorkflowStatus.SUCCEEDED)
self._tracking_client.get_workflow(workflow.id)
await self._tracking_client.update_job(eval_run_id, outputs)
await self._tracking_client.update_workflow_status(workflow.id, WorkflowStatus.SUCCEEDED)
await self._tracking_client.get_workflow(workflow.id)
except Exception as e:
loguru.logger.error(
"Workflow pipeline error: Workflow {}. Evaluation job: {} Error validating results: {}",
@@ -336,13 +337,13 @@ async def _run_inference_eval_pipeline(
await self._handle_workflow_failure(workflow.id)
return

def get_workflow_result_download(self, workflow_id: str) -> str:
async def get_workflow_result_download(self, workflow_id: str) -> str:
"""Return workflow results file URL for downloading.

Args:
workflow_id: ID of the workflow whose results will be returned
"""
workflow_details = self.get_workflow(workflow_id)
workflow_details = await self.get_workflow(workflow_id)
if workflow_details.artifacts_download_url:
return workflow_details.artifacts_download_url
else:
@@ -391,7 +392,7 @@ async def create_workflow(self, request: WorkflowCreateRequest) -> WorkflowRespo
)
request.system_prompt = default_system_prompt

workflow = self._tracking_client.create_workflow(
workflow = await self._tracking_client.create_workflow(
experiment_id=request.experiment_id,
description=request.description,
name=request.name,
@@ -406,17 +407,18 @@

return workflow

def delete_workflow(self, workflow_id: str, force: bool) -> WorkflowResponse:
async def delete_workflow(self, workflow_id: str, force: bool) -> WorkflowResponse:
"""Delete a workflow by ID."""
# if the workflow is running, we should throw an error
workflow = self.get_workflow(workflow_id)
if workflow.status == WorkflowStatus.RUNNING and not force:
raise WorkflowValidationError("Cannot delete a running workflow")
return self._tracking_client.delete_workflow(workflow_id)
return await self._tracking_client.delete_workflow(workflow_id)

def get_workflow_logs(self, workflow_id: str) -> JobLogsResponse:
@deprecated("get_workflow_logs is deprecated, it will be removed in future versions.")
async def get_workflow_logs(self, workflow_id: str) -> JobLogsResponse:
"""Get the logs for a workflow."""
job_list = self._tracking_client.list_jobs(workflow_id)
job_list = await self._tracking_client.list_jobs(workflow_id)
# sort the jobs by created_at, with the oldest last
job_list = sorted(job_list, key=lambda x: x.info.start_time)
all_ray_job_ids = [run.data.params.get("ray_job_id") for run in job_list]
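
`get_workflow_logs` is now marked with `@deprecated` from `typing_extensions`. As I understand PEP 702, type checkers flag call sites of the decorated function and, at runtime, calling it emits a `DeprecationWarning` by default. A small sketch with a stand-in function, not the real service method:

```python
# Hedged sketch of typing_extensions.deprecated behavior (PEP 702); `fetch_logs` is a stand-in.
import warnings

from typing_extensions import deprecated

@deprecated("fetch_logs is deprecated, it will be removed in future versions.")
def fetch_logs(workflow_id: str) -> str:
    return f"logs for {workflow_id}"

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    fetch_logs("wf-123")

# Calling the decorated function should have recorded a DeprecationWarning.
assert any(issubclass(w.category, DeprecationWarning) for w in caught)
```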
11 changes: 10 additions & 1 deletion lumigator/backend/backend/tests/conftest.py
@@ -277,14 +277,23 @@ def boto_s3fs() -> Generator[S3FileSystem, None, None]:
aws_endpoint_url = os.environ.get("AWS_ENDPOINT_URL", "http://localhost:9000")
aws_default_region = os.environ.get("AWS_DEFAULT_REGION", "us-east-2")

# Mock the S3 'storage_options' property to match the real client.
s3fs = S3FileSystem(
key=aws_access_key_id,
secret=aws_secret_access_key,
endpoint_url=aws_endpoint_url,
client_kwargs={"region_name": aws_default_region},
)

mock_s3fs = MagicMock(wraps=s3fs, storage_options={"endpoint_url": aws_endpoint_url})
mock_s3fs = MagicMock(
wraps=s3fs,
storage_options={
"client_kwargs": {"region_name": aws_default_region},
"key": aws_access_key_id,
"secret": aws_secret_access_key,
"endpoint_url": aws_endpoint_url,
},
)

yield mock_s3fs
logger.info(f"intercepted s3fs calls: {str(mock_s3fs.mock_calls)}")
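
On the `storage_options` change: with `MagicMock(wraps=...)`, attribute access returns a child mock that wraps the attribute rather than the attribute's value, so a property like `storage_options` has to be pinned explicitly for code that reads it as a plain dict. A hedged sketch; `FakeFS` is illustrative, not the real `S3FileSystem`:

```python
# Hedged sketch: why storage_options is passed explicitly to the MagicMock constructor.
from unittest.mock import MagicMock

class FakeFS:
    @property
    def storage_options(self) -> dict:
        return {"endpoint_url": "http://localhost:9000"}

wrapped_only = MagicMock(wraps=FakeFS())
# Attribute access on a wrapping mock yields a child MagicMock, not the underlying dict.
print(type(wrapped_only.storage_options))  # <class 'unittest.mock.MagicMock'>

pinned = MagicMock(wraps=FakeFS(), storage_options={"endpoint_url": "http://localhost:9000"})
print(pinned.storage_options)  # the real dict, matching what downstream code expects
```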
30 changes: 15 additions & 15 deletions lumigator/backend/backend/tracking/mlflow.py
@@ -33,7 +33,7 @@ def __init__(self, tracking_uri: str, s3_file_system: S3FileSystem):
self._client = MlflowClient(tracking_uri=tracking_uri)
self._s3_file_system = s3_file_system

def create_experiment(
async def create_experiment(
self,
name: str,
description: str,
@@ -79,15 +79,15 @@ def create_experiment(
created_at=datetime.fromtimestamp(experiment.creation_time / 1000),
)

def delete_experiment(self, experiment_id: str) -> None:
async def delete_experiment(self, experiment_id: str) -> None:
"""Delete an experiment. Although Mflow has a delete_experiment method,
We will use the functions of this class instead, so that we make sure we correctly
clean up all the artifacts/runs/etc. associated with the experiment.
"""
workflow_ids = self._find_workflows(experiment_id)
# delete all the workflows
for workflow_id in workflow_ids:
self.delete_workflow(workflow_id)
await self.delete_workflow(workflow_id)
# delete the experiment
self._client.delete_experiment(experiment_id)

@@ -164,7 +164,7 @@ async def _format_experiment(self, experiment: MlflowExperiment) -> GetExperimen
workflows=workflows,
)

def update_experiment(self, experiment_id: str, new_name: str) -> None:
async def update_experiment(self, experiment_id: str, new_name: str) -> None:
"""Update the name of an experiment."""
raise NotImplementedError

@@ -199,7 +199,7 @@ async def experiments_count(self):
# this corresponds to creating a run in MLflow.
# The run will have n number of nested runs,
# which correspond to what we call "jobs" in our system
def create_workflow(
async def create_workflow(
self, experiment_id: str, description: str, name: str, model: str, system_prompt: str
) -> WorkflowResponse:
"""Create a new workflow."""
@@ -256,7 +256,7 @@ async def get_workflow(self, workflow_id: str) -> WorkflowDetailsResponse | None
system_prompt=workflow.data.tags.get("system_prompt"),
status=WorkflowStatus(workflow.data.tags.get("status")),
created_at=datetime.fromtimestamp(workflow.info.start_time / 1000),
jobs=[self.get_job(job_id) for job_id in all_job_ids],
jobs=[await self.get_job(job_id) for job_id in all_job_ids],
metrics=self._compile_metrics(all_job_ids),
parameters=self._compile_parameters(all_job_ids),
)
@@ -302,7 +302,7 @@ async def get_workflow(self, workflow_id: str) -> WorkflowDetailsResponse | None
workflow_details.artifacts_download_url = download_url
return workflow_details

def update_workflow_status(self, workflow_id: str, status: WorkflowStatus) -> None:
async def update_workflow_status(self, workflow_id: str, status: WorkflowStatus) -> None:
"""Update the status of a workflow."""
self._client.set_tag(workflow_id, "status", status.value)

@@ -328,7 +328,7 @@ def _get_ray_job_logs(self, ray_job_id: str):
loguru.logger.error(f"Response text: {resp.text}")
raise JobUpstreamError(ray_job_id, "JSON decode error in Ray response") from e

def get_workflow_logs(self, workflow_id: str) -> JobLogsResponse:
async def get_workflow_logs(self, workflow_id: str) -> JobLogsResponse:
workflow_run = self._client.get_run(workflow_id)
# get the jobs associated with the workflow
all_jobs = self._client.search_runs(
@@ -343,7 +343,7 @@ def get_workflow_logs(self, workflow_id: str) -> JobLogsResponse:
# TODO: This is not a great solution but it matches the current API
return JobLogsResponse(logs="\n================\n".join([log.logs for log in logs]))

def delete_workflow(self, workflow_id: str) -> WorkflowResponse:
async def delete_workflow(self, workflow_id: str) -> WorkflowResponse:
"""Delete a workflow."""
# first, get the workflow
workflow = self._client.get_run(workflow_id)
@@ -370,11 +370,11 @@ def delete_workflow(self, workflow_id: str) -> WorkflowResponse:
created_at=datetime.fromtimestamp(workflow.info.start_time / 1000),
)

def list_workflows(self, experiment_id: str) -> list:
async def list_workflows(self, experiment_id: str) -> list:
"""List all workflows in an experiment."""
raise NotImplementedError

def create_job(self, experiment_id: str, workflow_id: str, name: str, job_id: str) -> str:
async def create_job(self, experiment_id: str, workflow_id: str, name: str, job_id: str) -> str:
"""Link a started job to an experiment and a workflow."""
run = self._client.create_run(
experiment_id=experiment_id,
@@ -384,14 +384,14 @@ def create_job(self, experiment_id: str, workflow_id: str, name: str, job_id: st
self._client.log_param(run.info.run_id, "ray_job_id", job_id)
return run.info.run_id

def update_job(self, job_id: str, data: RunOutputs):
async def update_job(self, job_id: str, data: RunOutputs):
"""Update the metrics and parameters of a job."""
for metric, value in data.metrics.items():
self._client.log_metric(job_id, metric, value)
for parameter, value in data.parameters.items():
self._client.log_param(job_id, parameter, value)

def get_job(self, job_id: str):
async def get_job(self, job_id: str):
"""Get the results of a job."""
run = self._client.get_run(job_id)
if run.info.lifecycle_stage == "deleted":
@@ -404,11 +404,11 @@ def get_job(self, job_id: str):
artifact_url="TODO",
)

def delete_job(self, job_id: str):
async def delete_job(self, job_id: str):
"""Delete a job."""
self._client.delete_run(job_id)

def list_jobs(self, workflow_id: str):
async def list_jobs(self, workflow_id: str):
"""List all jobs in a workflow."""
workflow_run = self._client.get_run(workflow_id)
# get the jobs associated with the workflow
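
Since the tracking-client methods are now coroutines, anything that calls them directly, including tests, needs an event loop or an async test runner. A hedged sketch of what an async-style test could look like; pytest-asyncio, the marker, and the fake client are assumptions, not part of this PR:

```python
# Hedged sketch of exercising a coroutine-based tracking client from an async test.
# Requires pytest and pytest-asyncio; FakeTrackingClient is an illustrative stand-in.
import pytest

class FakeTrackingClient:
    async def create_workflow(self, experiment_id: str, name: str) -> dict:
        return {"experiment_id": experiment_id, "name": name, "status": "created"}

    async def delete_workflow(self, workflow_id: str) -> dict:
        return {"id": workflow_id, "status": "deleted"}

@pytest.fixture
def tracking_client() -> FakeTrackingClient:
    return FakeTrackingClient()

@pytest.mark.asyncio
async def test_workflow_lifecycle(tracking_client):
    workflow = await tracking_client.create_workflow("exp-1", "demo")
    assert workflow["status"] == "created"

    deleted = await tracking_client.delete_workflow("wf-1")
    assert deleted["status"] == "deleted"
```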