1010from vlmrun .common .image import encode_image
1111from vlmrun .client .base_requestor import APIRequestor
1212from vlmrun .types .abstract import VLMRunProtocol
13- from vlmrun .client .types import PredictionResponse , FileResponse
13+ from vlmrun .client .types import (
14+ PredictionResponse ,
15+ FileResponse ,
16+ GenerationConfig ,
17+ RequestMetadata ,
18+ )
1419
1520
1621class Predictions :
@@ -82,23 +87,19 @@ class ImagePredictions(Predictions):
8287 def generate (
8388 self ,
8489 images : list [Path | Image .Image ],
85- model : str ,
8690 domain : str ,
87- json_schema : dict | None = None ,
88- detail : str = "auto" ,
8991 batch : bool = False ,
90- metadata : dict = {},
92+ metadata : RequestMetadata | None = None ,
93+ config : GenerationConfig | None = None ,
9194 callback_url : str | None = None ,
9295 ) -> PredictionResponse :
9396 """Generate a document prediction.
9497
9598 Args:
9699 images: List of images to generate predictions from
97- model: Model to use for prediction
98100 domain: Domain to use for prediction
99- json_schema: JSON schema to use for prediction
100- detail: Detail level for prediction
101101 batch: Whether to run prediction in batch mode
102+ config: GenerationConfig to use for prediction
102103 metadata: Metadata to include in prediction
103104 callback_url: URL to call when prediction is complete
104105
@@ -117,18 +118,20 @@ def generate(
117118 else :
118119 raise ValueError ("Image must be a path or a PIL Image" )
119120
121+ additional_kwargs = {}
122+ if config :
123+ additional_kwargs ["config" ] = config .model_dump ()
124+ if metadata :
125+ additional_kwargs ["metadata" ] = metadata .model_dump ()
120126 response , status_code , headers = self ._requestor .request (
121127 method = "POST" ,
122128 url = "image/generate" ,
123129 data = {
124130 "image" : encode_image (images [0 ], format = "JPEG" ),
125- "model" : model ,
126131 "domain" : domain ,
127- "json_schema" : json_schema ,
128- "detail" : detail ,
129132 "batch" : batch ,
130- "metadata" : metadata ,
131133 "callback_url" : callback_url ,
134+ ** additional_kwargs ,
132135 },
133136 )
134137 if not isinstance (response , dict ):
@@ -144,64 +147,75 @@ class _FilePredictions(Predictions):
144147
145148 def generate (
146149 self ,
147- file_or_url : str | Path ,
148- model : str ,
149- domain : str ,
150- json_schema : dict | None = None ,
151- detail : str = "auto" ,
150+ file : Path | str | None = None ,
151+ url : str | None = None ,
152+ domain : str | None = None ,
152153 batch : bool = False ,
153- metadata : dict = {},
154+ config : GenerationConfig | None = GenerationConfig (),
155+ metadata : RequestMetadata | None = RequestMetadata (),
154156 callback_url : str | None = None ,
155157 ) -> PredictionResponse :
156158 """Generate a document prediction.
157159
158160 Args:
159- file_or_url : File (pathlib.Path) or file_id or URL to generate prediction from
160- model: Model to use for prediction
161+ file : File (pathlib.Path) or file_id to generate prediction from
162+ url: URL to generate prediction from
161163 domain: Domain to use for prediction
162- json_schema: JSON schema to use for prediction
163- detail: Detail level for prediction
164164 batch: Whether to run prediction in batch mode
165+ config: GenerationConfig to use for prediction
165166 metadata: Metadata to include in prediction
166167 callback_url: URL to call when prediction is complete
167168
168169 Returns:
169170 PredictionResponse: Prediction response
170171 """
171172 is_url = False
172- if isinstance (file_or_url , Path ):
173- logger .debug (
174- f"Uploading file [path={ file_or_url } , size={ file_or_url .stat ().st_size / 1024 / 1024 :.2f} MB] to VLM Run"
175- )
176- upload_response , _ , _ = self ._client .files .upload (
177- file = file_or_url , purpose = "assistants"
178- )
179- if not isinstance (upload_response , dict ):
180- raise TypeError ("Expected dict response" )
181- response = FileResponse (** upload_response )
182- logger .debug (
183- f"Uploaded file [file_id={ response .id } , name={ response .filename } ]"
184- )
185- file_or_url = response .id
186- elif isinstance (file_or_url , str ):
187- is_url = str (file_or_url ).startswith (("http://" , "https://" ))
173+ if not file and not url :
174+ raise ValueError ("Either `file` or `url` must be provided" )
175+ if file and url :
176+ raise ValueError ("Only one of `file` or `url` can be provided" )
177+ if file :
178+ if isinstance (file , Path ) or (
179+ isinstance (file , str ) and Path (file ).suffix
180+ ):
181+ logger .debug (
182+ f"Uploading file [path={ file } , size={ file .stat ().st_size / 1024 / 1024 :.2f} MB] to VLM Run"
183+ )
184+ response : FileResponse = self ._client .files .upload (
185+ file = Path (file ), purpose = "assistants"
186+ )
187+ logger .debug (
188+ f"Uploaded file [file_id={ response .id } , name={ response .filename } ]"
189+ )
190+ file_or_url = response .id
191+ elif isinstance (file , str ):
192+ logger .debug (f"Using file_id [file_id={ file } ]" )
193+ assert not Path (file ).suffix , "File must not have an extension"
194+ file_or_url = file
195+ else :
196+ raise ValueError ("File must be a pathlib.Path or a string" )
197+ elif url :
198+ is_url = True
199+ file_or_url = url
188200 else :
189201 raise ValueError (
190202 "File or URL must be a pathlib.Path, str, or AnyHttpUrl"
191203 )
192204
205+ additional_kwargs = {}
206+ if config :
207+ additional_kwargs ["config" ] = config .model_dump ()
208+ if metadata :
209+ additional_kwargs ["metadata" ] = metadata .model_dump ()
193210 response , status_code , headers = self ._requestor .request (
194211 method = "POST" ,
195212 url = f"{ route } /generate" ,
196213 data = {
197214 "url" if is_url else "file_id" : file_or_url ,
198- "model" : model ,
199215 "domain" : domain ,
200- "json_schema" : json_schema ,
201- "detail" : detail ,
202216 "batch" : batch ,
203- "metadata" : metadata ,
204217 "callback_url" : callback_url ,
218+ ** additional_kwargs ,
205219 },
206220 )
207221 if not isinstance (response , dict ):
0 commit comments