diff --git a/docs/api-guide/exceptions.md b/docs/api-guide/exceptions.md index 3e4b3e8be5..f0f178d92f 100644 --- a/docs/api-guide/exceptions.md +++ b/docs/api-guide/exceptions.md @@ -98,7 +98,7 @@ Note that the exception handler will only be called for responses generated by r The **base class** for all exceptions raised inside an `APIView` class or `@api_view`. -To provide a custom exception, subclass `APIException` and set the `.status_code` and `.default_detail` properties on the class. +To provide a custom exception, subclass `APIException` and set the `.status_code`, `.default_detail`, and `default_code` attributes on the class. For example, if your API relies on a third party service that may sometimes be unreachable, you might want to implement an exception for the "503 Service Unavailable" HTTP response code. You could do this like so: @@ -107,10 +107,42 @@ For example, if your API relies on a third party service that may sometimes be u class ServiceUnavailable(APIException): status_code = 503 default_detail = 'Service temporarily unavailable, try again later.' + default_code = 'service_unavailable' + +#### Inspecting API exceptions + +There are a number of different properties available for inspecting the status +of an API exception. You can use these to build custom exception handling +for your project. + +The available attributes and methods are: + +* `.detail` - Return the textual description of the error. +* `.get_codes()` - Return the code identifier of the error. +* `.full_details()` - Return both the textual description and the code identifier. + +In most cases the error detail will be a simple item: + + >>> print(exc.detail) + You do not have permission to perform this action. + >>> print(exc.get_codes()) + permission_denied + >>> print(exc.full_details()) + {'message':'You do not have permission to perform this action.','code':'permission_denied'} + +In the case of validation errors the error detail will be either a list or +dictionary of items: + + >>> print(exc.detail) + {"name":"This field is required.","age":"A valid integer is required."} + >>> print(exc.get_codes()) + {"name":"required","age":"invalid"} + >>> print(exc.get_full_details()) + {"name":{"message":"This field is required.","code":"required"},"age":{"message":"A valid integer is required.","code":"invalid"}} ## ParseError -**Signature:** `ParseError(detail=None)` +**Signature:** `ParseError(detail=None, code=None)` Raised if the request contains malformed data when accessing `request.data`. @@ -118,7 +150,7 @@ By default this exception results in a response with the HTTP status code "400 B ## AuthenticationFailed -**Signature:** `AuthenticationFailed(detail=None)` +**Signature:** `AuthenticationFailed(detail=None, code=None)` Raised when an incoming request includes incorrect authentication. @@ -126,7 +158,7 @@ By default this exception results in a response with the HTTP status code "401 U ## NotAuthenticated -**Signature:** `NotAuthenticated(detail=None)` +**Signature:** `NotAuthenticated(detail=None, code=None)` Raised when an unauthenticated request fails the permission checks. @@ -134,7 +166,7 @@ By default this exception results in a response with the HTTP status code "401 U ## PermissionDenied -**Signature:** `PermissionDenied(detail=None)` +**Signature:** `PermissionDenied(detail=None, code=None)` Raised when an authenticated request fails the permission checks. @@ -142,7 +174,7 @@ By default this exception results in a response with the HTTP status code "403 F ## NotFound -**Signature:** `NotFound(detail=None)` +**Signature:** `NotFound(detail=None, code=None)` Raised when a resource does not exists at the given URL. This exception is equivalent to the standard `Http404` Django exception. @@ -150,7 +182,7 @@ By default this exception results in a response with the HTTP status code "404 N ## MethodNotAllowed -**Signature:** `MethodNotAllowed(method, detail=None)` +**Signature:** `MethodNotAllowed(method, detail=None, code=None)` Raised when an incoming request occurs that does not map to a handler method on the view. @@ -158,7 +190,7 @@ By default this exception results in a response with the HTTP status code "405 M ## NotAcceptable -**Signature:** `NotAcceptable(detail=None)` +**Signature:** `NotAcceptable(detail=None, code=None)` Raised when an incoming request occurs with an `Accept` header that cannot be satisfied by any of the available renderers. @@ -166,7 +198,7 @@ By default this exception results in a response with the HTTP status code "406 N ## UnsupportedMediaType -**Signature:** `UnsupportedMediaType(media_type, detail=None)` +**Signature:** `UnsupportedMediaType(media_type, detail=None, code=None)` Raised if there are no parsers that can handle the content type of the request data when accessing `request.data`. @@ -174,7 +206,7 @@ By default this exception results in a response with the HTTP status code "415 U ## Throttled -**Signature:** `Throttled(wait=None, detail=None)` +**Signature:** `Throttled(wait=None, detail=None, code=None)` Raised when an incoming request fails the throttling checks. @@ -182,7 +214,7 @@ By default this exception results in a response with the HTTP status code "429 T ## ValidationError -**Signature:** `ValidationError(detail)` +**Signature:** `ValidationError(detail, code=None)` The `ValidationError` exception is slightly different from the other `APIException` classes: diff --git a/rest_framework/authtoken/serializers.py b/rest_framework/authtoken/serializers.py index 90d3bd96ed..b91a8454fc 100644 --- a/rest_framework/authtoken/serializers.py +++ b/rest_framework/authtoken/serializers.py @@ -21,13 +21,13 @@ def validate(self, attrs): # (Assuming the default `ModelBackend` authentication backend.) if not user.is_active: msg = _('User account is disabled.') - raise serializers.ValidationError(msg) + raise serializers.ValidationError(msg, code='authorization') else: msg = _('Unable to log in with provided credentials.') - raise serializers.ValidationError(msg) + raise serializers.ValidationError(msg, code='authorization') else: msg = _('Must include "username" and "password".') - raise serializers.ValidationError(msg) + raise serializers.ValidationError(msg, code='authorization') attrs['user'] = user return attrs diff --git a/rest_framework/exceptions.py b/rest_framework/exceptions.py index 29afaffe00..e41655fef6 100644 --- a/rest_framework/exceptions.py +++ b/rest_framework/exceptions.py @@ -17,27 +17,61 @@ from rest_framework.utils.serializer_helpers import ReturnDict, ReturnList -def _force_text_recursive(data): +def _get_error_details(data, default_code=None): """ Descend into a nested data structure, forcing any - lazy translation strings into plain text. + lazy translation strings or strings into `ErrorDetail`. """ if isinstance(data, list): ret = [ - _force_text_recursive(item) for item in data + _get_error_details(item, default_code) for item in data ] if isinstance(data, ReturnList): return ReturnList(ret, serializer=data.serializer) return ret elif isinstance(data, dict): ret = { - key: _force_text_recursive(value) + key: _get_error_details(value, default_code) for key, value in data.items() } if isinstance(data, ReturnDict): return ReturnDict(ret, serializer=data.serializer) return ret - return force_text(data) + + text = force_text(data) + code = getattr(data, 'code', default_code) + return ErrorDetail(text, code) + + +def _get_codes(detail): + if isinstance(detail, list): + return [_get_codes(item) for item in detail] + elif isinstance(detail, dict): + return {key: _get_codes(value) for key, value in detail.items()} + return detail.code + + +def _get_full_details(detail): + if isinstance(detail, list): + return [_get_full_details(item) for item in detail] + elif isinstance(detail, dict): + return {key: _get_full_details(value) for key, value in detail.items()} + return { + 'message': detail, + 'code': detail.code + } + + +class ErrorDetail(six.text_type): + """ + A string-like object that can additionally + """ + code = None + + def __new__(cls, string, code=None): + self = super(ErrorDetail, cls).__new__(cls, string) + self.code = code + return self class APIException(Exception): @@ -47,16 +81,35 @@ class APIException(Exception): """ status_code = status.HTTP_500_INTERNAL_SERVER_ERROR default_detail = _('A server error occurred.') + default_code = 'error' - def __init__(self, detail=None): - if detail is not None: - self.detail = force_text(detail) - else: - self.detail = force_text(self.default_detail) + def __init__(self, detail=None, code=None): + if detail is None: + detail = self.default_detail + if code is None: + code = self.default_code + + self.detail = _get_error_details(detail, code) def __str__(self): return self.detail + def get_codes(self): + """ + Return only the code part of the error details. + + Eg. {"name": ["required"]} + """ + return _get_codes(self.detail) + + def get_full_details(self): + """ + Return both the message & code parts of the error details. + + Eg. {"name": [{"message": "This field is required.", "code": "required"}]} + """ + return _get_full_details(self.detail) + # The recommended style for using `ValidationError` is to keep it namespaced # under `serializers`, in order to minimize potential confusion with Django's @@ -67,13 +120,21 @@ def __str__(self): class ValidationError(APIException): status_code = status.HTTP_400_BAD_REQUEST + default_detail = _('Invalid input.') + default_code = 'invalid' + + def __init__(self, detail, code=None): + if detail is None: + detail = self.default_detail + if code is None: + code = self.default_code - def __init__(self, detail): - # For validation errors the 'detail' key is always required. - # The details should always be coerced to a list if not already. + # For validation failures, we may collect may errors together, so the + # details should always be coerced to a list if not already. if not isinstance(detail, dict) and not isinstance(detail, list): detail = [detail] - self.detail = _force_text_recursive(detail) + + self.detail = _get_error_details(detail, code) def __str__(self): return six.text_type(self.detail) @@ -82,62 +143,63 @@ def __str__(self): class ParseError(APIException): status_code = status.HTTP_400_BAD_REQUEST default_detail = _('Malformed request.') + default_code = 'parse_error' class AuthenticationFailed(APIException): status_code = status.HTTP_401_UNAUTHORIZED default_detail = _('Incorrect authentication credentials.') + default_code = 'authentication_failed' class NotAuthenticated(APIException): status_code = status.HTTP_401_UNAUTHORIZED default_detail = _('Authentication credentials were not provided.') + default_code = 'not_authenticated' class PermissionDenied(APIException): status_code = status.HTTP_403_FORBIDDEN default_detail = _('You do not have permission to perform this action.') + default_code = 'permission_denied' class NotFound(APIException): status_code = status.HTTP_404_NOT_FOUND default_detail = _('Not found.') + default_code = 'not_found' class MethodNotAllowed(APIException): status_code = status.HTTP_405_METHOD_NOT_ALLOWED default_detail = _('Method "{method}" not allowed.') + default_code = 'method_not_allowed' - def __init__(self, method, detail=None): - if detail is not None: - self.detail = force_text(detail) - else: - self.detail = force_text(self.default_detail).format(method=method) + def __init__(self, method, detail=None, code=None): + if detail is None: + detail = force_text(self.default_detail).format(method=method) + super(MethodNotAllowed, self).__init__(detail, code) class NotAcceptable(APIException): status_code = status.HTTP_406_NOT_ACCEPTABLE default_detail = _('Could not satisfy the request Accept header.') + default_code = 'not_acceptable' - def __init__(self, detail=None, available_renderers=None): - if detail is not None: - self.detail = force_text(detail) - else: - self.detail = force_text(self.default_detail) + def __init__(self, detail=None, code=None, available_renderers=None): self.available_renderers = available_renderers + super(NotAcceptable, self).__init__(detail, code) class UnsupportedMediaType(APIException): status_code = status.HTTP_415_UNSUPPORTED_MEDIA_TYPE default_detail = _('Unsupported media type "{media_type}" in request.') + default_code = 'unsupported_media_type' - def __init__(self, media_type, detail=None): - if detail is not None: - self.detail = force_text(detail) - else: - self.detail = force_text(self.default_detail).format( - media_type=media_type - ) + def __init__(self, media_type, detail=None, code=None): + if detail is None: + detail = force_text(self.default_detail).format(media_type=media_type) + super(UnsupportedMediaType, self).__init__(detail, code) class Throttled(APIException): @@ -145,12 +207,10 @@ class Throttled(APIException): default_detail = _('Request was throttled.') extra_detail_singular = 'Expected available in {wait} second.' extra_detail_plural = 'Expected available in {wait} seconds.' + default_code = 'throttled' - def __init__(self, wait=None, detail=None): - if detail is not None: - self.detail = force_text(detail) - else: - self.detail = force_text(self.default_detail) + def __init__(self, wait=None, detail=None, code=None): + super(Throttled, self).__init__(detail, code) if wait is None: self.wait = None diff --git a/rest_framework/fields.py b/rest_framework/fields.py index 7f8391b8a8..1894b064cc 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -34,7 +34,7 @@ from rest_framework.compat import ( get_remote_field, unicode_repr, unicode_to_repr, value_from_object ) -from rest_framework.exceptions import ValidationError +from rest_framework.exceptions import ErrorDetail, ValidationError from rest_framework.settings import api_settings from rest_framework.utils import html, humanize_datetime, representation @@ -224,6 +224,18 @@ def __init__(self, value, display_text, disabled=False): yield Option(value='n/a', display_text=cutoff_text, disabled=True) +def get_error_detail(exc_info): + """ + Given a Django ValidationError, return a list of ErrorDetail, + with the `code` populated. + """ + code = getattr(exc_info, 'code', None) or 'invalid' + return [ + ErrorDetail(msg, code=code) + for msg in exc_info.messages + ] + + class CreateOnlyDefault(object): """ This class may be used to provide default values that are only used @@ -525,7 +537,7 @@ def run_validators(self, value): raise errors.extend(exc.detail) except DjangoValidationError as exc: - errors.extend(exc.messages) + errors.extend(get_error_detail(exc)) if errors: raise ValidationError(errors) @@ -563,7 +575,7 @@ def fail(self, key, **kwargs): msg = MISSING_ERROR_MESSAGE.format(class_name=class_name, key=key) raise AssertionError(msg) message_string = msg.format(**kwargs) - raise ValidationError(message_string) + raise ValidationError(message_string, code=key) @cached_property def root(self): diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 7e99d40b31..a6ed7d87ee 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -291,32 +291,29 @@ def __new__(cls, name, bases, attrs): return super(SerializerMetaclass, cls).__new__(cls, name, bases, attrs) -def get_validation_error_detail(exc): +def as_serializer_error(exc): assert isinstance(exc, (ValidationError, DjangoValidationError)) if isinstance(exc, DjangoValidationError): - # Normally you should raise `serializers.ValidationError` - # inside your codebase, but we handle Django's validation - # exception class as well for simpler compat. - # Eg. Calling Model.clean() explicitly inside Serializer.validate() - return { - api_settings.NON_FIELD_ERRORS_KEY: list(exc.messages) - } - elif isinstance(exc.detail, dict): + detail = get_error_detail(exc) + else: + detail = exc.detail + + if isinstance(detail, dict): # If errors may be a dict we use the standard {key: list of values}. # Here we ensure that all the values are *lists* of errors. return { key: value if isinstance(value, (list, dict)) else [value] - for key, value in exc.detail.items() + for key, value in detail.items() } - elif isinstance(exc.detail, list): + elif isinstance(detail, list): # Errors raised as a list are non-field errors. return { - api_settings.NON_FIELD_ERRORS_KEY: exc.detail + api_settings.NON_FIELD_ERRORS_KEY: detail } # Errors raised as a string are non-field errors. return { - api_settings.NON_FIELD_ERRORS_KEY: [exc.detail] + api_settings.NON_FIELD_ERRORS_KEY: [detail] } @@ -410,7 +407,7 @@ def run_validation(self, data=empty): value = self.validate(value) assert value is not None, '.validate() should return the validated data' except (ValidationError, DjangoValidationError) as exc: - raise ValidationError(detail=get_validation_error_detail(exc)) + raise ValidationError(detail=as_serializer_error(exc)) return value @@ -424,7 +421,7 @@ def to_internal_value(self, data): ) raise ValidationError({ api_settings.NON_FIELD_ERRORS_KEY: [message] - }) + }, code='invalid') ret = OrderedDict() errors = OrderedDict() @@ -440,7 +437,7 @@ def to_internal_value(self, data): except ValidationError as exc: errors[field.field_name] = exc.detail except DjangoValidationError as exc: - errors[field.field_name] = list(exc.messages) + errors[field.field_name] = get_error_detail(exc) except SkipField: pass else: @@ -564,7 +561,7 @@ def run_validation(self, data=empty): value = self.validate(value) assert value is not None, '.validate() should return the validated data' except (ValidationError, DjangoValidationError) as exc: - raise ValidationError(detail=get_validation_error_detail(exc)) + raise ValidationError(detail=as_serializer_error(exc)) return value @@ -581,13 +578,13 @@ def to_internal_value(self, data): ) raise ValidationError({ api_settings.NON_FIELD_ERRORS_KEY: [message] - }) + }, code='not_a_list') if not self.allow_empty and len(data) == 0: message = self.error_messages['empty'] raise ValidationError({ api_settings.NON_FIELD_ERRORS_KEY: [message] - }) + }, code='empty') ret = [] errors = [] diff --git a/rest_framework/validators.py b/rest_framework/validators.py index 84af0b9d58..57f8bad530 100644 --- a/rest_framework/validators.py +++ b/rest_framework/validators.py @@ -80,7 +80,7 @@ def __call__(self, value): queryset = self.filter_queryset(value, queryset) queryset = self.exclude_current_instance(queryset) if qs_exists(queryset): - raise ValidationError(self.message) + raise ValidationError(self.message, code='unique') def __repr__(self): return unicode_to_repr('<%s(queryset=%s)>' % ( @@ -120,13 +120,13 @@ def enforce_required_fields(self, attrs): if self.instance is not None: return - missing = { + missing_items = { field_name: self.missing_message for field_name in self.fields if field_name not in attrs } - if missing: - raise ValidationError(missing) + if missing_items: + raise ValidationError(missing_items, code='required') def filter_queryset(self, attrs, queryset): """ @@ -167,7 +167,8 @@ def __call__(self, attrs): ] if None not in checked_values and qs_exists(queryset): field_names = ', '.join(self.fields) - raise ValidationError(self.message.format(field_names=field_names)) + message = self.message.format(field_names=field_names) + raise ValidationError(message, code='unique') def __repr__(self): return unicode_to_repr('<%s(queryset=%s, fields=%s)>' % ( @@ -204,13 +205,13 @@ def enforce_required_fields(self, attrs): The `UniqueForValidator` classes always force an implied 'required' state on the fields they are applied to. """ - missing = { + missing_items = { field_name: self.missing_message for field_name in [self.field, self.date_field] if field_name not in attrs } - if missing: - raise ValidationError(missing) + if missing_items: + raise ValidationError(missing_items, code='required') def filter_queryset(self, attrs, queryset): raise NotImplementedError('`filter_queryset` must be implemented.') @@ -231,7 +232,9 @@ def __call__(self, attrs): queryset = self.exclude_current_instance(attrs, queryset) if qs_exists(queryset): message = self.message.format(date_field=self.date_field) - raise ValidationError({self.field: message}) + raise ValidationError({ + self.field: message + }, code='unique') def __repr__(self): return unicode_to_repr('<%s(queryset=%s, field=%s, date_field=%s)>' % ( diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py index cec050eb87..29703cb779 100644 --- a/tests/test_exceptions.py +++ b/tests/test_exceptions.py @@ -3,19 +3,39 @@ from django.test import TestCase from django.utils.translation import ugettext_lazy as _ -from rest_framework.exceptions import _force_text_recursive +from rest_framework.exceptions import ErrorDetail, _get_error_details class ExceptionTestCase(TestCase): - def test_force_text_recursive(self): - - s = "sfdsfggiuytraetfdlklj" - self.assertEqual(_force_text_recursive(_(s)), s) - self.assertEqual(type(_force_text_recursive(_(s))), type(s)) - - self.assertEqual(_force_text_recursive({'a': _(s)})['a'], s) - self.assertEqual(type(_force_text_recursive({'a': _(s)})['a']), type(s)) - - self.assertEqual(_force_text_recursive([[_(s)]])[0][0], s) - self.assertEqual(type(_force_text_recursive([[_(s)]])[0][0]), type(s)) + def test_get_error_details(self): + + example = "string" + lazy_example = _(example) + + self.assertEqual( + _get_error_details(lazy_example), + example + ) + assert isinstance( + _get_error_details(lazy_example), + ErrorDetail + ) + + self.assertEqual( + _get_error_details({'nested': lazy_example})['nested'], + example + ) + assert isinstance( + _get_error_details({'nested': lazy_example})['nested'], + ErrorDetail + ) + + self.assertEqual( + _get_error_details([[lazy_example]])[0][0], + example + ) + assert isinstance( + _get_error_details([[lazy_example]])[0][0], + ErrorDetail + ) diff --git a/tests/test_validation.py b/tests/test_validation.py index ab04d59e6d..bc950dd22a 100644 --- a/tests/test_validation.py +++ b/tests/test_validation.py @@ -60,7 +60,7 @@ def test_nested_validation_error_detail(self): } }) - self.assertEqual(serializers.get_validation_error_detail(e), { + self.assertEqual(serializers.as_serializer_error(e), { 'nested': { 'field': ['error'], } diff --git a/tests/test_validation_error.py b/tests/test_validation_error.py new file mode 100644 index 0000000000..8e371a3494 --- /dev/null +++ b/tests/test_validation_error.py @@ -0,0 +1,101 @@ +from django.test import TestCase + +from rest_framework import serializers, status +from rest_framework.decorators import api_view +from rest_framework.response import Response +from rest_framework.settings import api_settings +from rest_framework.test import APIRequestFactory +from rest_framework.views import APIView + +factory = APIRequestFactory() + + +class ExampleSerializer(serializers.Serializer): + char = serializers.CharField() + integer = serializers.IntegerField() + + +class ErrorView(APIView): + def get(self, request, *args, **kwargs): + ExampleSerializer(data={}).is_valid(raise_exception=True) + + +@api_view(['GET']) +def error_view(request): + ExampleSerializer(data={}).is_valid(raise_exception=True) + + +class TestValidationErrorWithFullDetails(TestCase): + def setUp(self): + self.DEFAULT_HANDLER = api_settings.EXCEPTION_HANDLER + + def exception_handler(exc, request): + data = exc.get_full_details() + return Response(data, status=status.HTTP_400_BAD_REQUEST) + + api_settings.EXCEPTION_HANDLER = exception_handler + + self.expected_response_data = { + 'char': [{ + 'message': 'This field is required.', + 'code': 'required', + }], + 'integer': [{ + 'message': 'This field is required.', + 'code': 'required' + }], + } + + def tearDown(self): + api_settings.EXCEPTION_HANDLER = self.DEFAULT_HANDLER + + def test_class_based_view_exception_handler(self): + view = ErrorView.as_view() + + request = factory.get('/', content_type='application/json') + response = view(request) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertEqual(response.data, self.expected_response_data) + + def test_function_based_view_exception_handler(self): + view = error_view + + request = factory.get('/', content_type='application/json') + response = view(request) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertEqual(response.data, self.expected_response_data) + + +class TestValidationErrorWithCodes(TestCase): + def setUp(self): + self.DEFAULT_HANDLER = api_settings.EXCEPTION_HANDLER + + def exception_handler(exc, request): + data = exc.get_codes() + return Response(data, status=status.HTTP_400_BAD_REQUEST) + + api_settings.EXCEPTION_HANDLER = exception_handler + + self.expected_response_data = { + 'char': ['required'], + 'integer': ['required'], + } + + def tearDown(self): + api_settings.EXCEPTION_HANDLER = self.DEFAULT_HANDLER + + def test_class_based_view_exception_handler(self): + view = ErrorView.as_view() + + request = factory.get('/', content_type='application/json') + response = view(request) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertEqual(response.data, self.expected_response_data) + + def test_function_based_view_exception_handler(self): + view = error_view + + request = factory.get('/', content_type='application/json') + response = view(request) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertEqual(response.data, self.expected_response_data)