Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Rejig ErrorMessages, to prefer code= directly on ValidationError
  • Loading branch information
lovelydinosaur committed Oct 10, 2016
commit 7943429dabe7dbda33798e1ef8e8d19e1670430e
20 changes: 6 additions & 14 deletions rest_framework/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,37 +10,36 @@

from django.utils import six
from django.utils.encoding import force_text
from django.utils.functional import Promise
from django.utils.translation import ugettext_lazy as _
from django.utils.translation import ungettext

from rest_framework import status
from rest_framework.utils.serializer_helpers import ReturnDict, ReturnList


def _force_text_recursive(data):
def _force_text_recursive(data, code=None):
"""
Descend into a nested data structure, forcing any
lazy translation strings into plain text.
lazy translation strings or strings into `ErrorMessage`.
"""
if isinstance(data, list):
ret = [
_force_text_recursive(item) for item in data
_force_text_recursive(item, code) for item in data
]
if isinstance(data, ReturnList):
return ReturnList(ret, serializer=data.serializer)
return ret
elif isinstance(data, dict):
ret = {
key: _force_text_recursive(value)
key: _force_text_recursive(value, code)
for key, value in data.items()
}
if isinstance(data, ReturnDict):
return ReturnDict(ret, serializer=data.serializer)
return ret

text = force_text(data)
code = getattr(data, 'code', 'invalid')
code = getattr(data, 'code', code or 'invalid')
return ErrorMessage(text, code)


Expand Down Expand Up @@ -82,18 +81,11 @@ class ValidationError(APIException):
status_code = status.HTTP_400_BAD_REQUEST

def __init__(self, detail, code=None):
if code is not None:
assert isinstance(detail, six.string_types + (Promise,)), (
"When providing a 'code', the detail must be a string argument. "
"Use 'ErrorMessage' to set the code for a composite ValidationError"
)
detail = ErrorMessage(detail, code)

# For validation errors the 'detail' key is always required.
# The details should always be coerced to a list if not already.
if not isinstance(detail, dict) and not isinstance(detail, list):
detail = [detail]
self.detail = _force_text_recursive(detail)
self.detail = _force_text_recursive(detail, code=code)

def __str__(self):
return six.text_type(self.detail)
Expand Down
14 changes: 7 additions & 7 deletions rest_framework/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,8 +423,8 @@ def to_internal_value(self, data):
datatype=type(data).__name__
)
raise ValidationError({
api_settings.NON_FIELD_ERRORS_KEY: [ErrorMessage(message, code='invalid')]
})
api_settings.NON_FIELD_ERRORS_KEY: [message]
}, code='invalid')

ret = OrderedDict()
errors = OrderedDict()
Expand All @@ -440,7 +440,7 @@ def to_internal_value(self, data):
except ValidationError as exc:
errors[field.field_name] = exc.detail
except DjangoValidationError as exc:
errors[field.field_name] = list(exc.messages)
errors[field.field_name] = get_validation_error_detail(exc)
except SkipField:
pass
else:
Expand Down Expand Up @@ -580,14 +580,14 @@ def to_internal_value(self, data):
input_type=type(data).__name__
)
raise ValidationError({
api_settings.NON_FIELD_ERRORS_KEY: [ErrorMessage(message, code='not_a_list')]
})
api_settings.NON_FIELD_ERRORS_KEY: [message]
}, code='not_a_list')

if not self.allow_empty and len(data) == 0:
message = self.error_messages['empty']
raise ValidationError({
api_settings.NON_FIELD_ERRORS_KEY: [ErrorMessage(message, code='empty')]
})
api_settings.NON_FIELD_ERRORS_KEY: [message]
}, code='empty')

ret = []
errors = []
Expand Down
14 changes: 7 additions & 7 deletions rest_framework/validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from django.utils.translation import ugettext_lazy as _

from rest_framework.compat import unicode_to_repr
from rest_framework.exceptions import ErrorMessage, ValidationError
from rest_framework.exceptions import ValidationError
from rest_framework.utils.representation import smart_repr


Expand Down Expand Up @@ -121,12 +121,12 @@ def enforce_required_fields(self, attrs):
return

missing_items = {
field_name: ErrorMessage(self.missing_message, code='required')
field_name: self.missing_message
for field_name in self.fields
if field_name not in attrs
}
if missing_items:
raise ValidationError(missing_items)
raise ValidationError(missing_items, code='required')

def filter_queryset(self, attrs, queryset):
"""
Expand Down Expand Up @@ -206,12 +206,12 @@ def enforce_required_fields(self, attrs):
'required' state on the fields they are applied to.
"""
missing_items = {
field_name: ErrorMessage(self.missing_message, code='required')
field_name: self.missing_message
for field_name in [self.field, self.date_field]
if field_name not in attrs
}
if missing_items:
raise ValidationError(missing_items)
raise ValidationError(missing_items, code='required')

def filter_queryset(self, attrs, queryset):
raise NotImplementedError('`filter_queryset` must be implemented.')
Expand All @@ -233,8 +233,8 @@ def __call__(self, attrs):
if qs_exists(queryset):
message = self.message.format(date_field=self.date_field)
raise ValidationError({
self.field: ErrorMessage(message, code='unique')
})
self.field: message
}, code='unique')

def __repr__(self):
return unicode_to_repr('<%s(queryset=%s, field=%s, date_field=%s)>' % (
Expand Down