Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Add 'code' to ValidationError
  • Loading branch information
lovelydinosaur committed Oct 10, 2016
commit f078a04f7e64db75d33513983a3688a3a7e7b6a5
6 changes: 3 additions & 3 deletions rest_framework/authtoken/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,13 @@ def validate(self, attrs):
# (Assuming the default `ModelBackend` authentication backend.)
if not user.is_active:
msg = _('User account is disabled.')
raise serializers.ValidationError(msg)
raise serializers.ValidationError(msg, code='authorization')
else:
msg = _('Unable to log in with provided credentials.')
raise serializers.ValidationError(msg)
raise serializers.ValidationError(msg, code='authorization')
else:
msg = _('Must include "username" and "password".')
raise serializers.ValidationError(msg)
raise serializers.ValidationError(msg, code='authorization')
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't the codes be different from each other here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not necessarily. We're not trying to identify every possible case, just a categorization. If uses want to have specific codes for those different cases then that's simple enough to do by customizing this.


attrs['user'] = user
return attrs
24 changes: 22 additions & 2 deletions rest_framework/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from django.utils import six
from django.utils.encoding import force_text
from django.utils.functional import Promise
from django.utils.translation import ugettext_lazy as _
from django.utils.translation import ungettext

Expand Down Expand Up @@ -37,7 +38,19 @@ def _force_text_recursive(data):
if isinstance(data, ReturnDict):
return ReturnDict(ret, serializer=data.serializer)
return ret
return force_text(data)

text = force_text(data)
code = getattr(data, 'code', 'invalid')
return ErrorMessage(text, code)


class ErrorMessage(six.text_type):
code = None

def __new__(cls, string, code=None):
self = super(ErrorMessage, cls).__new__(cls, string)
self.code = code
return self


class APIException(Exception):
Expand Down Expand Up @@ -68,7 +81,14 @@ def __str__(self):
class ValidationError(APIException):
status_code = status.HTTP_400_BAD_REQUEST

def __init__(self, detail):
def __init__(self, detail, code=None):
if code is not None:
assert isinstance(detail, six.string_types + (Promise,)), (
"When providing a 'code', the detail must be a string argument. "
"Use 'ErrorMessage' to set the code for a composite ValidationError"
)
detail = ErrorMessage(detail, code)

# For validation errors the 'detail' key is always required.
# The details should always be coerced to a list if not already.
if not isinstance(detail, dict) and not isinstance(detail, list):
Expand Down
18 changes: 15 additions & 3 deletions rest_framework/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
from rest_framework.compat import (
get_remote_field, unicode_repr, unicode_to_repr, value_from_object
)
from rest_framework.exceptions import ValidationError
from rest_framework.exceptions import ErrorMessage, ValidationError
from rest_framework.settings import api_settings
from rest_framework.utils import html, humanize_datetime, representation

Expand Down Expand Up @@ -224,6 +224,18 @@ def __init__(self, value, display_text, disabled=False):
yield Option(value='n/a', display_text=cutoff_text, disabled=True)


def get_error_messages(exc_info):
"""
Given a Django ValidationError, return a list of ErrorMessage,
with the `code` populated.
"""
code = getattr(exc_info, 'code', None) or 'invalid'
return [
ErrorMessage(msg, code=code)
for msg in exc_info.messages
]


class CreateOnlyDefault(object):
"""
This class may be used to provide default values that are only used
Expand Down Expand Up @@ -525,7 +537,7 @@ def run_validators(self, value):
raise
errors.extend(exc.detail)
except DjangoValidationError as exc:
errors.extend(exc.messages)
errors.extend(get_error_messages(exc))
if errors:
raise ValidationError(errors)

Expand Down Expand Up @@ -563,7 +575,7 @@ def fail(self, key, **kwargs):
msg = MISSING_ERROR_MESSAGE.format(class_name=class_name, key=key)
raise AssertionError(msg)
message_string = msg.format(**kwargs)
raise ValidationError(message_string)
raise ValidationError(message_string, code=key)

@cached_property
def root(self):
Expand Down
8 changes: 4 additions & 4 deletions rest_framework/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,7 @@ def get_validation_error_detail(exc):
# exception class as well for simpler compat.
# Eg. Calling Model.clean() explicitly inside Serializer.validate()
return {
api_settings.NON_FIELD_ERRORS_KEY: list(exc.messages)
api_settings.NON_FIELD_ERRORS_KEY: get_error_messages(exc)
}
elif isinstance(exc.detail, dict):
# If errors may be a dict we use the standard {key: list of values}.
Expand Down Expand Up @@ -423,7 +423,7 @@ def to_internal_value(self, data):
datatype=type(data).__name__
)
raise ValidationError({
api_settings.NON_FIELD_ERRORS_KEY: [message]
api_settings.NON_FIELD_ERRORS_KEY: [ErrorMessage(message, code='invalid')]
})

ret = OrderedDict()
Expand Down Expand Up @@ -580,13 +580,13 @@ def to_internal_value(self, data):
input_type=type(data).__name__
)
raise ValidationError({
api_settings.NON_FIELD_ERRORS_KEY: [message]
api_settings.NON_FIELD_ERRORS_KEY: [ErrorMessage(message, code='not_a_list')]
})

if not self.allow_empty and len(data) == 0:
message = self.error_messages['empty']
raise ValidationError({
api_settings.NON_FIELD_ERRORS_KEY: [message]
api_settings.NON_FIELD_ERRORS_KEY: [ErrorMessage(message, code='empty')]
})

ret = []
Expand Down
27 changes: 15 additions & 12 deletions rest_framework/validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from django.utils.translation import ugettext_lazy as _

from rest_framework.compat import unicode_to_repr
from rest_framework.exceptions import ValidationError
from rest_framework.exceptions import ErrorMessage, ValidationError
from rest_framework.utils.representation import smart_repr


Expand Down Expand Up @@ -80,7 +80,7 @@ def __call__(self, value):
queryset = self.filter_queryset(value, queryset)
queryset = self.exclude_current_instance(queryset)
if qs_exists(queryset):
raise ValidationError(self.message)
raise ValidationError(self.message, code='unique')

def __repr__(self):
return unicode_to_repr('<%s(queryset=%s)>' % (
Expand Down Expand Up @@ -120,13 +120,13 @@ def enforce_required_fields(self, attrs):
if self.instance is not None:
return

missing = {
field_name: self.missing_message
missing_items = {
field_name: ErrorMessage(self.missing_message, code='required')
for field_name in self.fields
if field_name not in attrs
}
if missing:
raise ValidationError(missing)
if missing_items:
raise ValidationError(missing_items)

def filter_queryset(self, attrs, queryset):
"""
Expand Down Expand Up @@ -167,7 +167,8 @@ def __call__(self, attrs):
]
if None not in checked_values and qs_exists(queryset):
field_names = ', '.join(self.fields)
raise ValidationError(self.message.format(field_names=field_names))
message = self.message.format(field_names=field_names)
raise ValidationError(message, code='unique')

def __repr__(self):
return unicode_to_repr('<%s(queryset=%s, fields=%s)>' % (
Expand Down Expand Up @@ -204,13 +205,13 @@ def enforce_required_fields(self, attrs):
The `UniqueFor<Range>Validator` classes always force an implied
'required' state on the fields they are applied to.
"""
missing = {
field_name: self.missing_message
missing_items = {
field_name: ErrorMessage(self.missing_message, code='required')
for field_name in [self.field, self.date_field]
if field_name not in attrs
}
if missing:
raise ValidationError(missing)
if missing_items:
raise ValidationError(missing_items)

def filter_queryset(self, attrs, queryset):
raise NotImplementedError('`filter_queryset` must be implemented.')
Expand All @@ -231,7 +232,9 @@ def __call__(self, attrs):
queryset = self.exclude_current_instance(attrs, queryset)
if qs_exists(queryset):
message = self.message.format(date_field=self.date_field)
raise ValidationError({self.field: message})
raise ValidationError({
self.field: ErrorMessage(message, code='unique')
})

def __repr__(self):
return unicode_to_repr('<%s(queryset=%s, field=%s, date_field=%s)>' % (
Expand Down
8 changes: 4 additions & 4 deletions tests/test_exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from django.test import TestCase
from django.utils.translation import ugettext_lazy as _

from rest_framework.exceptions import _force_text_recursive
from rest_framework.exceptions import ErrorMessage, _force_text_recursive


class ExceptionTestCase(TestCase):
Expand All @@ -12,10 +12,10 @@ def test_force_text_recursive(self):

s = "sfdsfggiuytraetfdlklj"
self.assertEqual(_force_text_recursive(_(s)), s)
self.assertEqual(type(_force_text_recursive(_(s))), type(s))
assert isinstance(_force_text_recursive(_(s)), ErrorMessage)

self.assertEqual(_force_text_recursive({'a': _(s)})['a'], s)
self.assertEqual(type(_force_text_recursive({'a': _(s)})['a']), type(s))
assert isinstance(_force_text_recursive({'a': _(s)})['a'], ErrorMessage)

self.assertEqual(_force_text_recursive([[_(s)]])[0][0], s)
self.assertEqual(type(_force_text_recursive([[_(s)]])[0][0]), type(s))
assert isinstance(_force_text_recursive([[_(s)]])[0][0], ErrorMessage)