Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Introduce ValidationErrorMessage
`ValidationErrorMessage` is a string-like object that holds a code
attribute.

The code attribute has been removed from ValidationError to be able
to maintain better backward compatibility.

`ValidationErrorMessage` is abstracted in `ValidationError`'s
constructor
  • Loading branch information
johnraz committed Aug 28, 2016
commit 2bf6ee47f3987a39d83d32a1105f524fe8e72728
10 changes: 4 additions & 6 deletions rest_framework/authtoken/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,20 +20,18 @@ def validate(self, attrs):
msg = _('User account is disabled.')
raise serializers.ValidationError(
msg,
code='authorization'
)
code='authorization')
else:
msg = _('Unable to log in with provided credentials.')
raise serializers.ValidationError(
msg,
code='authorization'
)
code='authorization')

else:
msg = _('Must include "username" and "password".')
raise serializers.ValidationError(
msg,
code='authorization'
)
code='authorization')

attrs['user'] = user
return attrs
24 changes: 17 additions & 7 deletions rest_framework/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def __str__(self):
def build_error_from_django_validation_error(exc_info):
code = getattr(exc_info, 'code', None) or 'invalid'
return [
ValidationError(msg, code=code)
ValidationErrorMessage(msg, code=code)
for msg in exc_info.messages
]

Expand All @@ -73,20 +73,30 @@ def build_error_from_django_validation_error(exc_info):
# from rest_framework import serializers
# raise serializers.ValidationError('Value was invalid')

class ValidationErrorMessage(six.text_type):
code = None

def __new__(cls, string, code=None, *args, **kwargs):
self = super(ValidationErrorMessage, cls).__new__(
cls, string, *args, **kwargs)

self.code = code
return self


class ValidationError(APIException):
status_code = status.HTTP_400_BAD_REQUEST

def __init__(self, detail, code=None):
# If code is there, this means we are dealing with a message.
if code and not isinstance(detail, ValidationErrorMessage):
detail = ValidationErrorMessage(detail, code=code)

# For validation errors the 'detail' key is always required.
# The details should always be coerced to a list if not already.
if not isinstance(detail, dict) and not isinstance(detail, list):
detail = [detail]
elif isinstance(detail, dict) or (detail and isinstance(detail[0], ValidationError)):
assert code is None, (
'The `code` argument must not be set for compound errors.')

self.detail = detail
self.code = code
self.detail = _force_text_recursive(detail)

def __str__(self):
return six.text_type(self.detail)
Expand Down
2 changes: 1 addition & 1 deletion rest_framework/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -509,7 +509,7 @@ def run_validators(self, value):
# attempting to accumulate a list of errors.
if isinstance(exc.detail, dict):
raise
errors.append(ValidationError(exc.detail, code=exc.code))
errors.extend(exc.detail)
except DjangoValidationError as exc:
errors.extend(
build_error_from_django_validation_error(exc)
Expand Down
1 change: 1 addition & 0 deletions rest_framework/response.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def __init__(self, data=None, status=None,
'`.error`. representation.'
)
raise AssertionError(msg)

self.data = data
self.template_name = template_name
self.exception = exception
Expand Down
54 changes: 10 additions & 44 deletions rest_framework/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from rest_framework import exceptions
from rest_framework.compat import JSONField as ModelJSONField
from rest_framework.compat import postgres_fields, unicode_to_repr
from rest_framework.exceptions import ValidationErrorMessage
from rest_framework.utils import model_meta
from rest_framework.utils.field_mapping import (
ClassLookupDict, get_field_kwargs, get_nested_relation_kwargs,
Expand Down Expand Up @@ -220,14 +221,7 @@ def is_valid(self, raise_exception=False):
self._errors = {}

if self._errors and raise_exception:
return_errors = None
if isinstance(self._errors, list):
return_errors = ReturnList(self._errors, serializer=self)
elif isinstance(self._errors, dict):
return_errors = ReturnDict(self._errors, serializer=self)

raise ValidationError(return_errors)

raise ValidationError(self.errors)
return not bool(self._errors)

@property
Expand All @@ -251,42 +245,12 @@ def data(self):
self._data = self.get_initial()
return self._data

def _transform_to_legacy_errors(self, errors_to_transform):
# Do not mutate `errors_to_transform` here.
errors = ReturnDict(serializer=self)
for field_name, values in errors_to_transform.items():
if isinstance(values, list):
errors[field_name] = values
continue

if isinstance(values.detail, list):
errors[field_name] = []
for value in values.detail:
if isinstance(value, ValidationError):
errors[field_name].extend(value.detail)
elif isinstance(value, list):
errors[field_name].extend(value)
else:
errors[field_name].append(value)

elif isinstance(values.detail, dict):
errors[field_name] = {}
for sub_field_name, value in values.detail.items():
errors[field_name][sub_field_name] = []
for validation_error in value:
errors[field_name][sub_field_name].extend(validation_error.detail)
return errors

@property
def errors(self):
if not hasattr(self, '_errors'):
msg = 'You must call `.is_valid()` before accessing `.errors`.'
raise AssertionError(msg)

if isinstance(self._errors, list):
return map(self._transform_to_legacy_errors, self._errors)
else:
return self._transform_to_legacy_errors(self._errors)
return self._errors

@property
def validated_data(self):
Expand Down Expand Up @@ -461,7 +425,7 @@ def to_internal_value(self, data):
message = self.error_messages['invalid'].format(
datatype=type(data).__name__
)
error = ValidationError(message, code='invalid')
error = ValidationErrorMessage(message, code='invalid')
raise ValidationError({
api_settings.NON_FIELD_ERRORS_KEY: [error]
})
Expand All @@ -478,7 +442,7 @@ def to_internal_value(self, data):
if validate_method is not None:
validated_value = validate_method(validated_value)
except ValidationError as exc:
errors[field.field_name] = exc
errors[field.field_name] = exc.detail
except DjangoValidationError as exc:
errors[field.field_name] = (
exceptions.build_error_from_django_validation_error(exc)
Expand Down Expand Up @@ -621,7 +585,7 @@ def to_internal_value(self, data):
message = self.error_messages['not_a_list'].format(
input_type=type(data).__name__
)
error = ValidationError(
error = ValidationErrorMessage(
message,
code='not_a_list'
)
Expand All @@ -630,9 +594,11 @@ def to_internal_value(self, data):
})

if not self.allow_empty and len(data) == 0:
message = self.error_messages['empty']
message = ValidationErrorMessage(
self.error_messages['empty'],
code='empty_not_allowed')
raise ValidationError({
api_settings.NON_FIELD_ERRORS_KEY: [ValidationError(message, code='empty_not_allowed')]
api_settings.NON_FIELD_ERRORS_KEY: [message]
})

ret = []
Expand Down
17 changes: 10 additions & 7 deletions rest_framework/validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from django.utils.translation import ugettext_lazy as _

from rest_framework.compat import unicode_to_repr
from rest_framework.exceptions import ValidationError
from rest_framework.exceptions import ValidationError, ValidationErrorMessage
from rest_framework.utils.representation import smart_repr


Expand Down Expand Up @@ -120,9 +120,10 @@ def enforce_required_fields(self, attrs):
return

missing = {
field_name: ValidationError(
field_name: ValidationErrorMessage(
self.missing_message,
code='required')

for field_name in self.fields
if field_name not in attrs
}
Expand Down Expand Up @@ -168,8 +169,9 @@ def __call__(self, attrs):
]
if None not in checked_values and qs_exists(queryset):
field_names = ', '.join(self.fields)
raise ValidationError(self.message.format(field_names=field_names),
code='unique')
raise ValidationError(
self.message.format(field_names=field_names),
code='unique')

def __repr__(self):
return unicode_to_repr('<%s(queryset=%s, fields=%s)>' % (
Expand Down Expand Up @@ -207,7 +209,7 @@ def enforce_required_fields(self, attrs):
'required' state on the fields they are applied to.
"""
missing = {
field_name: ValidationError(
field_name: ValidationErrorMessage(
self.missing_message,
code='required')
for field_name in [self.field, self.date_field]
Expand Down Expand Up @@ -235,8 +237,9 @@ def __call__(self, attrs):
queryset = self.exclude_current_instance(attrs, queryset)
if qs_exists(queryset):
message = self.message.format(date_field=self.date_field)
error = ValidationError(message, code='unique')
raise ValidationError({self.field: error})
raise ValidationError({
self.field: ValidationErrorMessage(message, code='unique'),
})

def __repr__(self):
return unicode_to_repr('<%s(queryset=%s, field=%s, date_field=%s)>' % (
Expand Down
2 changes: 1 addition & 1 deletion rest_framework/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def exception_handler(exc, context):
headers['Retry-After'] = '%d' % exc.wait

if isinstance(exc.detail, (list, dict)):
data = exc.detail.serializer.errors
data = exc.detail
else:
data = {'detail': exc.detail}

Expand Down
74 changes: 74 additions & 0 deletions tests/test_validation_error.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
from django.test import TestCase

from rest_framework import serializers, status
from rest_framework.decorators import api_view
from rest_framework.response import Response
from rest_framework.settings import api_settings
from rest_framework.test import APIRequestFactory
from rest_framework.views import APIView

factory = APIRequestFactory()


class ExampleSerializer(serializers.Serializer):
char = serializers.CharField()
integer = serializers.IntegerField()


class ErrorView(APIView):
def get(self, request, *args, **kwargs):
ExampleSerializer(data={}).is_valid(raise_exception=True)


@api_view(['GET'])
def error_view(request):
ExampleSerializer(data={}).is_valid(raise_exception=True)


class TestValidationErrorWithCode(TestCase):
def setUp(self):
self.DEFAULT_HANDLER = api_settings.EXCEPTION_HANDLER

def exception_handler(exc, request):
return_errors = {}
for field_name, errors in exc.detail.items():
return_errors[field_name] = []
for error in errors:
return_errors[field_name].append({
'code': error.code,
'message': error
})

return Response(return_errors, status=status.HTTP_400_BAD_REQUEST)

api_settings.EXCEPTION_HANDLER = exception_handler

self.expected_response_data = {
'char': [{
'message': 'This field is required.',
'code': 'required',
}],
'integer': [{
'message': 'This field is required.',
'code': 'required'
}],
}

def tearDown(self):
api_settings.EXCEPTION_HANDLER = self.DEFAULT_HANDLER

def test_class_based_view_exception_handler(self):
view = ErrorView.as_view()

request = factory.get('/', content_type='application/json')
response = view(request)
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertEqual(response.data, self.expected_response_data)

def test_function_based_view_exception_handler(self):
view = error_view

request = factory.get('/', content_type='application/json')
response = view(request)
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertEqual(response.data, self.expected_response_data)