diff --git a/lumigator/backend/backend/api/deps.py b/lumigator/backend/backend/api/deps.py index 5fd74593a..0d980eac6 100644 --- a/lumigator/backend/backend/api/deps.py +++ b/lumigator/backend/backend/api/deps.py @@ -2,10 +2,8 @@ from collections.abc import Generator from typing import Annotated -import boto3 from fastapi import BackgroundTasks, Depends from lumigator_schemas.redactor import Redactor -from mypy_boto3_s3.client import S3Client from ray.job_submission import JobSubmissionClient from s3fs import S3FileSystem from sqlalchemy.orm import Session @@ -32,23 +30,6 @@ def get_db_session() -> Generator[Session, None, None]: DBSessionDep = Annotated[Session, Depends(get_db_session)] -def get_s3_client() -> S3Client: - aws_access_key = os.environ.get("AWS_ACCESS_KEY_ID") - aws_secret_access_key = os.environ.get("AWS_SECRET_ACCESS_KEY") - aws_default_region = os.environ.get("AWS_DEFAULT_REGION") - - return boto3.client( - "s3", - endpoint_url=settings.S3_ENDPOINT_URL, - aws_access_key_id=aws_access_key, - aws_secret_access_key=aws_secret_access_key, - region_name=aws_default_region, - ) - - -S3ClientDep = Annotated[S3Client, Depends(get_s3_client)] - - def get_s3_filesystem() -> S3FileSystem: aws_access_key_id = os.environ.get("AWS_ACCESS_KEY_ID") aws_secret_access_key = os.environ.get("AWS_SECRET_ACCESS_KEY") @@ -66,13 +47,12 @@ def get_s3_filesystem() -> S3FileSystem: S3FileSystemDep = Annotated[S3FileSystem, Depends(get_s3_filesystem)] -def get_tracking_client_manager(s3_client: S3ClientDep, s3_file_system: S3FileSystemDep) -> TrackingClientManager: +def get_tracking_client_manager(s3_file_system: S3FileSystemDep) -> TrackingClientManager: """Dependency to provide a tracking client manager instance.""" if settings.TRACKING_BACKEND == settings.TrackingBackendType.MLFLOW: return MLflowClientManager( tracking_uri=settings.TRACKING_BACKEND_URI, s3_file_system=s3_file_system, - s3_client=s3_client, ) else: raise ValueError(f"Unsupported tracking backend: {settings.TRACKING_BACKEND}") @@ -92,11 +72,9 @@ def get_tracking_client( TrackingClientDep = Annotated[TrackingClient, Depends(get_tracking_client)] -def get_dataset_service( - session: DBSessionDep, s3_client: S3ClientDep, s3_filesystem: S3FileSystemDep -) -> DatasetService: +def get_dataset_service(session: DBSessionDep, s3_filesystem: S3FileSystemDep) -> DatasetService: dataset_repo = DatasetRepository(session) - return DatasetService(dataset_repo, s3_client, s3_filesystem) + return DatasetService(dataset_repo, s3_filesystem) DatasetServiceDep = Annotated[DatasetService, Depends(get_dataset_service)] diff --git a/lumigator/backend/backend/api/routes/datasets.py b/lumigator/backend/backend/api/routes/datasets.py index 5f20c0747..d52cc1457 100644 --- a/lumigator/backend/backend/api/routes/datasets.py +++ b/lumigator/backend/backend/api/routes/datasets.py @@ -89,7 +89,7 @@ def list_datasets( @router.get("/{dataset_id}/download") -def get_dataset_download( +async def get_dataset_download( service: DatasetServiceDep, dataset_id: UUID, extension: str | None = Query( @@ -100,4 +100,4 @@ def get_dataset_download( ), ) -> DatasetDownloadResponse: """Returns a collection of pre-signed URLs which can be used to download the dataset.""" - return service.get_dataset_download(dataset_id, extension) + return await service.get_dataset_download(dataset_id, extension) diff --git a/lumigator/backend/backend/api/routes/experiments.py b/lumigator/backend/backend/api/routes/experiments.py index 30f53fc33..233fc6e23 100644 --- 
a/lumigator/backend/backend/api/routes/experiments.py +++ b/lumigator/backend/backend/api/routes/experiments.py @@ -27,19 +27,21 @@ def create_experiment_id(service: ExperimentServiceDep, request: ExperimentCreat @router.get("/{experiment_id}") -def get_experiment(service: ExperimentServiceDep, experiment_id: str) -> GetExperimentResponse: +async def get_experiment(service: ExperimentServiceDep, experiment_id: str) -> GetExperimentResponse: """Get an experiment by ID.""" - return GetExperimentResponse.model_validate(service.get_experiment(experiment_id).model_dump()) + experiment = await service.get_experiment(experiment_id) + return GetExperimentResponse.model_validate(experiment.model_dump()) @router.get("/") -def list_experiments( +async def list_experiments( service: ExperimentServiceDep, skip: int = 0, limit: int = 100, ) -> ListingResponse[GetExperimentResponse]: """List all experiments.""" - return ListingResponse[GetExperimentResponse].model_validate(service.list_experiments(skip, limit).model_dump()) + experiments = await service.list_experiments(skip, limit) + return ListingResponse[GetExperimentResponse].model_validate(experiments.model_dump()) @router.delete("/{experiment_id}") diff --git a/lumigator/backend/backend/api/routes/jobs.py b/lumigator/backend/backend/api/routes/jobs.py index 7be3b2769..0e89f5524 100644 --- a/lumigator/backend/backend/api/routes/jobs.py +++ b/lumigator/backend/backend/api/routes/jobs.py @@ -221,12 +221,12 @@ def get_job_dataset( @router.get("/{job_id}/result/download") -def get_job_result_download( +async def get_job_result_download( service: JobServiceDep, job_id: UUID, ) -> JobResultDownloadResponse: """Return job results file URL for downloading.""" - return service.get_job_result_download(job_id) + return await service.get_job_result_download(job_id) def _get_all_ray_jobs() -> list[RayJobDetails]: diff --git a/lumigator/backend/backend/api/routes/workflows.py b/lumigator/backend/backend/api/routes/workflows.py index 5379a9b1e..3b68f5b3e 100644 --- a/lumigator/backend/backend/api/routes/workflows.py +++ b/lumigator/backend/backend/api/routes/workflows.py @@ -32,13 +32,14 @@ async def create_workflow(service: WorkflowServiceDep, request: WorkflowCreateRe It must be associated with an experiment id, which means you must already have created an experiment and have that ID in the request. 
""" - return WorkflowResponse.model_validate(service.create_workflow(request)) + return WorkflowResponse.model_validate(await service.create_workflow(request)) @router.get("/{workflow_id}") -def get_workflow(service: WorkflowServiceDep, workflow_id: str) -> WorkflowDetailsResponse: +async def get_workflow(service: WorkflowServiceDep, workflow_id: str) -> WorkflowDetailsResponse: """Get a workflow by ID.""" - return WorkflowDetailsResponse.model_validate(service.get_workflow(workflow_id).model_dump()) + workflow_details = await service.get_workflow(workflow_id) + return WorkflowDetailsResponse.model_validate(workflow_details.model_dump()) # get the logs diff --git a/lumigator/backend/backend/services/datasets.py b/lumigator/backend/backend/services/datasets.py index 1f65d7b0a..55d89c8f1 100644 --- a/lumigator/backend/backend/services/datasets.py +++ b/lumigator/backend/backend/services/datasets.py @@ -9,7 +9,6 @@ from loguru import logger from lumigator_schemas.datasets import DatasetDownloadResponse, DatasetFormat, DatasetResponse from lumigator_schemas.extras import ListingResponse -from mypy_boto3_s3.client import S3Client from pydantic import ByteSize from s3fs import S3FileSystem @@ -102,9 +101,8 @@ def dataset_has_gt(filename: str) -> bool: class DatasetService: - def __init__(self, dataset_repo: DatasetRepository, s3_client: S3Client, s3_filesystem: S3FileSystem): + def __init__(self, dataset_repo: DatasetRepository, s3_filesystem: S3FileSystem): self.dataset_repo = dataset_repo - self.s3_client = s3_client self.s3_filesystem = s3_filesystem def _get_dataset_record(self, dataset_id: UUID) -> DatasetRecord | None: @@ -153,7 +151,7 @@ def _save_dataset_to_s3(self, temp_fname, record): if record: self.dataset_repo.delete(record.id) - raise DatasetUpstreamError("s3", "error attempting to save dataset to S3", e) from e + raise DatasetUpstreamError("s3", "error attempting to save dataset to S3") from e finally: # Clean up temp file Path(temp.name).unlink() @@ -281,12 +279,12 @@ def delete_dataset(self, dataset_id: UUID) -> None: f"Cleaning up DB by removing ID. {e}" ) except Exception as e: - raise DatasetUpstreamError("s3", f"error attempting to delete dataset {dataset_id} from S3", e) from e + raise DatasetUpstreamError("s3", f"error attempting to delete dataset {dataset_id} from S3") from e # Getting this far means we are OK to remove the record from the DB. self.dataset_repo.delete(record.id) - def get_dataset_download(self, dataset_id: UUID, extension: str | None = None) -> DatasetDownloadResponse: + async def get_dataset_download(self, dataset_id: UUID, extension: str | None = None) -> DatasetDownloadResponse: """Generate pre-signed download URLs for dataset files. When supplied, only URLs for files that match the specified extension are returned. 
@@ -306,31 +304,32 @@ def get_dataset_download(self, dataset_id: UUID, extension: str | None = None) - dataset_key = self._get_s3_key(dataset_id, record.filename) try: - # Call list_objects_v2 to get all objects whose key names start with `dataset_key` - s3_response = self.s3_client.list_objects_v2(Bucket=settings.S3_BUCKET, Prefix=dataset_key) + # Call find to get all objects whose key names start with `dataset_key` + s3_response = self.s3_filesystem.find(path=settings.S3_BUCKET, prefix=dataset_key) - if s3_response.get("KeyCount") == 0: + if not len(s3_response): raise DatasetNotFoundError(dataset_id, f"No S3 files found with prefix '{dataset_key}'") from None download_urls = [] - for s3_object in s3_response["Contents"]: + for s3_object in s3_response: # Ignore files that don't end with the extension if it was specified - if extension and not s3_object["Key"].lower().endswith(extension): + if extension and not s3_object.lower().endswith(extension): continue - download_url = self.s3_client.generate_presigned_url( + download_url = await self.s3_filesystem.s3.generate_presigned_url( "get_object", Params={ "Bucket": settings.S3_BUCKET, - "Key": s3_object["Key"], + "Key": s3_object, }, ExpiresIn=settings.S3_URL_EXPIRATION, ) download_urls.append(download_url) - + except DatasetNotFoundError: + raise except Exception as e: msg = f"Error generating pre-signed download URLs for dataset {dataset_id}" - raise DatasetUpstreamError("s3", msg, e) from e + raise DatasetUpstreamError("s3", msg) from e return DatasetDownloadResponse(id=dataset_id, download_urls=download_urls) diff --git a/lumigator/backend/backend/services/experiments.py b/lumigator/backend/backend/services/experiments.py index 892d40897..3bf1fd126 100644 --- a/lumigator/backend/backend/services/experiments.py +++ b/lumigator/backend/backend/services/experiments.py @@ -36,16 +36,17 @@ def create_experiment(self, request: ExperimentCreate) -> GetExperimentResponse: loguru.logger.info(f"Created tracking experiment '{experiment.name}' with ID '{experiment.id}'.") return experiment - def get_experiment(self, experiment_id: str) -> GetExperimentResponse: - record = self._tracking_session.get_experiment(experiment_id) + async def get_experiment(self, experiment_id: str) -> GetExperimentResponse: + record = await self._tracking_session.get_experiment(experiment_id) if record is None: raise ExperimentNotFoundError(experiment_id) from None return GetExperimentResponse.model_validate(record) - def list_experiments(self, skip: int, limit: int) -> ListingResponse[GetExperimentResponse]: - records = self._tracking_session.list_experiments(skip, limit) + async def list_experiments(self, skip: int, limit: int) -> ListingResponse[GetExperimentResponse]: + records = await self._tracking_session.list_experiments(skip, limit) + total = await self._tracking_session.experiments_count() return ListingResponse( - total=self._tracking_session.experiments_count(), + total=total, items=[GetExperimentResponse.model_validate(x) for x in records], ) diff --git a/lumigator/backend/backend/services/jobs.py b/lumigator/backend/backend/services/jobs.py index fc3b94844..c8b591996 100644 --- a/lumigator/backend/backend/services/jobs.py +++ b/lumigator/backend/backend/services/jobs.py @@ -294,7 +294,7 @@ def get_upstream_job_status(self, job_id: UUID) -> str: status_response = self.ray_client.get_job_status(str(job_id)) return str(status_response.value.lower()) except RuntimeError as e: - raise JobUpstreamError("ray", "error getting Ray job status", e) from e + raise 
JobUpstreamError("ray", "error getting Ray job status") from e def get_job_logs(self, job_id: UUID) -> JobLogsResponse: """Retrieves the logs for a job from the upstream service. @@ -572,11 +572,12 @@ def get_job_result(self, job_id: UUID) -> JobResultResponse: return JobResultResponse.model_validate(result_record) - def get_job_result_download(self, job_id: UUID) -> JobResultDownloadResponse: + async def get_job_result_download(self, job_id: UUID) -> JobResultDownloadResponse: """Return job results file URL for downloading.""" # Generate presigned download URL for the object result_key = self._get_results_s3_key(job_id) - download_url = self._dataset_service.s3_client.generate_presigned_url( + + download_url = await self._dataset_service.s3_filesystem.s3.generate_presigned_url( "get_object", Params={ "Bucket": settings.S3_BUCKET, diff --git a/lumigator/backend/backend/services/workflows.py b/lumigator/backend/backend/services/workflows.py index 38247f68c..590d845cb 100644 --- a/lumigator/backend/backend/services/workflows.py +++ b/lumigator/backend/backend/services/workflows.py @@ -111,7 +111,7 @@ async def _run_inference_eval_pipeline( """Currently this is our only workflow. As we make this more flexible to handle different sequences of jobs, we'll need to refactor this function to be more generic. """ - experiment = self._tracking_client.get_experiment(request.experiment_id) + experiment = await self._tracking_client.get_experiment(request.experiment_id) # input is WorkflowCreateRequest, we need to split the configs and generate one # JobInferenceCreate and one JobEvalCreate @@ -334,14 +334,14 @@ async def _run_inference_eval_pipeline( await self._handle_workflow_failure(workflow.id) return - def get_workflow(self, workflow_id: str) -> WorkflowDetailsResponse: + async def get_workflow(self, workflow_id: str) -> WorkflowDetailsResponse: """Get a workflow.""" - tracking_server_workflow = self._tracking_client.get_workflow(workflow_id) + tracking_server_workflow = await self._tracking_client.get_workflow(workflow_id) if tracking_server_workflow is None: raise WorkflowNotFoundError(workflow_id) from None return tracking_server_workflow - def create_workflow(self, request: WorkflowCreateRequest) -> WorkflowResponse: + async def create_workflow(self, request: WorkflowCreateRequest) -> WorkflowResponse: """Creates a new workflow and submits inference and evaluation jobs. Args: @@ -351,7 +351,7 @@ def create_workflow(self, request: WorkflowCreateRequest) -> WorkflowResponse: WorkflowResponse: The response object containing the details of the created workflow. """ # If the experiment this workflow is associated with doesn't exist, there's no point in continuing. 
- experiment = self._tracking_client.get_experiment(request.experiment_id) + experiment = await self._tracking_client.get_experiment(request.experiment_id) if not experiment: raise WorkflowValidationError(f"Cannot create workflow '{request.name}': No experiment found.") from None diff --git a/lumigator/backend/backend/tests/conftest.py b/lumigator/backend/backend/tests/conftest.py index 5a76a18ed..7207b47ff 100644 --- a/lumigator/backend/backend/tests/conftest.py +++ b/lumigator/backend/backend/tests/conftest.py @@ -5,13 +5,11 @@ import uuid from collections.abc import Generator from pathlib import Path -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock from uuid import UUID -import boto3 import evaluator import fsspec -import loguru import pytest import requests_mock import yaml @@ -28,13 +26,12 @@ JobType, ) from lumigator_schemas.models import ModelsResponse -from mypy_boto3_s3 import S3Client from s3fs import S3FileSystem from sqlalchemy import Engine, create_engine from sqlalchemy.orm import Session from starlette.background import BackgroundTasks -from backend.api.deps import get_db_session, get_job_service, get_s3_client, get_s3_filesystem +from backend.api.deps import get_db_session, get_job_service, get_s3_filesystem from backend.api.router import API_V1_PREFIX from backend.main import create_app from backend.records.jobs import JobRecord @@ -259,31 +256,17 @@ def db_session(db_engine: Engine): def fake_s3fs() -> S3FileSystem: """Replace the filesystem registry for S3 with a MemoryFileSystem implementation.""" fsspec.register_implementation("s3", MemoryFileSystem, clobber=True, errtxt="Failed to register mock S3FS") - yield MemoryFileSystem() - logger.info(f"final s3fs contents: {str(MemoryFileSystem.store)}") - - -@pytest.fixture(scope="function") -def fake_s3_client(fake_s3fs) -> S3Client: - """Provide a fake S3 client using MemoryFileSystem as underlying storage.""" - return FakeS3Client(MemoryFileSystem.store) - - -@pytest.fixture(scope="function") -def boto_s3_client() -> S3Client: - """Provide a real S3 client.""" - aws_access_key_id = os.environ.get("AWS_ACCESS_KEY_ID", "lumigator") - aws_secret_access_key = os.environ.get("AWS_SECRET_ACCESS_KEY", "lumigator") - aws_endpoint_url = os.environ.get("AWS_ENDPOINT_URL", "http://localhost:9000") - aws_default_region = os.environ.get("AWS_DEFAULT_REGION", "us-east-2") + mfs = MemoryFileSystem() + mfs_mock = MagicMock(wraps=mfs) + mfs_mock.s3 = FakeS3Client(MemoryFileSystem.store) + # Mock the find method to match the path (minus the S3:// prefix) + # and be a bit less strict about just seeing the prefix in the path in general. + mfs_mock.find = lambda path, prefix: [ + k for k in MemoryFileSystem.store.keys() if k.removeprefix("s3://").startswith(path) and k.find(prefix) != -1 + ] - return boto3.client( - "s3", - endpoint_url=aws_endpoint_url, - aws_access_key_id=aws_access_key_id, - aws_secret_access_key=aws_secret_access_key, - region_name=aws_default_region, - ) + yield mfs_mock + logger.info(f"final s3fs contents: {str(MemoryFileSystem.store)}") @pytest.fixture(scope="function") @@ -340,9 +323,7 @@ def local_client(app: FastAPI): @pytest.fixture(scope="function") -def dependency_overrides_fakes( - app: FastAPI, db_session: Session, fake_s3_client: S3Client, fake_s3fs: S3FileSystem -) -> None: +def dependency_overrides_fakes(app: FastAPI, db_session: Session, fake_s3fs: S3FileSystem) -> None: """Override the FastAPI dependency injection for test DB sessions. Uses mocks/fakes for unit tests. 
Reference: https://sqlmodel.tiangolo.com/tutorial/fastapi/tests/#override-a-dependency @@ -351,21 +332,15 @@ def dependency_overrides_fakes( def get_db_session_override(): yield db_session - def get_s3_client_override(): - yield fake_s3_client - def get_s3_filesystem_override(): yield fake_s3fs app.dependency_overrides[get_db_session] = get_db_session_override - app.dependency_overrides[get_s3_client] = get_s3_client_override app.dependency_overrides[get_s3_filesystem] = get_s3_filesystem_override @pytest.fixture(scope="function") -def dependency_overrides_services( - app: FastAPI, db_session: Session, boto_s3_client: S3Client, boto_s3fs: S3FileSystem -) -> None: +def dependency_overrides_services(app: FastAPI, db_session: Session, boto_s3fs: S3FileSystem) -> None: """Override the FastAPI dependency injection for test DB sessions. Uses real clients for integration tests. Reference: https://sqlmodel.tiangolo.com/tutorial/fastapi/tests/#override-a-dependency @@ -374,14 +349,10 @@ def dependency_overrides_services( def get_db_session_override(): yield db_session - def get_s3_client_override(): - yield boto_s3_client - def get_s3_filesystem_override(): yield boto_s3fs app.dependency_overrides[get_db_session] = get_db_session_override - app.dependency_overrides[get_s3_client] = get_s3_client_override app.dependency_overrides[get_s3_filesystem] = get_s3_filesystem_override @@ -422,9 +393,9 @@ def result_repository(db_session): @pytest.fixture(scope="function") -def dataset_service(db_session, fake_s3_client, fake_s3fs): +def dataset_service(db_session, fake_s3fs): dataset_repo = DatasetRepository(db_session) - return DatasetService(dataset_repo=dataset_repo, s3_client=fake_s3_client, s3_filesystem=fake_s3fs) + return DatasetService(dataset_repo=dataset_repo, s3_filesystem=fake_s3fs) @pytest.fixture(scope="function") diff --git a/lumigator/backend/backend/tests/fakes/fake_s3.py b/lumigator/backend/backend/tests/fakes/fake_s3.py index 6f38ad0d0..26b6cae2a 100644 --- a/lumigator/backend/backend/tests/fakes/fake_s3.py +++ b/lumigator/backend/backend/tests/fakes/fake_s3.py @@ -39,7 +39,7 @@ def list_objects_v2(self, **kwargs): "Contents": [self.__map_entry_to_content(k) for k in self.storage.keys() if k.startswith(key)], } - def generate_presigned_url(self, method, **kwargs): + async def generate_presigned_url(self, method, **kwargs): params = kwargs["Params"] bucket = params["Bucket"] path = params["Key"] diff --git a/lumigator/backend/backend/tests/unit/services/test_dataset_service.py b/lumigator/backend/backend/tests/unit/services/test_dataset_service.py index e159879cf..5c8a2bef8 100644 --- a/lumigator/backend/backend/tests/unit/services/test_dataset_service.py +++ b/lumigator/backend/backend/tests/unit/services/test_dataset_service.py @@ -1,3 +1,4 @@ +import asyncio from uuid import UUID import pytest @@ -7,7 +8,7 @@ from backend.services.datasets import DatasetService, dataset_has_gt -def test_delete_dataset_file_not_found(db_session, fake_s3_client, fake_s3fs): +def test_delete_dataset_file_not_found(db_session, fake_s3fs): filename = "dataset.csv" format = "job" dataset_repo = DatasetRepository(db_session) @@ -15,14 +16,14 @@ def test_delete_dataset_file_not_found(db_session, fake_s3_client, fake_s3fs): assert dataset_record is not None assert dataset_record.filename == filename assert dataset_record.format == format - dataset_service = DatasetService(dataset_repo, fake_s3_client, fake_s3fs) + dataset_service = DatasetService(dataset_repo, fake_s3fs) 
dataset_service.delete_dataset(dataset_record.id) dataset_record = dataset_service._get_dataset_record(dataset_record.id) assert dataset_record is None -def test_upload_dataset(db_session, fake_s3_client, fake_s3fs, valid_upload_file): - dataset_service = DatasetService(DatasetRepository(db_session), fake_s3_client, fake_s3fs) +def test_upload_dataset(db_session, fake_s3fs, valid_upload_file): + dataset_service = DatasetService(DatasetRepository(db_session), fake_s3fs) upload_response = dataset_service.upload_dataset(valid_upload_file, DatasetFormat.JOB) assert upload_response.id is not None @@ -42,14 +43,14 @@ def test_upload_dataset(db_session, fake_s3_client, fake_s3fs, valid_upload_file (" ", 4), ], ) -def test_dataset_download_with_extensions(db_session, fake_s3_client, fake_s3fs, valid_upload_file, extension, total): - dataset_service = DatasetService(DatasetRepository(db_session), fake_s3_client, fake_s3fs) +def test_dataset_download_with_extensions(db_session, fake_s3fs, valid_upload_file, extension, total): + dataset_service = DatasetService(DatasetRepository(db_session), fake_s3fs) upload_response = dataset_service.upload_dataset(valid_upload_file, DatasetFormat.JOB) assert upload_response.id is not None assert isinstance(upload_response.id, UUID) - download_response = dataset_service.get_dataset_download(upload_response.id, extension) + download_response = asyncio.run(dataset_service.get_dataset_download(upload_response.id, extension)) assert download_response.id is not None assert isinstance(download_response.id, UUID) # 4 files total (HF dataset: 1 x arrow file + 2 x json) + 1 CSV. diff --git a/lumigator/backend/backend/tests/unit/services/test_workflow_service.py b/lumigator/backend/backend/tests/unit/services/test_workflow_service.py index 69ee39e79..c532ddda7 100644 --- a/lumigator/backend/backend/tests/unit/services/test_workflow_service.py +++ b/lumigator/backend/backend/tests/unit/services/test_workflow_service.py @@ -1,5 +1,6 @@ +import asyncio import unittest -from unittest.mock import MagicMock +from unittest.mock import AsyncMock, MagicMock from uuid import UUID from lumigator_schemas.tasks import ( @@ -23,7 +24,7 @@ def test_workflow_request_requires_system_prompt_for_text_generation(workflow_se experiment_mock.name = "Test Experiment" # Configure tracking client to return our mock experiment - workflow_service._tracking_client.get_experiment.return_value = experiment_mock + workflow_service._tracking_client.get_experiment = AsyncMock(return_value=experiment_mock) # Create request without system prompt request = WorkflowCreateRequest( @@ -36,7 +37,7 @@ def test_workflow_request_requires_system_prompt_for_text_generation(workflow_se # Act & Assert with unittest.TestCase().assertRaises(WorkflowValidationError) as context: - workflow_service.create_workflow(request) + asyncio.run(workflow_service.create_workflow(request)) # Verify the error message assert str(context.exception) == "Default system_prompt not available for text-generation" diff --git a/lumigator/backend/backend/tracking/mlflow.py b/lumigator/backend/backend/tracking/mlflow.py index 96616aa4d..2281404c4 100644 --- a/lumigator/backend/backend/tracking/mlflow.py +++ b/lumigator/backend/backend/tracking/mlflow.py @@ -17,7 +17,6 @@ from mlflow.exceptions import MlflowException from mlflow.tracking import MlflowClient from mlflow.utils.mlflow_tags import MLFLOW_PARENT_RUN_ID -from mypy_boto3_s3 import S3Client from pydantic import TypeAdapter from s3fs import S3FileSystem @@ -30,10 +29,9 @@ class 
MLflowTrackingClient(TrackingClient): """MLflow implementation of the TrackingClient interface.""" - def __init__(self, tracking_uri: str, s3_file_system: S3FileSystem, s3_client: S3Client): + def __init__(self, tracking_uri: str, s3_file_system: S3FileSystem): self._client = MlflowClient(tracking_uri=tracking_uri) self._s3_file_system = s3_file_system - self._s3_client = s3_client def create_experiment( self, @@ -134,7 +132,7 @@ def _find_workflows(self, experiment_id: str) -> list: workflow_ids.append(run.info.run_id) return workflow_ids - def get_experiment(self, experiment_id: str) -> GetExperimentResponse | None: + async def get_experiment(self, experiment_id: str) -> GetExperimentResponse | None: """Get an experiment and all its workflows.""" try: experiment = self._client.get_experiment(experiment_id) @@ -146,12 +144,12 @@ def get_experiment(self, experiment_id: str) -> GetExperimentResponse | None: # If the experiment is in the deleted lifecylce, return None if experiment.lifecycle_stage == "deleted": return None - return self._format_experiment(experiment) + return await self._format_experiment(experiment) - def _format_experiment(self, experiment: MlflowExperiment) -> GetExperimentResponse: + async def _format_experiment(self, experiment: MlflowExperiment) -> GetExperimentResponse: # now get all the workflows associated with that experiment workflow_ids = self._find_workflows(experiment.experiment_id) - workflows = [self.get_workflow(workflow_id) for workflow_id in workflow_ids] + workflows = [await self.get_workflow(workflow_id) for workflow_id in workflow_ids] task_definition_json = experiment.tags.get("task_definition") task_definition = TypeAdapter(TaskDefinition).validate_python(json.loads(task_definition_json)) return GetExperimentResponse( @@ -170,7 +168,7 @@ def update_experiment(self, experiment_id: str, new_name: str) -> None: """Update the name of an experiment.""" raise NotImplementedError - def list_experiments(self, skip: int, limit: int) -> list[GetExperimentResponse]: + async def list_experiments(self, skip: int, limit: int | None) -> list[GetExperimentResponse]: """List all experiments.""" page_token = None experiments = [] @@ -191,12 +189,12 @@ def list_experiments(self, skip: int, limit: int) -> list[GetExperimentResponse] if response.token is None: break reduced_experiments = experiments[:limit] if limit is not None else experiments - return [self._format_experiment(experiment) for experiment in reduced_experiments] + return [await self._format_experiment(experiment) for experiment in reduced_experiments] # TODO find a cheaper call - def experiments_count(self): + async def experiments_count(self): """Get the number of experiments.""" - return len(self.list_experiments(skip=0, limit=None)) + return len(await self.list_experiments(skip=0, limit=None)) # this corresponds to creating a run in MLflow. 
# The run will have n number of nested runs, @@ -227,7 +225,7 @@ def create_workflow( created_at=datetime.fromtimestamp(workflow.info.start_time / 1000), ) - def get_workflow(self, workflow_id: str) -> WorkflowDetailsResponse | None: + async def get_workflow(self, workflow_id: str) -> WorkflowDetailsResponse | None: """Get a workflow and all its jobs.""" try: workflow = self._client.get_run(workflow_id) @@ -293,7 +291,7 @@ def get_workflow(self, workflow_id: str) -> WorkflowDetailsResponse | None: with self._s3_file_system.open(f"{settings.S3_BUCKET}/{workflow_id}/compiled.json", "w") as f: f.write(json.dumps(compiled_results)) # Generate presigned download URL for the object - download_url = self._s3_client.generate_presigned_url( + download_url = await self._s3_file_system.s3.generate_presigned_url( "get_object", Params={ "Bucket": settings.S3_BUCKET, @@ -424,10 +422,9 @@ def list_jobs(self, workflow_id: str): class MLflowClientManager: """Connection manager for MLflow client.""" - def __init__(self, tracking_uri: str, s3_file_system: S3FileSystem, s3_client: S3Client): + def __init__(self, tracking_uri: str, s3_file_system: S3FileSystem): self._tracking_uri = tracking_uri self._s3_file_system = s3_file_system - self._s3_client = s3_client @contextlib.contextmanager def connect(self) -> Generator[TrackingClient, None, None]: @@ -435,6 +432,5 @@ def connect(self) -> Generator[TrackingClient, None, None]: tracking_client = MLflowTrackingClient( tracking_uri=self._tracking_uri, s3_file_system=self._s3_file_system, - s3_client=self._s3_client, ) yield tracking_client diff --git a/lumigator/backend/backend/tracking/tracking_interface.py b/lumigator/backend/backend/tracking/tracking_interface.py index 725407610..4b142b1e8 100644 --- a/lumigator/backend/backend/tracking/tracking_interface.py +++ b/lumigator/backend/backend/tracking/tracking_interface.py @@ -27,7 +27,7 @@ def create_experiment( """Create a new experiment.""" ... - def get_experiment(self, experiment_id: str) -> GetExperimentResponse | None: + async def get_experiment(self, experiment_id: str) -> GetExperimentResponse | None: """Get an experiment.""" ... @@ -39,11 +39,11 @@ def delete_experiment(self, experiment_id: str) -> None: """Delete an experiment.""" ... - def list_experiments(self, skip: int, limit: int) -> list[GetExperimentResponse]: + async def list_experiments(self, skip: int, limit: int) -> list[GetExperimentResponse]: """List all experiments.""" ... - def experiments_count(self) -> int: + async def experiments_count(self) -> int: """Count all experiments.""" ... @@ -53,7 +53,7 @@ def create_workflow( """Create a new workflow.""" ... - def get_workflow(self, workflow_id: str) -> WorkflowDetailsResponse | None: + async def get_workflow(self, workflow_id: str) -> WorkflowDetailsResponse | None: """Get a workflow.""" ... 
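The tracking-client and service methods above are now coroutines, so unit tests stub them with AsyncMock and drive the async entry points with asyncio.run, as test_workflow_service.py does. A minimal, self-contained sketch of that pattern — ExampleService and the test name are hypothetical, not taken from the repository:

```python
# Illustrative sketch only; ExampleService and this test are hypothetical.
import asyncio
from unittest.mock import AsyncMock, MagicMock


class ExampleService:
    """Stand-in for a service whose tracking client is now async."""

    def __init__(self, tracking_client):
        self._tracking_client = tracking_client

    async def get_experiment(self, experiment_id: str):
        # Awaiting the client mirrors the change made in ExperimentService above.
        return await self._tracking_client.get_experiment(experiment_id)


def test_get_experiment_awaits_tracking_client():
    experiment = MagicMock()
    experiment.name = "Test Experiment"

    tracking_client = MagicMock()
    # AsyncMock makes the stubbed method awaitable, as in test_workflow_service.py.
    tracking_client.get_experiment = AsyncMock(return_value=experiment)

    service = ExampleService(tracking_client)
    result = asyncio.run(service.get_experiment("exp-1"))

    assert result is experiment
    tracking_client.get_experiment.assert_awaited_once_with("exp-1")
```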
diff --git a/lumigator/backend/pyproject.toml b/lumigator/backend/pyproject.toml index afc13c929..4fcbbc9a4 100644 --- a/lumigator/backend/pyproject.toml +++ b/lumigator/backend/pyproject.toml @@ -24,12 +24,13 @@ dependencies = [ "alembic>=1.13.3", "lumigator-schemas", "mlflow>=2.20.3", - "cryptography>=43.0.0", + "cryptography>=43.0.0" ] [dependency-groups] dev = [ "pytest>=8.3.3", + "pytest-asyncio>=0.25.3", "requests-mock>=1.12.1", "moto[s3]>=5.0,<6", "debugpy>=1.8.11" @@ -37,3 +38,6 @@ dev = [ [tool.uv.sources] lumigator-schemas = { path = "../schemas", editable = true } + +[tool.pytest.ini_options] +asyncio_mode = "auto" diff --git a/lumigator/backend/uv.lock b/lumigator/backend/uv.lock index 0f0de21d0..1a189b719 100644 --- a/lumigator/backend/uv.lock +++ b/lumigator/backend/uv.lock @@ -182,6 +182,7 @@ dependencies = [ { name = "psycopg2-binary" }, { name = "pydantic" }, { name = "pydantic-settings" }, + { name = "pytest-asyncio" }, { name = "python-dotenv" }, { name = "ray", extra = ["client"] }, { name = "requests" }, @@ -213,6 +214,7 @@ requires-dist = [ { name = "psycopg2-binary", specifier = "==2.9.9" }, { name = "pydantic", specifier = ">=2.10.0" }, { name = "pydantic-settings", specifier = "==2.2.1" }, + { name = "pytest-asyncio", specifier = ">=0.25.3" }, { name = "python-dotenv", specifier = ">=1.0.1" }, { name = "ray", extras = ["client"], specifier = "==2.30.0" }, { name = "requests", specifier = ">=2,<3" }, @@ -2192,6 +2194,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/11/92/76a1c94d3afee238333bc0a42b82935dd8f9cf8ce9e336ff87ee14d9e1cf/pytest-8.3.4-py3-none-any.whl", hash = "sha256:50e16d954148559c9a74109af1eaf0c945ba2d8f30f0a3d3335edde19788b6f6", size = 343083 }, ] +[[package]] +name = "pytest-asyncio" +version = "0.25.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pytest" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/f2/a8/ecbc8ede70921dd2f544ab1cadd3ff3bf842af27f87bbdea774c7baa1d38/pytest_asyncio-0.25.3.tar.gz", hash = "sha256:fc1da2cf9f125ada7e710b4ddad05518d4cee187ae9412e9ac9271003497f07a", size = 54239 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/67/17/3493c5624e48fd97156ebaec380dcaafee9506d7e2c46218ceebbb57d7de/pytest_asyncio-0.25.3-py3-none-any.whl", hash = "sha256:9e89518e0f9bd08928f97a3482fdc4e244df17529460bc038291ccaf8f85c7c3", size = 19467 }, +] + [[package]] name = "python-dateutil" version = "2.9.0.post0" diff --git a/lumigator/schemas/uv.lock b/lumigator/schemas/uv.lock index d0255ea50..128c7f5f0 100644 --- a/lumigator/schemas/uv.lock +++ b/lumigator/schemas/uv.lock @@ -35,7 +35,7 @@ wheels = [ [[package]] name = "lumigator-schemas" -version = "0.1.2a0" +version = "0.1.3a0" source = { editable = "." } dependencies = [ { name = "pydantic" }, diff --git a/lumigator/sdk/uv.lock b/lumigator/sdk/uv.lock index 104ead946..33e974271 100644 --- a/lumigator/sdk/uv.lock +++ b/lumigator/sdk/uv.lock @@ -141,7 +141,7 @@ wheels = [ [[package]] name = "lumigator-schemas" -version = "0.1.2a0" +version = "0.1.3a0" source = { editable = "../schemas" } dependencies = [ { name = "pydantic" }, @@ -155,7 +155,7 @@ dev = [{ name = "pytest", specifier = ">=8.3.3" }] [[package]] name = "lumigator-sdk" -version = "0.1.2a0" +version = "0.1.3a0" source = { editable = "." } dependencies = [ { name = "loguru" },
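The asyncio_mode = "auto" setting added under [tool.pytest.ini_options] lets pytest-asyncio collect plain async def tests without a per-test @pytest.mark.asyncio marker, which keeps the new coroutine endpoints easy to exercise. A small sketch under that assumption — the helper, bucket, and key below are illustrative, not from the repository:

```python
# Illustrative sketch only; with asyncio_mode = "auto" in pyproject.toml,
# pytest-asyncio runs this coroutine test without an explicit marker.
async def fake_generate_presigned_url(method: str, **kwargs) -> str:
    # Hypothetical stand-in for an awaitable client call such as
    # s3_filesystem.s3.generate_presigned_url(...).
    params = kwargs["Params"]
    return f"http://localhost:9000/{params['Bucket']}/{params['Key']}?signature=fake"


async def test_presigned_url_shape():
    url = await fake_generate_presigned_url(
        "get_object",
        Params={"Bucket": "lumigator-storage", "Key": "datasets/abc/dataset.csv"},
    )
    assert url.startswith("http://localhost:9000/lumigator-storage/datasets/")
```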