28 changes: 3 additions & 25 deletions lumigator/backend/backend/api/deps.py
@@ -2,10 +2,8 @@
from collections.abc import Generator
from typing import Annotated

import boto3
from fastapi import BackgroundTasks, Depends
from lumigator_schemas.redactor import Redactor
from mypy_boto3_s3.client import S3Client
from ray.job_submission import JobSubmissionClient
from s3fs import S3FileSystem
from sqlalchemy.orm import Session
@@ -32,23 +30,6 @@ def get_db_session() -> Generator[Session, None, None]:
DBSessionDep = Annotated[Session, Depends(get_db_session)]


def get_s3_client() -> S3Client:
aws_access_key = os.environ.get("AWS_ACCESS_KEY_ID")
aws_secret_access_key = os.environ.get("AWS_SECRET_ACCESS_KEY")
aws_default_region = os.environ.get("AWS_DEFAULT_REGION")

return boto3.client(
"s3",
endpoint_url=settings.S3_ENDPOINT_URL,
aws_access_key_id=aws_access_key,
aws_secret_access_key=aws_secret_access_key,
region_name=aws_default_region,
)


S3ClientDep = Annotated[S3Client, Depends(get_s3_client)]


def get_s3_filesystem() -> S3FileSystem:
aws_access_key_id = os.environ.get("AWS_ACCESS_KEY_ID")
aws_secret_access_key = os.environ.get("AWS_SECRET_ACCESS_KEY")
@@ -66,13 +47,12 @@ def get_s3_filesystem() -> S3FileSystem:
S3FileSystemDep = Annotated[S3FileSystem, Depends(get_s3_filesystem)]


def get_tracking_client_manager(s3_client: S3ClientDep, s3_file_system: S3FileSystemDep) -> TrackingClientManager:
def get_tracking_client_manager(s3_file_system: S3FileSystemDep) -> TrackingClientManager:
"""Dependency to provide a tracking client manager instance."""
if settings.TRACKING_BACKEND == settings.TrackingBackendType.MLFLOW:
return MLflowClientManager(
tracking_uri=settings.TRACKING_BACKEND_URI,
s3_file_system=s3_file_system,
s3_client=s3_client,
)
else:
raise ValueError(f"Unsupported tracking backend: {settings.TRACKING_BACKEND}")
@@ -92,11 +72,9 @@ def get_tracking_client(
TrackingClientDep = Annotated[TrackingClient, Depends(get_tracking_client)]


def get_dataset_service(
session: DBSessionDep, s3_client: S3ClientDep, s3_filesystem: S3FileSystemDep
) -> DatasetService:
def get_dataset_service(session: DBSessionDep, s3_filesystem: S3FileSystemDep) -> DatasetService:
dataset_repo = DatasetRepository(session)
return DatasetService(dataset_repo, s3_client, s3_filesystem)
return DatasetService(dataset_repo, s3_filesystem)


DatasetServiceDep = Annotated[DatasetService, Depends(get_dataset_service)]
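With the boto3 client dependency gone, the remaining S3FileSystem dependency becomes the single entry point for S3 access in the backend. Its body is truncated in the hunk above, so the following is only a sketch of how such an s3fs dependency is typically wired, assuming the same credential environment variables as the removed get_s3_client and an env-var stand-in for settings.S3_ENDPOINT_URL:

# Hedged sketch, not the PR's exact code: get_s3_filesystem is truncated in the diff above.
import os

from s3fs import S3FileSystem


def get_s3_filesystem() -> S3FileSystem:
    # Same credential env vars the removed boto3-based client read.
    return S3FileSystem(
        key=os.environ.get("AWS_ACCESS_KEY_ID"),
        secret=os.environ.get("AWS_SECRET_ACCESS_KEY"),
        client_kwargs={"endpoint_url": os.environ.get("S3_ENDPOINT_URL")},
    )

Consolidating on one filesystem object leaves a single set of credentials and one client configuration to maintain for both dataset storage and pre-signed URL generation.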
4 changes: 2 additions & 2 deletions lumigator/backend/backend/api/routes/datasets.py
@@ -89,7 +89,7 @@ def list_datasets(


@router.get("/{dataset_id}/download")
def get_dataset_download(
async def get_dataset_download(
service: DatasetServiceDep,
dataset_id: UUID,
extension: str | None = Query(
@@ -100,4 +100,4 @@ def get_dataset_download(
),
) -> DatasetDownloadResponse:
"""Returns a collection of pre-signed URLs which can be used to download the dataset."""
return service.get_dataset_download(dataset_id, extension)
return await service.get_dataset_download(dataset_id, extension)
10 changes: 6 additions & 4 deletions lumigator/backend/backend/api/routes/experiments.py
@@ -27,19 +27,21 @@ def create_experiment_id(service: ExperimentServiceDep, request: ExperimentCreat


@router.get("/{experiment_id}")
def get_experiment(service: ExperimentServiceDep, experiment_id: str) -> GetExperimentResponse:
async def get_experiment(service: ExperimentServiceDep, experiment_id: str) -> GetExperimentResponse:
"""Get an experiment by ID."""
return GetExperimentResponse.model_validate(service.get_experiment(experiment_id).model_dump())
experiment = await service.get_experiment(experiment_id)
return GetExperimentResponse.model_validate(experiment.model_dump())


@router.get("/")
def list_experiments(
async def list_experiments(
service: ExperimentServiceDep,
skip: int = 0,
limit: int = 100,
) -> ListingResponse[GetExperimentResponse]:
"""List all experiments."""
return ListingResponse[GetExperimentResponse].model_validate(service.list_experiments(skip, limit).model_dump())
experiments = await service.list_experiments(skip, limit)
return ListingResponse[GetExperimentResponse].model_validate(experiments.model_dump())


@router.delete("/{experiment_id}")
4 changes: 2 additions & 2 deletions lumigator/backend/backend/api/routes/jobs.py
@@ -221,12 +221,12 @@ def get_job_dataset(


@router.get("/{job_id}/result/download")
def get_job_result_download(
async def get_job_result_download(
service: JobServiceDep,
job_id: UUID,
) -> JobResultDownloadResponse:
"""Return job results file URL for downloading."""
return service.get_job_result_download(job_id)
return await service.get_job_result_download(job_id)


def _get_all_ray_jobs() -> list[RayJobDetails]:
7 changes: 4 additions & 3 deletions lumigator/backend/backend/api/routes/workflows.py
@@ -32,13 +32,14 @@ async def create_workflow(service: WorkflowServiceDep, request: WorkflowCreateRe
It must be associated with an experiment id,
which means you must already have created an experiment and have that ID in the request.
"""
return WorkflowResponse.model_validate(service.create_workflow(request))
return WorkflowResponse.model_validate(await service.create_workflow(request))


@router.get("/{workflow_id}")
def get_workflow(service: WorkflowServiceDep, workflow_id: str) -> WorkflowDetailsResponse:
async def get_workflow(service: WorkflowServiceDep, workflow_id: str) -> WorkflowDetailsResponse:
"""Get a workflow by ID."""
return WorkflowDetailsResponse.model_validate(service.get_workflow(workflow_id).model_dump())
workflow_details = await service.get_workflow(workflow_id)
return WorkflowDetailsResponse.model_validate(workflow_details.model_dump())


# get the logs
29 changes: 14 additions & 15 deletions lumigator/backend/backend/services/datasets.py
@@ -9,7 +9,6 @@
from loguru import logger
from lumigator_schemas.datasets import DatasetDownloadResponse, DatasetFormat, DatasetResponse
from lumigator_schemas.extras import ListingResponse
from mypy_boto3_s3.client import S3Client
from pydantic import ByteSize
from s3fs import S3FileSystem

@@ -102,9 +101,8 @@ def dataset_has_gt(filename: str) -> bool:


class DatasetService:
def __init__(self, dataset_repo: DatasetRepository, s3_client: S3Client, s3_filesystem: S3FileSystem):
def __init__(self, dataset_repo: DatasetRepository, s3_filesystem: S3FileSystem):
self.dataset_repo = dataset_repo
self.s3_client = s3_client
self.s3_filesystem = s3_filesystem

def _get_dataset_record(self, dataset_id: UUID) -> DatasetRecord | None:
Expand Down Expand Up @@ -153,7 +151,7 @@ def _save_dataset_to_s3(self, temp_fname, record):
if record:
self.dataset_repo.delete(record.id)

raise DatasetUpstreamError("s3", "error attempting to save dataset to S3", e) from e
raise DatasetUpstreamError("s3", "error attempting to save dataset to S3") from e
finally:
# Clean up temp file
Path(temp.name).unlink()
@@ -281,12 +279,12 @@ def delete_dataset(self, dataset_id: UUID) -> None:
f"Cleaning up DB by removing ID. {e}"
)
except Exception as e:
raise DatasetUpstreamError("s3", f"error attempting to delete dataset {dataset_id} from S3", e) from e
raise DatasetUpstreamError("s3", f"error attempting to delete dataset {dataset_id} from S3") from e

# Getting this far means we are OK to remove the record from the DB.
self.dataset_repo.delete(record.id)

def get_dataset_download(self, dataset_id: UUID, extension: str | None = None) -> DatasetDownloadResponse:
async def get_dataset_download(self, dataset_id: UUID, extension: str | None = None) -> DatasetDownloadResponse:
"""Generate pre-signed download URLs for dataset files.

When supplied, only URLs for files that match the specified extension are returned.
@@ -306,31 +304,32 @@ def get_dataset_download(self, dataset_id: UUID, extension: str | None = None) -
dataset_key = self._get_s3_key(dataset_id, record.filename)

try:
# Call list_objects_v2 to get all objects whose key names start with `dataset_key`
s3_response = self.s3_client.list_objects_v2(Bucket=settings.S3_BUCKET, Prefix=dataset_key)
# Call find to get all objects whose key names start with `dataset_key`
s3_response = self.s3_filesystem.find(path=settings.S3_BUCKET, prefix=dataset_key)

if s3_response.get("KeyCount") == 0:
if not len(s3_response):
raise DatasetNotFoundError(dataset_id, f"No S3 files found with prefix '{dataset_key}'") from None

download_urls = []
for s3_object in s3_response["Contents"]:
for s3_object in s3_response:
# Ignore files that don't end with the extension if it was specified
if extension and not s3_object["Key"].lower().endswith(extension):
if extension and not s3_object.lower().endswith(extension):
continue

download_url = self.s3_client.generate_presigned_url(
download_url = await self.s3_filesystem.s3.generate_presigned_url(
"get_object",
Params={
"Bucket": settings.S3_BUCKET,
"Key": s3_object["Key"],
"Key": s3_object,
},
ExpiresIn=settings.S3_URL_EXPIRATION,
)
download_urls.append(download_url)

except DatasetNotFoundError:
raise
except Exception as e:
msg = f"Error generating pre-signed download URLs for dataset {dataset_id}"
raise DatasetUpstreamError("s3", msg, e) from e
raise DatasetUpstreamError("s3", msg) from e

return DatasetDownloadResponse(id=dataset_id, download_urls=download_urls)

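The rewritten download path lists matching objects with S3FileSystem.find and asks the filesystem's underlying aiobotocore client for pre-signed URLs. A standalone sketch of that pattern follows; the function and parameter names are illustrative, and it assumes the filesystem has already been used (so fs.s3, its aiobotocore client, is available) and that generate_presigned_url is exposed as a coroutine, as the diff above implies:

# Illustrative sketch of the s3fs-based pre-signed URL flow; names are placeholders.
from s3fs import S3FileSystem


async def presigned_urls(fs: S3FileSystem, bucket: str, prefix: str, expires: int = 3600) -> list[str]:
    # find() returns fully qualified "<bucket>/<key>" paths for the matching objects.
    paths = fs.find(path=bucket, prefix=prefix)
    urls = []
    for path in paths:
        key = path.removeprefix(f"{bucket}/")  # strip the bucket before building the Key param
        url = await fs.s3.generate_presigned_url(
            "get_object",
            Params={"Bucket": bucket, "Key": key},
            ExpiresIn=expires,
        )
        urls.append(url)
    return urls

Generating a pre-signed URL signs the request locally rather than calling S3, so awaiting it per object is cheap; the listing itself is still a synchronous find() call.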
11 changes: 6 additions & 5 deletions lumigator/backend/backend/services/experiments.py
@@ -36,16 +36,17 @@ def create_experiment(self, request: ExperimentCreate) -> GetExperimentResponse:
loguru.logger.info(f"Created tracking experiment '{experiment.name}' with ID '{experiment.id}'.")
return experiment

def get_experiment(self, experiment_id: str) -> GetExperimentResponse:
record = self._tracking_session.get_experiment(experiment_id)
async def get_experiment(self, experiment_id: str) -> GetExperimentResponse:
record = await self._tracking_session.get_experiment(experiment_id)
if record is None:
raise ExperimentNotFoundError(experiment_id) from None
return GetExperimentResponse.model_validate(record)

def list_experiments(self, skip: int, limit: int) -> ListingResponse[GetExperimentResponse]:
records = self._tracking_session.list_experiments(skip, limit)
async def list_experiments(self, skip: int, limit: int) -> ListingResponse[GetExperimentResponse]:
records = await self._tracking_session.list_experiments(skip, limit)
total = await self._tracking_session.experiments_count()
return ListingResponse(
total=self._tracking_session.experiments_count(),
total=total,
items=[GetExperimentResponse.model_validate(x) for x in records],
)

7 changes: 4 additions & 3 deletions lumigator/backend/backend/services/jobs.py
@@ -294,7 +294,7 @@ def get_upstream_job_status(self, job_id: UUID) -> str:
status_response = self.ray_client.get_job_status(str(job_id))
return str(status_response.value.lower())
except RuntimeError as e:
raise JobUpstreamError("ray", "error getting Ray job status", e) from e
raise JobUpstreamError("ray", "error getting Ray job status") from e

def get_job_logs(self, job_id: UUID) -> JobLogsResponse:
"""Retrieves the logs for a job from the upstream service.
@@ -572,11 +572,12 @@ def get_job_result(self, job_id: UUID) -> JobResultResponse:

return JobResultResponse.model_validate(result_record)

def get_job_result_download(self, job_id: UUID) -> JobResultDownloadResponse:
async def get_job_result_download(self, job_id: UUID) -> JobResultDownloadResponse:
"""Return job results file URL for downloading."""
# Generate presigned download URL for the object
result_key = self._get_results_s3_key(job_id)
download_url = self._dataset_service.s3_client.generate_presigned_url(

download_url = await self._dataset_service.s3_filesystem.s3.generate_presigned_url(
"get_object",
Params={
"Bucket": settings.S3_BUCKET,
10 changes: 5 additions & 5 deletions lumigator/backend/backend/services/workflows.py
@@ -111,7 +111,7 @@ async def _run_inference_eval_pipeline(
"""Currently this is our only workflow. As we make this more flexible to handle different
sequences of jobs, we'll need to refactor this function to be more generic.
"""
experiment = self._tracking_client.get_experiment(request.experiment_id)
experiment = await self._tracking_client.get_experiment(request.experiment_id)

# input is WorkflowCreateRequest, we need to split the configs and generate one
# JobInferenceCreate and one JobEvalCreate
@@ -334,14 +334,14 @@ async def _run_inference_eval_pipeline(
await self._handle_workflow_failure(workflow.id)
return

def get_workflow(self, workflow_id: str) -> WorkflowDetailsResponse:
async def get_workflow(self, workflow_id: str) -> WorkflowDetailsResponse:
"""Get a workflow."""
tracking_server_workflow = self._tracking_client.get_workflow(workflow_id)
tracking_server_workflow = await self._tracking_client.get_workflow(workflow_id)
if tracking_server_workflow is None:
raise WorkflowNotFoundError(workflow_id) from None
return tracking_server_workflow

def create_workflow(self, request: WorkflowCreateRequest) -> WorkflowResponse:
async def create_workflow(self, request: WorkflowCreateRequest) -> WorkflowResponse:
"""Creates a new workflow and submits inference and evaluation jobs.

Args:
@@ -351,7 +351,7 @@ def create_workflow(self, request: WorkflowCreateRequest) -> WorkflowResponse:
WorkflowResponse: The response object containing the details of the created workflow.
"""
# If the experiment this workflow is associated with doesn't exist, there's no point in continuing.
experiment = self._tracking_client.get_experiment(request.experiment_id)
experiment = await self._tracking_client.get_experiment(request.experiment_id)
if not experiment:
raise WorkflowValidationError(f"Cannot create workflow '{request.name}': No experiment found.") from None
