Merged
23 changes: 14 additions & 9 deletions lumigator/backend/backend/services/jobs.py
@@ -212,13 +212,18 @@ def _results_to_binary_file(self, results: dict[str, Any], fields: list[str]) ->
return bin_data

def _add_dataset_to_db(
- self, job_id: UUID, request: JobCreate, s3: S3FileSystem, dataset_filename: str, is_gt_generated: bool = True
+ self,
+ job_id: UUID,
+ request: JobCreate,
+ s3_file_system: S3FileSystem,
+ dataset_filename: str,
+ is_gt_generated: bool = True,
):
"""Attempts to add the result of a job (generated dataset) as a new dataset in Lumigator.

:param job_id: The ID of the job, used to identify the S3 path
:param request: The job request containing the dataset and output fields
- :param s3: The S3 filesystem dependency for accessing storage
+ :param s3_file_system: The S3 filesystem dependency for accessing storage
:raises DatasetNotFoundError: If the dataset in the request does not exist
:raises DatasetSizeError: if the dataset is too large
:raises DatasetInvalidError: if the dataset is invalid
@@ -228,7 +233,7 @@ def _add_dataset_to_db(
loguru.logger.info("Adding a new dataset entry to the database...")

# Get the dataset from the S3 bucket
- results = self._validate_results(job_id, s3)
+ results = self._validate_results(job_id, s3_file_system)

# make sure the artifacts are present in the results
required_keys = {"examples", "ground_truth", request.job_config.output_field}
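
The rename from `s3` to `s3_file_system` only touches the signature and docstring; callers now pass the dependency under the clearer name. A minimal sketch of a call site, assuming `S3FileSystem` is the `s3fs` class and that `job_service`, `job_id`, and `job_request` already exist in the caller's scope (none of these names come from this diff):

```python
import s3fs

# Hypothetical call site; every name except the method itself is illustrative.
s3_file_system = s3fs.S3FileSystem()  # credentials resolved from the environment
job_service._add_dataset_to_db(
    job_id=job_id,
    request=job_request,
    s3_file_system=s3_file_system,  # renamed from `s3` in this change
    dataset_filename="results.csv",
    is_gt_generated=True,
)
```
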
@@ -261,12 +266,12 @@ def _add_dataset_to_db(

loguru.logger.info(f"Dataset '{dataset_filename}' with ID '{dataset_record.id}' added to the database.")

- def _validate_results(self, job_id: UUID, s3: S3FileSystem) -> JobResultObject:
+ def _validate_results(self, job_id: UUID, s3_file_system: S3FileSystem) -> JobResultObject:
"""Handles the evaluation result for a given job.

Args:
job_id (UUID): The unique identifier of the job.
- s3 (S3FileSystem): The S3 file system object used to interact with the S3 bucket.
+ s3_file_system (S3FileSystem): The S3 file system object used to interact with the S3 bucket.

Note:
Currently, this function only validates the evaluation result. Future implementations
@@ -275,8 +280,8 @@ def _validate_results(self, job_id: UUID, s3: S3FileSystem) -> JobResultObject:
loguru.logger.info("Handling evaluation result")

result_key = self._get_results_s3_key(job_id)
- # TODO: Add dependency to the S3 service and use a path creation function.
- with s3.open(f"{settings.S3_BUCKET}/{result_key}", "r") as f:
+ # TODO: use a path creation function.
+ with s3_file_system.open(f"{settings.S3_BUCKET}/{result_key}", "r") as f:
return JobResultObject.model_validate(json.loads(f.read()))
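
The body of `_validate_results` is unchanged apart from the renamed handle: it reads the results JSON from S3 and validates it with pydantic. A self-contained sketch of that pattern, assuming `S3FileSystem` comes from `s3fs` and using a simplified stand-in for `JobResultObject` (the real schema's fields are not shown in this diff):

```python
import json

import s3fs
from pydantic import BaseModel


class JobResultObject(BaseModel):
    # Simplified stand-in; the real model's fields are assumptions here.
    artifacts: dict


def load_results(bucket: str, result_key: str) -> JobResultObject:
    """Read a job's results JSON from S3 and validate it against the schema."""
    fs = s3fs.S3FileSystem()
    with fs.open(f"{bucket}/{result_key}", "r") as f:
        return JobResultObject.model_validate(json.loads(f.read()))
```
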

def get_upstream_job_status(self, job_id: UUID) -> str:
@@ -401,7 +406,7 @@ def create_job(
:param request: The job creation request.
:return: The job response.
:raises JobTypeUnsupportedError: If the job type is not supported.
- :raises SecretNotFoundError: If the secret key identifying the API key required for the job is not found.
+ :raises JobValidationError: If the secret key identifying the API key required for the job is not found.
"""
# Typing won't allow other job_type's
job_type = request.job_config.job_type
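
The docstring now advertises `JobValidationError` instead of `SecretNotFoundError` when the API-key secret is missing. A hedged sketch of what that contract implies; the secret-lookup call and the config attribute are assumptions about code outside this diff, not the actual implementation:

```python
# Hypothetical illustration of the documented :raises contract only.
secret_name = request.job_config.secret_key_name  # assumed attribute name
if not self._secret_checker.is_configured(secret_name):  # assumed collaborator
    raise JobValidationError(f"API key secret '{secret_name}' was not found for this job")
```
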
@@ -477,7 +482,7 @@ def create_job(
runtime_env=runtime_env,
num_gpus=settings.RAY_WORKER_GPUS,
)
loguru.logger.info("Submitting {job_type} Ray job...")
loguru.logger.info(f"Submitting {job_type} Ray job...")
submit_ray_job(self.ray_client, entrypoint)

# NOTE: Only inference jobs can store results in a dataset atm. Among them:
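
The logging fix above is a one-character change, but it alters what ends up in the log. A standalone sketch showing the difference, plus loguru's own deferred-formatting style, which would also have worked:

```python
from loguru import logger

job_type = "INFERENCE"  # illustrative value; the real job_type comes from the request

logger.info("Submitting {job_type} Ray job...")   # before: logs the placeholder literally
logger.info(f"Submitting {job_type} Ray job...")  # after: the f-string interpolates the value
logger.info("Submitting {job_type} Ray job...", job_type=job_type)  # loguru can also format lazily
```
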