1010 HubInfoResponse ,
1111 HubDomainsResponse ,
1212 HubSchemaQueryResponse ,
13+ FileResponse ,
14+ PredictionResponse ,
15+ FeedbackSubmitResponse ,
1316)
1417
1518
@@ -24,6 +27,20 @@ def mock_client(monkeypatch):
2427 """Mock the Client class."""
2528
2629 class MockClient :
30+ class AudioPredictions :
31+ def __init__ (self , client ):
32+ self ._client = client
33+
34+ def generate (self , * args , ** kwargs ):
35+ return PredictionResponse (
36+ id = "prediction1" ,
37+ status = "completed" ,
38+ created_at = "2024-01-01T00:00:00Z" ,
39+ completed_at = "2024-01-01T00:00:01Z" ,
40+ response = {"result" : "test" },
41+ usage = {"total_tokens" : 100 }
42+ )
43+
2744 def __init__ (self , api_key = None , base_url = None ):
2845 self .api_key = api_key or "test-key"
2946 self .base_url = base_url or "https://api.vlm.run"
@@ -33,9 +50,11 @@ def __init__(self, api_key=None, base_url=None):
3350 self .files = self .Files (self )
3451 self .models = self .Models (self )
3552 self .hub = self .Hub (self )
36- self .image = self .Image (self )
37- self .video = self .Video (self )
38- self .document = self .Document (self )
53+ self .image = self .ImagePredictions (self )
54+ self .video = self .VideoPredictions (self )
55+ self .document = self .DocumentPredictions (self )
56+ self .audio = self .AudioPredictions (self )
57+ self .feedback = self .Feedback (self )
3958
4059 class Dataset :
4160 def __init__ (self , client ):
@@ -99,44 +118,89 @@ def __init__(self, client):
99118 self ._client = client
100119
101120 def create (self , model , prompt , ** kwargs ):
102- return {"id" : "prediction1" }
121+ return PredictionResponse (
122+ id = "prediction1" ,
123+ status = "running" ,
124+ created_at = "2024-01-01T00:00:00Z" ,
125+ completed_at = None ,
126+ response = None ,
127+ usage = {"total_tokens" : 0 }
128+ )
103129
104130 def list (self ):
105- return [{"id" : "prediction1" , "status" : "running" }]
131+ return [PredictionResponse (
132+ id = "prediction1" ,
133+ status = "running" ,
134+ created_at = "2024-01-01T00:00:00Z" ,
135+ completed_at = None ,
136+ response = None ,
137+ usage = {"total_tokens" : 0 }
138+ )]
106139
107140 def get (self , prediction_id ):
108- return {"id" : prediction_id , "status" : "running" }
141+ return PredictionResponse (
142+ id = prediction_id ,
143+ status = "running" ,
144+ created_at = "2024-01-01T00:00:00Z" ,
145+ completed_at = None ,
146+ response = None ,
147+ usage = {"total_tokens" : 0 }
148+ )
149+
150+ def wait (self , prediction_id , timeout = 60 , sleep = 1 ):
151+ return PredictionResponse (
152+ id = prediction_id ,
153+ status = "completed" ,
154+ created_at = "2024-01-01T00:00:00Z" ,
155+ completed_at = "2024-01-01T00:00:01Z" ,
156+ response = {"result" : "test" },
157+ usage = {"total_tokens" : 100 }
158+ )
109159
110160 class Files :
111161 def __init__ (self , client ):
112162 self ._client = client
113163
114164 def list (self ):
115165 return [
116- {
117- "id" : "file1" ,
118- "filename" : "test.txt" ,
119- "size" : 100 ,
120- "created_at" : "2024-01-01" ,
121- }
166+ FileResponse (
167+ id = "file1" ,
168+ filename = "test.txt" ,
169+ bytes = b"test content" ,
170+ purpose = "assistants" ,
171+ created_at = "2024-01-01T00:00:00Z"
172+ )
122173 ]
123174
124175 def upload (self , file_path , purpose = "fine-tune" ):
125- return {"id" : "file1" , "filename" : file_path }
176+ return FileResponse (
177+ id = "file1" ,
178+ filename = str (file_path ),
179+ bytes = b"test content" ,
180+ purpose = purpose ,
181+ created_at = "2024-01-01T00:00:00Z"
182+ )
126183
127184 def get (self , file_id ):
128- return {
129- "id" : file_id ,
130- "filename" : "test.txt" ,
131- "size" : 100 ,
132- "created_at" : "2024-01-01" ,
133- }
185+ return FileResponse (
186+ id = file_id ,
187+ filename = "test.txt" ,
188+ bytes = b"test content" ,
189+ purpose = "assistants" ,
190+ created_at = "2024-01-01T00:00:00Z"
191+ )
134192
135193 def get_content (self , file_id ):
136194 return b"test content"
137195
138196 def delete (self , file_id ):
139- return True
197+ return FileResponse (
198+ id = file_id ,
199+ filename = "test.txt" ,
200+ bytes = b"test content" ,
201+ purpose = "assistants" ,
202+ created_at = "2024-01-01T00:00:00Z"
203+ )
140204
141205 class Models :
142206 def __init__ (self , client ):
@@ -175,26 +239,67 @@ def get_schema(self, domain):
175239 schema_hash = "abcd1234" ,
176240 )
177241
178- class Image :
242+ class ImagePredictions :
179243 def __init__ (self , client ):
180244 self ._client = client
181245
182246 def generate (self , * args , ** kwargs ):
183- return b"image data"
247+ return PredictionResponse (
248+ id = "prediction1" ,
249+ status = "completed" ,
250+ created_at = "2024-01-01T00:00:00Z" ,
251+ completed_at = "2024-01-01T00:00:01Z" ,
252+ response = {"result" : "test" },
253+ usage = {"total_tokens" : 100 }
254+ )
184255
185- class Video :
256+ class VideoPredictions :
186257 def __init__ (self , client ):
187258 self ._client = client
188259
189260 def generate (self , * args , ** kwargs ):
190- return b"video data"
261+ return PredictionResponse (
262+ id = "prediction1" ,
263+ status = "completed" ,
264+ created_at = "2024-01-01T00:00:00Z" ,
265+ completed_at = "2024-01-01T00:00:01Z" ,
266+ response = {"result" : "test" },
267+ usage = {"total_tokens" : 100 }
268+ )
191269
192- class Document :
270+ class DocumentPredictions :
193271 def __init__ (self , client ):
194272 self ._client = client
195273
196274 def generate (self , * args , ** kwargs ):
197- return b"document data"
275+ return PredictionResponse (
276+ id = "prediction1" ,
277+ status = "completed" ,
278+ created_at = "2024-01-01T00:00:00Z" ,
279+ completed_at = "2024-01-01T00:00:01Z" ,
280+ response = {"result" : "test" },
281+ usage = {"total_tokens" : 100 }
282+ )
283+
284+ class Feedback :
285+ def __init__ (self , client ):
286+ self ._client = client
287+
288+ def submit (self , id , label = None , notes = None , flag = None ):
289+ return FeedbackSubmitResponse (
290+ id = "feedback1" ,
291+ created_at = "2024-01-01T00:00:00Z" ,
292+ request_id = id ,
293+ response = label
294+ )
295+
296+ def get (self , id ):
297+ return FeedbackSubmitResponse (
298+ id = "feedback1" ,
299+ created_at = "2024-01-01T00:00:00Z" ,
300+ request_id = id ,
301+ response = None
302+ )
198303
199304 monkeypatch .setattr ("vlmrun.cli.cli.Client" , MockClient )
200305 return MockClient ()
0 commit comments