diff --git a/django_filters/compat.py b/django_filters/compat.py index e7fa88b2d..cbac451ba 100644 --- a/django_filters/compat.py +++ b/django_filters/compat.py @@ -1,4 +1,6 @@ +from __future__ import absolute_import + import django from django.conf import settings @@ -12,6 +14,13 @@ is_crispy = 'crispy_forms' in settings.INSTALLED_APPS and crispy_forms +# coreapi only compatible with DRF 3.4+ +try: + from rest_framework.compat import coreapi +except ImportError: + coreapi = None + + def remote_field(field): """ https://docs.djangoproject.com/en/1.9/releases/1.9/#field-rel-changes diff --git a/django_filters/rest_framework/backends.py b/django_filters/rest_framework/backends.py index 4027ac2e0..37c7dfa57 100644 --- a/django_filters/rest_framework/backends.py +++ b/django_filters/rest_framework/backends.py @@ -89,3 +89,15 @@ def to_html(self, request, queryset, view): return template_render(template, context={ 'filter': filter_instance }) + + def get_schema_fields(self, view): + # This is not compatible with widgets where the query param differs from the + # filter's attribute name. Notably, this includes `MultiWidget`, where query + # params will be of the format `_0`, `_1`, etc... + assert compat.coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`' + filter_class = self.get_filter_class(view, view.get_queryset()) + + return [] if not filter_class else [ + compat.coreapi.Field(name=field_name, required=False, location='query') + for field_name in filter_class().filters.keys() + ] diff --git a/requirements/test-ci.txt b/requirements/test-ci.txt index 85ceb08a8..2c79c1d17 100644 --- a/requirements/test-ci.txt +++ b/requirements/test-ci.txt @@ -1,3 +1,5 @@ +coreapi + coverage mock pytz diff --git a/requirements/test.txt b/requirements/test.txt index 5b885fa47..b811569ac 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -1,3 +1,3 @@ --r travis-ci.txt +-r test-ci.txt django djangorestframework diff --git a/tests/rest_framework/test_backends.py b/tests/rest_framework/test_backends.py index f6b383c69..763d09383 100644 --- a/tests/rest_framework/test_backends.py +++ b/tests/rest_framework/test_backends.py @@ -2,6 +2,7 @@ import datetime from decimal import Decimal +from unittest import skipIf from django.conf.urls import url from django.test import TestCase @@ -17,7 +18,7 @@ from rest_framework import generics, serializers, status from rest_framework.test import APIRequestFactory -from django_filters import filters +from django_filters import compat, filters from django_filters.rest_framework import DjangoFilterBackend, FilterSet from django_filters.rest_framework import backends @@ -121,6 +122,35 @@ def get_queryset(self): ] +@skipIf(compat.coreapi is None, 'coreapi must be installed') +class GetSchemaFieldsTests(TestCase): + def test_fields_with_filter_fields_list(self): + backend = DjangoFilterBackend() + fields = backend.get_schema_fields(FilterFieldsRootView()) + fields = [f.name for f in fields] + + self.assertEqual(fields, ['decimal', 'date']) + + def test_fields_with_filter_fields_dict(self): + class DictFilterFieldsRootView(FilterFieldsRootView): + filter_fields = { + 'decimal': ['exact', 'lt', 'gt'], + } + + backend = DjangoFilterBackend() + fields = backend.get_schema_fields(DictFilterFieldsRootView()) + fields = [f.name for f in fields] + + self.assertEqual(fields, ['decimal', 'decimal__lt', 'decimal__gt']) + + def test_fields_with_filter_class(self): + backend = DjangoFilterBackend() + fields = backend.get_schema_fields(FilterClassRootView()) + fields = [f.name for f in fields] + + self.assertEqual(fields, ['text', 'decimal', 'date']) + + class CommonFilteringTestCase(TestCase): def _serialize_object(self, obj): return {'id': obj.id, 'text': obj.text, 'decimal': str(obj.decimal), 'date': obj.date.isoformat()}