-
Notifications
You must be signed in to change notification settings - Fork 691
Expand file tree
/
Copy pathtest_exceptions.py
More file actions
636 lines (486 loc) · 23.1 KB
/
test_exceptions.py
File metadata and controls
636 lines (486 loc) · 23.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
"""Tests for outlines/exceptions.py and per-provider exception mapping.
Each provider section can be run independently:
pytest tests/test_exceptions.py -k "openai"
pytest tests/test_exceptions.py -k "anthropic"
pytest tests/test_exceptions.py -k "mistral"
pytest tests/test_exceptions.py -k "gemini"
pytest tests/test_exceptions.py -k "ollama"
pytest tests/test_exceptions.py -k "tgi"
pytest tests/test_exceptions.py -k "dottxt"
pytest tests/test_exceptions.py -k "vllm"
pytest tests/test_exceptions.py -k "sglang"
"""
import pytest
from unittest.mock import Mock
from outlines.exceptions import (
APIConnectionError,
APIError,
APITimeoutError,
AuthenticationError,
BadRequestError,
GenerationError,
NotFoundError,
OutlinesError,
PermissionDeniedError,
ProviderResponseError,
RateLimitError,
ServerError,
_extract_status_code,
is_provider_exception,
normalize_provider_exception,
)
# ---------------------------------------------------------------------------
# Exception hierarchy
# ---------------------------------------------------------------------------
class TestExceptionHierarchy:
    """Structural checks on the outlines exception classes themselves."""

    # Every concrete API-level exception. APIError is listed separately in the
    # OutlinesError check because it is the base of this group.
    _API_ERROR_SUBCLASSES = (
        AuthenticationError,
        PermissionDeniedError,
        NotFoundError,
        RateLimitError,
        BadRequestError,
        ServerError,
        APITimeoutError,
        APIConnectionError,
        ProviderResponseError,
        GenerationError,
    )

    # Expected value of the class-level ``retryable`` flag for each subclass.
    _EXPECTED_RETRYABLE = {
        RateLimitError: True,
        ServerError: True,
        APITimeoutError: True,
        APIConnectionError: True,
        AuthenticationError: False,
        PermissionDeniedError: False,
        NotFoundError: False,
        BadRequestError: False,
        ProviderResponseError: False,
        GenerationError: False,
    }

    def test_all_are_outlines_error(self):
        for cls in (APIError, *self._API_ERROR_SUBCLASSES):
            assert issubclass(cls, OutlinesError), (
                f"{cls.__name__} should subclass OutlinesError"
            )

    def test_all_are_api_error(self):
        for cls in self._API_ERROR_SUBCLASSES:
            assert issubclass(cls, APIError), (
                f"{cls.__name__} should subclass APIError"
            )

    def test_retryable_flags(self):
        # Guard the table itself so a newly added subclass cannot silently
        # go without a retryable expectation.
        assert set(self._EXPECTED_RETRYABLE) == set(self._API_ERROR_SUBCLASSES)
        for cls, expected in self._EXPECTED_RETRYABLE.items():
            assert cls.retryable is expected, (
                f"{cls.__name__}.retryable should be {expected}"
            )

    def test_api_error_stores_provider_and_original(self):
        orig = ValueError("boom")
        err = APIError(provider="test", original_exception=orig)
        assert err.provider == "test"
        assert err.original_exception is orig

    def test_api_error_status_code_param(self):
        err = APIError(status_code=429)
        assert err.status_code == 429

    def test_api_error_default_message(self):
        err = APIError()
        assert str(err) == "API error"

    def test_api_error_provider_message(self):
        err = APIError(provider="openai")
        assert "[openai]" in str(err)

    def test_hint_stored_in_notes_or_attribute(self):
        err = AuthenticationError()
        # Python 3.11+ exposes exception notes via __notes__; older
        # interpreters fall back to the private _hint_note attribute.
        if hasattr(err, "__notes__"):
            assert any("Check API key" in note for note in err.__notes__)
        else:
            assert "Check API key" in err._hint_note

    def test_hint_note_uses_arrow_format(self):
        err = AuthenticationError()
        if hasattr(err, "__notes__"):
            assert any(note.startswith(" →") for note in err.__notes__)
        else:
            assert err._hint_note.startswith(" →")

    def test_hint_not_in_message(self):
        # The hint is attached as a note, never baked into the message text.
        err = AuthenticationError(provider="openai")
        assert "Check API key" not in str(err)
        assert "Check API key" not in err.args[0]

    def test_no_hint_no_notes(self):
        # APIError has an empty hint, so neither notes nor _hint_note appear.
        err = APIError()
        assert not getattr(err, "__notes__", [])
        assert not hasattr(err, "_hint_note")

    def test_hint_attribute_accessible(self):
        assert AuthenticationError.hint == "Check API key."
        assert RateLimitError.hint != ""
        assert APIError.hint == ""
# ---------------------------------------------------------------------------
# _extract_status_code
# ---------------------------------------------------------------------------
class TestExtractStatusCode:
    """_extract_status_code locates an integer HTTP status on an exception."""

    @staticmethod
    def _error_with(**attrs):
        """Instantiate a throwaway Exception subclass with *attrs* as class attributes."""
        return type("E", (Exception,), attrs)()

    def test_none_input(self):
        assert _extract_status_code(None) is None

    def test_status_code_attr(self):
        assert _extract_status_code(self._error_with(status_code=429)) == 429

    def test_code_attr(self):
        assert _extract_status_code(self._error_with(code=404)) == 404

    def test_status_attr(self):
        assert _extract_status_code(self._error_with(status=500)) == 500

    def test_response_status_code(self):
        # Status carried on a nested response object rather than the exception.
        response = type("Resp", (), {"status_code": 403})()
        assert _extract_status_code(self._error_with(response=response)) == 403

    def test_sentinel_minus_one_filtered(self):
        assert _extract_status_code(self._error_with(status_code=-1)) is None

    def test_non_int_ignored(self):
        assert _extract_status_code(self._error_with(status_code="429")) is None
# ---------------------------------------------------------------------------
# normalize_provider_exception — status-code fallback
# (uses provider "test" so no SDK-specific mapping is applied)
# ---------------------------------------------------------------------------
class TestNormalizeProviderException:
    """Status-code fallback of normalize_provider_exception.

    The provider name "test" has no SDK-specific mapping, so only the numeric
    status code decides which outlines exception class is produced.
    """

    def _exc(self, status_code=None):
        """Build an SDK-like exception; ``status_code`` is set only when given."""
        attrs = {} if status_code is None else {"status_code": status_code}
        return type("FakeSDKError", (Exception,), attrs)()

    def _normalized(self, status_code, provider="test"):
        """Normalize a fake exception carrying *status_code* for *provider*."""
        return normalize_provider_exception(self._exc(status_code), provider)

    def test_status_code_401(self):
        assert isinstance(self._normalized(401), AuthenticationError)

    def test_status_code_403(self):
        assert isinstance(self._normalized(403), PermissionDeniedError)

    def test_status_code_404(self):
        assert isinstance(self._normalized(404), NotFoundError)

    def test_status_code_429(self):
        assert isinstance(self._normalized(429), RateLimitError)

    def test_status_code_500(self):
        assert isinstance(self._normalized(500), ServerError)

    def test_status_code_503(self):
        assert isinstance(self._normalized(503), ServerError)

    def test_status_code_400(self):
        assert isinstance(self._normalized(400), BadRequestError)

    def test_status_code_422(self):
        assert isinstance(self._normalized(422), BadRequestError)

    def test_no_status_code_fallback_to_api_error(self):
        wrapped = normalize_provider_exception(ValueError("unknown"), "test")
        assert type(wrapped) is APIError

    def test_original_exception_preserved(self):
        source = ValueError("original")
        wrapped = normalize_provider_exception(source, "test")
        assert wrapped.original_exception is source

    def test_provider_stored(self):
        assert self._normalized(429, "mycloud").provider == "mycloud"

    def test_unknown_provider_uses_status_code_fallback(self):
        # A provider name with no SDK-specific mapping still gets the
        # status-code fallback.
        wrapped = self._normalized(429, "unknown-provider")
        assert isinstance(wrapped, RateLimitError)
# ---------------------------------------------------------------------------
# is_provider_exception
# ---------------------------------------------------------------------------
class TestIsProviderException:
    """is_provider_exception separates provider/transport failures (True)
    from programmer errors that must propagate unwrapped (False)."""

    @staticmethod
    def _sdk_error(**attrs):
        """Instantiate an ad-hoc Exception subclass with *attrs* as class attributes."""
        return type("FakeSDKError", (Exception,), attrs)()

    # Programmer errors pass through.

    def test_type_error_is_not_provider(self):
        assert not is_provider_exception(TypeError("bad arg"), "openai")

    def test_attribute_error_is_not_provider(self):
        assert not is_provider_exception(AttributeError("no attr"), "openai")

    def test_key_error_is_not_provider(self):
        assert not is_provider_exception(KeyError("missing"), "openai")

    def test_name_error_is_not_provider(self):
        assert not is_provider_exception(NameError("undef"), "openai")

    def test_assertion_error_is_not_provider(self):
        assert not is_provider_exception(AssertionError("assert"), "openai")

    def test_plain_exception_no_status_code_is_not_provider(self):
        assert not is_provider_exception(Exception("generic"), "openai")

    # Anything carrying an HTTP status code counts as a provider error.

    def test_status_code_429_is_provider(self):
        assert is_provider_exception(self._sdk_error(status_code=429), "openai")

    def test_status_code_500_is_provider(self):
        error = self._sdk_error(status_code=500)
        assert is_provider_exception(error, "unknown-provider")

    def test_status_code_on_response_is_provider(self):
        response = type("Resp", (), {"status_code": 403})()
        assert is_provider_exception(self._sdk_error(response=response), "openai")

    def test_sentinel_minus_one_is_not_provider(self):
        # A status_code of -1 means "unknown" and must not trigger wrapping.
        error = self._sdk_error(status_code=-1)
        assert not is_provider_exception(error, "ollama")

    # Known SDK exception classes are recognised through the explicit map.

    def test_openai_sdk_exception_is_provider(self):
        openai = pytest.importorskip("openai")
        fake = Mock(spec=openai.RateLimitError)
        assert is_provider_exception(fake, "openai")

    def test_anthropic_sdk_exception_is_provider(self):
        anthropic = pytest.importorskip("anthropic")
        fake = Mock(spec=anthropic.APITimeoutError)
        assert is_provider_exception(fake, "anthropic")

    # Unknown provider and no status code: not a provider error.

    def test_unknown_provider_no_status_code_is_not_provider(self):
        assert not is_provider_exception(ValueError("oops"), "unknown-provider")
# ---------------------------------------------------------------------------
# OpenAI provider mapping
# ---------------------------------------------------------------------------
class TestOpenAIExceptionMap:
    """normalize_provider_exception mapping for ``openai`` SDK exceptions."""

    @pytest.fixture(autouse=True)
    def _require_openai(self):
        pytest.importorskip("openai")

    def _check(self, sdk_exc_name, expected_outlines_cls):
        """Normalize a spec'd mock of ``openai.<sdk_exc_name>`` and check its type."""
        import openai

        sdk_exc_cls = getattr(openai, sdk_exc_name)
        outcome = normalize_provider_exception(Mock(spec=sdk_exc_cls), "openai")
        assert isinstance(outcome, expected_outlines_cls), (
            f"Expected {expected_outlines_cls.__name__}, got {type(outcome).__name__}"
        )

    def test_authentication_error(self):
        self._check("AuthenticationError", AuthenticationError)

    def test_permission_denied(self):
        self._check("PermissionDeniedError", PermissionDeniedError)

    def test_not_found(self):
        self._check("NotFoundError", NotFoundError)

    def test_rate_limit(self):
        self._check("RateLimitError", RateLimitError)

    def test_bad_request(self):
        self._check("BadRequestError", BadRequestError)

    def test_internal_server_error(self):
        self._check("InternalServerError", ServerError)

    def test_timeout(self):
        self._check("APITimeoutError", APITimeoutError)

    def test_connection_error(self):
        self._check("APIConnectionError", APIConnectionError)

    def test_response_validation(self):
        self._check("APIResponseValidationError", ProviderResponseError)

    def test_length_finish_reason(self):
        self._check("LengthFinishReasonError", GenerationError)

    def test_content_filter_finish_reason(self):
        self._check("ContentFilterFinishReasonError", GenerationError)
# ---------------------------------------------------------------------------
# Anthropic provider mapping
# ---------------------------------------------------------------------------
class TestAnthropicExceptionMap:
    """normalize_provider_exception mapping for ``anthropic`` SDK exceptions."""

    @pytest.fixture(autouse=True)
    def _require_anthropic(self):
        pytest.importorskip("anthropic")

    def _check(self, sdk_exc_cls, expected_outlines_cls):
        """Assert that a spec'd mock of *sdk_exc_cls* maps to *expected_outlines_cls*."""
        result = normalize_provider_exception(Mock(spec=sdk_exc_cls), "anthropic")
        assert isinstance(result, expected_outlines_cls), (
            f"Expected {expected_outlines_cls.__name__}, got {type(result).__name__}"
        )

    @staticmethod
    def _private_exc(name):
        """Return ``anthropic._exceptions.<name>``, skipping if absent.

        The class lives in a private module. NOTE(review): it is assumed it
        may not exist in every installed SDK version, so the test is skipped
        rather than erroring with AttributeError.
        """
        import anthropic._exceptions as anthropic_exc

        cls = getattr(anthropic_exc, name, None)
        if cls is None:
            pytest.skip(f"anthropic._exceptions.{name} not available in installed SDK")
        return cls

    def test_authentication_error(self):
        import anthropic
        self._check(anthropic.AuthenticationError, AuthenticationError)

    def test_permission_denied(self):
        import anthropic
        self._check(anthropic.PermissionDeniedError, PermissionDeniedError)

    def test_not_found(self):
        import anthropic
        self._check(anthropic.NotFoundError, NotFoundError)

    def test_rate_limit(self):
        import anthropic
        self._check(anthropic.RateLimitError, RateLimitError)

    def test_bad_request(self):
        import anthropic
        self._check(anthropic.BadRequestError, BadRequestError)

    def test_overloaded(self):
        self._check(self._private_exc("OverloadedError"), ServerError)

    def test_service_unavailable(self):
        self._check(self._private_exc("ServiceUnavailableError"), ServerError)

    def test_timeout(self):
        import anthropic
        self._check(anthropic.APITimeoutError, APITimeoutError)

    def test_connection_error(self):
        import anthropic
        self._check(anthropic.APIConnectionError, APIConnectionError)

    def test_response_validation(self):
        import anthropic
        self._check(anthropic.APIResponseValidationError, ProviderResponseError)
# ---------------------------------------------------------------------------
# Mistral provider mapping
# ---------------------------------------------------------------------------
class TestMistralExceptionMap:
    """normalize_provider_exception mapping for the ``mistralai`` SDK.

    Covers the SDK's model exceptions, the ``httpx`` transport errors the
    tests exercise, and the numeric status-code fallback.
    """

    @pytest.fixture(autouse=True)
    def _require_mistral(self):
        pytest.importorskip("mistralai")

    def _check(self, sdk_exc_cls, expected_outlines_cls):
        """Assert that a spec'd mock of *sdk_exc_cls* maps to *expected_outlines_cls*."""
        result = normalize_provider_exception(Mock(spec=sdk_exc_cls), "mistral")
        assert isinstance(result, expected_outlines_cls), (
            f"Expected {expected_outlines_cls.__name__}, got {type(result).__name__}"
        )

    def _check_status(self, status_code, expected_outlines_cls):
        """Assert that an exception exposing only *status_code* maps as expected."""
        sdk_error = type("MistralSDKError", (Exception,), {"status_code": status_code})()
        result = normalize_provider_exception(sdk_error, "mistral")
        assert isinstance(result, expected_outlines_cls), (
            f"status {status_code}: expected {expected_outlines_cls.__name__}, "
            f"got {type(result).__name__}"
        )

    def test_http_validation_error(self):
        import mistralai.models as mm
        self._check(mm.HTTPValidationError, BadRequestError)

    def test_validation_error(self):
        import mistralai.models as mm
        self._check(mm.ValidationError, BadRequestError)

    def test_response_validation_error(self):
        import mistralai.models as mm
        self._check(mm.ResponseValidationError, ProviderResponseError)

    def test_no_response_error(self):
        import mistralai.models as mm
        self._check(mm.NoResponseError, APIConnectionError)

    def test_httpx_timeout(self):
        import httpx
        self._check(httpx.TimeoutException, APITimeoutError)

    def test_httpx_connect_error(self):
        import httpx
        self._check(httpx.ConnectError, APIConnectionError)

    def test_status_code_fallback_401(self):
        self._check_status(401, AuthenticationError)

    def test_status_code_fallback_429(self):
        self._check_status(429, RateLimitError)

    def test_status_code_fallback_500(self):
        self._check_status(500, ServerError)
# ---------------------------------------------------------------------------
# Gemini provider mapping
# ---------------------------------------------------------------------------
class TestGeminiExceptionMap:
    """normalize_provider_exception mapping for the ``google.genai`` SDK."""

    @pytest.fixture(autouse=True)
    def _require_gemini(self):
        pytest.importorskip("google.genai")

    def _check(self, sdk_exc_cls, expected_outlines_cls):
        """Assert that a spec'd mock of *sdk_exc_cls* maps to *expected_outlines_cls*."""
        result = normalize_provider_exception(Mock(spec=sdk_exc_cls), "gemini")
        assert isinstance(result, expected_outlines_cls), (
            f"Expected {expected_outlines_cls.__name__}, got {type(result).__name__}"
        )

    def test_server_error(self):
        from google.genai import errors as genai_errors
        self._check(genai_errors.ServerError, ServerError)

    def test_httpx_timeout(self):
        import httpx
        self._check(httpx.TimeoutException, APITimeoutError)

    def test_httpx_connect_error(self):
        import httpx
        self._check(httpx.ConnectError, APIConnectionError)

    def test_client_error_status_code_fallback(self):
        from google.genai import errors as genai_errors
        # The status is supplied through the ``code`` attribute on the mock.
        mock_exc = Mock(spec=genai_errors.ClientError)
        mock_exc.code = 401
        result = normalize_provider_exception(mock_exc, "gemini")
        assert isinstance(result, AuthenticationError), (
            f"Expected AuthenticationError, got {type(result).__name__}"
        )
# ---------------------------------------------------------------------------
# Ollama provider mapping
# ---------------------------------------------------------------------------
class TestOllamaExceptionMap:
    """normalize_provider_exception mapping for real ``ollama`` client exceptions."""

    @pytest.fixture(autouse=True)
    def _require_ollama(self):
        pytest.importorskip("ollama")

    @staticmethod
    def _normalize(exc):
        """Normalize *exc* under the "ollama" provider name."""
        return normalize_provider_exception(exc, "ollama")

    def test_request_error(self):
        import ollama
        wrapped = self._normalize(ollama.RequestError("msg"))
        assert isinstance(wrapped, APIConnectionError)

    def test_response_error_with_status_code(self):
        import ollama
        wrapped = self._normalize(ollama.ResponseError("not found", status_code=404))
        assert isinstance(wrapped, NotFoundError)

    def test_response_error_sentinel_minus_one(self):
        # -1 means "status unknown", so only the generic APIError applies.
        import ollama
        wrapped = self._normalize(ollama.ResponseError("unknown", status_code=-1))
        assert type(wrapped) is APIError

    def test_response_error_429(self):
        import ollama
        wrapped = self._normalize(ollama.ResponseError("rate limited", status_code=429))
        assert isinstance(wrapped, RateLimitError)
# ---------------------------------------------------------------------------
# TGI provider mapping
# ---------------------------------------------------------------------------
class TestTGIExceptionMap:
    """normalize_provider_exception mapping for ``huggingface_hub`` (TGI) errors."""

    @pytest.fixture(autouse=True)
    def _require_tgi(self):
        pytest.importorskip("huggingface_hub")

    def _check(self, sdk_exc_cls, expected_outlines_cls):
        """Assert that a spec'd mock of *sdk_exc_cls* maps to *expected_outlines_cls*."""
        result = normalize_provider_exception(Mock(spec=sdk_exc_cls), "tgi")
        assert isinstance(result, expected_outlines_cls), (
            f"Expected {expected_outlines_cls.__name__}, got {type(result).__name__}"
        )

    def test_inference_timeout(self):
        from huggingface_hub.errors import InferenceTimeoutError
        self._check(InferenceTimeoutError, APITimeoutError)

    def test_overloaded(self):
        from huggingface_hub.errors import OverloadedError
        self._check(OverloadedError, ServerError)

    def test_validation_error(self):
        from huggingface_hub.errors import ValidationError
        self._check(ValidationError, BadRequestError)

    def test_generation_error(self):
        from huggingface_hub.errors import GenerationError as HFGenerationError
        self._check(HFGenerationError, GenerationError)

    def test_incomplete_generation_error(self):
        from huggingface_hub.errors import IncompleteGenerationError
        self._check(IncompleteGenerationError, GenerationError)

    def test_bad_request_error(self):
        from huggingface_hub.errors import BadRequestError as HFBadRequestError
        self._check(HFBadRequestError, BadRequestError)

    def test_gated_repo_error(self):
        from huggingface_hub.errors import GatedRepoError
        self._check(GatedRepoError, PermissionDeniedError)

    def test_repository_not_found(self):
        from huggingface_hub.errors import RepositoryNotFoundError
        self._check(RepositoryNotFoundError, NotFoundError)
# ---------------------------------------------------------------------------
# Dottxt provider mapping
# ---------------------------------------------------------------------------
class TestDottxtExceptionMap:
    """normalize_provider_exception mapping of ``urllib3`` errors for dottxt."""

    @pytest.fixture(autouse=True)
    def _require_dottxt(self):
        pytest.importorskip("urllib3")

    def _check(self, sdk_exc_cls, expected_outlines_cls):
        """Assert that a spec'd mock of *sdk_exc_cls* maps to *expected_outlines_cls*."""
        result = normalize_provider_exception(Mock(spec=sdk_exc_cls), "dottxt")
        assert isinstance(result, expected_outlines_cls), (
            f"Expected {expected_outlines_cls.__name__}, got {type(result).__name__}"
        )

    def test_connect_timeout(self):
        from urllib3.exceptions import ConnectTimeoutError
        self._check(ConnectTimeoutError, APITimeoutError)

    def test_read_timeout(self):
        from urllib3.exceptions import ReadTimeoutError
        self._check(ReadTimeoutError, APITimeoutError)

    def test_max_retry_error(self):
        from urllib3.exceptions import MaxRetryError
        self._check(MaxRetryError, APIConnectionError)

    def test_new_connection_error(self):
        from urllib3.exceptions import NewConnectionError
        self._check(NewConnectionError, APIConnectionError)
# ---------------------------------------------------------------------------
# vLLM provider mapping (reuses OpenAI map via provider name)
# ---------------------------------------------------------------------------
class TestVLLMExceptionMap:
    """vLLM errors surface as ``openai`` SDK exceptions under provider "vllm"."""

    @pytest.fixture(autouse=True)
    def _require_openai(self):
        pytest.importorskip("openai")

    @staticmethod
    def _normalize(sdk_exc_name):
        """Normalize a spec'd mock of ``openai.<sdk_exc_name>`` for provider "vllm"."""
        import openai

        fake = Mock(spec=getattr(openai, sdk_exc_name))
        return normalize_provider_exception(fake, "vllm")

    def test_authentication_error(self):
        assert isinstance(self._normalize("AuthenticationError"), AuthenticationError)

    def test_rate_limit(self):
        assert isinstance(self._normalize("RateLimitError"), RateLimitError)

    def test_provider_name(self):
        assert self._normalize("RateLimitError").provider == "vllm"
# ---------------------------------------------------------------------------
# SGLang provider mapping (reuses OpenAI map via provider name)
# ---------------------------------------------------------------------------
class TestSGLangExceptionMap:
    """SGLang errors surface as ``openai`` SDK exceptions under provider "sglang"."""

    @pytest.fixture(autouse=True)
    def _require_openai(self):
        pytest.importorskip("openai")

    @staticmethod
    def _normalize(sdk_exc_name):
        """Normalize a spec'd mock of ``openai.<sdk_exc_name>`` for provider "sglang"."""
        import openai

        fake = Mock(spec=getattr(openai, sdk_exc_name))
        return normalize_provider_exception(fake, "sglang")

    def test_authentication_error(self):
        assert isinstance(self._normalize("AuthenticationError"), AuthenticationError)

    def test_rate_limit(self):
        assert isinstance(self._normalize("RateLimitError"), RateLimitError)

    def test_provider_name(self):
        assert self._normalize("RateLimitError").provider == "sglang"