Support /image/schema and /document/schema routes
spillai committed Mar 5, 2025
commit 9b5163d56087196547c11a37c9ecc03c9b4c1121
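
In brief, both prediction resources gain a `schema()` method that POSTs to the new `{route}/schema` endpoints and returns a `PredictionResponse` whose `response` payload is parsed into a `SchemaResponse`. A minimal usage sketch follows; the `VLMRun` entry point and the `client.image` / `client.document` accessor names are assumptions about the SDK layout, not shown in this diff.

from pathlib import Path

from vlmrun.client import VLMRun  # assumed entry point; not part of this diff

client = VLMRun()  # assumes the API key is read from the environment

# Auto-generate a schema from a local image (opened with PIL and encoded before upload)
prediction = client.image.schema(images=[Path("invoice.jpg")])
print(prediction.response)  # parsed into a SchemaResponse

# Auto-generate a schema from a document URL (sent as {"url": ...})
prediction = client.document.schema(url="https://example.com/invoice.pdf")
print(prediction.response)
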
147 changes: 124 additions & 23 deletions vlmrun/client/predictions.py
@@ -142,6 +142,45 @@ def wait(self, id: str, timeout: int = 60, sleep: int = 1) -> PredictionResponse
class ImagePredictions(SchemaCastMixin, Predictions):
"""Image prediction resource for VLM Run API."""

@staticmethod
def _handle_images_or_urls(
images: Optional[List[Union[Path, Image.Image]]] = None,
urls: Optional[List[str]] = None
) -> List[str]:
"""Handle images and URLs.

Args:
images: List of images to handle
urls: List of URLs to handle
"""
# Input validation
if not images and not urls:
raise ValueError("Either `images` or `urls` must be provided")
if images and urls:
raise ValueError("Only one of `images` or `urls` can be provided")
if images:
# Check if all images are of the same type
image_type = type(images[0])
if not all(isinstance(image, image_type) for image in images):
raise ValueError("All images must be of the same type")
if isinstance(images[0], Path):
images = [Image.open(str(image)) for image in images]
elif isinstance(images[0], Image.Image):
pass
else:
raise ValueError("Image must be a path or a PIL Image")
images_data = [encode_image(image, format="JPEG") for image in images]
else:
# URL handling: the earlier validation guarantees `urls` is non-empty here
if not all(isinstance(url, str) for url in urls):
raise ValueError("All URLs must be strings")
images_data = urls
return images_data
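# Illustrative summary of the helper above (return values are assumptions about
# `encode_image`, which is not shown in this diff):
#   _handle_images_or_urls(images=[Path("a.jpg")])  -> [<encoded JPEG data>]
#   _handle_images_or_urls(urls=["https://example.com/a.jpg"])  -> ["https://example.com/a.jpg"]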

def generate(
self,
domain: str,
@@ -201,6 +240,7 @@ def generate(
raise ValueError("URLs must start with 'http'")
images_data = urls

images_data = self._handle_images_or_urls(images, urls)
additional_kwargs = {}
if config:
additional_kwargs["config"] = config.model_dump()
@@ -225,40 +265,47 @@ def generate(
self._cast_response_to_schema(prediction, domain, config)
return prediction

def schema(
self,
images: Optional[List[Union[Path, Image.Image]]] = None,
urls: Optional[List[str]] = None,
) -> PredictionResponse:
"""Auto-generate a schema for the given images or image URLs.

Args:
images: List of images to generate the schema from
urls: List of URLs to generate the schema from

Returns:
PredictionResponse: Prediction response
"""
images_data = self._handle_images_or_urls(images, urls)
response, status_code, headers = self._requestor.request(
method="POST",
url="image/schema",
data={"images": images_data},
)
if not isinstance(response, dict):
raise TypeError("Expected dict response")
prediction = PredictionResponse(**response)
prediction.response = SchemaResponse(**prediction.response)
return prediction
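# Usage sketch (the `client.image` accessor name is an assumption; see the
# example under the commit header): client.image.schema(images=[...]) POSTs
# the encoded images to `image/schema` and returns a PredictionResponse
# whose `response` is a SchemaResponse.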



def FilePredictions(route: str):
"""File prediction resource for VLM Run API."""

class _FilePredictions(SchemaCastMixin, Predictions):
"""File prediction resource for VLM Run API."""

def generate(
def _handle_file_or_url(
self,
file: Optional[Union[Path, str]] = None,
url: Optional[str] = None,
model: str = "vlm-1",
domain: Optional[str] = None,
batch: bool = False,
config: Optional[GenerationConfig] = GenerationConfig(),
metadata: Optional[RequestMetadata] = RequestMetadata(),
callback_url: Optional[str] = None,
) -> PredictionResponse:
"""Generate a document prediction.

Args:
model: Model to use for prediction
file: File (pathlib.Path) or file_id to generate prediction from
url: URL to generate prediction from
domain: Domain to use for prediction
batch: Whether to run prediction in batch mode
config: GenerationConfig to use for prediction
metadata: Metadata to include in prediction
callback_url: URL to call when prediction is complete

Returns:
PredictionResponse: Prediction response
"""
) -> tuple[bool, str]:
"""Handle file or URL."""
is_url = False
file_or_url = None
if not file and not url:
raise ValueError("Either `file` or `url` must be provided")
if file and url:
@@ -290,6 +337,35 @@ def generate(
raise ValueError(
"File or URL must be a pathlib.Path, str, or AnyHttpUrl"
)
return is_url, file_or_url

def generate(
self,
file: Optional[Union[Path, str]] = None,
url: Optional[str] = None,
model: str = "vlm-1",
domain: Optional[str] = None,
batch: bool = False,
config: Optional[GenerationConfig] = GenerationConfig(),
metadata: Optional[RequestMetadata] = RequestMetadata(),
callback_url: Optional[str] = None,
) -> PredictionResponse:
"""Generate a document prediction.

Args:
model: Model to use for prediction
file: File (pathlib.Path) or file_id to generate prediction from
url: URL to generate prediction from
domain: Domain to use for prediction
batch: Whether to run prediction in batch mode
config: GenerationConfig to use for prediction
metadata: Metadata to include in prediction
callback_url: URL to call when prediction is complete

Returns:
PredictionResponse: Prediction response
"""
is_url, file_or_url = self._handle_file_or_url(file, url)

additional_kwargs = {}
if config:
@@ -315,6 +391,31 @@ def generate(
self._cast_response_to_schema(prediction, domain, config)
return prediction

def schema(
self,
file: Optional[Union[Path, str]] = None,
url: Optional[str] = None,
) -> PredictionResponse:
"""Auto-generate a schema for a given document.

Args:
file: File (pathlib.Path) or file_id to generate the schema from
url: URL to generate the schema from

Returns:
PredictionResponse: Prediction response
"""
is_url, file_or_url = self._handle_file_or_url(file, url)
response, status_code, headers = self._requestor.request(
method="POST",
url=f"{route}/schema",
data={"url" if is_url else "file_id": file_or_url},
)
if not isinstance(response, dict):
raise TypeError("Expected dict response")
prediction = PredictionResponse(**response)
prediction.response = SchemaResponse(**prediction.response)
return prediction
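# Usage sketch (assumes the "document" route binds this class as `client.document`):
# client.document.schema(url="https://example.com/doc.pdf") POSTs {"url": ...} to
# `document/schema`; a local file would instead be sent as {"file_id": ...}.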

return _FilePredictions


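For completeness, the new validation paths surface as `ValueError`s. A short sketch, reusing the assumed `client` from the example under the commit header:

from pathlib import Path

# Mixing local inputs and URLs is rejected by the new helpers
try:
    client.image.schema(images=[Path("a.jpg")], urls=["https://example.com/b.jpg"])
except ValueError as e:
    print(e)  # Only one of `images` or `urls` can be provided

# Omitting both inputs is rejected as well
try:
    client.document.schema()
except ValueError as e:
    print(e)  # Either `file` or `url` must be provided
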
2 changes: 1 addition & 1 deletion vlmrun/version.py
@@ -1 +1 @@
__version__ = "0.2.8"
__version__ = "0.2.9"