3 changes: 0 additions & 3 deletions pyproject.toml
@@ -39,9 +39,6 @@ video = [
doc = [
"pypdfium2>=4.30.0"
]
hub = [
"vlmrun-hub>=0.1.19a"
]
all = [
"numpy>=1.24.0",
"opencv-python>=4.8.0",
1 change: 1 addition & 0 deletions requirements/requirements.txt
@@ -7,3 +7,4 @@ rich
tenacity
tqdm
typer>=0.9.0
vlmrun-hub>=0.1.28
54 changes: 38 additions & 16 deletions tests/conftest.py
@@ -9,13 +9,14 @@
ModelInfoResponse,
DatasetCreateResponse,
HubInfoResponse,
HubDomainInfo,
HubSchemaResponse,
HubDomainInfo,
FileResponse,
PredictionResponse,
FeedbackSubmitResponse,
CreditUsage,
)
from vlmrun.client.predictions import SchemaCastMixin


@pytest.fixture
@@ -29,19 +30,21 @@ def mock_client(monkeypatch):
"""Mock the VLMRun class."""

class MockVLMRun:
class AudioPredictions:
class AudioPredictions(SchemaCastMixin):
def __init__(self, client):
self._client = client

def generate(self, *args, **kwargs):
return PredictionResponse(
def generate(self, domain: str = None, **kwargs):
prediction = PredictionResponse(
id="prediction1",
status="completed",
created_at="2024-01-01T00:00:00+00:00",
completed_at="2024-01-01T00:00:01+00:00",
response={"result": "test"},
response={"invoice_number": "INV-001", "total_amount": 100.0},
usage=CreditUsage(credits_used=100),
)
self._cast_response_to_schema(prediction, domain, kwargs.get("config"))
return prediction

def __init__(self, api_key=None, base_url=None):
self.api_key = api_key or "test-key"
@@ -207,7 +210,18 @@ def get_schema(self, domain):
schema_hash="abcd1234",
)

class ImagePredictions:
def get_pydantic_model(self, domain: str):
"""Mock implementation for schema lookup."""
from pydantic import BaseModel

class MockInvoiceSchema(BaseModel):
invoice_number: str
total_amount: float

schemas = {"document.invoice": MockInvoiceSchema, "general": None}
return schemas.get(domain)

class ImagePredictions(SchemaCastMixin):
def __init__(self, client):
self._client = client

@@ -216,42 +230,50 @@ def generate(self, domain: str, images=None, urls=None, **kwargs):
raise ValueError("Either `images` or `urls` must be provided")
if images and urls:
raise ValueError("Only one of `images` or `urls` can be provided")
return PredictionResponse(

prediction = PredictionResponse(
id="prediction1",
status="completed",
created_at="2024-01-01T00:00:00+00:00",
completed_at="2024-01-01T00:00:01+00:00",
response={"result": "test"},
response={"invoice_number": "INV-001", "total_amount": 100.0},
usage=CreditUsage(credits_used=100),
)

class VideoPredictions:
self._cast_response_to_schema(prediction, domain, kwargs.get("config"))
return prediction

class VideoPredictions(SchemaCastMixin):
def __init__(self, client):
self._client = client

def generate(self, *args, **kwargs):
return PredictionResponse(
def generate(self, domain: str = None, **kwargs):
prediction = PredictionResponse(
id="prediction1",
status="completed",
created_at="2024-01-01T00:00:00+00:00",
completed_at="2024-01-01T00:00:01+00:00",
response={"result": "test"},
response={"invoice_number": "INV-001", "total_amount": 100.0},
usage=CreditUsage(credits_used=100),
)
self._cast_response_to_schema(prediction, domain, kwargs.get("config"))
return prediction

class DocumentPredictions:
class DocumentPredictions(SchemaCastMixin):
def __init__(self, client):
self._client = client

def generate(self, *args, **kwargs):
return PredictionResponse(
def generate(self, domain: str = None, **kwargs):
prediction = PredictionResponse(
id="prediction1",
status="completed",
created_at="2024-01-01T00:00:00+00:00",
completed_at="2024-01-01T00:00:01+00:00",
response={"result": "test"},
response={"invoice_number": "INV-001", "total_amount": 100.0},
usage=CreditUsage(credits_used=100),
)
self._cast_response_to_schema(prediction, domain, kwargs.get("config"))
return prediction

class Dataset:
def __init__(self, client):
61 changes: 59 additions & 2 deletions tests/test_predictions.py
@@ -1,9 +1,16 @@
"""Tests for predictions operations."""

import pytest

from pydantic import BaseModel
from PIL import Image
from vlmrun.client.types import PredictionResponse
from vlmrun.client.types import PredictionResponse, GenerationConfig


class MockInvoiceSchema(BaseModel):
"""Mock invoice schema for testing."""

invoice_number: str
total_amount: float


def test_list_predictions(mock_client):
@@ -129,3 +136,53 @@ def test_audio_generate(mock_client, tmp_path):
json_schema={"type": "object"},
)
assert isinstance(response, PredictionResponse)


def test_schema_casting_with_domain(mock_client):
"""Test response casting using domain schema."""

def mock_get_schema(domain):
return MockInvoiceSchema

mock_client.hub.get_pydantic_model = mock_get_schema

response = mock_client.image.generate(
domain="document.invoice", urls=["https://example.com/test.jpg"]
)

assert isinstance(response.response, MockInvoiceSchema)


def test_schema_casting_with_custom_schema(mock_client):
"""Test response casting using custom schema from GenerationConfig."""
response = mock_client.image.generate(
domain="document.invoice",
urls=["https://example.com/test.jpg"],
config=GenerationConfig(json_schema=MockInvoiceSchema.model_json_schema()),
)

assert response.response.invoice_number == "INV-001"
assert response.response.total_amount == 100.0


@pytest.mark.parametrize("prediction_type", ["image", "document", "video", "audio"])
def test_schema_casting_across_prediction_types(mock_client, prediction_type):
"""Test schema casting works consistently across different prediction types."""

def mock_get_schema(domain):
return MockInvoiceSchema

mock_client.hub.get_pydantic_model = mock_get_schema

pred_client = getattr(mock_client, prediction_type)

if prediction_type == "image":
response = pred_client.generate(
domain="document.invoice", urls=["https://example.com/test.jpg"]
)
else:
response = pred_client.generate(
domain="document.invoice", url="https://example.com/test.file"
)

assert isinstance(response.response, MockInvoiceSchema)
21 changes: 20 additions & 1 deletion vlmrun/client/hub.py
@@ -1,13 +1,15 @@
"""VLM Run Hub API implementation."""

from typing import TYPE_CHECKING, List
from typing import TYPE_CHECKING, List, Type
from pydantic import BaseModel

from vlmrun.client.base_requestor import APIError
from vlmrun.client.types import (
HubSchemaResponse,
HubInfoResponse,
HubDomainInfo,
)
from vlmrun.hub.registry import registry

if TYPE_CHECKING:
from vlmrun.types.abstract import VLMRunProtocol
@@ -116,3 +118,20 @@ def get_schema(self, domain: str) -> HubSchemaResponse:
return HubSchemaResponse(**response)
except Exception as e:
raise APIError(f"Failed to get schema for domain {domain}: {str(e)}")

def get_pydantic_model(self, domain: str) -> Type[BaseModel]:
"""Get the Pydantic model for a given domain.

Args:
domain: Domain identifier (e.g. "document.invoice")

Returns:
Type[BaseModel]: The Pydantic model class for the domain

Raises:
APIError: If the domain is not found
"""
try:
return registry[domain]
except KeyError:
raise APIError(f"Domain not found: {domain}")
44 changes: 40 additions & 4 deletions vlmrun/client/predictions.py
@@ -17,6 +17,36 @@
GenerationConfig,
RequestMetadata,
)
from vlmrun.hub.utils import jsonschema_to_model


class SchemaCastMixin:
"""Mixin class to handle schema casting for predictions."""

def _cast_response_to_schema(
self,
prediction: PredictionResponse,
domain: str,
config: Optional[GenerationConfig] = None,
) -> None:
"""Cast prediction response to appropriate schema.

Args:
prediction: PredictionResponse to cast
domain: Domain identifier
config: Optional GenerationConfig with custom schema
"""
if prediction.status == "completed" and prediction.response:
try:
if config and hasattr(config, "json_schema"):
schema = jsonschema_to_model(config.json_schema)
else:
schema = self._client.hub.get_pydantic_model(domain)

if schema:
prediction.response = schema(**prediction.response)
except Exception as e:
logger.debug(f"Failed to cast response to schema: {e}")


class Predictions:
@@ -81,7 +111,7 @@ def wait(self, id: str, timeout: int = 60, sleep: int = 1) -> PredictionResponse
raise TimeoutError(f"Prediction {id} did not complete within {timeout} seconds")


class ImagePredictions(Predictions):
class ImagePredictions(SchemaCastMixin, Predictions):
"""Image prediction resource for VLM Run API."""

def generate(
@@ -157,13 +187,16 @@ def generate(
)
if not isinstance(response, dict):
raise TypeError("Expected dict response")
return PredictionResponse(**response)
prediction = PredictionResponse(**response)

self._cast_response_to_schema(prediction, domain, config)
return prediction


def FilePredictions(route: str):
"""File prediction resource for VLM Run API."""

class _FilePredictions(Predictions):
class _FilePredictions(SchemaCastMixin, Predictions):
"""File prediction resource for VLM Run API."""

def generate(
@@ -241,7 +274,10 @@ def generate(
)
if not isinstance(response, dict):
raise TypeError("Expected dict response")
return PredictionResponse(**response)
prediction = PredictionResponse(**response)

self._cast_response_to_schema(prediction, domain, config)
return prediction

return _FilePredictions

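
Taken together with the hub change, `SchemaCastMixin` gives `generate` calls typed responses. A rough usage sketch mirroring the new tests (`client` stands for an initialized VLMRun client whose construction is omitted; the URL and the `MyInvoice` model are illustrative):

```python
from pydantic import BaseModel
from vlmrun.client.types import GenerationConfig

# client = ...  # an initialized VLMRun client (construction omitted)

# 1) Domain-driven casting: when the prediction is completed, the raw response
#    dict is cast to the Pydantic model registered for "document.invoice".
response = client.image.generate(
    domain="document.invoice",
    urls=["https://example.com/invoice.jpg"],
)
print(response.response.invoice_number, response.response.total_amount)

# 2) Custom schema via GenerationConfig: a user-supplied JSON schema takes
#    precedence over the domain model; casting failures fall back to the raw dict.
class MyInvoice(BaseModel):
    invoice_number: str
    total_amount: float

response = client.image.generate(
    domain="document.invoice",
    urls=["https://example.com/invoice.jpg"],
    config=GenerationConfig(json_schema=MyInvoice.model_json_schema()),
)
```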
2 changes: 1 addition & 1 deletion vlmrun/version.py
@@ -1 +1 @@
__version__ = "0.2.0"
__version__ = "0.2.1"