diff --git a/rest_framework/compat.py b/rest_framework/compat.py index e37feb5660..8b782a1c26 100644 --- a/rest_framework/compat.py +++ b/rest_framework/compat.py @@ -185,6 +185,11 @@ def apply_markdown(text): else: DurationField = duration_string = parse_duration = None +try: + # DecimalValidator is unavailable in Django < 1.9 + from django.core.validators import DecimalValidator +except ImportError: + DecimalValidator = None def set_rollback(): if hasattr(transaction, 'set_rollback'): diff --git a/rest_framework/utils/field_mapping.py b/rest_framework/utils/field_mapping.py index 05dd2e379e..1d8cb22f2f 100644 --- a/rest_framework/utils/field_mapping.py +++ b/rest_framework/utils/field_mapping.py @@ -8,6 +8,7 @@ from django.db import models from django.utils.text import capfirst +from rest_framework.compat import DecimalValidator from rest_framework.validators import UniqueValidator NUMERIC_FIELD_TYPES = ( @@ -132,7 +133,7 @@ def get_field_kwargs(field_name, model_field): if isinstance(model_field, models.DecimalField): validator_kwarg = [ validator for validator in validator_kwarg - if not isinstance(validator, validators.DecimalValidator) + if DecimalValidator and not isinstance(validator, DecimalValidator) ] # Ensure that max_length is passed explicitly as a keyword arg, diff --git a/tests/test_model_serializer.py b/tests/test_model_serializer.py index 93a69fe4b1..57e540e7a5 100644 --- a/tests/test_model_serializer.py +++ b/tests/test_model_serializer.py @@ -22,7 +22,7 @@ from rest_framework import serializers from rest_framework.compat import DurationField as ModelDurationField -from rest_framework.compat import unicode_repr +from rest_framework.compat import DecimalValidator, unicode_repr def dedent(blocktext): @@ -861,3 +861,41 @@ class Meta: }] assert serializer.data == expected + + +class DecimalFieldModel(models.Model): + decimal_field = models.DecimalField( + max_digits=3, + decimal_places=1, + validators=[MinValueValidator(1), MaxValueValidator(3)] + ) + + +class TestDecimalFieldMappings(TestCase): + @pytest.mark.skipif(DecimalValidator is not None, + reason='DecimalValidator is available in Django 1.9+') + def test_decimal_field_has_no_decimal_validator(self): + """ + Test that a DecimalField has no validators before Django 1.9. + """ + class TestSerializer(serializers.ModelSerializer): + class Meta: + model = DecimalFieldModel + + serializer = TestSerializer() + + assert len(serializer.fields['decimal_field'].validators) == 0 + + @pytest.mark.skipif(DecimalValidator is None, + reason='DecimalValidator is available in Django 1.9+') + def test_decimal_field_has_decimal_validator(self): + """ + Test that a DecimalField has DecimalValidator in Django 1.9+. + """ + class TestSerializer(serializers.ModelSerializer): + class Meta: + model = DecimalFieldModel + + serializer = TestSerializer() + + assert len(serializer.fields['decimal_field'].validators) == 2