44from typer .testing import CliRunner
55
66from datetime import datetime
7+ from typing import List
78from vlmrun .client .types import (
89 ModelInfoResponse ,
9- DatasetResponse ,
10+ DatasetCreateResponse ,
1011 HubInfoResponse ,
1112 HubDomainsResponse ,
1213 HubSchemaQueryResponse ,
1314 FileResponse ,
1415 PredictionResponse ,
1516 FeedbackSubmitResponse ,
17+ CreditUsage ,
1618)
1719
1820
@@ -38,7 +40,7 @@ def generate(self, *args, **kwargs):
3840 created_at = "2024-01-01T00:00:00Z" ,
3941 completed_at = "2024-01-01T00:00:01Z" ,
4042 response = {"result" : "test" },
41- usage = {"total_tokens" : 100 }
43+ usage = {"total_tokens" : 100 },
4244 )
4345
4446 def __init__ (self , api_key = None , base_url = None ):
@@ -56,40 +58,6 @@ def __init__(self, api_key=None, base_url=None):
5658 self .audio = self .AudioPredictions (self )
5759 self .feedback = self .Feedback (self )
5860
59- class Dataset :
60- def __init__ (self , client ):
61- self ._client = client
62-
63- def create (
64- self ,
65- file_id : str ,
66- domain : str ,
67- dataset_name : str ,
68- dataset_type : str = "images" ,
69- ) -> DatasetResponse :
70- if dataset_type not in ["images" , "videos" , "documents" ]:
71- raise ValueError (
72- "dataset_type must be one of: images, videos, documents"
73- )
74- return DatasetResponse (
75- dataset_id = "dataset1" ,
76- dataset_uri = "gs://vlmrun-test-bucket/dataset1.tar.gz" ,
77- dataset_type = dataset_type ,
78- domain = domain ,
79- message = "Dataset created successfully" ,
80- created_at = datetime .fromisoformat ("2024-01-01T00:00:00+00:00" ),
81- )
82-
83- def get (self , dataset_id : str ) -> DatasetResponse :
84- return DatasetResponse (
85- dataset_id = "dataset1" ,
86- dataset_uri = "gs://vlmrun-test-bucket/dataset1.tar.gz" ,
87- dataset_type = "images" ,
88- domain = "test-domain" ,
89- message = "Dataset created successfully" ,
90- created_at = datetime .fromisoformat ("2024-01-01T00:00:00+00:00" ),
91- )
92-
9361 class FineTuning :
9462 def __init__ (self , client ):
9563 self ._client = client
@@ -124,18 +92,20 @@ def create(self, model, prompt, **kwargs):
12492 created_at = "2024-01-01T00:00:00Z" ,
12593 completed_at = None ,
12694 response = None ,
127- usage = {"total_tokens" : 0 }
95+ usage = {"total_tokens" : 0 },
12896 )
12997
13098 def list (self ):
131- return [PredictionResponse (
132- id = "prediction1" ,
133- status = "running" ,
134- created_at = "2024-01-01T00:00:00Z" ,
135- completed_at = None ,
136- response = None ,
137- usage = {"total_tokens" : 0 }
138- )]
99+ return [
100+ PredictionResponse (
101+ id = "prediction1" ,
102+ status = "running" ,
103+ created_at = "2024-01-01T00:00:00Z" ,
104+ completed_at = None ,
105+ response = None ,
106+ usage = {"total_tokens" : 0 },
107+ )
108+ ]
139109
140110 def get (self , prediction_id ):
141111 return PredictionResponse (
@@ -144,17 +114,17 @@ def get(self, prediction_id):
144114 created_at = "2024-01-01T00:00:00Z" ,
145115 completed_at = None ,
146116 response = None ,
147- usage = {"total_tokens" : 0 }
117+ usage = {"total_tokens" : 0 },
148118 )
149-
119+
150120 def wait (self , prediction_id , timeout = 60 , sleep = 1 ):
151121 return PredictionResponse (
152122 id = prediction_id ,
153123 status = "completed" ,
154124 created_at = "2024-01-01T00:00:00Z" ,
155125 completed_at = "2024-01-01T00:00:01Z" ,
156126 response = {"result" : "test" },
157- usage = {"total_tokens" : 100 }
127+ usage = {"total_tokens" : 100 },
158128 )
159129
160130 class Files :
@@ -166,28 +136,28 @@ def list(self):
166136 FileResponse (
167137 id = "file1" ,
168138 filename = "test.txt" ,
169- bytes = b"test content" ,
139+ bytes = 10 ,
170140 purpose = "assistants" ,
171- created_at = "2024-01-01T00:00:00Z"
141+ created_at = "2024-01-01T00:00:00Z" ,
172142 )
173143 ]
174144
175145 def upload (self , file_path , purpose = "fine-tune" ):
176146 return FileResponse (
177147 id = "file1" ,
178148 filename = str (file_path ),
179- bytes = b"test content" ,
149+ bytes = 10 ,
180150 purpose = purpose ,
181- created_at = "2024-01-01T00:00:00Z"
151+ created_at = "2024-01-01T00:00:00Z" ,
182152 )
183153
184154 def get (self , file_id ):
185155 return FileResponse (
186156 id = file_id ,
187157 filename = "test.txt" ,
188- bytes = b"test content" ,
158+ bytes = 10 ,
189159 purpose = "assistants" ,
190- created_at = "2024-01-01T00:00:00Z"
160+ created_at = "2024-01-01T00:00:00Z" ,
191161 )
192162
193163 def get_content (self , file_id ):
@@ -197,9 +167,9 @@ def delete(self, file_id):
197167 return FileResponse (
198168 id = file_id ,
199169 filename = "test.txt" ,
200- bytes = b"test content" ,
170+ bytes = 10 ,
201171 purpose = "assistants" ,
202- created_at = "2024-01-01T00:00:00Z"
172+ created_at = "2024-01-01T00:00:00Z" ,
203173 )
204174
205175 class Models :
@@ -250,7 +220,7 @@ def generate(self, *args, **kwargs):
250220 created_at = "2024-01-01T00:00:00Z" ,
251221 completed_at = "2024-01-01T00:00:01Z" ,
252222 response = {"result" : "test" },
253- usage = {"total_tokens" : 100 }
223+ usage = {"total_tokens" : 100 },
254224 )
255225
256226 class VideoPredictions :
@@ -264,7 +234,7 @@ def generate(self, *args, **kwargs):
264234 created_at = "2024-01-01T00:00:00Z" ,
265235 completed_at = "2024-01-01T00:00:01Z" ,
266236 response = {"result" : "test" },
267- usage = {"total_tokens" : 100 }
237+ usage = {"total_tokens" : 100 },
268238 )
269239
270240 class DocumentPredictions :
@@ -278,9 +248,76 @@ def generate(self, *args, **kwargs):
278248 created_at = "2024-01-01T00:00:00Z" ,
279249 completed_at = "2024-01-01T00:00:01Z" ,
280250 response = {"result" : "test" },
281- usage = {"total_tokens" : 100 }
251+ usage = {"total_tokens" : 100 },
282252 )
283253
254+ class Dataset :
255+ def __init__ (self , client ):
256+ self ._client = client
257+
258+ def create (
259+ self ,
260+ file_id : str ,
261+ domain : str ,
262+ dataset_name : str ,
263+ dataset_type : str = "images" ,
264+ ) -> DatasetCreateResponse :
265+ if dataset_type not in ["images" , "videos" , "documents" ]:
266+ raise ValueError (
267+ "dataset_type must be one of: images, videos, documents"
268+ )
269+ return DatasetCreateResponse (
270+ dataset_id = "dataset1" ,
271+ dataset_uri = "gs://vlmrun-test-bucket/dataset1.tar.gz" ,
272+ dataset_type = dataset_type ,
273+ dataset_name = dataset_name ,
274+ domain = domain ,
275+ message = "Dataset created successfully" ,
276+ created_at = datetime .fromisoformat ("2024-01-01T00:00:00+00:00" ),
277+ status = "pending" ,
278+ usage = CreditUsage (
279+ credits_used = 10 ,
280+ elements_processed = 10 ,
281+ element_type = "image" ,
282+ ),
283+ )
284+
285+ def get (self , dataset_id : str ) -> DatasetCreateResponse :
286+ return DatasetCreateResponse (
287+ dataset_id = "dataset1" ,
288+ dataset_uri = "gs://vlmrun-test-bucket/dataset1.tar.gz" ,
289+ dataset_type = "images" ,
290+ dataset_name = "test-dataset" ,
291+ domain = "test-domain" ,
292+ message = "Dataset created successfully" ,
293+ created_at = datetime .fromisoformat ("2024-01-01T00:00:00+00:00" ),
294+ status = "completed" ,
295+ usage = CreditUsage (
296+ credits_used = 10 ,
297+ elements_processed = 10 ,
298+ element_type = "image" ,
299+ ),
300+ )
301+
302+ def list (self ) -> List [DatasetCreateResponse ]:
303+ return [
304+ DatasetCreateResponse (
305+ dataset_id = "dataset1" ,
306+ dataset_uri = "gs://vlmrun-test-bucket/dataset1.tar.gz" ,
307+ dataset_type = "images" ,
308+ domain = "test-domain" ,
309+ dataset_name = "test-dataset" ,
310+ message = "Dataset created successfully" ,
311+ created_at = datetime .fromisoformat ("2024-01-01T00:00:00+00:00" ),
312+ status = "completed" ,
313+ usage = CreditUsage (
314+ credits_used = 10 ,
315+ elements_processed = 10 ,
316+ element_type = "image" ,
317+ ),
318+ )
319+ ]
320+
284321 class Feedback :
285322 def __init__ (self , client ):
286323 self ._client = client
@@ -290,15 +327,15 @@ def submit(self, id, label=None, notes=None, flag=None):
290327 id = "feedback1" ,
291328 created_at = "2024-01-01T00:00:00Z" ,
292329 request_id = id ,
293- response = label
330+ response = label ,
294331 )
295332
296333 def get (self , id ):
297334 return FeedbackSubmitResponse (
298335 id = "feedback1" ,
299336 created_at = "2024-01-01T00:00:00Z" ,
300337 request_id = id ,
301- response = None
338+ response = None ,
302339 )
303340
304341 monkeypatch .setattr ("vlmrun.cli.cli.Client" , MockClient )
0 commit comments