Tweaks for waiting for jobs #1262
base: main
@@ -1,6 +1,7 @@
 import asyncio
 import csv
 import json
+import time
 from http import HTTPStatus
 from io import BytesIO, StringIO
 from pathlib import Path

@@ -61,7 +62,7 @@
 DEFAULT_SKIP = 0
 DEFAULT_LIMIT = 100
-DEFAULT_POST_INFER_JOB_TIMEOUT_SEC = 10 * 60
+DEFAULT_POST_INFER_JOB_TIMEOUT_SEC = 5 * 60
 JobSpecificRestrictedConfig = type[JobEvalConfig | JobInferenceConfig]

@@ -135,8 +136,8 @@ async def stop_job(self, job_id: UUID) -> bool:
             return True

         try:
-            status = await self.wait_for_job_complete(job_id, max_wait_time_sec=10)
-        except JobUpstreamError as e:
+            status = await self.wait_for_job_complete(job_id, timeout_seconds=10)
+        except TimeoutError as e:
             loguru.logger.error("Failed to stop job {}: {}", job_id, e)
             return False

@@ -339,34 +340,50 @@ def _retrieve_job_logs(self, job_id: UUID) -> JobLogsResponse:
         except json.JSONDecodeError as e:
             raise JobUpstreamError("ray", f"JSON decode error from {resp.text or ''}") from e

-    async def wait_for_job_complete(self, job_id, max_wait_time_sec):
+    async def wait_for_job_complete(
+        self,
+        job_id: UUID,
+        timeout_seconds: int = 300,
Contributor:
No defaults at this place. Defaults only at top level, please. Also, I'm not sure we need a complex backoff scheme, but I don't have firm arguments against it at the moment :-/

Contributor (Author):
Do you mean just the …

I understand why you might want constant defaults declared separately to help with centralisation/management, but to me it makes sense having them inline at the moment. Here's my summary of 'why'...

If we need to re-use them later, we can just extract them to private class constants rather than at the top of the file? Not that it means "it's fine" but we also already have inline defaults all through the code, in function calls and schemas.
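For illustration only, a minimal sketch of one possible reading of the "defaults only at top level" suggestion, assuming hypothetical module-level constants that feed the signature (constant names are placeholders, not part of this PR):

# Hypothetical sketch: defaults declared once at module level instead of
# inline in the method signature; the polling loop itself is unchanged.
DEFAULT_JOB_TIMEOUT_SECONDS = 300
DEFAULT_INITIAL_POLL_INTERVAL_SECONDS = 1.0
DEFAULT_MAX_POLL_INTERVAL_SECONDS = 10.0
DEFAULT_BACKOFF_FACTOR = 1.5


class JobService:
    async def wait_for_job_complete(
        self,
        job_id,
        timeout_seconds: int = DEFAULT_JOB_TIMEOUT_SECONDS,
        initial_poll_interval_seconds: float = DEFAULT_INITIAL_POLL_INTERVAL_SECONDS,
        max_poll_interval_seconds: float = DEFAULT_MAX_POLL_INTERVAL_SECONDS,
        backoff_factor: float = DEFAULT_BACKOFF_FACTOR,
    ):
        # Polling loop as shown in the diff above; elided here.
        ...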
+        initial_poll_interval_seconds: float = 1.0,
+        max_poll_interval_seconds: float = 10.0,
+        backoff_factor: float = 1.5,
+    ):
         """Waits for a job to complete, or until a maximum wait time is reached.

         :param job_id: The ID of the job to wait for.
-        :param max_wait_time_sec: The maximum time in seconds to wait for the job to complete.
-        :return: The status of the job when it completes.
-        :rtype: str
-        :raises JobUpstreamError: If there is an error with the upstream service returning the
-            job status
+        :param timeout_seconds: The maximum time in seconds to wait for the job to complete.
+        :param initial_poll_interval_seconds: The initial time in seconds to wait between polling the job status.
+        :param max_poll_interval_seconds: The maximum time in seconds to wait between polling the job status.
+        :param backoff_factor: The factor by which the poll interval will increase after each poll.
+        :return str: The status of the job when it completes.
+        :raises TimeoutError: If the job does not complete within the timeout period.
         """
-        loguru.logger.info(f"Waiting for job {job_id} to complete...")
-        # Get the initial job status
-        job_status = self.get_upstream_job_status(job_id)
-
-        # Wait for the job to complete
-        elapsed_time = 0
-        while job_status not in self.TERMINAL_STATUS:
-            if elapsed_time >= max_wait_time_sec:
-                loguru.logger.info(f"Job {job_id} did not complete within the maximum wait time.")
-                break
-            await asyncio.sleep(5)
-            elapsed_time += 5
-            job_status = self.get_upstream_job_status(job_id)
-
-        # Once the job is finished, retrieve the log and store it in the internal db
-        self.get_job_logs(job_id)
-
-        return job_status
+        start_time = time.time()
+        poll_interval = initial_poll_interval_seconds
+        previous_status = ""
+        loguru.logger.info(f"Waiting for job {job_id} to complete (timeout {timeout_seconds} seconds)...")
+
+        while time.time() - start_time < timeout_seconds:
+            try:
+                job_status = self.get_upstream_job_status(job_id)
+
+                if job_status in self.TERMINAL_STATUS:
+                    # Once the job is finished, retrieve the log and store it in the internal DB
+                    self.get_job_logs(job_id)
+                    loguru.logger.info(f"Job {job_id}, terminal status: {job_status}")
+                    return job_status
+
+                if job_status != previous_status:
+                    loguru.logger.info(f"Job {job_id}, current status: {job_status}")
+                    previous_status = job_status
+            except JobUpstreamError as e:
+                loguru.logger.error("Error waiting for job {}. Cannot get upstream status: {}", job_id, e)
+
+            await asyncio.sleep(poll_interval)
+            poll_interval = min(poll_interval * backoff_factor, max_poll_interval_seconds)
+
+        raise TimeoutError(f"Job {job_id} did not complete within {timeout_seconds} seconds.")

     async def handle_annotation_job(self, job_id: UUID, request: JobCreate, max_wait_time_sec: int):
         """Long term we maybe want to move logic about how to handle a specific job

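As an aside on the backoff scheme discussed above, a small standalone snippet (not part of the diff) showing the poll intervals produced by the defaults proposed in this PR (start at 1.0 s, multiply by 1.5, cap at 10.0 s):

# Illustration only: compute the first few poll intervals used by the
# exponential backoff above (initial 1.0 s, factor 1.5, capped at 10.0 s).
poll_interval = 1.0
intervals = []
for _ in range(10):
    intervals.append(round(poll_interval, 2))
    poll_interval = min(poll_interval * 1.5, 10.0)
print(intervals)  # [1.0, 1.5, 2.25, 3.38, 5.06, 7.59, 10.0, 10.0, 10.0, 10.0]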
@@ -381,8 +398,13 @@ async def handle_annotation_job(self, job_id: UUID, request: JobCreate, max_wait
         dataset_filename = self._dataset_service.get_dataset(dataset_id=request.dataset).filename
         dataset_filename = Path(dataset_filename).stem
         dataset_filename = f"{dataset_filename}-annotated.csv"
+        job_status = ""

-        job_status = await self.wait_for_job_complete(job_id, max_wait_time_sec)
+        try:
+            job_status = await self.wait_for_job_complete(job_id, max_wait_time_sec)
+        except TimeoutError as e:
+            loguru.logger.error(f"Job {job_id} timed out after {max_wait_time_sec} seconds: {e}")
+
         if job_status == JobStatus.SUCCEEDED.value:
             self._add_dataset_to_db(job_id, request, self._dataset_service.s3_filesystem, dataset_filename)
         else:

@@ -5,6 +5,7 @@
 import pytest
 import requests
 from fastapi.testclient import TestClient
+from httpx import HTTPStatusError, RequestError
 from inference.schemas import GenerationConfig, InferenceJobConfig, InferenceServerConfig
 from loguru import logger
 from lumigator_schemas.datasets import DatasetFormat, DatasetResponse

@@ -23,7 +24,7 @@
 from lumigator_schemas.secrets import SecretUploadRequest
 from lumigator_schemas.tasks import TaskType
 from lumigator_schemas.workflows import WorkflowDetailsResponse, WorkflowResponse, WorkflowStatus
-from pydantic import PositiveInt
+from pydantic import PositiveInt, ValidationError

 from backend.main import app
 from backend.tests.conftest import (

@@ -261,7 +262,7 @@ def run_workflow(
         "model": model,
         "provider": "hf",
         "experiment_id": experiment_id,
-        "job_timeout_sec": 1000,
+        "job_timeout_sec": 60 * 3,
     }
     # The timeout cannot be 0
     if job_timeout_sec:
@@ -432,19 +433,49 @@ def test_job_non_existing(local_client: TestClient, dependency_overrides_service
     assert response.json()["detail"] == f"Job with ID {non_existing_id} not found"


-def wait_for_workflow_complete(local_client: TestClient, workflow_id: UUID):
+def wait_for_workflow_complete(
+    local_client: TestClient,
+    workflow_id: UUID,
+    timeout_seconds: int = 300,
+    initial_poll_interval_seconds: float = 1.0,
+    max_poll_interval_seconds: float = 10.0,
+    backoff_factor: float = 1.5,
+):
+    start_time = time.time()
     workflow_status = WorkflowStatus.CREATED
-    for _ in range(1, 300):
-        time.sleep(1)
-        workflow_details = WorkflowDetailsResponse.model_validate(local_client.get(f"/workflows/{workflow_id}").json())
-        workflow_status = WorkflowStatus(workflow_details.status)
-        if workflow_status in [WorkflowStatus.SUCCEEDED, WorkflowStatus.FAILED]:
-            logger.info(f"Workflow status: {workflow_status}")
-            break
-    if workflow_status not in [WorkflowStatus.SUCCEEDED, WorkflowStatus.FAILED]:
-        raise Exception(f"Stopped, job remains in {workflow_status} status")
-
-    return workflow_details
+    status_retrieved = False
+    poll_interval = initial_poll_interval_seconds
+
+    logger.info(f"Waiting for workflow {workflow_id} to complete (timeout {timeout_seconds} seconds)...")
+
+    while time.time() - start_time < timeout_seconds:
+        try:
+            response = local_client.get(f"/workflows/{workflow_id}")
+            response.raise_for_status()
+
+            workflow_details = WorkflowDetailsResponse.model_validate(response.json())
+            workflow_status = workflow_details.status
+            status_retrieved = True
+
+            if workflow_status in {WorkflowStatus.SUCCEEDED, WorkflowStatus.FAILED}:
+                logger.info(f"Workflow {workflow_id} completed with status: {workflow_status}")
+                return workflow_details
+
+            logger.info(f"Workflow {workflow_id}, current status: {workflow_status}")
+
+        except (RequestError, HTTPStatusError) as e:
+            logger.error(f"Workflow {workflow_id}, request failed (HTTP): {e}")
+        except ValidationError as e:
+            logger.error(f"Workflow {workflow_id}, response parse error: {e}")
+
+        time.sleep(poll_interval)
+        poll_interval = min(poll_interval * backoff_factor, max_poll_interval_seconds)
+
+    raise TimeoutError(
+        f"Workflow {workflow_id} did not complete within {timeout_seconds} seconds.(last status: {workflow_status})"
+        if status_retrieved
+        else "(status never retrieved)"
+    )


 def _test_launch_job_with_secret(

@@ -1,12 +1,13 @@
-accelerate==1.1.1
-datasets==2.20.0
+accelerate==1.5.2
+datasets==2.19.1
 langcodes==3.5.0
-litellm==1.60.6
-loguru==0.7.2
-pydantic>=2.10.0
-python-box==7.2.0
-requests-mock==1.12.1
-s3fs==2024.5.0
+litellm==1.63.12
+loguru==0.7.3
+numpy==1.26.3
+pandas==2.2.3
+pydantic==2.10.6
+python-box==7.3.2
+s3fs==2024.2.0
 sentencepiece==0.2.0
-torch==2.5.1
-transformers==4.46.3
+torch==2.6.0
+transformers==4.49.0

Contributor:
@ividal can the ML team please check these versions just in case?

@@ -1,12 +1,14 @@
 --extra-index-url https://download.pytorch.org/whl/cpu
-accelerate==1.1.1
-datasets==2.20.0
+accelerate==1.5.2
+datasets==2.19.1
 langcodes==3.5.0
-litellm==1.60.4
-loguru==0.7.2
-pydantic>=2.10.0
-python-box==7.2.0
-s3fs
+litellm==1.63.12
+loguru==0.7.3
+numpy==1.26.3
+pandas==2.2.3
+pydantic==2.10.6
+python-box==7.3.2
+s3fs==2024.2.0
 sentencepiece==0.2.0
-torch==2.5.1
-transformers==4.46.3
+torch==2.6.0
+transformers==4.49.0

Sometimes it's hard to pin the right time. I'd suggest using the extra param to set the waiting time in tests and avoid modifying the default, but let's wait until we get feedback from the people running their own jobs.
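For illustration, a hypothetical sketch of that suggestion: a slow test would pass the waiting time explicitly instead of changing the default (the test name, the workflow_id fixture, and the 600-second value are placeholders, and imports are assumed from the existing test module):

# Hypothetical sketch, not part of this PR: only the slow test opts into a
# longer wait by passing timeout_seconds explicitly.
def test_slow_workflow_completes(local_client, workflow_id):
    details = wait_for_workflow_complete(
        local_client,
        workflow_id,
        timeout_seconds=600,  # placeholder value; the default stays untouched
    )
    assert details.status == WorkflowStatus.SUCCEEDED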