Skip to content

Commit 9a54ffd

Browse files
authored
Fully functional datasets support (#41)
- updated `files` and `dataset` CLI
1 parent 4fc0b5c commit 9a54ffd

File tree

18 files changed

+273
-197
lines changed

18 files changed

+273
-197
lines changed

tests/common/test_utils.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1-
from vlmrun.common.utils import download_artifact
1+
from pathlib import Path
2+
3+
from vlmrun.common.utils import download_artifact, create_archive
24

35
PDF_URL = "https://storage.googleapis.com/vlm-data-public-prod/hub/examples/document.bank-statement/lending_bankstatement.pdf"
46

@@ -7,3 +9,20 @@ def test_download_artifact():
79
"""Test that download_artifact can download a PDF."""
810
pdf = download_artifact(PDF_URL, "file")
911
assert pdf.exists()
12+
13+
14+
def test_create_archive():
    """Check that create_archive packs the image dataset into a .tar.gz.

    The archive is expected to contain a base directory named after the
    archive (without the ``.tar.gz`` suffix) followed by the three images.
    """
    import tarfile

    dataset_dir = Path(__file__).parent.parent / "test_data/image_dataset"
    archive_path: Path = create_archive(dataset_dir, "test_image_dataset")
    assert archive_path.exists()
    assert archive_path.name.endswith(".tar.gz")

    # Read the member list once, then verify the count and that the first
    # entry is the folder carrying the archive's stem.
    expected_root = archive_path.name.replace(".tar.gz", "")
    with tarfile.open(archive_path, "r:gz") as tar:
        members = tar.getmembers()
    assert len(members) == 4  # basedir + 3 images
    assert members[0].name == expected_root

tests/conftest.py

Lines changed: 98 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,17 @@
44
from typer.testing import CliRunner
55

66
from datetime import datetime
7+
from typing import List
78
from vlmrun.client.types import (
89
ModelInfoResponse,
9-
DatasetResponse,
10+
DatasetCreateResponse,
1011
HubInfoResponse,
1112
HubDomainsResponse,
1213
HubSchemaQueryResponse,
1314
FileResponse,
1415
PredictionResponse,
1516
FeedbackSubmitResponse,
17+
CreditUsage,
1618
)
1719

1820

@@ -38,7 +40,7 @@ def generate(self, *args, **kwargs):
3840
created_at="2024-01-01T00:00:00Z",
3941
completed_at="2024-01-01T00:00:01Z",
4042
response={"result": "test"},
41-
usage={"total_tokens": 100}
43+
usage={"total_tokens": 100},
4244
)
4345

4446
def __init__(self, api_key=None, base_url=None):
@@ -56,40 +58,6 @@ def __init__(self, api_key=None, base_url=None):
5658
self.audio = self.AudioPredictions(self)
5759
self.feedback = self.Feedback(self)
5860

59-
class Dataset:
60-
def __init__(self, client):
61-
self._client = client
62-
63-
def create(
64-
self,
65-
file_id: str,
66-
domain: str,
67-
dataset_name: str,
68-
dataset_type: str = "images",
69-
) -> DatasetResponse:
70-
if dataset_type not in ["images", "videos", "documents"]:
71-
raise ValueError(
72-
"dataset_type must be one of: images, videos, documents"
73-
)
74-
return DatasetResponse(
75-
dataset_id="dataset1",
76-
dataset_uri="gs://vlmrun-test-bucket/dataset1.tar.gz",
77-
dataset_type=dataset_type,
78-
domain=domain,
79-
message="Dataset created successfully",
80-
created_at=datetime.fromisoformat("2024-01-01T00:00:00+00:00"),
81-
)
82-
83-
def get(self, dataset_id: str) -> DatasetResponse:
84-
return DatasetResponse(
85-
dataset_id="dataset1",
86-
dataset_uri="gs://vlmrun-test-bucket/dataset1.tar.gz",
87-
dataset_type="images",
88-
domain="test-domain",
89-
message="Dataset created successfully",
90-
created_at=datetime.fromisoformat("2024-01-01T00:00:00+00:00"),
91-
)
92-
9361
class FineTuning:
9462
def __init__(self, client):
9563
self._client = client
@@ -124,18 +92,20 @@ def create(self, model, prompt, **kwargs):
12492
created_at="2024-01-01T00:00:00Z",
12593
completed_at=None,
12694
response=None,
127-
usage={"total_tokens": 0}
95+
usage={"total_tokens": 0},
12896
)
12997

13098
def list(self):
131-
return [PredictionResponse(
132-
id="prediction1",
133-
status="running",
134-
created_at="2024-01-01T00:00:00Z",
135-
completed_at=None,
136-
response=None,
137-
usage={"total_tokens": 0}
138-
)]
99+
return [
100+
PredictionResponse(
101+
id="prediction1",
102+
status="running",
103+
created_at="2024-01-01T00:00:00Z",
104+
completed_at=None,
105+
response=None,
106+
usage={"total_tokens": 0},
107+
)
108+
]
139109

140110
def get(self, prediction_id):
141111
return PredictionResponse(
@@ -144,17 +114,17 @@ def get(self, prediction_id):
144114
created_at="2024-01-01T00:00:00Z",
145115
completed_at=None,
146116
response=None,
147-
usage={"total_tokens": 0}
117+
usage={"total_tokens": 0},
148118
)
149-
119+
150120
def wait(self, prediction_id, timeout=60, sleep=1):
151121
return PredictionResponse(
152122
id=prediction_id,
153123
status="completed",
154124
created_at="2024-01-01T00:00:00Z",
155125
completed_at="2024-01-01T00:00:01Z",
156126
response={"result": "test"},
157-
usage={"total_tokens": 100}
127+
usage={"total_tokens": 100},
158128
)
159129

160130
class Files:
@@ -166,28 +136,28 @@ def list(self):
166136
FileResponse(
167137
id="file1",
168138
filename="test.txt",
169-
bytes=b"test content",
139+
bytes=10,
170140
purpose="assistants",
171-
created_at="2024-01-01T00:00:00Z"
141+
created_at="2024-01-01T00:00:00Z",
172142
)
173143
]
174144

175145
def upload(self, file_path, purpose="fine-tune"):
176146
return FileResponse(
177147
id="file1",
178148
filename=str(file_path),
179-
bytes=b"test content",
149+
bytes=10,
180150
purpose=purpose,
181-
created_at="2024-01-01T00:00:00Z"
151+
created_at="2024-01-01T00:00:00Z",
182152
)
183153

184154
def get(self, file_id):
185155
return FileResponse(
186156
id=file_id,
187157
filename="test.txt",
188-
bytes=b"test content",
158+
bytes=10,
189159
purpose="assistants",
190-
created_at="2024-01-01T00:00:00Z"
160+
created_at="2024-01-01T00:00:00Z",
191161
)
192162

193163
def get_content(self, file_id):
@@ -197,9 +167,9 @@ def delete(self, file_id):
197167
return FileResponse(
198168
id=file_id,
199169
filename="test.txt",
200-
bytes=b"test content",
170+
bytes=10,
201171
purpose="assistants",
202-
created_at="2024-01-01T00:00:00Z"
172+
created_at="2024-01-01T00:00:00Z",
203173
)
204174

205175
class Models:
@@ -250,7 +220,7 @@ def generate(self, *args, **kwargs):
250220
created_at="2024-01-01T00:00:00Z",
251221
completed_at="2024-01-01T00:00:01Z",
252222
response={"result": "test"},
253-
usage={"total_tokens": 100}
223+
usage={"total_tokens": 100},
254224
)
255225

256226
class VideoPredictions:
@@ -264,7 +234,7 @@ def generate(self, *args, **kwargs):
264234
created_at="2024-01-01T00:00:00Z",
265235
completed_at="2024-01-01T00:00:01Z",
266236
response={"result": "test"},
267-
usage={"total_tokens": 100}
237+
usage={"total_tokens": 100},
268238
)
269239

270240
class DocumentPredictions:
@@ -278,9 +248,76 @@ def generate(self, *args, **kwargs):
278248
created_at="2024-01-01T00:00:00Z",
279249
completed_at="2024-01-01T00:00:01Z",
280250
response={"result": "test"},
281-
usage={"total_tokens": 100}
251+
usage={"total_tokens": 100},
282252
)
283253

254+
class Dataset:
    """Mock dataset resource that answers every call with a canned response."""

    def __init__(self, client):
        self._client = client

    @staticmethod
    def _canned_response(
        dataset_type: str, dataset_name: str, domain: str, status: str
    ) -> DatasetCreateResponse:
        """Build the fixed ``dataset1`` response, varying only the given fields."""
        credit_usage = CreditUsage(
            credits_used=10,
            elements_processed=10,
            element_type="image",
        )
        return DatasetCreateResponse(
            dataset_id="dataset1",
            dataset_uri="gs://vlmrun-test-bucket/dataset1.tar.gz",
            dataset_type=dataset_type,
            dataset_name=dataset_name,
            domain=domain,
            message="Dataset created successfully",
            created_at=datetime.fromisoformat("2024-01-01T00:00:00+00:00"),
            status=status,
            usage=credit_usage,
        )

    def create(
        self,
        file_id: str,
        domain: str,
        dataset_name: str,
        dataset_type: str = "images",
    ) -> DatasetCreateResponse:
        """Return a pending dataset echoing the caller's domain, name and type.

        Raises:
            ValueError: if ``dataset_type`` is not images, videos or documents.
        """
        supported_types = ("images", "videos", "documents")
        if dataset_type in supported_types:
            return self._canned_response(
                dataset_type=dataset_type,
                dataset_name=dataset_name,
                domain=domain,
                status="pending",
            )
        raise ValueError("dataset_type must be one of: images, videos, documents")

    def get(self, dataset_id: str) -> DatasetCreateResponse:
        """Return the completed canned dataset; ``dataset_id`` is not consulted."""
        return self._canned_response(
            dataset_type="images",
            dataset_name="test-dataset",
            domain="test-domain",
            status="completed",
        )

    def list(self) -> List[DatasetCreateResponse]:
        """Return a one-element list holding the completed canned dataset."""
        only_dataset = self._canned_response(
            dataset_type="images",
            dataset_name="test-dataset",
            domain="test-domain",
            status="completed",
        )
        return [only_dataset]
320+
284321
class Feedback:
285322
def __init__(self, client):
286323
self._client = client
@@ -290,15 +327,15 @@ def submit(self, id, label=None, notes=None, flag=None):
290327
id="feedback1",
291328
created_at="2024-01-01T00:00:00Z",
292329
request_id=id,
293-
response=label
330+
response=label,
294331
)
295332

296333
def get(self, id):
297334
return FeedbackSubmitResponse(
298335
id="feedback1",
299336
created_at="2024-01-01T00:00:00Z",
300337
request_id=id,
301-
response=None
338+
response=None,
302339
)
303340

304341
monkeypatch.setattr("vlmrun.cli.cli.Client", MockClient)
106 KB
Loading
594 KB
Loading
93.7 KB
Loading

tests/test_dataset.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import pytest
44
from datetime import datetime
5-
from vlmrun.client.types import DatasetResponse
5+
from vlmrun.client.types import DatasetCreateResponse
66

77

88
def test_dataset_create(mock_client):
@@ -13,7 +13,7 @@ def test_dataset_create(mock_client):
1313
dataset_name="test-dataset",
1414
dataset_type="images",
1515
)
16-
assert isinstance(response, DatasetResponse)
16+
assert isinstance(response, DatasetCreateResponse)
1717
assert response.dataset_id == "dataset1"
1818
assert response.domain == "test-domain"
1919
assert response.dataset_type == "images"
@@ -23,7 +23,7 @@ def test_dataset_create(mock_client):
2323
def test_dataset_get(mock_client):
2424
"""Test dataset retrieval."""
2525
response = mock_client.dataset.get("dataset1")
26-
assert isinstance(response, DatasetResponse)
26+
assert isinstance(response, DatasetCreateResponse)
2727
assert response.dataset_id == "dataset1"
2828
assert response.domain == "test-domain"
2929
assert response.dataset_type == "images"

tests/test_feedback.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
class TestLabel(BaseModel):
    """Test label model."""

    # The "Test" name prefix makes pytest try to collect this class as a test
    # case and warn that it has an __init__ constructor. Opt out explicitly so
    # it is treated as a plain data model.
    __test__ = False

    score: int
    comment: str
1112

@@ -14,10 +15,7 @@ def test_submit_feedback(mock_client):
1415
"""Test submitting feedback for a prediction."""
1516
label = TestLabel(score=5, comment="Great prediction!")
1617
response = mock_client.feedback.submit(
17-
id="prediction1",
18-
label=label,
19-
notes="Test feedback",
20-
flag=False
18+
id="prediction1", label=label, notes="Test feedback", flag=False
2119
)
2220
assert isinstance(response, FeedbackSubmitResponse)
2321
assert response.id == "feedback1"

tests/test_files.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
"""Tests for files operations."""
22

3-
from pathlib import Path
43
from vlmrun.client.types import FileResponse
54

65

@@ -19,7 +18,7 @@ def test_upload_file(mock_client, tmp_path):
1918
# Create a temporary file
2019
test_file = tmp_path / "test.txt"
2120
test_file.write_text("test content")
22-
21+
2322
response = mock_client.files.upload(test_file)
2423
assert isinstance(response, FileResponse)
2524
assert response.id == "file1"
@@ -32,14 +31,7 @@ def test_get_file(mock_client):
3231
assert isinstance(response, FileResponse)
3332
assert response.id == "file1"
3433
assert response.filename == "test.txt"
35-
assert len(response.bytes) == len(b"test content")
36-
37-
38-
def test_get_content(mock_client):
39-
"""Test getting file content."""
40-
response = mock_client.files.get_content("file1")
41-
assert isinstance(response, bytes)
42-
assert response == b"test content"
34+
assert response.bytes == 10
4335

4436

4537
def test_delete_file(mock_client):

0 commit comments

Comments
 (0)