diff --git a/lumigator/backend/backend/tests/integration/api/routes/test_api_workflows.py b/lumigator/backend/backend/tests/integration/api/routes/test_api_workflows.py
index 269420422..4b230c348 100644
--- a/lumigator/backend/backend/tests/integration/api/routes/test_api_workflows.py
+++ b/lumigator/backend/backend/tests/integration/api/routes/test_api_workflows.py
@@ -5,6 +5,7 @@
 import pytest
 import requests
 from fastapi.testclient import TestClient
+from httpx import HTTPStatusError, RequestError
 from inference.schemas import GenerationConfig, InferenceJobConfig, InferenceServerConfig
 from loguru import logger
 from lumigator_schemas.datasets import DatasetFormat, DatasetResponse
@@ -23,10 +24,12 @@
 from lumigator_schemas.secrets import SecretUploadRequest
 from lumigator_schemas.tasks import TaskType
 from lumigator_schemas.workflows import WorkflowDetailsResponse, WorkflowResponse, WorkflowStatus
-from pydantic import PositiveInt
+from pydantic import PositiveInt, ValidationError
 
 from backend.main import app
 from backend.tests.conftest import (
+    MAX_POLLS,
+    POLL_WAIT_TIME,
     TEST_CAUSAL_MODEL,
     TEST_SEQ2SEQ_MODEL,
     wait_for_job,
@@ -437,21 +440,72 @@ def test_job_non_existing(local_client: TestClient, dependency_overrides_service
     assert response.json()["detail"] == f"Job with ID {non_existing_id} not found"
 
 
-def wait_for_workflow_complete(local_client: TestClient, workflow_id: UUID):
-    workflow_status = WorkflowStatus.CREATED
-    workflow_details = None
+def wait_for_workflow_complete(local_client: TestClient, workflow_id: UUID) -> WorkflowDetailsResponse | None:
+    """Wait for the workflow to complete, including post-completion processing to create compiled results.
 
-    for _ in range(1, 300):
-        time.sleep(1)
-        workflow_details = WorkflowDetailsResponse.model_validate(local_client.get(f"/workflows/{workflow_id}").json())
-        workflow_status = WorkflowStatus(workflow_details.status)
-        if workflow_status in [WorkflowStatus.SUCCEEDED, WorkflowStatus.FAILED]:
-            logger.info(f"Workflow status: {workflow_status}")
-            break
-    if workflow_status not in [WorkflowStatus.SUCCEEDED, WorkflowStatus.FAILED]:
-        raise Exception(f"Stopped, job remains in {workflow_status} status")
+    Makes at most ``MAX_POLLS`` requests (as configured in the ``conftest.py``).
+    Sleeps for ``POLL_WAIT_TIME`` seconds between each poll (as configured in the ``conftest.py``).
 
-    return workflow_details
+    :param local_client: The test client.
+    :param workflow_id: The workflow ID of the workflow to wait for.
+    :return: The workflow details, or ``None`` if the workflow failed or did not succeed (with its artifacts
+        download URL populated) within the maximum number of polls.
+    """
+    attempt = 0
+    max_attempts = MAX_POLLS
+    wait_duration = POLL_WAIT_TIME
+
+    while attempt < max_attempts:
+        # Allow the waiting interval if we're coming around again.
+        if attempt > 0:
+            time.sleep(wait_duration)
+
+        attempt += 1
+        try:
+            response = local_client.get(f"/workflows/{workflow_id}")
+            response.raise_for_status()
+            # Validation failure will raise an exception (``ValidationError``) which is fine
+            # as if we're getting a response we expect it to be valid.
+            workflow = WorkflowDetailsResponse.model_validate(response.json())
+        except (RequestError, HTTPStatusError) as e:
+            # Log the error but allow us to retry the request until we've maxed out our attempts.
+            logger.warning(f"Workflow: {workflow_id}, request: ({attempt}/{max_attempts}) failed: {e}")
+            continue
+
+        # Check if the workflow is not in a terminal state.
+        if workflow.status not in {WorkflowStatus.SUCCEEDED, WorkflowStatus.FAILED}:
+            logger.info(
+                f"Workflow: {workflow_id}, "
+                f"request: ({attempt}/{max_attempts}), "
+                f"status: {workflow.status} "
+                f"not in terminal state"
+            )
+            continue
+
+        # If the workflow failed, we can stop checking.
+        if workflow.status == WorkflowStatus.FAILED:
+            return None
+
+        # The workflow was successful, but we need the artifacts download url to be populated.
+        if not workflow.artifacts_download_url:
+            logger.info(
+                f"Workflow: {workflow_id}, "
+                f"request: ({attempt}/{max_attempts}), "
+                f"status: {workflow.status} "
+                f"artifacts not ready"
+            )
+            continue
+
+        logger.info(
+            f"Workflow: {workflow_id}, "
+            f"request: ({attempt}/{max_attempts}), "
+            f"status: {workflow.status} "
+            f"completed and processed"
+        )
+        return workflow
+
+    # Couldn't get the workflow details within the maximum number of polls.
+    return None
 
 
 def _test_launch_job_with_secret(