Skip to content
Draft
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/source/specs/openapi.json

Large diffs are not rendered by default.

72 changes: 47 additions & 25 deletions lumigator/backend/backend/services/jobs.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
import csv
import json
import time
from http import HTTPStatus
from io import BytesIO, StringIO
from pathlib import Path
Expand Down Expand Up @@ -61,7 +62,7 @@

DEFAULT_SKIP = 0
DEFAULT_LIMIT = 100
DEFAULT_POST_INFER_JOB_TIMEOUT_SEC = 10 * 60
DEFAULT_POST_INFER_JOB_TIMEOUT_SEC = 5 * 60
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sometimes it's hard to pin the right time. I'd suggest using the extra param to set the waiting time in tests and avoid modifying the default, but let's wait until we get feedback from the people running their own jobs.

JobSpecificRestrictedConfig = type[JobEvalConfig | JobInferenceConfig]


Expand Down Expand Up @@ -135,8 +136,8 @@ async def stop_job(self, job_id: UUID) -> bool:
return True

try:
status = await self.wait_for_job_complete(job_id, max_wait_time_sec=10)
except JobUpstreamError as e:
status = await self.wait_for_job_complete(job_id, timeout_seconds=10)
except TimeoutError as e:
loguru.logger.error("Failed to stop job {}: {}", job_id, e)
return False

Expand Down Expand Up @@ -339,34 +340,50 @@ def _retrieve_job_logs(self, job_id: UUID) -> JobLogsResponse:
except json.JSONDecodeError as e:
raise JobUpstreamError("ray", f"JSON decode error from {resp.text or ''}") from e

async def wait_for_job_complete(self, job_id, max_wait_time_sec):
async def wait_for_job_complete(
self,
job_id: UUID,
timeout_seconds: int = 300,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No defaults at this place. Defaults only at top level, please. Also, I'm not sure we need a complex backoff scheme, but I don't have firm arguments against it at the moment :-/

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you mean just the timeout_seconds or all the defaults there?

I understand why you might want constant defaults declared separately to help with centralisation/management, but to me it makes sense having them inline at the moment. Here's my summary of 'why'...

  • Function signature clearly provides the values
  • Reduces cognitive load in having to track/lookup the defaults
  • Encapsulates the logic within the required scope, changing them doesn't impact other code not related to waiting for a job
  • Not being used anywhere else at the moment

If we need to re-use them later, we can just extract them to private class constants rather than at the top of the file?

Not that it means "it's fine" but we also already have inline defaults all through the code, in function calls and schemas.

initial_poll_interval_seconds: float = 1.0,
max_poll_interval_seconds: float = 10.0,
backoff_factor: float = 1.5,
):
"""Waits for a job to complete, or until a maximum wait time is reached.

:param job_id: The ID of the job to wait for.
:param max_wait_time_sec: The maximum time in seconds to wait for the job to complete.
:return: The status of the job when it completes.
:rtype: str
:raises JobUpstreamError: If there is an error with the upstream service returning the
job status
:param timeout_seconds: The maximum time in seconds to wait for the job to complete.
:param initial_poll_interval_seconds: The initial time in seconds to wait between polling the job status.
:param max_poll_interval_seconds: The maximum time in seconds to wait between polling the job status.
:param backoff_factor: The factor by which the poll interval will increase after each poll.
:return str: The status of the job when it completes.
:raises TimeoutError: If the job does not complete within the timeout period.
"""
loguru.logger.info(f"Waiting for job {job_id} to complete...")
# Get the initial job status
job_status = self.get_upstream_job_status(job_id)
start_time = time.time()
poll_interval = initial_poll_interval_seconds
previous_status = ""
loguru.logger.info(f"Waiting for job {job_id} to complete (timeout {timeout_seconds} seconds)...")

# Wait for the job to complete
elapsed_time = 0
while job_status not in self.TERMINAL_STATUS:
if elapsed_time >= max_wait_time_sec:
loguru.logger.info(f"Job {job_id} did not complete within the maximum wait time.")
break
await asyncio.sleep(5)
elapsed_time += 5
job_status = self.get_upstream_job_status(job_id)
while time.time() - start_time < timeout_seconds:
try:
job_status = self.get_upstream_job_status(job_id)

if job_status in self.TERMINAL_STATUS:
# Once the job is finished, retrieve the log and store it in the internal DB
self.get_job_logs(job_id)
loguru.logger.info(f"Job {job_id}, terminal status: {job_status}")
return job_status

if job_status != previous_status:
loguru.logger.info(f"Job {job_id}, current status: {job_status}")
previous_status = job_status

# Once the job is finished, retrieve the log and store it in the internal db
self.get_job_logs(job_id)
except JobUpstreamError as e:
loguru.logger.error("Error waiting for job {}. Cannot get upstream status: {}", job_id, e)

return job_status
await asyncio.sleep(poll_interval)
poll_interval = min(poll_interval * backoff_factor, max_poll_interval_seconds)

raise TimeoutError(f"Job {job_id} did not complete within {timeout_seconds} seconds.")

async def handle_annotation_job(self, job_id: UUID, request: JobCreate, max_wait_time_sec: int):
"""Long term we maybe want to move logic about how to handle a specific job
Expand All @@ -381,8 +398,13 @@ async def handle_annotation_job(self, job_id: UUID, request: JobCreate, max_wait
dataset_filename = self._dataset_service.get_dataset(dataset_id=request.dataset).filename
dataset_filename = Path(dataset_filename).stem
dataset_filename = f"{dataset_filename}-annotated.csv"
job_status = ""

try:
job_status = await self.wait_for_job_complete(job_id, max_wait_time_sec)
except TimeoutError as e:
loguru.logger.error(f"Job {job_id} timed out after {max_wait_time_sec} seconds: {e}")

job_status = await self.wait_for_job_complete(job_id, max_wait_time_sec)
if job_status == JobStatus.SUCCEEDED.value:
self._add_dataset_to_db(job_id, request, self._dataset_service.s3_filesystem, dataset_filename)
else:
Expand Down
8 changes: 4 additions & 4 deletions lumigator/backend/backend/services/workflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,13 +154,13 @@ async def _run_inference_eval_pipeline(
try:
# Wait for the inference job to 'complete'.
status = await self._job_service.wait_for_job_complete(
inference_job.id, max_wait_time_sec=request.job_timeout_sec
inference_job.id, timeout_seconds=request.job_timeout_sec
)

if status != JobStatus.SUCCEEDED:
# Trigger the failure handling logic
raise JobUpstreamError(f"Inference job {inference_job.id} failed with status {status}") from None
except JobUpstreamError as e:
except TimeoutError as e:
loguru.logger.error(
"Workflow pipeline error: Workflow {}. Inference job: {} failed: {}", workflow.id, inference_job.id, e
)
Expand Down Expand Up @@ -278,7 +278,7 @@ async def _run_inference_eval_pipeline(
try:
# wait for the evaluation job to complete
status = await self._job_service.wait_for_job_complete(
evaluation_job.id, max_wait_time_sec=request.job_timeout_sec
evaluation_job.id, timeout_seconds=request.job_timeout_sec
)

if status != JobStatus.SUCCEEDED:
Expand All @@ -287,7 +287,7 @@ async def _run_inference_eval_pipeline(

# TODO: Handle other error types that can be raised by the method.
self._job_service._validate_results(evaluation_job.id, self._dataset_service.s3_filesystem)
except (JobUpstreamError, ValidationError) as e:
except (TimeoutError, ValidationError) as e:
loguru.logger.error(
"Workflow pipeline error: Workflow {}. Evaluation job: {} failed: {}", workflow.id, evaluation_job.id, e
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import pytest
import requests
from fastapi.testclient import TestClient
from httpx import HTTPStatusError, RequestError
from inference.schemas import GenerationConfig, InferenceJobConfig, InferenceServerConfig
from loguru import logger
from lumigator_schemas.datasets import DatasetFormat, DatasetResponse
Expand All @@ -23,7 +24,7 @@
from lumigator_schemas.secrets import SecretUploadRequest
from lumigator_schemas.tasks import TaskType
from lumigator_schemas.workflows import WorkflowDetailsResponse, WorkflowResponse, WorkflowStatus
from pydantic import PositiveInt
from pydantic import PositiveInt, ValidationError

from backend.main import app
from backend.tests.conftest import (
Expand Down Expand Up @@ -261,7 +262,7 @@ def run_workflow(
"model": model,
"provider": "hf",
"experiment_id": experiment_id,
"job_timeout_sec": 1000,
"job_timeout_sec": 60 * 3,
}
# The timeout cannot be 0
if job_timeout_sec:
Expand Down Expand Up @@ -432,19 +433,49 @@ def test_job_non_existing(local_client: TestClient, dependency_overrides_service
assert response.json()["detail"] == f"Job with ID {non_existing_id} not found"


def wait_for_workflow_complete(local_client: TestClient, workflow_id: UUID):
def wait_for_workflow_complete(
local_client: TestClient,
workflow_id: UUID,
timeout_seconds: int = 300,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No defaults at this level, again.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The previous (test) code already had magic numbers further down the method (300 iterations and 1 second sleeps), I don't think this change actually makes things worse as it makes the values and purpose clear in the method signature.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, that doesn't mean it was ok :)

initial_poll_interval_seconds: float = 1.0,
max_poll_interval_seconds: float = 10.0,
backoff_factor: float = 1.5,
):
start_time = time.time()
workflow_status = WorkflowStatus.CREATED
for _ in range(1, 300):
time.sleep(1)
workflow_details = WorkflowDetailsResponse.model_validate(local_client.get(f"/workflows/{workflow_id}").json())
workflow_status = WorkflowStatus(workflow_details.status)
if workflow_status in [WorkflowStatus.SUCCEEDED, WorkflowStatus.FAILED]:
logger.info(f"Workflow status: {workflow_status}")
break
if workflow_status not in [WorkflowStatus.SUCCEEDED, WorkflowStatus.FAILED]:
raise Exception(f"Stopped, job remains in {workflow_status} status")

return workflow_details
status_retrieved = False
poll_interval = initial_poll_interval_seconds

logger.info(f"Waiting for workflow {workflow_id} to complete (timeout {timeout_seconds} seconds)...")

while time.time() - start_time < timeout_seconds:
try:
response = local_client.get(f"/workflows/{workflow_id}")
response.raise_for_status()

workflow_details = WorkflowDetailsResponse.model_validate(response.json())
workflow_status = workflow_details.status
status_retrieved = True

if workflow_status in {WorkflowStatus.SUCCEEDED, WorkflowStatus.FAILED}:
logger.info(f"Workflow {workflow_id} completed with status: {workflow_status}")
return workflow_details

logger.info(f"Workflow {workflow_id}, current status: {workflow_status}")

except (RequestError, HTTPStatusError) as e:
logger.error(f"Workflow {workflow_id}, request failed (HTTP): {e}")
except ValidationError as e:
logger.error(f"Workflow {workflow_id}, response parse error: {e}")

time.sleep(poll_interval)
poll_interval = min(poll_interval * backoff_factor, max_poll_interval_seconds)

raise TimeoutError(
f"Workflow {workflow_id} did not complete within {timeout_seconds} seconds.(last status: {workflow_status})"
if status_retrieved
else "(status never retrieved)"
)


def _test_launch_job_with_secret(
Expand Down
21 changes: 11 additions & 10 deletions lumigator/jobs/inference/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
accelerate==1.1.1
datasets==2.20.0
accelerate==1.5.2
datasets==2.19.1
langcodes==3.5.0
litellm==1.60.6
loguru==0.7.2
pydantic>=2.10.0
python-box==7.2.0
requests-mock==1.12.1
s3fs==2024.5.0
litellm==1.63.12
loguru==0.7.3
numpy==1.26.3
pandas==2.2.3
pydantic==2.10.6
python-box==7.3.2
s3fs==2024.2.0
sentencepiece==0.2.0
torch==2.5.1
transformers==4.46.3
torch==2.6.0
transformers==4.49.0
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ividal can the ML team please check these versions just in case?

20 changes: 11 additions & 9 deletions lumigator/jobs/inference/requirements_cpu.txt
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
--extra-index-url https://download.pytorch.org/whl/cpu
accelerate==1.1.1
datasets==2.20.0
accelerate==1.5.2
datasets==2.19.1
langcodes==3.5.0
litellm==1.60.4
loguru==0.7.2
pydantic>=2.10.0
python-box==7.2.0
s3fs
litellm==1.63.12
loguru==0.7.3
numpy==1.26.3
pandas==2.2.3
pydantic==2.10.6
python-box==7.3.2
s3fs==2024.2.0
sentencepiece==0.2.0
torch==2.5.1
transformers==4.46.3
torch==2.6.0
transformers==4.49.0
2 changes: 1 addition & 1 deletion lumigator/schemas/lumigator_schemas/workflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class WorkflowCreateRequest(BaseModel):
inference_output_field: str = "predictions"
config_template: str | None = None
generation_config: GenerationConfig = Field(default_factory=GenerationConfig)
job_timeout_sec: PositiveInt = 60 * 60
job_timeout_sec: PositiveInt = 60 * 5
# Eventually metrics should be managed by the experiment level https://github.com/mozilla-ai/lumigator/issues/1105
metrics: list[str] | None = None

Expand Down
14 changes: 11 additions & 3 deletions notebooks/walkthrough.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -605,6 +605,8 @@
"metadata": {},
"outputs": [],
"source": [
"import time\n",
"\n",
"# set this value to limit the evaluation to the first max_samples items (0=all)\n",
"max_samples = 10\n",
"# team_name is a way to group jobs together under the same namespace, feel free to customize it\n",
Expand All @@ -626,9 +628,15 @@
" job_config=infer_job_config,\n",
" )\n",
" job_infer_creation_result = lm_client.jobs.create_job(infer_job_create)\n",
" lm_client.jobs.wait_for_job(job_infer_creation_result.id)\n",
"\n",
" infer_dataset = lm_client.jobs.get_job_dataset(str(job_infer_creation_result.id))\n",
" try:\n",
" lm_client.jobs.wait_for_job(job_infer_creation_result.id)\n",
" except Exception as e:\n",
" print(f\"Job {job_infer_creation_result.id} error: {e}\")\n",
" continue\n",
"\n",
" # Allow a few seconds for the new dataset to be added now that the job has succeeded\n",
" time.sleep(10)\n",
" infer_dataset = lm_client.jobs.get_job_dataset(job_infer_creation_result.id)\n",
"\n",
" eval_job_config = JobEvalConfig(\n",
" metrics=[\"rouge\", \"meteor\", \"bertscore\"],\n",
Expand Down
Loading