From 08c7853655252cb9de0ade24eddfc9762a01cee7 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Mon, 1 Aug 2016 14:15:35 +0100 Subject: [PATCH 01/43] Start test case --- rest_framework/test.py | 5 +++++ tests/test_requests_client.py | 24 ++++++++++++++++++++++++ 2 files changed, 29 insertions(+) create mode 100644 tests/test_requests_client.py diff --git a/rest_framework/test.py b/rest_framework/test.py index 3ba4059a9f..fb08c4a736 100644 --- a/rest_framework/test.py +++ b/rest_framework/test.py @@ -12,6 +12,7 @@ from django.utils import six from django.utils.encoding import force_bytes from django.utils.http import urlencode +from requests import Session from rest_framework.settings import api_settings @@ -221,6 +222,10 @@ class APITransactionTestCase(testcases.TransactionTestCase): class APITestCase(testcases.TestCase): client_class = APIClient + def _pre_setup(self): + super(APITestCase, self)._pre_setup() + self.requests = Session() + class APISimpleTestCase(testcases.SimpleTestCase): client_class = APIClient diff --git a/tests/test_requests_client.py b/tests/test_requests_client.py new file mode 100644 index 0000000000..a36349a3f0 --- /dev/null +++ b/tests/test_requests_client.py @@ -0,0 +1,24 @@ +from __future__ import unicode_literals + +from django.conf.urls import url +from django.test import override_settings + +from rest_framework.response import Response +from rest_framework.test import APITestCase +from rest_framework.views import APIView + + +class Root(APIView): + def get(self, request): + return Response({'hello': 'world'}) + + +urlpatterns = [ + url(r'^$', Root.as_view()), +] + + +@override_settings(ROOT_URLCONF='tests.test_requests_client') +class RequestsClientTests(APITestCase): + def test_get_root(self): + print self.requests.get('http://example.com') From 3d1fff3f26835612be17b6624d766a56520880ec Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Mon, 15 Aug 2016 15:42:35 +0100 Subject: [PATCH 02/43] Added 'requests' test client --- rest_framework/request.py | 2 +- rest_framework/test.py | 86 +++++++++++++++++++++++- tests/test_requests_client.py | 119 +++++++++++++++++++++++++++++++++- 3 files changed, 202 insertions(+), 5 deletions(-) diff --git a/rest_framework/request.py b/rest_framework/request.py index aafafcb325..f5738bfd50 100644 --- a/rest_framework/request.py +++ b/rest_framework/request.py @@ -373,7 +373,7 @@ def POST(self): if not _hasattr(self, '_data'): self._load_data_and_files() if is_form_media_type(self.content_type): - return self.data + return self._data return QueryDict('', encoding=self._request._encoding) @property diff --git a/rest_framework/test.py b/rest_framework/test.py index fb08c4a736..1fd530a0c3 100644 --- a/rest_framework/test.py +++ b/rest_framework/test.py @@ -4,7 +4,10 @@ # to make it harder for the user to import the wrong thing without realizing. from __future__ import unicode_literals +import io + from django.conf import settings +from django.core.handlers.wsgi import WSGIHandler from django.test import testcases from django.test.client import Client as DjangoClient from django.test.client import RequestFactory as DjangoRequestFactory @@ -13,6 +16,10 @@ from django.utils.encoding import force_bytes from django.utils.http import urlencode from requests import Session +from requests.adapters import BaseAdapter +from requests.models import Response +from requests.structures import CaseInsensitiveDict +from requests.utils import get_encoding_from_headers from rest_framework.settings import api_settings @@ -22,6 +29,83 @@ def force_authenticate(request, user=None, token=None): request._force_auth_token = token +class DjangoTestAdapter(BaseAdapter): + """ + A transport adaptor for `requests`, that makes requests via the + Django WSGI app, rather than making actual HTTP requests ovet the network. + """ + def __init__(self): + self.app = WSGIHandler() + self.factory = DjangoRequestFactory() + + def get_environ(self, request): + """ + Given a `requests.PreparedRequest` instance, return a WSGI environ dict. + """ + method = request.method + url = request.url + kwargs = {} + + # Set request content, if any exists. + if request.body is not None: + kwargs['data'] = request.body + if 'content-type' in request.headers: + kwargs['content_type'] = request.headers['content-type'] + + # Set request headers. + for key, value in request.headers.items(): + key = key.upper() + if key in ('CONNECTION', 'CONTENT_LENGTH', 'CONTENT-TYPE'): + continue + kwargs['HTTP_%s' % key] = value + + return self.factory.generic(method, url, **kwargs).environ + + def send(self, request, *args, **kwargs): + """ + Make an outgoing request to the Django WSGI application. + """ + response = Response() + + def start_response(status, headers): + status_code, _, reason_phrase = status.partition(' ') + response.status_code = int(status_code) + response.reason = reason_phrase + response.headers = CaseInsensitiveDict(headers) + response.encoding = get_encoding_from_headers(response.headers) + + environ = self.get_environ(request) + raw_bytes = self.app(environ, start_response) + + response.request = request + response.url = request.url + response.raw = io.BytesIO(b''.join(raw_bytes)) + + return response + + def close(self): + pass + + +class DjangoTestSession(Session): + def __init__(self, *args, **kwargs): + super(DjangoTestSession, self).__init__(*args, **kwargs) + + adapter = DjangoTestAdapter() + hostnames = list(settings.ALLOWED_HOSTS) + ['testserver'] + + for hostname in hostnames: + if hostname == '*': + hostname = '' + self.mount('http://%s' % hostname, adapter) + self.mount('https://%s' % hostname, adapter) + + def request(self, method, url, *args, **kwargs): + if ':' not in url: + url = 'http://testserver/' + url.lstrip('/') + return super(DjangoTestSession, self).request(method, url, *args, **kwargs) + + class APIRequestFactory(DjangoRequestFactory): renderer_classes_list = api_settings.TEST_REQUEST_RENDERER_CLASSES default_format = api_settings.TEST_REQUEST_DEFAULT_FORMAT @@ -224,7 +308,7 @@ class APITestCase(testcases.TestCase): def _pre_setup(self): super(APITestCase, self)._pre_setup() - self.requests = Session() + self.requests = DjangoTestSession() class APISimpleTestCase(testcases.SimpleTestCase): diff --git a/tests/test_requests_client.py b/tests/test_requests_client.py index a36349a3f0..0687dd92e1 100644 --- a/tests/test_requests_client.py +++ b/tests/test_requests_client.py @@ -10,15 +10,128 @@ class Root(APIView): def get(self, request): - return Response({'hello': 'world'}) + return Response({ + 'method': request.method, + 'query_params': request.query_params, + }) + + def post(self, request): + files = { + key: (value.name, value.read()) + for key, value in request.FILES.items() + } + post = request.POST + json = None + if request.META.get('CONTENT_TYPE') == 'application/json': + json = request.data + + return Response({ + 'method': request.method, + 'query_params': request.query_params, + 'POST': post, + 'FILES': files, + 'JSON': json + }) + + +class Headers(APIView): + def get(self, request): + headers = { + key[5:]: value + for key, value in request.META.items() + if key.startswith('HTTP_') + } + return Response({ + 'method': request.method, + 'headers': headers + }) urlpatterns = [ url(r'^$', Root.as_view()), + url(r'^headers/$', Headers.as_view()), ] @override_settings(ROOT_URLCONF='tests.test_requests_client') class RequestsClientTests(APITestCase): - def test_get_root(self): - print self.requests.get('http://example.com') + def test_get_request(self): + response = self.requests.get('/') + assert response.status_code == 200 + assert response.headers['Content-Type'] == 'application/json' + expected = { + 'method': 'GET', + 'query_params': {} + } + assert response.json() == expected + + def test_get_request_query_params_in_url(self): + response = self.requests.get('/?key=value') + assert response.status_code == 200 + assert response.headers['Content-Type'] == 'application/json' + expected = { + 'method': 'GET', + 'query_params': {'key': 'value'} + } + assert response.json() == expected + + def test_get_request_query_params_by_kwarg(self): + response = self.requests.get('/', params={'key': 'value'}) + assert response.status_code == 200 + assert response.headers['Content-Type'] == 'application/json' + expected = { + 'method': 'GET', + 'query_params': {'key': 'value'} + } + assert response.json() == expected + + def test_get_with_headers(self): + response = self.requests.get('/headers/', headers={'User-Agent': 'example'}) + assert response.status_code == 200 + assert response.headers['Content-Type'] == 'application/json' + headers = response.json()['headers'] + assert headers['USER-AGENT'] == 'example' + + def test_post_form_request(self): + response = self.requests.post('/', data={'key': 'value'}) + assert response.status_code == 200 + assert response.headers['Content-Type'] == 'application/json' + expected = { + 'method': 'POST', + 'query_params': {}, + 'POST': {'key': 'value'}, + 'FILES': {}, + 'JSON': None + } + assert response.json() == expected + + def test_post_json_request(self): + response = self.requests.post('/', json={'key': 'value'}) + assert response.status_code == 200 + assert response.headers['Content-Type'] == 'application/json' + expected = { + 'method': 'POST', + 'query_params': {}, + 'POST': {}, + 'FILES': {}, + 'JSON': {'key': 'value'} + } + assert response.json() == expected + + def test_post_multipart_request(self): + files = { + 'file': ('report.csv', 'some,data,to,send\nanother,row,to,send\n') + } + response = self.requests.post('/', files=files) + assert response.status_code == 200 + assert response.headers['Content-Type'] == 'application/json' + expected = { + 'method': 'POST', + 'query_params': {}, + 'FILES': {'file': ['report.csv', 'some,data,to,send\nanother,row,to,send\n']}, + 'POST': {}, + 'JSON': None + } + assert response.json() == expected + + # cookies/session auth From e76ca6eb8838148ccb1c25a2a8a735a42f644d99 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Mon, 15 Aug 2016 16:06:04 +0100 Subject: [PATCH 03/43] Address typos --- rest_framework/test.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/rest_framework/test.py b/rest_framework/test.py index 1fd530a0c3..e1d8eff82e 100644 --- a/rest_framework/test.py +++ b/rest_framework/test.py @@ -31,8 +31,8 @@ def force_authenticate(request, user=None, token=None): class DjangoTestAdapter(BaseAdapter): """ - A transport adaptor for `requests`, that makes requests via the - Django WSGI app, rather than making actual HTTP requests ovet the network. + A transport adapter for `requests`, that makes requests via the + Django WSGI app, rather than making actual HTTP requests over the network. """ def __init__(self): self.app = WSGIHandler() @@ -55,9 +55,9 @@ def get_environ(self, request): # Set request headers. for key, value in request.headers.items(): key = key.upper() - if key in ('CONNECTION', 'CONTENT_LENGTH', 'CONTENT-TYPE'): + if key in ('CONNECTION', 'CONTENT-LENGTH', 'CONTENT-TYPE'): continue - kwargs['HTTP_%s' % key] = value + kwargs['HTTP_%s' % key.replace('-', '_')] = value return self.factory.generic(method, url, **kwargs).environ From 6ede654315e415362f4c7c8e38a3f641039bfc46 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Wed, 17 Aug 2016 12:11:01 +0100 Subject: [PATCH 04/43] Graceful fallback if requests is not installed. --- rest_framework/compat.py | 7 ++ rest_framework/test.py | 161 +++++++++++++++++----------------- tests/test_requests_client.py | 6 +- 3 files changed, 92 insertions(+), 82 deletions(-) diff --git a/rest_framework/compat.py b/rest_framework/compat.py index cee430a84b..bda346fa82 100644 --- a/rest_framework/compat.py +++ b/rest_framework/compat.py @@ -178,6 +178,13 @@ def value_from_object(field, obj): uritemplate = None +# requests is optional +try: + import requests +except ImportError: + requests = None + + # Django-guardian is optional. Import only if guardian is in INSTALLED_APPS # Fixes (#1712). We keep the try/except for the test suite. guardian = None diff --git a/rest_framework/test.py b/rest_framework/test.py index e1d8eff82e..eba4b96cfa 100644 --- a/rest_framework/test.py +++ b/rest_framework/test.py @@ -15,12 +15,8 @@ from django.utils import six from django.utils.encoding import force_bytes from django.utils.http import urlencode -from requests import Session -from requests.adapters import BaseAdapter -from requests.models import Response -from requests.structures import CaseInsensitiveDict -from requests.utils import get_encoding_from_headers +from rest_framework.compat import requests from rest_framework.settings import api_settings @@ -29,81 +25,81 @@ def force_authenticate(request, user=None, token=None): request._force_auth_token = token -class DjangoTestAdapter(BaseAdapter): - """ - A transport adapter for `requests`, that makes requests via the - Django WSGI app, rather than making actual HTTP requests over the network. - """ - def __init__(self): - self.app = WSGIHandler() - self.factory = DjangoRequestFactory() - - def get_environ(self, request): - """ - Given a `requests.PreparedRequest` instance, return a WSGI environ dict. - """ - method = request.method - url = request.url - kwargs = {} - - # Set request content, if any exists. - if request.body is not None: - kwargs['data'] = request.body - if 'content-type' in request.headers: - kwargs['content_type'] = request.headers['content-type'] - - # Set request headers. - for key, value in request.headers.items(): - key = key.upper() - if key in ('CONNECTION', 'CONTENT-LENGTH', 'CONTENT-TYPE'): - continue - kwargs['HTTP_%s' % key.replace('-', '_')] = value - - return self.factory.generic(method, url, **kwargs).environ - - def send(self, request, *args, **kwargs): +if requests is not None: + class DjangoTestAdapter(requests.adapters.BaseAdapter): """ - Make an outgoing request to the Django WSGI application. + A transport adapter for `requests`, that makes requests via the + Django WSGI app, rather than making actual HTTP requests over the network. """ - response = Response() - - def start_response(status, headers): - status_code, _, reason_phrase = status.partition(' ') - response.status_code = int(status_code) - response.reason = reason_phrase - response.headers = CaseInsensitiveDict(headers) - response.encoding = get_encoding_from_headers(response.headers) - - environ = self.get_environ(request) - raw_bytes = self.app(environ, start_response) - - response.request = request - response.url = request.url - response.raw = io.BytesIO(b''.join(raw_bytes)) - - return response - - def close(self): - pass - - -class DjangoTestSession(Session): - def __init__(self, *args, **kwargs): - super(DjangoTestSession, self).__init__(*args, **kwargs) - - adapter = DjangoTestAdapter() - hostnames = list(settings.ALLOWED_HOSTS) + ['testserver'] - - for hostname in hostnames: - if hostname == '*': - hostname = '' - self.mount('http://%s' % hostname, adapter) - self.mount('https://%s' % hostname, adapter) - - def request(self, method, url, *args, **kwargs): - if ':' not in url: - url = 'http://testserver/' + url.lstrip('/') - return super(DjangoTestSession, self).request(method, url, *args, **kwargs) + def __init__(self): + self.app = WSGIHandler() + self.factory = DjangoRequestFactory() + + def get_environ(self, request): + """ + Given a `requests.PreparedRequest` instance, return a WSGI environ dict. + """ + method = request.method + url = request.url + kwargs = {} + + # Set request content, if any exists. + if request.body is not None: + kwargs['data'] = request.body + if 'content-type' in request.headers: + kwargs['content_type'] = request.headers['content-type'] + + # Set request headers. + for key, value in request.headers.items(): + key = key.upper() + if key in ('CONNECTION', 'CONTENT-LENGTH', 'CONTENT-TYPE'): + continue + kwargs['HTTP_%s' % key.replace('-', '_')] = value + + return self.factory.generic(method, url, **kwargs).environ + + def send(self, request, *args, **kwargs): + """ + Make an outgoing request to the Django WSGI application. + """ + response = requests.models.Response() + + def start_response(status, headers): + status_code, _, reason_phrase = status.partition(' ') + response.status_code = int(status_code) + response.reason = reason_phrase + response.headers = requests.structures.CaseInsensitiveDict(headers) + response.encoding = requests.utils.get_encoding_from_headers(response.headers) + + environ = self.get_environ(request) + raw_bytes = self.app(environ, start_response) + + response.request = request + response.url = request.url + response.raw = io.BytesIO(b''.join(raw_bytes)) + + return response + + def close(self): + pass + + class DjangoTestSession(requests.Session): + def __init__(self, *args, **kwargs): + super(DjangoTestSession, self).__init__(*args, **kwargs) + + adapter = DjangoTestAdapter() + hostnames = list(settings.ALLOWED_HOSTS) + ['testserver'] + + for hostname in hostnames: + if hostname == '*': + hostname = '' + self.mount('http://%s' % hostname, adapter) + self.mount('https://%s' % hostname, adapter) + + def request(self, method, url, *args, **kwargs): + if ':' not in url: + url = 'http://testserver/' + url.lstrip('/') + return super(DjangoTestSession, self).request(method, url, *args, **kwargs) class APIRequestFactory(DjangoRequestFactory): @@ -306,9 +302,12 @@ class APITransactionTestCase(testcases.TransactionTestCase): class APITestCase(testcases.TestCase): client_class = APIClient - def _pre_setup(self): - super(APITestCase, self)._pre_setup() - self.requests = DjangoTestSession() + @property + def requests(self): + if not hasattr(self, '_requests'): + assert requests is not None, 'requests must be installed' + self._requests = DjangoTestSession() + return self._requests class APISimpleTestCase(testcases.SimpleTestCase): diff --git a/tests/test_requests_client.py b/tests/test_requests_client.py index 0687dd92e1..24e29d3b86 100644 --- a/tests/test_requests_client.py +++ b/tests/test_requests_client.py @@ -1,8 +1,11 @@ from __future__ import unicode_literals +import unittest + from django.conf.urls import url from django.test import override_settings +from rest_framework.compat import requests from rest_framework.response import Response from rest_framework.test import APITestCase from rest_framework.views import APIView @@ -37,7 +40,7 @@ def post(self, request): class Headers(APIView): def get(self, request): headers = { - key[5:]: value + key[5:].replace('_', '-'): value for key, value in request.META.items() if key.startswith('HTTP_') } @@ -53,6 +56,7 @@ def get(self, request): ] +@unittest.skipUnless(requests, 'requests not installed') @override_settings(ROOT_URLCONF='tests.test_requests_client') class RequestsClientTests(APITestCase): def test_get_request(self): From 049a39e060ab8bbd028ffaa0fb2f5104a71f400d Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Wed, 17 Aug 2016 15:43:12 +0100 Subject: [PATCH 05/43] Add cookie support --- rest_framework/test.py | 41 +++++++++++++------ tests/test_requests_client.py | 76 ++++++++++++++++++++++++++++++++++- 2 files changed, 102 insertions(+), 15 deletions(-) diff --git a/rest_framework/test.py b/rest_framework/test.py index eba4b96cfa..bc8ecc5db6 100644 --- a/rest_framework/test.py +++ b/rest_framework/test.py @@ -26,7 +26,7 @@ def force_authenticate(request, user=None, token=None): if requests is not None: - class DjangoTestAdapter(requests.adapters.BaseAdapter): + class DjangoTestAdapter(requests.adapters.HTTPAdapter): """ A transport adapter for `requests`, that makes requests via the Django WSGI app, rather than making actual HTTP requests over the network. @@ -62,23 +62,38 @@ def send(self, request, *args, **kwargs): """ Make an outgoing request to the Django WSGI application. """ - response = requests.models.Response() + raw_kwargs = {} - def start_response(status, headers): - status_code, _, reason_phrase = status.partition(' ') - response.status_code = int(status_code) - response.reason = reason_phrase - response.headers = requests.structures.CaseInsensitiveDict(headers) - response.encoding = requests.utils.get_encoding_from_headers(response.headers) + def start_response(wsgi_status, wsgi_headers): + class MockOriginalResponse(object): + def __init__(self, headers): + self.msg = requests.packages.urllib3._collections.HTTPHeaderDict(headers) + self.closed = False + def isclosed(self): + return self.closed + + def close(self): + self.closed = True + + status, _, reason = wsgi_status.partition(' ') + raw_kwargs['status'] = int(status) + raw_kwargs['reason'] = reason + raw_kwargs['headers'] = wsgi_headers + raw_kwargs['version'] = 11 + raw_kwargs['preload_content'] = False + raw_kwargs['original_response'] = MockOriginalResponse(wsgi_headers) + + # Make the outgoing request via WSGI. environ = self.get_environ(request) - raw_bytes = self.app(environ, start_response) + wsgi_response = self.app(environ, start_response) - response.request = request - response.url = request.url - response.raw = io.BytesIO(b''.join(raw_bytes)) + # Build the underlying urllib3.HTTPResponse + raw_kwargs['body'] = io.BytesIO(b''.join(wsgi_response)) + raw = requests.packages.urllib3.HTTPResponse(**raw_kwargs) - return response + # Build the requests.Response + return self.build_response(request, raw) def close(self): pass diff --git a/tests/test_requests_client.py b/tests/test_requests_client.py index 24e29d3b86..10158efa7d 100644 --- a/tests/test_requests_client.py +++ b/tests/test_requests_client.py @@ -37,7 +37,7 @@ def post(self, request): }) -class Headers(APIView): +class HeadersView(APIView): def get(self, request): headers = { key[5:].replace('_', '-'): value @@ -50,9 +50,32 @@ def get(self, request): }) +class SessionView(APIView): + def get(self, request): + return Response({ + key: value for key, value in request.session.items() + }) + + def post(self, request): + for key, value in request.data.items(): + request.session[key] = value + return Response({ + key: value for key, value in request.session.items() + }) + + +class CookiesView(APIView): + def get(self, request): + return Response({ + key: value for key, value in request.COOKIES.items() + }) + + urlpatterns = [ url(r'^$', Root.as_view()), - url(r'^headers/$', Headers.as_view()), + url(r'^headers/$', HeadersView.as_view()), + url(r'^session/$', SessionView.as_view()), + url(r'^cookies/$', CookiesView.as_view()), ] @@ -138,4 +161,53 @@ def test_post_multipart_request(self): } assert response.json() == expected + def test_session(self): + response = self.requests.get('/session/') + assert response.status_code == 200 + assert response.headers['Content-Type'] == 'application/json' + expected = {} + assert response.json() == expected + + response = self.requests.post('/session/', json={'example': 'abc'}) + assert response.status_code == 200 + assert response.headers['Content-Type'] == 'application/json' + expected = {'example': 'abc'} + assert response.json() == expected + + response = self.requests.get('/session/') + assert response.status_code == 200 + assert response.headers['Content-Type'] == 'application/json' + expected = {'example': 'abc'} + assert response.json() == expected + + def test_cookies(self): + """ + Test for explicitly setting a cookie. + """ + my_cookie = { + "version": 0, + "name": 'COOKIE_NAME', + "value": 'COOKIE_VALUE', + "port": None, + # "port_specified":False, + "domain": 'testserver.local', + # "domain_specified":False, + # "domain_initial_dot":False, + "path": '/', + # "path_specified":True, + "secure": False, + "expires": None, + "discard": True, + "comment": None, + "comment_url": None, + "rest": {}, + "rfc2109": False + } + self.requests.cookies.set(**my_cookie) + response = self.requests.get('/cookies/') + assert response.status_code == 200 + assert response.headers['Content-Type'] == 'application/json' + expected = {'COOKIE_NAME': 'COOKIE_VALUE'} + assert response.json() == expected + # cookies/session auth From 64e19c738fce463df6fafd8f377ecc702f068813 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Wed, 17 Aug 2016 17:54:03 +0100 Subject: [PATCH 06/43] Tests for auth and CSRF --- tests/test_requests_client.py | 85 ++++++++++++++++++++++------------- 1 file changed, 55 insertions(+), 30 deletions(-) diff --git a/tests/test_requests_client.py b/tests/test_requests_client.py index 10158efa7d..aa99a71da5 100644 --- a/tests/test_requests_client.py +++ b/tests/test_requests_client.py @@ -3,9 +3,14 @@ import unittest from django.conf.urls import url +from django.contrib.auth import authenticate, login +from django.contrib.auth.models import User +from django.shortcuts import redirect from django.test import override_settings +from django.utils.decorators import method_decorator +from django.views.decorators.csrf import csrf_protect, ensure_csrf_cookie -from rest_framework.compat import requests +from rest_framework.compat import is_authenticated, requests from rest_framework.response import Response from rest_framework.test import APITestCase from rest_framework.views import APIView @@ -64,18 +69,33 @@ def post(self, request): }) -class CookiesView(APIView): +class AuthView(APIView): + @method_decorator(ensure_csrf_cookie) def get(self, request): + if is_authenticated(request.user): + username = request.user.username + else: + username = None return Response({ - key: value for key, value in request.COOKIES.items() + 'username': username }) + @method_decorator(csrf_protect) + def post(self, request): + username = request.data['username'] + password = request.data['password'] + user = authenticate(username=username, password=password) + if user is None: + return Response({'error': 'incorrect credentials'}) + login(request, user) + return redirect('/auth/') + urlpatterns = [ url(r'^$', Root.as_view()), url(r'^headers/$', HeadersView.as_view()), url(r'^session/$', SessionView.as_view()), - url(r'^cookies/$', CookiesView.as_view()), + url(r'^auth/$', AuthView.as_view()), ] @@ -180,34 +200,39 @@ def test_session(self): expected = {'example': 'abc'} assert response.json() == expected - def test_cookies(self): - """ - Test for explicitly setting a cookie. - """ - my_cookie = { - "version": 0, - "name": 'COOKIE_NAME', - "value": 'COOKIE_VALUE', - "port": None, - # "port_specified":False, - "domain": 'testserver.local', - # "domain_specified":False, - # "domain_initial_dot":False, - "path": '/', - # "path_specified":True, - "secure": False, - "expires": None, - "discard": True, - "comment": None, - "comment_url": None, - "rest": {}, - "rfc2109": False + def test_auth(self): + # Confirm session is not authenticated + response = self.requests.get('/auth/') + assert response.status_code == 200 + assert response.headers['Content-Type'] == 'application/json' + expected = { + 'username': None } - self.requests.cookies.set(**my_cookie) - response = self.requests.get('/cookies/') + assert response.json() == expected + assert 'csrftoken' in response.cookies + csrftoken = response.cookies['csrftoken'] + + user = User.objects.create(username='tom') + user.set_password('password') + user.save() + + # Perform a login + response = self.requests.post('/auth/', json={ + 'username': 'tom', + 'password': 'password' + }, headers={'X-CSRFToken': csrftoken}) assert response.status_code == 200 assert response.headers['Content-Type'] == 'application/json' - expected = {'COOKIE_NAME': 'COOKIE_VALUE'} + expected = { + 'username': 'tom' + } assert response.json() == expected - # cookies/session auth + # Confirm session is authenticated + response = self.requests.get('/auth/') + assert response.status_code == 200 + assert response.headers['Content-Type'] == 'application/json' + expected = { + 'username': 'tom' + } + assert response.json() == expected From da47c345c09eaf4fffca9cfcd18bba7acd844e8b Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Wed, 17 Aug 2016 18:09:19 +0100 Subject: [PATCH 07/43] Py3 compat --- rest_framework/test.py | 26 +++++++++++++++----------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/rest_framework/test.py b/rest_framework/test.py index bc8ecc5db6..a95d185376 100644 --- a/rest_framework/test.py +++ b/rest_framework/test.py @@ -26,6 +26,21 @@ def force_authenticate(request, user=None, token=None): if requests is not None: + class HeaderDict(requests.packages.urllib3._collections.HTTPHeaderDict): + def get_all(self, key): + return self.getheaders(self, key) + + class MockOriginalResponse(object): + def __init__(self, headers): + self.msg = HeaderDict(headers) + self.closed = False + + def isclosed(self): + return self.closed + + def close(self): + self.closed = True + class DjangoTestAdapter(requests.adapters.HTTPAdapter): """ A transport adapter for `requests`, that makes requests via the @@ -65,17 +80,6 @@ def send(self, request, *args, **kwargs): raw_kwargs = {} def start_response(wsgi_status, wsgi_headers): - class MockOriginalResponse(object): - def __init__(self, headers): - self.msg = requests.packages.urllib3._collections.HTTPHeaderDict(headers) - self.closed = False - - def isclosed(self): - return self.closed - - def close(self): - self.closed = True - status, _, reason = wsgi_status.partition(' ') raw_kwargs['status'] = int(status) raw_kwargs['reason'] = reason From 53117698e042691a7045688d41edae9cc118643b Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Wed, 17 Aug 2016 18:47:01 +0100 Subject: [PATCH 08/43] py3 compat --- rest_framework/test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rest_framework/test.py b/rest_framework/test.py index a95d185376..bf22ff08d5 100644 --- a/rest_framework/test.py +++ b/rest_framework/test.py @@ -27,7 +27,7 @@ def force_authenticate(request, user=None, token=None): if requests is not None: class HeaderDict(requests.packages.urllib3._collections.HTTPHeaderDict): - def get_all(self, key): + def get_all(self, key, default): return self.getheaders(self, key) class MockOriginalResponse(object): From 0b3db028a2a7a0a91c3111fc6febbdbcd9cbd6b5 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Wed, 17 Aug 2016 18:50:02 +0100 Subject: [PATCH 09/43] py3 compat --- rest_framework/test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rest_framework/test.py b/rest_framework/test.py index bf22ff08d5..ded9d5fe93 100644 --- a/rest_framework/test.py +++ b/rest_framework/test.py @@ -28,7 +28,7 @@ def force_authenticate(request, user=None, token=None): if requests is not None: class HeaderDict(requests.packages.urllib3._collections.HTTPHeaderDict): def get_all(self, key, default): - return self.getheaders(self, key) + return self.getheaders(key) class MockOriginalResponse(object): def __init__(self, headers): From 0cc3f5008fcd41d1597776c43dc57502ed2a7542 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Thu, 18 Aug 2016 15:34:19 +0100 Subject: [PATCH 10/43] Add get_requests_client --- rest_framework/test.py | 12 +++++------- tests/test_requests_client.py | 37 ++++++++++++++++++++++------------- 2 files changed, 28 insertions(+), 21 deletions(-) diff --git a/rest_framework/test.py b/rest_framework/test.py index ded9d5fe93..e17c19a43f 100644 --- a/rest_framework/test.py +++ b/rest_framework/test.py @@ -121,6 +121,11 @@ def request(self, method, url, *args, **kwargs): return super(DjangoTestSession, self).request(method, url, *args, **kwargs) +def get_requests_client(): + assert requests is not None, 'requests must be installed' + return DjangoTestSession() + + class APIRequestFactory(DjangoRequestFactory): renderer_classes_list = api_settings.TEST_REQUEST_RENDERER_CLASSES default_format = api_settings.TEST_REQUEST_DEFAULT_FORMAT @@ -321,13 +326,6 @@ class APITransactionTestCase(testcases.TransactionTestCase): class APITestCase(testcases.TestCase): client_class = APIClient - @property - def requests(self): - if not hasattr(self, '_requests'): - assert requests is not None, 'requests must be installed' - self._requests = DjangoTestSession() - return self._requests - class APISimpleTestCase(testcases.SimpleTestCase): client_class = APIClient diff --git a/tests/test_requests_client.py b/tests/test_requests_client.py index aa99a71da5..37bde10922 100644 --- a/tests/test_requests_client.py +++ b/tests/test_requests_client.py @@ -12,7 +12,7 @@ from rest_framework.compat import is_authenticated, requests from rest_framework.response import Response -from rest_framework.test import APITestCase +from rest_framework.test import APITestCase, get_requests_client from rest_framework.views import APIView @@ -103,7 +103,8 @@ def post(self, request): @override_settings(ROOT_URLCONF='tests.test_requests_client') class RequestsClientTests(APITestCase): def test_get_request(self): - response = self.requests.get('/') + client = get_requests_client() + response = client.get('/') assert response.status_code == 200 assert response.headers['Content-Type'] == 'application/json' expected = { @@ -113,7 +114,8 @@ def test_get_request(self): assert response.json() == expected def test_get_request_query_params_in_url(self): - response = self.requests.get('/?key=value') + client = get_requests_client() + response = client.get('/?key=value') assert response.status_code == 200 assert response.headers['Content-Type'] == 'application/json' expected = { @@ -123,7 +125,8 @@ def test_get_request_query_params_in_url(self): assert response.json() == expected def test_get_request_query_params_by_kwarg(self): - response = self.requests.get('/', params={'key': 'value'}) + client = get_requests_client() + response = client.get('/', params={'key': 'value'}) assert response.status_code == 200 assert response.headers['Content-Type'] == 'application/json' expected = { @@ -133,14 +136,16 @@ def test_get_request_query_params_by_kwarg(self): assert response.json() == expected def test_get_with_headers(self): - response = self.requests.get('/headers/', headers={'User-Agent': 'example'}) + client = get_requests_client() + response = client.get('/headers/', headers={'User-Agent': 'example'}) assert response.status_code == 200 assert response.headers['Content-Type'] == 'application/json' headers = response.json()['headers'] assert headers['USER-AGENT'] == 'example' def test_post_form_request(self): - response = self.requests.post('/', data={'key': 'value'}) + client = get_requests_client() + response = client.post('/', data={'key': 'value'}) assert response.status_code == 200 assert response.headers['Content-Type'] == 'application/json' expected = { @@ -153,7 +158,8 @@ def test_post_form_request(self): assert response.json() == expected def test_post_json_request(self): - response = self.requests.post('/', json={'key': 'value'}) + client = get_requests_client() + response = client.post('/', json={'key': 'value'}) assert response.status_code == 200 assert response.headers['Content-Type'] == 'application/json' expected = { @@ -166,10 +172,11 @@ def test_post_json_request(self): assert response.json() == expected def test_post_multipart_request(self): + client = get_requests_client() files = { 'file': ('report.csv', 'some,data,to,send\nanother,row,to,send\n') } - response = self.requests.post('/', files=files) + response = client.post('/', files=files) assert response.status_code == 200 assert response.headers['Content-Type'] == 'application/json' expected = { @@ -182,19 +189,20 @@ def test_post_multipart_request(self): assert response.json() == expected def test_session(self): - response = self.requests.get('/session/') + client = get_requests_client() + response = client.get('/session/') assert response.status_code == 200 assert response.headers['Content-Type'] == 'application/json' expected = {} assert response.json() == expected - response = self.requests.post('/session/', json={'example': 'abc'}) + response = client.post('/session/', json={'example': 'abc'}) assert response.status_code == 200 assert response.headers['Content-Type'] == 'application/json' expected = {'example': 'abc'} assert response.json() == expected - response = self.requests.get('/session/') + response = client.get('/session/') assert response.status_code == 200 assert response.headers['Content-Type'] == 'application/json' expected = {'example': 'abc'} @@ -202,7 +210,8 @@ def test_session(self): def test_auth(self): # Confirm session is not authenticated - response = self.requests.get('/auth/') + client = get_requests_client() + response = client.get('/auth/') assert response.status_code == 200 assert response.headers['Content-Type'] == 'application/json' expected = { @@ -217,7 +226,7 @@ def test_auth(self): user.save() # Perform a login - response = self.requests.post('/auth/', json={ + response = client.post('/auth/', json={ 'username': 'tom', 'password': 'password' }, headers={'X-CSRFToken': csrftoken}) @@ -229,7 +238,7 @@ def test_auth(self): assert response.json() == expected # Confirm session is authenticated - response = self.requests.get('/auth/') + response = client.get('/auth/') assert response.status_code == 200 assert response.headers['Content-Type'] == 'application/json' expected = { From e4f692831e850a710c60a7343e76bca19e8e6e83 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Fri, 2 Sep 2016 18:04:19 +0100 Subject: [PATCH 11/43] Added SchemaGenerator.should_include_link --- rest_framework/schemas.py | 65 ++++++++++++++++++++++++--------------- 1 file changed, 40 insertions(+), 25 deletions(-) diff --git a/rest_framework/schemas.py b/rest_framework/schemas.py index 1b899450ff..bf1e6dd4a5 100644 --- a/rest_framework/schemas.py +++ b/rest_framework/schemas.py @@ -72,31 +72,10 @@ def get_schema(self, request=None): links = [] for path, method, category, action, callback in self.endpoints: - view = callback.cls() - for attr, val in getattr(callback, 'initkwargs', {}).items(): - setattr(view, attr, val) - view.args = () - view.kwargs = {} - view.format_kwarg = None - - actions = getattr(callback, 'actions', None) - if actions is not None: - if method == 'OPTIONS': - view.action = 'metadata' - else: - view.action = actions.get(method.lower()) - - if request is not None: - view.request = clone_request(request, method) - try: - view.check_permissions(view.request) - except exceptions.APIException: - continue - else: - view.request = None - - link = self.get_link(path, method, callback, view) - links.append((category, action, link)) + view = self.setup_view(callback, method, request) + if self.should_include_link(path, method, callback, view): + link = self.get_link(path, method, callback, view) + links.append((category, action, link)) if not links: return None @@ -215,8 +194,44 @@ def get_category(self, path, method, callback, action): except IndexError: return None + def setup_view(self, callback, method, request): + """ + Setup a view instance. + """ + view = callback.cls() + for attr, val in getattr(callback, 'initkwargs', {}).items(): + setattr(view, attr, val) + view.args = () + view.kwargs = {} + view.format_kwarg = None + + actions = getattr(callback, 'actions', None) + if actions is not None: + if method == 'OPTIONS': + view.action = 'metadata' + else: + view.action = actions.get(method.lower()) + + if request is not None: + view.request = clone_request(request, method) + else: + view.request = None + + return view + # Methods for generating each individual `Link` instance... + def should_include_link(self, path, method, callback, view): + if view.request is None: + return True + + try: + view.check_permissions(view.request) + except exceptions.APIException: + return False + + return True + def get_link(self, path, method, callback, view): """ Return a `coreapi.Link` instance for the given endpoint. From 46b9e4edf9ba7c5b10deef4515d287898cdebe38 Mon Sep 17 00:00:00 2001 From: Andy Schriner Date: Mon, 22 Aug 2016 11:38:33 -0700 Subject: [PATCH 12/43] add settings for html cutoff on related fields --- docs/api-guide/relations.md | 2 ++ docs/api-guide/settings.md | 16 ++++++++++++++++ rest_framework/relations.py | 31 +++++++++++++++++++++---------- rest_framework/settings.py | 4 ++++ 4 files changed, 43 insertions(+), 10 deletions(-) diff --git a/docs/api-guide/relations.md b/docs/api-guide/relations.md index 8695b2c1e6..75643a83e7 100644 --- a/docs/api-guide/relations.md +++ b/docs/api-guide/relations.md @@ -457,6 +457,8 @@ There are two keyword arguments you can use to control this behavior: - `html_cutoff` - If set this will be the maximum number of choices that will be displayed by a HTML select drop down. Set to `None` to disable any limiting. Defaults to `1000`. - `html_cutoff_text` - If set this will display a textual indicator if the maximum number of items have been cutoff in an HTML select drop down. Defaults to `"More than {count} items…"` +You can also control these globally using the settings `HTML_SELECT_CUTOFF` and `HTML_SELECT_CUTOFF_TEXT`. + In cases where the cutoff is being enforced you may want to instead use a plain input field in the HTML form. You can do so using the `style` keyword argument. For example: assigned_to = serializers.SlugRelatedField( diff --git a/docs/api-guide/settings.md b/docs/api-guide/settings.md index ea018053fe..67d317839f 100644 --- a/docs/api-guide/settings.md +++ b/docs/api-guide/settings.md @@ -382,6 +382,22 @@ This should be a function with the following signature: Default: `'rest_framework.views.get_view_description'` +## HTML Select Field cutoffs + +Global settings for [select field cutoffs for rendering relational fields](relations.md#select-field-cutoffs) in the browsable API. + +#### HTML_SELECT_CUTOFF + +Global setting for the `html_cutoff` value. Must be an integer. + +Default: 1000 + +#### HTML_SELECT_CUTOFF_TEXT + +A string representing a global setting for `html_cutoff_text`. + +Default: `"More than {count} items..."` + --- ## Miscellaneous settings diff --git a/rest_framework/relations.py b/rest_framework/relations.py index 65c4c03187..7fe22ec5b8 100644 --- a/rest_framework/relations.py +++ b/rest_framework/relations.py @@ -18,6 +18,7 @@ Field, empty, get_attribute, is_simple_callable, iter_options ) from rest_framework.reverse import reverse +from rest_framework.settings import api_settings from rest_framework.utils import html @@ -71,14 +72,19 @@ def __str__(self): class RelatedField(Field): queryset = None - html_cutoff = 1000 - html_cutoff_text = _('More than {count} items...') + html_cutoff = None + html_cutoff_text = None def __init__(self, **kwargs): self.queryset = kwargs.pop('queryset', self.queryset) - self.html_cutoff = kwargs.pop('html_cutoff', self.html_cutoff) - self.html_cutoff_text = kwargs.pop('html_cutoff_text', self.html_cutoff_text) - + self.html_cutoff = kwargs.pop( + 'html_cutoff', + self.html_cutoff or int(api_settings.HTML_SELECT_CUTOFF) + ) + self.html_cutoff_text = kwargs.pop( + 'html_cutoff_text', + self.html_cutoff_text or _(api_settings.HTML_SELECT_CUTOFF_TEXT) + ) if not method_overridden('get_queryset', RelatedField, self): assert self.queryset is not None or kwargs.get('read_only', None), ( 'Relational field must provide a `queryset` argument, ' @@ -447,15 +453,20 @@ class ManyRelatedField(Field): 'not_a_list': _('Expected a list of items but got type "{input_type}".'), 'empty': _('This list may not be empty.') } - html_cutoff = 1000 - html_cutoff_text = _('More than {count} items...') + html_cutoff = None + html_cutoff_text = None def __init__(self, child_relation=None, *args, **kwargs): self.child_relation = child_relation self.allow_empty = kwargs.pop('allow_empty', True) - self.html_cutoff = kwargs.pop('html_cutoff', self.html_cutoff) - self.html_cutoff_text = kwargs.pop('html_cutoff_text', self.html_cutoff_text) - + self.html_cutoff = kwargs.pop( + 'html_cutoff', + self.html_cutoff or int(api_settings.HTML_SELECT_CUTOFF) + ) + self.html_cutoff_text = kwargs.pop( + 'html_cutoff_text', + self.html_cutoff_text or _(api_settings.HTML_SELECT_CUTOFF_TEXT) + ) assert child_relation is not None, '`child_relation` is a required argument.' super(ManyRelatedField, self).__init__(*args, **kwargs) self.child_relation.bind(field_name='', parent=self) diff --git a/rest_framework/settings.py b/rest_framework/settings.py index 68c7709e87..89e27e743e 100644 --- a/rest_framework/settings.py +++ b/rest_framework/settings.py @@ -111,6 +111,10 @@ 'COMPACT_JSON': True, 'COERCE_DECIMAL_TO_STRING': True, 'UPLOADED_FILES_USE_URL': True, + + # Browseable API + 'HTML_SELECT_CUTOFF': 1000, + 'HTML_SELECT_CUTOFF_TEXT': "More than {count} items...", } From a556b9cb426f2893cb6807299a94044d585b7c2c Mon Sep 17 00:00:00 2001 From: Christian Sauer Date: Wed, 14 Sep 2016 18:01:30 -0400 Subject: [PATCH 13/43] Router doesn't work if prefix is blank, though project urls.py handles prefix --- tests/test_routers.py | 49 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 49 insertions(+) diff --git a/tests/test_routers.py b/tests/test_routers.py index f45039f802..d28e301a0a 100644 --- a/tests/test_routers.py +++ b/tests/test_routers.py @@ -1,5 +1,6 @@ from __future__ import unicode_literals +import json from collections import namedtuple from django.conf.urls import include, url @@ -47,6 +48,21 @@ class MockViewSet(viewsets.ModelViewSet): serializer_class = None +class EmptyPrefixSerializer(serializers.HyperlinkedModelSerializer): + class Meta: + model = RouterTestModel + fields = ('uuid', 'text') + + +class EmptyPrefixViewSet(viewsets.ModelViewSet): + queryset = [RouterTestModel(id=1, uuid='111', text='First'), RouterTestModel(id=2, uuid='222', text='Second')] + serializer_class = EmptyPrefixSerializer + + def get_object(self, *args, **kwargs): + index = int(self.kwargs['pk']) - 1 + return self.queryset[index] + + notes_router = SimpleRouter() notes_router.register(r'notes', NoteViewSet) @@ -56,11 +72,19 @@ class MockViewSet(viewsets.ModelViewSet): namespaced_router = DefaultRouter() namespaced_router.register(r'example', MockViewSet, base_name='example') +empty_prefix_router = SimpleRouter() +empty_prefix_router.register(r'', EmptyPrefixViewSet, base_name='empty_prefix') +empty_prefix_urls = [ + url(r'^', include(empty_prefix_router.urls)), +] + urlpatterns = [ url(r'^non-namespaced/', include(namespaced_router.urls)), url(r'^namespaced/', include(namespaced_router.urls, namespace='example')), url(r'^example/', include(notes_router.urls)), url(r'^example2/', include(kwarged_notes_router.urls)), + + url(r'^empty-prefix/', include(empty_prefix_urls)), ] @@ -384,3 +408,28 @@ def test_list_and_detail_route_decorators(self): def test_inherited_list_and_detail_route_decorators(self): self._test_list_and_detail_route_decorators(SubDynamicListAndDetailViewSet) + + +@override_settings(ROOT_URLCONF='tests.test_routers') +class TestEmptyPrefix(TestCase): + def test_empty_prefix_list(self): + response = self.client.get('/empty-prefix/') + self.assertEqual(200, response.status_code) + self.assertEqual( + json.loads(response.content.decode('utf-8')), + [ + {'uuid': '111', 'text': 'First'}, + {'uuid': '222', 'text': 'Second'} + ] + ) + + def test_empty_prefix_detail(self): + response = self.client.get('/empty-prefix/1/') + self.assertEqual(200, response.status_code) + self.assertEqual( + json.loads(response.content.decode('utf-8')), + { + 'uuid': '111', + 'text': 'First' + } + ) From 8609c9ca8c0d9c747d594d2b4f5cc7cc8be98f0a Mon Sep 17 00:00:00 2001 From: Ryan P Kilby Date: Wed, 14 Sep 2016 23:44:48 -0400 Subject: [PATCH 14/43] Fix Django 1.10 to-many deprecation --- rest_framework/compat.py | 8 ++++++++ rest_framework/serializers.py | 10 +++++++--- tests/test_model_serializer.py | 4 ++-- tests/test_permissions.py | 10 +++++----- 4 files changed, 22 insertions(+), 10 deletions(-) diff --git a/rest_framework/compat.py b/rest_framework/compat.py index cee430a84b..9ee40b257f 100644 --- a/rest_framework/compat.py +++ b/rest_framework/compat.py @@ -277,3 +277,11 @@ def template_render(template, context=None, request=None): # backends template, e.g. django.template.backends.django.Template else: return template.render(context, request=request) + + +def set_many(instance, field, value): + if django.VERSION < (1, 10): + setattr(instance, field, value) + else: + field = getattr(instance, field) + field.set(value) diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 4d1ed63aef..28f70bd406 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -23,7 +23,7 @@ from django.utils.translation import ugettext_lazy as _ from rest_framework.compat import JSONField as ModelJSONField -from rest_framework.compat import postgres_fields, unicode_to_repr +from rest_framework.compat import postgres_fields, set_many, unicode_to_repr from rest_framework.utils import model_meta from rest_framework.utils.field_mapping import ( ClassLookupDict, get_field_kwargs, get_nested_relation_kwargs, @@ -892,19 +892,23 @@ def create(self, validated_data): # Save many-to-many relationships after the instance is created. if many_to_many: for field_name, value in many_to_many.items(): - setattr(instance, field_name, value) + set_many(instance, field_name, value) return instance def update(self, instance, validated_data): raise_errors_on_nested_writes('update', self, validated_data) + info = model_meta.get_field_info(instance) # Simply set each attribute on the instance, and then save it. # Note that unlike `.create()` we don't need to treat many-to-many # relationships as being a special case. During updates we already # have an instance pk for the relationships to be associated with. for attr, value in validated_data.items(): - setattr(instance, attr, value) + if attr in info.relations and info.relations[attr].to_many: + set_many(instance, attr, value) + else: + setattr(instance, attr, value) instance.save() return instance diff --git a/tests/test_model_serializer.py b/tests/test_model_serializer.py index 01243ff6ee..5dac16f2bc 100644 --- a/tests/test_model_serializer.py +++ b/tests/test_model_serializer.py @@ -20,7 +20,7 @@ from django.utils import six from rest_framework import serializers -from rest_framework.compat import unicode_repr +from rest_framework.compat import set_many, unicode_repr def dedent(blocktext): @@ -651,7 +651,7 @@ def setUp(self): foreign_key=self.foreign_key_target, one_to_one=self.one_to_one_target, ) - self.instance.many_to_many = self.many_to_many_targets + set_many(self.instance, 'many_to_many', self.many_to_many_targets) self.instance.save() def test_pk_retrival(self): diff --git a/tests/test_permissions.py b/tests/test_permissions.py index 5cef22628b..0445f27caa 100644 --- a/tests/test_permissions.py +++ b/tests/test_permissions.py @@ -12,7 +12,7 @@ HTTP_HEADER_ENCODING, authentication, generics, permissions, serializers, status ) -from rest_framework.compat import guardian +from rest_framework.compat import guardian, set_many from rest_framework.filters import DjangoObjectPermissionsFilter from rest_framework.routers import DefaultRouter from rest_framework.test import APIRequestFactory @@ -74,15 +74,15 @@ class ModelPermissionsIntegrationTests(TestCase): def setUp(self): User.objects.create_user('disallowed', 'disallowed@example.com', 'password') user = User.objects.create_user('permitted', 'permitted@example.com', 'password') - user.user_permissions = [ + set_many(user, 'user_permissions', [ Permission.objects.get(codename='add_basicmodel'), Permission.objects.get(codename='change_basicmodel'), Permission.objects.get(codename='delete_basicmodel') - ] + ]) user = User.objects.create_user('updateonly', 'updateonly@example.com', 'password') - user.user_permissions = [ + set_many(user, 'user_permissions', [ Permission.objects.get(codename='change_basicmodel'), - ] + ]) self.permitted_credentials = basic_auth_header('permitted', 'password') self.disallowed_credentials = basic_auth_header('disallowed', 'password') From 197b63ab85edb0bb1f02a5bc05ce87b1e4d541a5 Mon Sep 17 00:00:00 2001 From: Ryan P Kilby Date: Wed, 14 Sep 2016 23:51:51 -0400 Subject: [PATCH 15/43] Add django.core.urlresolvers compatibility --- docs/api-guide/reverse.md | 4 ++-- docs/api-guide/testing.md | 2 +- rest_framework/compat.py | 10 ++++++++++ rest_framework/relations.py | 6 +++--- rest_framework/reverse.py | 6 +++--- rest_framework/routers.py | 2 +- rest_framework/schemas.py | 5 +++-- rest_framework/templatetags/rest_framework.py | 3 +-- rest_framework/urlpatterns.py | 2 +- rest_framework/utils/breadcrumbs.py | 2 +- tests/test_filters.py | 3 +-- tests/test_permissions.py | 3 +-- tests/test_reverse.py | 2 +- tests/test_urlpatterns.py | 8 ++++---- tests/utils.py | 2 +- 15 files changed, 34 insertions(+), 26 deletions(-) diff --git a/docs/api-guide/reverse.md b/docs/api-guide/reverse.md index 71fb83f9e9..35d88e2db1 100644 --- a/docs/api-guide/reverse.md +++ b/docs/api-guide/reverse.md @@ -23,7 +23,7 @@ There's no requirement for you to use them, but if you do then the self-describi **Signature:** `reverse(viewname, *args, **kwargs)` -Has the same behavior as [`django.core.urlresolvers.reverse`][reverse], except that it returns a fully qualified URL, using the request to determine the host and port. +Has the same behavior as [`django.urls.reverse`][reverse], except that it returns a fully qualified URL, using the request to determine the host and port. You should **include the request as a keyword argument** to the function, for example: @@ -44,7 +44,7 @@ You should **include the request as a keyword argument** to the function, for ex **Signature:** `reverse_lazy(viewname, *args, **kwargs)` -Has the same behavior as [`django.core.urlresolvers.reverse_lazy`][reverse-lazy], except that it returns a fully qualified URL, using the request to determine the host and port. +Has the same behavior as [`django.urls.reverse_lazy`][reverse-lazy], except that it returns a fully qualified URL, using the request to determine the host and port. As with the `reverse` function, you should **include the request as a keyword argument** to the function, for example: diff --git a/docs/api-guide/testing.md b/docs/api-guide/testing.md index 69da7d1056..18f9e19e90 100644 --- a/docs/api-guide/testing.md +++ b/docs/api-guide/testing.md @@ -197,7 +197,7 @@ REST framework includes the following test case classes, that mirror the existin You can use any of REST framework's test case classes as you would for the regular Django test case classes. The `self.client` attribute will be an `APIClient` instance. - from django.core.urlresolvers import reverse + from django.urls import reverse from rest_framework import status from rest_framework.test import APITestCase from myproject.apps.core.models import Account diff --git a/rest_framework/compat.py b/rest_framework/compat.py index 9ee40b257f..958c481cd1 100644 --- a/rest_framework/compat.py +++ b/rest_framework/compat.py @@ -23,6 +23,16 @@ from django.utils import importlib # Will be removed in Django 1.9 +try: + from django.urls import ( + NoReverseMatch, RegexURLPattern, RegexURLResolver, ResolverMatch, Resolver404, get_script_prefix, reverse, reverse_lazy, resolve + ) +except ImportError: + from django.core.urlresolvers import ( # Will be removed in Django 2.0 + NoReverseMatch, RegexURLPattern, RegexURLResolver, ResolverMatch, Resolver404, get_script_prefix, reverse, reverse_lazy, resolve + ) + + try: import urlparse # Python 2.x except ImportError: diff --git a/rest_framework/relations.py b/rest_framework/relations.py index 65c4c03187..4317d1199d 100644 --- a/rest_framework/relations.py +++ b/rest_framework/relations.py @@ -4,9 +4,6 @@ from collections import OrderedDict from django.core.exceptions import ImproperlyConfigured, ObjectDoesNotExist -from django.core.urlresolvers import ( - NoReverseMatch, Resolver404, get_script_prefix, resolve -) from django.db.models import Manager from django.db.models.query import QuerySet from django.utils import six @@ -14,6 +11,9 @@ from django.utils.six.moves.urllib import parse as urlparse from django.utils.translation import ugettext_lazy as _ +from rest_framework.compat import ( + NoReverseMatch, Resolver404, get_script_prefix, resolve +) from rest_framework.fields import ( Field, empty, get_attribute, is_simple_callable, iter_options ) diff --git a/rest_framework/reverse.py b/rest_framework/reverse.py index 5a7ba09a87..fd418dcca0 100644 --- a/rest_framework/reverse.py +++ b/rest_framework/reverse.py @@ -3,11 +3,11 @@ """ from __future__ import unicode_literals -from django.core.urlresolvers import reverse as django_reverse -from django.core.urlresolvers import NoReverseMatch from django.utils import six from django.utils.functional import lazy +from rest_framework.compat import reverse as django_reverse +from rest_framework.compat import NoReverseMatch from rest_framework.settings import api_settings from rest_framework.utils.urls import replace_query_param @@ -54,7 +54,7 @@ def reverse(viewname, args=None, kwargs=None, request=None, format=None, **extra def _reverse(viewname, args=None, kwargs=None, request=None, format=None, **extra): """ - Same as `django.core.urlresolvers.reverse`, but optionally takes a request + Same as `django.urls.reverse`, but optionally takes a request and returns a fully qualified URL, using the request to get the base URL. """ if format is not None: diff --git a/rest_framework/routers.py b/rest_framework/routers.py index 4eec70bda3..c6516b06bc 100644 --- a/rest_framework/routers.py +++ b/rest_framework/routers.py @@ -20,9 +20,9 @@ from django.conf.urls import url from django.core.exceptions import ImproperlyConfigured -from django.core.urlresolvers import NoReverseMatch from rest_framework import exceptions, renderers, views +from rest_framework.compat import NoReverseMatch from rest_framework.response import Response from rest_framework.reverse import reverse from rest_framework.schemas import SchemaGenerator diff --git a/rest_framework/schemas.py b/rest_framework/schemas.py index 1b899450ff..72cd8017a3 100644 --- a/rest_framework/schemas.py +++ b/rest_framework/schemas.py @@ -2,12 +2,13 @@ from django.conf import settings from django.contrib.admindocs.views import simplify_regex -from django.core.urlresolvers import RegexURLPattern, RegexURLResolver from django.utils import six from django.utils.encoding import force_text from rest_framework import exceptions, serializers -from rest_framework.compat import coreapi, uritemplate, urlparse +from rest_framework.compat import ( + RegexURLPattern, RegexURLResolver, coreapi, uritemplate, urlparse +) from rest_framework.request import clone_request from rest_framework.views import APIView diff --git a/rest_framework/templatetags/rest_framework.py b/rest_framework/templatetags/rest_framework.py index 3bb85e472c..c1c8a53968 100644 --- a/rest_framework/templatetags/rest_framework.py +++ b/rest_framework/templatetags/rest_framework.py @@ -3,14 +3,13 @@ import re from django import template -from django.core.urlresolvers import NoReverseMatch, reverse from django.template import loader from django.utils import six from django.utils.encoding import force_text, iri_to_uri from django.utils.html import escape, format_html, smart_urlquote from django.utils.safestring import SafeData, mark_safe -from rest_framework.compat import template_render +from rest_framework.compat import NoReverseMatch, reverse, template_render from rest_framework.renderers import HTMLFormRenderer from rest_framework.utils.urls import replace_query_param diff --git a/rest_framework/urlpatterns.py b/rest_framework/urlpatterns.py index 7a02bb0f0a..4ea55300e8 100644 --- a/rest_framework/urlpatterns.py +++ b/rest_framework/urlpatterns.py @@ -1,8 +1,8 @@ from __future__ import unicode_literals from django.conf.urls import include, url -from django.core.urlresolvers import RegexURLResolver +from rest_framework.compat import RegexURLResolver from rest_framework.settings import api_settings diff --git a/rest_framework/utils/breadcrumbs.py b/rest_framework/utils/breadcrumbs.py index 2e3ab90842..74f4f7840a 100644 --- a/rest_framework/utils/breadcrumbs.py +++ b/rest_framework/utils/breadcrumbs.py @@ -1,6 +1,6 @@ from __future__ import unicode_literals -from django.core.urlresolvers import get_script_prefix, resolve +from rest_framework.compat import get_script_prefix, resolve def get_breadcrumbs(url, request=None): diff --git a/tests/test_filters.py b/tests/test_filters.py index 03d61fc37e..fdb3c1c0bd 100644 --- a/tests/test_filters.py +++ b/tests/test_filters.py @@ -6,7 +6,6 @@ from django.conf.urls import url from django.core.exceptions import ImproperlyConfigured -from django.core.urlresolvers import reverse from django.db import models from django.test import TestCase from django.test.utils import override_settings @@ -14,7 +13,7 @@ from django.utils.six.moves import reload_module from rest_framework import filters, generics, serializers, status -from rest_framework.compat import django_filters +from rest_framework.compat import django_filters, reverse from rest_framework.test import APIRequestFactory from .models import BaseFilterableItem, BasicModel, FilterableItem diff --git a/tests/test_permissions.py b/tests/test_permissions.py index 0445f27caa..f8561e61da 100644 --- a/tests/test_permissions.py +++ b/tests/test_permissions.py @@ -4,7 +4,6 @@ import unittest from django.contrib.auth.models import Group, Permission, User -from django.core.urlresolvers import ResolverMatch from django.db import models from django.test import TestCase @@ -12,7 +11,7 @@ HTTP_HEADER_ENCODING, authentication, generics, permissions, serializers, status ) -from rest_framework.compat import guardian, set_many +from rest_framework.compat import ResolverMatch, guardian, set_many from rest_framework.filters import DjangoObjectPermissionsFilter from rest_framework.routers import DefaultRouter from rest_framework.test import APIRequestFactory diff --git a/tests/test_reverse.py b/tests/test_reverse.py index 03d31f1f93..f30a8bf9a0 100644 --- a/tests/test_reverse.py +++ b/tests/test_reverse.py @@ -1,9 +1,9 @@ from __future__ import unicode_literals from django.conf.urls import url -from django.core.urlresolvers import NoReverseMatch from django.test import TestCase, override_settings +from rest_framework.compat import NoReverseMatch from rest_framework.reverse import reverse from rest_framework.test import APIRequestFactory diff --git a/tests/test_urlpatterns.py b/tests/test_urlpatterns.py index 78d37c1a85..33d367e1d7 100644 --- a/tests/test_urlpatterns.py +++ b/tests/test_urlpatterns.py @@ -3,9 +3,9 @@ from collections import namedtuple from django.conf.urls import include, url -from django.core import urlresolvers from django.test import TestCase +from rest_framework.compat import RegexURLResolver, Resolver404 from rest_framework.test import APIRequestFactory from rest_framework.urlpatterns import format_suffix_patterns @@ -28,7 +28,7 @@ def _resolve_urlpatterns(self, urlpatterns, test_paths): urlpatterns = format_suffix_patterns(urlpatterns) except Exception: self.fail("Failed to apply `format_suffix_patterns` on the supplied urlpatterns") - resolver = urlresolvers.RegexURLResolver(r'^/', urlpatterns) + resolver = RegexURLResolver(r'^/', urlpatterns) for test_path in test_paths: request = factory.get(test_path.path) try: @@ -43,7 +43,7 @@ def test_trailing_slash(self): urlpatterns = format_suffix_patterns([ url(r'^test/$', dummy_view), ]) - resolver = urlresolvers.RegexURLResolver(r'^/', urlpatterns) + resolver = RegexURLResolver(r'^/', urlpatterns) test_paths = [ (URLTestPath('/test.api', (), {'format': 'api'}), True), @@ -55,7 +55,7 @@ def test_trailing_slash(self): request = factory.get(test_path.path) try: callback, callback_args, callback_kwargs = resolver.resolve(request.path_info) - except urlresolvers.Resolver404: + except Resolver404: callback, callback_args, callback_kwargs = (None, None, None) if not expected_resolved: assert callback is None diff --git a/tests/utils.py b/tests/utils.py index 5b2d75864a..52582f0934 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,5 +1,5 @@ from django.core.exceptions import ObjectDoesNotExist -from django.core.urlresolvers import NoReverseMatch +from rest_framework.compat import NoReverseMatch class MockObject(object): From bb37cb79929e145714c0018d90aed93e1d58d5a4 Mon Sep 17 00:00:00 2001 From: Ryan P Kilby Date: Wed, 14 Sep 2016 23:53:14 -0400 Subject: [PATCH 16/43] Update django-filter & django-guardian --- requirements/requirements-optionals.txt | 4 ++-- tests/test_filters.py | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/requirements/requirements-optionals.txt b/requirements/requirements-optionals.txt index 20436e6b4c..87e5034cd7 100644 --- a/requirements/requirements-optionals.txt +++ b/requirements/requirements-optionals.txt @@ -1,5 +1,5 @@ # Optional packages which may be used with REST framework. markdown==2.6.4 -django-guardian==1.4.3 -django-filter==0.13.0 +django-guardian==1.4.6 +django-filter==0.14.0 coreapi==1.32.0 diff --git a/tests/test_filters.py b/tests/test_filters.py index fdb3c1c0bd..d696309c14 100644 --- a/tests/test_filters.py +++ b/tests/test_filters.py @@ -76,6 +76,7 @@ class BaseFilterableItemFilter(django_filters.FilterSet): class Meta: model = BaseFilterableItem + fields = '__all__' class BaseFilterableItemFilterRootView(generics.ListCreateAPIView): queryset = FilterableItem.objects.all() From a372a8edea70260ebfb559f3d0bd692132cdfb40 Mon Sep 17 00:00:00 2001 From: Christian Sauer Date: Thu, 15 Sep 2016 12:54:35 -0400 Subject: [PATCH 17/43] Check for empty router prefix; adjust URL accordingly It's easiest to fix this issue after we have made the regex. To try to fix it before would require doing something different for List vs Detail, which means we'd have to know which type of url we're constructing before acting accordingly. --- rest_framework/routers.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/rest_framework/routers.py b/rest_framework/routers.py index 4eec70bda3..e2fbfa77b5 100644 --- a/rest_framework/routers.py +++ b/rest_framework/routers.py @@ -83,6 +83,7 @@ def urls(self): class SimpleRouter(BaseRouter): + routes = [ # List route. Route( @@ -258,6 +259,13 @@ def get_urls(self): trailing_slash=self.trailing_slash ) + # If there is no prefix, the first part of the url is probably + # controlled by project's urls.py and the router is in an app, + # so a slash in the beginning will (A) cause Django to give + # warnings and (B) generate URLS that will require using '//'. + if not prefix and regex[:2] == '^/': + regex = '^' + regex[2:] + view = viewset.as_view(mapping, **route.initkwargs) name = route.name.format(basename=basename) ret.append(url(regex, view, name=name)) From a084924ced2ee4cb18d0f8c48b603f71ce6d2429 Mon Sep 17 00:00:00 2001 From: Ryan P Kilby Date: Wed, 14 Sep 2016 23:54:47 -0400 Subject: [PATCH 18/43] Fix misc django deprecations --- rest_framework/compat.py | 6 ++++ tests/browsable_api/test_form_rendering.py | 1 + tests/conftest.py | 15 ++++++---- tests/test_atomic_requests.py | 32 +++++++++++----------- tests/test_authentication.py | 3 +- tests/test_filters.py | 2 +- tests/test_model_serializer.py | 3 +- tests/test_request.py | 5 ++-- 8 files changed, 40 insertions(+), 27 deletions(-) diff --git a/rest_framework/compat.py b/rest_framework/compat.py index 958c481cd1..1a94f22b68 100644 --- a/rest_framework/compat.py +++ b/rest_framework/compat.py @@ -138,6 +138,12 @@ def is_authenticated(user): return user.is_authenticated +def is_anonymous(user): + if django.VERSION < (1, 10): + return user.is_anonymous() + return user.is_anonymous + + def get_related_model(field): if django.VERSION < (1, 9): return _resolve_model(field.rel.to) diff --git a/tests/browsable_api/test_form_rendering.py b/tests/browsable_api/test_form_rendering.py index 5a31ae0dd4..8b79ab6ff8 100644 --- a/tests/browsable_api/test_form_rendering.py +++ b/tests/browsable_api/test_form_rendering.py @@ -11,6 +11,7 @@ class BasicSerializer(serializers.ModelSerializer): class Meta: model = BasicModel + fields = '__all__' class ManyPostView(generics.GenericAPIView): diff --git a/tests/conftest.py b/tests/conftest.py index a5123b9d8d..2566782266 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,13 @@ def pytest_configure(): from django.conf import settings + MIDDLEWARE = ( + 'django.middleware.common.CommonMiddleware', + 'django.contrib.sessions.middleware.SessionMiddleware', + 'django.contrib.auth.middleware.AuthenticationMiddleware', + 'django.contrib.messages.middleware.MessageMiddleware', + ) + settings.configure( DEBUG_PROPAGATE_EXCEPTIONS=True, DATABASES={ @@ -21,12 +28,8 @@ def pytest_configure(): 'APP_DIRS': True, }, ], - MIDDLEWARE_CLASSES=( - 'django.middleware.common.CommonMiddleware', - 'django.contrib.sessions.middleware.SessionMiddleware', - 'django.contrib.auth.middleware.AuthenticationMiddleware', - 'django.contrib.messages.middleware.MessageMiddleware', - ), + MIDDLEWARE=MIDDLEWARE, + MIDDLEWARE_CLASSES=MIDDLEWARE, INSTALLED_APPS=( 'django.contrib.auth', 'django.contrib.contenttypes', diff --git a/tests/test_atomic_requests.py b/tests/test_atomic_requests.py index 8342ad3aff..09d7f2fb1f 100644 --- a/tests/test_atomic_requests.py +++ b/tests/test_atomic_requests.py @@ -5,7 +5,7 @@ from django.conf.urls import url from django.db import connection, connections, transaction from django.http import Http404 -from django.test import TestCase, TransactionTestCase +from django.test import TestCase, TransactionTestCase, override_settings from django.utils.decorators import method_decorator from rest_framework import status @@ -36,6 +36,20 @@ def post(self, request, *args, **kwargs): raise APIException +class NonAtomicAPIExceptionView(APIView): + @method_decorator(transaction.non_atomic_requests) + def dispatch(self, *args, **kwargs): + return super(NonAtomicAPIExceptionView, self).dispatch(*args, **kwargs) + + def get(self, request, *args, **kwargs): + BasicModel.objects.all() + raise Http404 + +urlpatterns = ( + url(r'^$', NonAtomicAPIExceptionView.as_view()), +) + + @unittest.skipUnless( connection.features.uses_savepoints, "'atomic' requires transactions and savepoints." @@ -124,22 +138,8 @@ def test_api_exception_rollback_transaction(self): connection.features.uses_savepoints, "'atomic' requires transactions and savepoints." ) +@override_settings(ROOT_URLCONF='tests.test_atomic_requests') class NonAtomicDBTransactionAPIExceptionTests(TransactionTestCase): - @property - def urls(self): - class NonAtomicAPIExceptionView(APIView): - @method_decorator(transaction.non_atomic_requests) - def dispatch(self, *args, **kwargs): - return super(NonAtomicAPIExceptionView, self).dispatch(*args, **kwargs) - - def get(self, request, *args, **kwargs): - BasicModel.objects.all() - raise Http404 - - return ( - url(r'^$', NonAtomicAPIExceptionView.as_view()), - ) - def setUp(self): connections.databases['default']['ATOMIC_REQUESTS'] = True diff --git a/tests/test_authentication.py b/tests/test_authentication.py index 5ef620abe2..6f17ea14f7 100644 --- a/tests/test_authentication.py +++ b/tests/test_authentication.py @@ -20,6 +20,7 @@ ) from rest_framework.authtoken.models import Token from rest_framework.authtoken.views import obtain_auth_token +from rest_framework.compat import is_authenticated from rest_framework.response import Response from rest_framework.test import APIClient, APIRequestFactory from rest_framework.views import APIView @@ -408,7 +409,7 @@ class AuthAccessingRenderer(renderers.BaseRenderer): def render(self, data, media_type=None, renderer_context=None): request = renderer_context['request'] - if request.user.is_authenticated(): + if is_authenticated(request.user): return b'authenticated' return b'not authenticated' diff --git a/tests/test_filters.py b/tests/test_filters.py index d696309c14..c67412dd73 100644 --- a/tests/test_filters.py +++ b/tests/test_filters.py @@ -456,7 +456,7 @@ class AttributeModel(models.Model): class SearchFilterModelFk(models.Model): title = models.CharField(max_length=20) - attribute = models.ForeignKey(AttributeModel) + attribute = models.ForeignKey(AttributeModel, on_delete=models.CASCADE) class SearchFilterFkSerializer(serializers.ModelSerializer): diff --git a/tests/test_model_serializer.py b/tests/test_model_serializer.py index 5dac16f2bc..cd9b2dfc3b 100644 --- a/tests/test_model_serializer.py +++ b/tests/test_model_serializer.py @@ -962,7 +962,7 @@ class OneToOneTargetTestModel(models.Model): class OneToOneSourceTestModel(models.Model): - target = models.OneToOneField(OneToOneTargetTestModel, primary_key=True) + target = models.OneToOneField(OneToOneTargetTestModel, primary_key=True, on_delete=models.CASCADE) class TestModelFieldValues(TestCase): @@ -990,6 +990,7 @@ class Meta: class TestSerializer(serializers.ModelSerializer): class Meta: model = TestModel + fields = '__all__' extra_kwargs = {'field_1': {'required': False}} fields = TestSerializer().fields diff --git a/tests/test_request.py b/tests/test_request.py index dbfa695fd7..32fbbc50bb 100644 --- a/tests/test_request.py +++ b/tests/test_request.py @@ -13,6 +13,7 @@ from rest_framework import status from rest_framework.authentication import SessionAuthentication +from rest_framework.compat import is_anonymous from rest_framework.parsers import BaseParser, FormParser, MultiPartParser from rest_framework.request import Request from rest_framework.response import Response @@ -169,9 +170,9 @@ def test_user_can_login(self): def test_user_can_logout(self): self.request.user = self.user - self.assertFalse(self.request.user.is_anonymous()) + self.assertFalse(is_anonymous(self.request.user)) logout(self.request) - self.assertTrue(self.request.user.is_anonymous()) + self.assertTrue(is_anonymous(self.request.user)) def test_logged_in_user_is_set_on_wrapped_request(self): login(self.request, self.user) From 3bfb0b716874559044e8c5bee3e575a549e057ab Mon Sep 17 00:00:00 2001 From: Ryan P Kilby Date: Wed, 14 Sep 2016 23:57:52 -0400 Subject: [PATCH 19/43] Use TOC extension instead of header --- rest_framework/compat.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/rest_framework/compat.py b/rest_framework/compat.py index 1a94f22b68..8afe52f543 100644 --- a/rest_framework/compat.py +++ b/rest_framework/compat.py @@ -216,8 +216,13 @@ def value_from_object(field, obj): if markdown.version <= '2.2': HEADERID_EXT_PATH = 'headerid' - else: + LEVEL_PARAM = 'level' + elif markdown.version < '2.6': HEADERID_EXT_PATH = 'markdown.extensions.headerid' + LEVEL_PARAM = 'level' + else: + HEADERID_EXT_PATH = 'markdown.extensions.toc' + LEVEL_PARAM = 'baselevel' def apply_markdown(text): """ @@ -227,7 +232,7 @@ def apply_markdown(text): extensions = [HEADERID_EXT_PATH] extension_configs = { HEADERID_EXT_PATH: { - 'level': '2' + LEVEL_PARAM: '2' } } md = markdown.Markdown( From 3bdb0e9dd8b8f098111fad4a874ee99d392e944a Mon Sep 17 00:00:00 2001 From: Ryan P Kilby Date: Fri, 16 Sep 2016 13:08:26 -0400 Subject: [PATCH 20/43] Fix deprecations for py3k --- rest_framework/fields.py | 12 ++++++++++++ tests/test_schemas.py | 2 +- 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/rest_framework/fields.py b/rest_framework/fields.py index f76e4e8011..7d677a408b 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -53,6 +53,18 @@ def is_simple_callable(obj): """ True if the object is a callable that takes no arguments. """ + if not hasattr(inspect, 'signature'): + return py2k_is_simple_callable(obj) + + if not callable(obj): + return False + + sig = inspect.signature(obj) + params = sig.parameters.values() + return all(param.default != param.empty for param in params) + + +def py2k_is_simple_callable(obj): function = inspect.isfunction(obj) method = inspect.ismethod(obj) diff --git a/tests/test_schemas.py b/tests/test_schemas.py index 197e62eb0b..dc01d8cd84 100644 --- a/tests/test_schemas.py +++ b/tests/test_schemas.py @@ -215,4 +215,4 @@ def test_view(self): } } ) - self.assertEquals(schema, expected) + self.assertEqual(schema, expected) From a0a8b9890a821bb56aace86ed18e57b2a1137e07 Mon Sep 17 00:00:00 2001 From: Ryan P Kilby Date: Thu, 22 Sep 2016 13:56:27 -0400 Subject: [PATCH 21/43] Add py3k compatibility to is_simple_callable --- rest_framework/fields.py | 38 ++++++++++++++++++++++++++------------ 1 file changed, 26 insertions(+), 12 deletions(-) diff --git a/rest_framework/fields.py b/rest_framework/fields.py index 917a151e58..7f8391b8a8 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -49,20 +49,34 @@ class empty: pass -def is_simple_callable(obj): - """ - True if the object is a callable that takes no arguments. - """ - function = inspect.isfunction(obj) - method = inspect.ismethod(obj) +if six.PY3: + def is_simple_callable(obj): + """ + True if the object is a callable that takes no arguments. + """ + if not callable(obj): + return False + + sig = inspect.signature(obj) + params = sig.parameters.values() + return all(param.default != param.empty for param in params) + +else: + def is_simple_callable(obj): + function = inspect.isfunction(obj) + method = inspect.ismethod(obj) + + if not (function or method): + return False + + if method: + is_unbound = obj.im_self is None - if not (function or method): - return False + args, _, _, defaults = inspect.getargspec(obj) - args, _, _, defaults = inspect.getargspec(obj) - len_args = len(args) if function else len(args) - 1 - len_defaults = len(defaults) if defaults else 0 - return len_args <= len_defaults + len_args = len(args) if function or is_unbound else len(args) - 1 + len_defaults = len(defaults) if defaults else 0 + return len_args <= len_defaults def get_attribute(instance, attrs): From adcf6536e75272649da1d8e5a8b5ba7926e33147 Mon Sep 17 00:00:00 2001 From: Ryan P Kilby Date: Thu, 22 Sep 2016 13:56:37 -0400 Subject: [PATCH 22/43] Add is_simple_callable tests --- tests/test_fields.py | 62 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 62 insertions(+) diff --git a/tests/test_fields.py b/tests/test_fields.py index 4a4b741c53..c271afa9ee 100644 --- a/tests/test_fields.py +++ b/tests/test_fields.py @@ -1,6 +1,7 @@ import datetime import os import re +import unittest import uuid from decimal import Decimal @@ -11,6 +12,67 @@ import rest_framework from rest_framework import serializers +from rest_framework.fields import is_simple_callable + +try: + import typings +except ImportError: + typings = False + + +# Tests for helper functions. +# --------------------------- + +class TestIsSimpleCallable: + + def test_method(self): + class Foo: + @classmethod + def classmethod(cls): + pass + + def valid(self): + pass + + def valid_kwargs(self, param='value'): + pass + + def invalid(self, param): + pass + + assert is_simple_callable(Foo.classmethod) + + # unbound methods + assert not is_simple_callable(Foo.valid) + assert not is_simple_callable(Foo.valid_kwargs) + assert not is_simple_callable(Foo.invalid) + + # bound methods + assert is_simple_callable(Foo().valid) + assert is_simple_callable(Foo().valid_kwargs) + assert not is_simple_callable(Foo().invalid) + + def test_function(self): + def simple(): + pass + + def valid(param='value', param2='value'): + pass + + def invalid(param, param2='value'): + pass + + assert is_simple_callable(simple) + assert is_simple_callable(valid) + assert not is_simple_callable(invalid) + + @unittest.skipUnless(typings, 'requires python 3.5') + def test_type_annotation(self): + # The annotation will otherwise raise a syntax error in python < 3.5 + exec("def valid(param: str='value'): pass", locals()) + valid = locals()['valid'] + + assert is_simple_callable(valid) # Tests for field keyword arguments and core functionality. From b3afcb25d9f0621091231212fcc8c92631e035e4 Mon Sep 17 00:00:00 2001 From: Ryan P Kilby Date: Thu, 22 Sep 2016 15:19:48 -0400 Subject: [PATCH 23/43] Drop python 3.2 support (EOL, Dropped by Django) --- .travis.yml | 1 - tox.ini | 3 +-- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/.travis.yml b/.travis.yml index c9d9a16484..100a7cd8b4 100644 --- a/.travis.yml +++ b/.travis.yml @@ -14,7 +14,6 @@ env: - TOX_ENV=py35-django18 - TOX_ENV=py34-django18 - TOX_ENV=py33-django18 - - TOX_ENV=py32-django18 - TOX_ENV=py27-django18 - TOX_ENV=py27-django110 - TOX_ENV=py35-django110 diff --git a/tox.ini b/tox.ini index 1e8a7e5c46..c20021e4be 100644 --- a/tox.ini +++ b/tox.ini @@ -4,7 +4,7 @@ addopts=--tb=short [tox] envlist = py27-{lint,docs}, - {py27,py32,py33,py34,py35}-django18, + {py27,py33,py34,py35}-django18, {py27,py34,py35}-django19, {py27,py34,py35}-django110, {py27,py34,py35}-django{master} @@ -25,7 +25,6 @@ basepython = py35: python3.5 py34: python3.4 py33: python3.3 - py32: python3.2 py27: python2.7 [testenv:py27-lint] From b5167120764c997562de8fb9caf71249ae4a8c9e Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Wed, 28 Sep 2016 12:15:46 +0100 Subject: [PATCH 24/43] schema_renderers= should *set* the renderers, not append to them. --- rest_framework/routers.py | 57 ++++++++++++++++++++------------------- 1 file changed, 30 insertions(+), 27 deletions(-) diff --git a/rest_framework/routers.py b/rest_framework/routers.py index 4eec70bda3..64f110cd98 100644 --- a/rest_framework/routers.py +++ b/rest_framework/routers.py @@ -289,42 +289,42 @@ def __init__(self, *args, **kwargs): self.root_renderers = list(api_settings.DEFAULT_RENDERER_CLASSES) super(DefaultRouter, self).__init__(*args, **kwargs) + def get_schema_root_view(self, api_urls=None): + """ + Return a schema root view. + """ + schema_renderers = self.schema_renderers + schema_generator = SchemaGenerator( + title=self.schema_title, + url=self.schema_url, + patterns=api_urls + ) + + class APISchemaView(views.APIView): + _ignore_model_permissions = True + renderer_classes = schema_renderers + + def get(self, request, *args, **kwargs): + schema = schema_generator.get_schema(request) + if schema is None: + raise exceptions.PermissionDenied() + return Response(schema) + + return APISchemaView.as_view() + def get_api_root_view(self, api_urls=None): """ - Return a view to use as the API root. + Return a basic root view. """ api_root_dict = OrderedDict() list_name = self.routes[0].name for prefix, viewset, basename in self.registry: api_root_dict[prefix] = list_name.format(basename=basename) - view_renderers = list(self.root_renderers) - schema_media_types = [] - - if api_urls and self.schema_title: - view_renderers += list(self.schema_renderers) - schema_generator = SchemaGenerator( - title=self.schema_title, - url=self.schema_url, - patterns=api_urls - ) - schema_media_types = [ - renderer.media_type - for renderer in self.schema_renderers - ] - - class APIRoot(views.APIView): + class APIRootView(views.APIView): _ignore_model_permissions = True - renderer_classes = view_renderers def get(self, request, *args, **kwargs): - if request.accepted_renderer.media_type in schema_media_types: - # Return a schema response. - schema = schema_generator.get_schema(request) - if schema is None: - raise exceptions.PermissionDenied() - return Response(schema) - # Return a plain {"name": "hyperlink"} response. ret = OrderedDict() namespace = request.resolver_match.namespace @@ -345,7 +345,7 @@ def get(self, request, *args, **kwargs): return Response(ret) - return APIRoot.as_view() + return APIRootView.as_view() def get_urls(self): """ @@ -355,7 +355,10 @@ def get_urls(self): urls = super(DefaultRouter, self).get_urls() if self.include_root_view: - view = self.get_api_root_view(api_urls=urls) + if self.schema_title: + view = self.get_schema_root_view(api_urls=urls) + else: + view = self.get_api_root_view(api_urls=urls) root_url = url(r'^$', view, name=self.root_view_name) urls.append(root_url) From 37b3475e5d048f7f7e751dee6bdb23695fe484bd Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Thu, 29 Sep 2016 09:46:28 +0100 Subject: [PATCH 25/43] API client (#4424) --- README.md | 2 +- docs/api-guide/permissions.md | 7 +- docs/topics/release-notes.md | 25 + requirements/requirements-optionals.txt | 2 +- rest_framework/__init__.py | 2 +- rest_framework/fields.py | 2 +- rest_framework/relations.py | 6 +- rest_framework/renderers.py | 9 +- rest_framework/request.py | 5 + rest_framework/schemas.py | 7 + rest_framework/serializers.py | 8 +- .../static/rest_framework/js/csrf.js | 2 +- .../templates/rest_framework/admin.html | 1 + .../templates/rest_framework/base.html | 1 + .../vertical/checkbox_multiple.html | 4 +- rest_framework/test.py | 15 +- rest_framework/views.py | 20 +- tests/test_api_client.py | 452 ++++++++++++++++++ tests/test_request.py | 11 + tests/test_schemas.py | 1 + 20 files changed, 557 insertions(+), 25 deletions(-) create mode 100644 tests/test_api_client.py diff --git a/README.md b/README.md index 179f2891a8..e1e2526091 100644 --- a/README.md +++ b/README.md @@ -169,7 +169,7 @@ You may also want to [follow the author on Twitter][twitter]. # Security -If you believe you’ve found something in Django REST framework which has security implications, please **do not raise the issue in a public forum**. +If you believe you've found something in Django REST framework which has security implications, please **do not raise the issue in a public forum**. Send a description of the issue via email to [rest-framework-security@googlegroups.com][security-mail]. The project maintainers will then work with you to resolve any issues where required, prior to any public disclosure. diff --git a/docs/api-guide/permissions.md b/docs/api-guide/permissions.md index e0838e94a9..7cdb595313 100644 --- a/docs/api-guide/permissions.md +++ b/docs/api-guide/permissions.md @@ -92,7 +92,7 @@ Or, if you're using the `@api_view` decorator with function based views. from rest_framework.permissions import IsAuthenticated from rest_framework.response import Response - @api_view('GET') + @api_view(['GET']) @permission_classes((IsAuthenticated, )) def example_view(request, format=None): content = { @@ -261,6 +261,10 @@ The [REST Condition][rest-condition] package is another extension for building c The [DRY Rest Permissions][dry-rest-permissions] package provides the ability to define different permissions for individual default and custom actions. This package is made for apps with permissions that are derived from relationships defined in the app's data model. It also supports permission checks being returned to a client app through the API's serializer. Additionally it supports adding permissions to the default and custom list actions to restrict the data they retrive per user. +## Django Rest Framework Roles + +The [Django Rest Framework Roles][django-rest-framework-roles] package makes it easier to parameterize your API over multiple types of users. + [cite]: https://developer.apple.com/library/mac/#documentation/security/Conceptual/AuthenticationAndAuthorizationGuide/Authorization/Authorization.html [authentication]: authentication.md [throttling]: throttling.md @@ -275,3 +279,4 @@ The [DRY Rest Permissions][dry-rest-permissions] package provides the ability to [composed-permissions]: https://github.com/niwibe/djangorestframework-composed-permissions [rest-condition]: https://github.com/caxap/rest_condition [dry-rest-permissions]: https://github.com/Helioscene/dry-rest-permissions +[django-rest-framework-roles]: https://github.com/computer-lab/django-rest-framework-roles diff --git a/docs/topics/release-notes.md b/docs/topics/release-notes.md index 78a5a8ba90..24728a252e 100644 --- a/docs/topics/release-notes.md +++ b/docs/topics/release-notes.md @@ -40,6 +40,18 @@ You can determine your currently installed version using `pip freeze`: ## 3.4.x series +### 3.4.5 + +**Date**: [19th August 2016][3.4.5-milestone] + +* Improve debug error handling. ([#4416][gh4416], [#4409][gh4409]) +* Allow custom CSRF_HEADER_NAME setting. ([#4415][gh4415], [#4410][gh4410]) +* Include .action attribute on viewsets when generating schemas. ([#4408][gh4408], [#4398][gh4398]) +* Do not include request.FILES items in request.POST. ([#4407][gh4407]) +* Fix rendering of checkbox multiple. ([#4403][gh4403]) +* Fix docstring of Field.get_default. ([#4404][gh4404]) +* Replace utf8 character with its ascii counterpart in README. ([#4412][gh4412]) + ### 3.4.4 **Date**: [12th August 2016][3.4.4-milestone] @@ -560,6 +572,7 @@ For older release notes, [please see the version 2.x documentation][old-release- [3.4.2-milestone]: https://github.com/tomchristie/django-rest-framework/issues?q=milestone%3A%223.4.2+Release%22 [3.4.3-milestone]: https://github.com/tomchristie/django-rest-framework/issues?q=milestone%3A%223.4.3+Release%22 [3.4.4-milestone]: https://github.com/tomchristie/django-rest-framework/issues?q=milestone%3A%223.4.4+Release%22 +[3.4.5-milestone]: https://github.com/tomchristie/django-rest-framework/issues?q=milestone%3A%223.4.5+Release%22 [gh2013]: https://github.com/tomchristie/django-rest-framework/issues/2013 @@ -1065,3 +1078,15 @@ For older release notes, [please see the version 2.x documentation][old-release- [gh4392]: https://github.com/tomchristie/django-rest-framework/issues/4392 [gh4393]: https://github.com/tomchristie/django-rest-framework/issues/4393 [gh4394]: https://github.com/tomchristie/django-rest-framework/issues/4394 + + +[gh4416]: https://github.com/tomchristie/django-rest-framework/issues/4416 +[gh4409]: https://github.com/tomchristie/django-rest-framework/issues/4409 +[gh4415]: https://github.com/tomchristie/django-rest-framework/issues/4415 +[gh4410]: https://github.com/tomchristie/django-rest-framework/issues/4410 +[gh4408]: https://github.com/tomchristie/django-rest-framework/issues/4408 +[gh4398]: https://github.com/tomchristie/django-rest-framework/issues/4398 +[gh4407]: https://github.com/tomchristie/django-rest-framework/issues/4407 +[gh4403]: https://github.com/tomchristie/django-rest-framework/issues/4403 +[gh4404]: https://github.com/tomchristie/django-rest-framework/issues/4404 +[gh4412]: https://github.com/tomchristie/django-rest-framework/issues/4412 diff --git a/requirements/requirements-optionals.txt b/requirements/requirements-optionals.txt index 20436e6b4c..afade0aa0c 100644 --- a/requirements/requirements-optionals.txt +++ b/requirements/requirements-optionals.txt @@ -2,4 +2,4 @@ markdown==2.6.4 django-guardian==1.4.3 django-filter==0.13.0 -coreapi==1.32.0 +coreapi==2.0.0 diff --git a/rest_framework/__init__.py b/rest_framework/__init__.py index 999c5de315..3f8736c258 100644 --- a/rest_framework/__init__.py +++ b/rest_framework/__init__.py @@ -8,7 +8,7 @@ """ __title__ = 'Django REST framework' -__version__ = '3.4.4' +__version__ = '3.4.5' __author__ = 'Tom Christie' __license__ = 'BSD 2-Clause' __copyright__ = 'Copyright 2011-2016 Tom Christie' diff --git a/rest_framework/fields.py b/rest_framework/fields.py index 8f12b2df48..f76e4e8011 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -432,7 +432,7 @@ def get_default(self): is provided for this field. If a default has not been set for this field then this will simply - return `empty`, indicating that no value should be set in the + raise `SkipField`, indicating that no value should be set in the validated data for this field. """ if self.default is empty or getattr(self.root, 'partial', False): diff --git a/rest_framework/relations.py b/rest_framework/relations.py index 4b6b3bea45..65c4c03187 100644 --- a/rest_framework/relations.py +++ b/rest_framework/relations.py @@ -10,7 +10,7 @@ from django.db.models import Manager from django.db.models.query import QuerySet from django.utils import six -from django.utils.encoding import smart_text +from django.utils.encoding import python_2_unicode_compatible, smart_text from django.utils.six.moves.urllib import parse as urlparse from django.utils.translation import ugettext_lazy as _ @@ -47,6 +47,7 @@ def __getnewargs__(self): is_hyperlink = True +@python_2_unicode_compatible class PKOnlyObject(object): """ This is a mock object, used for when we only need the pk of the object @@ -56,6 +57,9 @@ class PKOnlyObject(object): def __init__(self, pk): self.pk = pk + def __str__(self): + return "%s" % self.pk + # We assume that 'validators' are intended for the child serializer, # rather than the parent serializer. diff --git a/rest_framework/renderers.py b/rest_framework/renderers.py index 371cd6ec77..11e9fb9607 100644 --- a/rest_framework/renderers.py +++ b/rest_framework/renderers.py @@ -645,6 +645,12 @@ def get_context(self, data, accepted_media_type, renderer_context): else: paginator = None + csrf_cookie_name = settings.CSRF_COOKIE_NAME + csrf_header_name = getattr(settings, 'CSRF_HEADER_NAME', 'HTTP_X_CSRFToken') # Fallback for Django 1.8 + if csrf_header_name.startswith('HTTP_'): + csrf_header_name = csrf_header_name[5:] + csrf_header_name = csrf_header_name.replace('_', '-') + context = { 'content': self.get_content(renderer, data, accepted_media_type, renderer_context), 'view': view, @@ -675,7 +681,8 @@ def get_context(self, data, accepted_media_type, renderer_context): 'display_edit_forms': bool(response.status_code != 403), 'api_settings': api_settings, - 'csrf_cookie_name': settings.CSRF_COOKIE_NAME, + 'csrf_cookie_name': csrf_cookie_name, + 'csrf_header_name': csrf_header_name } return context diff --git a/rest_framework/request.py b/rest_framework/request.py index f5738bfd50..355cccad77 100644 --- a/rest_framework/request.py +++ b/rest_framework/request.py @@ -391,3 +391,8 @@ def QUERY_PARAMS(self): '`request.QUERY_PARAMS` has been deprecated in favor of `request.query_params` ' 'since version 3.0, and has been fully removed as of version 3.2.' ) + + def force_plaintext_errors(self, value): + # Hack to allow our exception handler to force choice of + # plaintext or html error responses. + self._request.is_ajax = lambda: value diff --git a/rest_framework/schemas.py b/rest_framework/schemas.py index 0618e94fd2..c9834c64d0 100644 --- a/rest_framework/schemas.py +++ b/rest_framework/schemas.py @@ -79,6 +79,13 @@ def get_schema(self, request=None): view.kwargs = {} view.format_kwarg = None + actions = getattr(callback, 'actions', None) + if actions is not None: + if method == 'OPTIONS': + view.action = 'metadata' + else: + view.action = actions.get(method.lower()) + if request is not None: view.request = clone_request(request, method) try: diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 41412af8ad..4d1ed63aef 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -12,6 +12,7 @@ """ from __future__ import unicode_literals +import traceback import warnings from django.db import models @@ -870,19 +871,20 @@ def create(self, validated_data): try: instance = ModelClass.objects.create(**validated_data) - except TypeError as exc: + except TypeError: + tb = traceback.format_exc() msg = ( 'Got a `TypeError` when calling `%s.objects.create()`. ' 'This may be because you have a writable field on the ' 'serializer class that is not a valid argument to ' '`%s.objects.create()`. You may need to make the field ' 'read-only, or override the %s.create() method to handle ' - 'this correctly.\nOriginal exception text was: %s.' % + 'this correctly.\nOriginal exception was:\n %s' % ( ModelClass.__name__, ModelClass.__name__, self.__class__.__name__, - exc + tb ) ) raise TypeError(msg) diff --git a/rest_framework/static/rest_framework/js/csrf.js b/rest_framework/static/rest_framework/js/csrf.js index f8ab4428cb..97c8d01242 100644 --- a/rest_framework/static/rest_framework/js/csrf.js +++ b/rest_framework/static/rest_framework/js/csrf.js @@ -46,7 +46,7 @@ $.ajaxSetup({ // Send the token to same-origin, relative URLs only. // Send the token only if the method warrants CSRF protection // Using the CSRFToken value acquired earlier - xhr.setRequestHeader("X-CSRFToken", csrftoken); + xhr.setRequestHeader(window.drf.csrfHeaderName, csrftoken); } } }); diff --git a/rest_framework/templates/rest_framework/admin.html b/rest_framework/templates/rest_framework/admin.html index 89af81ef74..eb2b8f1c7e 100644 --- a/rest_framework/templates/rest_framework/admin.html +++ b/rest_framework/templates/rest_framework/admin.html @@ -232,6 +232,7 @@ {% block script %} diff --git a/rest_framework/templates/rest_framework/base.html b/rest_framework/templates/rest_framework/base.html index 4c1136087c..989a086ea7 100644 --- a/rest_framework/templates/rest_framework/base.html +++ b/rest_framework/templates/rest_framework/base.html @@ -263,6 +263,7 @@

{{ name }}

{% block script %} diff --git a/rest_framework/templates/rest_framework/vertical/checkbox_multiple.html b/rest_framework/templates/rest_framework/vertical/checkbox_multiple.html index b933f4ff51..7a43b3f58b 100644 --- a/rest_framework/templates/rest_framework/vertical/checkbox_multiple.html +++ b/rest_framework/templates/rest_framework/vertical/checkbox_multiple.html @@ -9,7 +9,7 @@
{% for key, text in field.choices.items %} {% endfor %} @@ -18,7 +18,7 @@ {% for key, text in field.choices.items %}
diff --git a/rest_framework/test.py b/rest_framework/test.py index e17c19a43f..b8e486b216 100644 --- a/rest_framework/test.py +++ b/rest_framework/test.py @@ -16,7 +16,7 @@ from django.utils.encoding import force_bytes from django.utils.http import urlencode -from rest_framework.compat import requests +from rest_framework.compat import coreapi, requests from rest_framework.settings import api_settings @@ -60,7 +60,10 @@ def get_environ(self, request): # Set request content, if any exists. if request.body is not None: - kwargs['data'] = request.body + if hasattr(request.body, 'read'): + kwargs['data'] = request.body.read() + else: + kwargs['data'] = request.body if 'content-type' in request.headers: kwargs['content_type'] = request.headers['content-type'] @@ -126,6 +129,14 @@ def get_requests_client(): return DjangoTestSession() +def get_api_client(): + assert coreapi is not None, 'coreapi must be installed' + session = get_requests_client() + return coreapi.Client(transports=[ + coreapi.transports.HTTPTransport(session=session) + ]) + + class APIRequestFactory(DjangoRequestFactory): renderer_classes_list = api_settings.TEST_REQUEST_RENDERER_CLASSES default_format = api_settings.TEST_REQUEST_DEFAULT_FORMAT diff --git a/rest_framework/views.py b/rest_framework/views.py index b86bb7eaa6..15d8c6cde2 100644 --- a/rest_framework/views.py +++ b/rest_framework/views.py @@ -3,17 +3,14 @@ """ from __future__ import unicode_literals -import sys - from django.conf import settings from django.core.exceptions import PermissionDenied from django.db import models from django.http import Http404 -from django.http.response import HttpResponse, HttpResponseBase +from django.http.response import HttpResponseBase from django.utils import six from django.utils.encoding import smart_text from django.utils.translation import ugettext_lazy as _ -from django.views import debug from django.views.decorators.csrf import csrf_exempt from django.views.generic import View @@ -95,11 +92,6 @@ def exception_handler(exc, context): set_rollback() return Response(data, status=status.HTTP_403_FORBIDDEN) - # throw django's error page if debug is True - if settings.DEBUG: - exception_reporter = debug.ExceptionReporter(context.get('request'), *sys.exc_info()) - return HttpResponse(exception_reporter.get_traceback_html(), status=500) - return None @@ -439,11 +431,19 @@ def handle_exception(self, exc): response = exception_handler(exc, context) if response is None: - raise + self.raise_uncaught_exception(exc) response.exception = True return response + def raise_uncaught_exception(self, exc): + if settings.DEBUG: + request = self.request + renderer_format = getattr(request.accepted_renderer, 'format') + use_plaintext_traceback = renderer_format not in ('html', 'api', 'admin') + request.force_plaintext_errors(use_plaintext_traceback) + raise + # Note: Views are made CSRF exempt from within `as_view` as to prevent # accidental removal of this exemption in cases where `dispatch` needs to # be overridden. diff --git a/tests/test_api_client.py b/tests/test_api_client.py new file mode 100644 index 0000000000..9daf3f3fe4 --- /dev/null +++ b/tests/test_api_client.py @@ -0,0 +1,452 @@ +from __future__ import unicode_literals + +import os +import tempfile +import unittest + +from django.conf.urls import url +from django.http import HttpResponse +from django.test import override_settings + +from rest_framework.compat import coreapi +from rest_framework.parsers import FileUploadParser +from rest_framework.renderers import CoreJSONRenderer +from rest_framework.response import Response +from rest_framework.test import APITestCase, get_api_client +from rest_framework.views import APIView + + +def get_schema(): + return coreapi.Document( + url='https://api.example.com/', + title='Example API', + content={ + 'simple_link': coreapi.Link('/example/', description='example link'), + 'location': { + 'query': coreapi.Link('/example/', fields=[ + coreapi.Field(name='example', description='example field') + ]), + 'form': coreapi.Link('/example/', action='post', fields=[ + coreapi.Field(name='example'), + ]), + 'body': coreapi.Link('/example/', action='post', fields=[ + coreapi.Field(name='example', location='body') + ]), + 'path': coreapi.Link('/example/{id}', fields=[ + coreapi.Field(name='id', location='path') + ]) + }, + 'encoding': { + 'multipart': coreapi.Link('/example/', action='post', encoding='multipart/form-data', fields=[ + coreapi.Field(name='example') + ]), + 'multipart-body': coreapi.Link('/example/', action='post', encoding='multipart/form-data', fields=[ + coreapi.Field(name='example', location='body') + ]), + 'urlencoded': coreapi.Link('/example/', action='post', encoding='application/x-www-form-urlencoded', fields=[ + coreapi.Field(name='example') + ]), + 'urlencoded-body': coreapi.Link('/example/', action='post', encoding='application/x-www-form-urlencoded', fields=[ + coreapi.Field(name='example', location='body') + ]), + 'raw_upload': coreapi.Link('/upload/', action='post', encoding='application/octet-stream', fields=[ + coreapi.Field(name='example', location='body') + ]), + }, + 'response': { + 'download': coreapi.Link('/download/'), + 'text': coreapi.Link('/text/') + } + } + ) + + +def _iterlists(querydict): + if hasattr(querydict, 'iterlists'): + return querydict.iterlists() + return querydict.lists() + + +def _get_query_params(request): + # Return query params in a plain dict, using a list value if more + # than one item is present for a given key. + return { + key: (value[0] if len(value) == 1 else value) + for key, value in + _iterlists(request.query_params) + } + + +def _get_data(request): + if not isinstance(request.data, dict): + return request.data + # Coerce multidict into regular dict, and remove files to + # make assertions simpler. + if hasattr(request.data, 'iterlists') or hasattr(request.data, 'lists'): + # Use a list value if a QueryDict contains multiple items for a key. + return { + key: value[0] if len(value) == 1 else value + for key, value in _iterlists(request.data) + if key not in request.FILES + } + return { + key: value + for key, value in request.data.items() + if key not in request.FILES + } + + +def _get_files(request): + if not request.FILES: + return {} + return { + key: {'name': value.name, 'content': value.read()} + for key, value in request.FILES.items() + } + + +class SchemaView(APIView): + renderer_classes = [CoreJSONRenderer] + + def get(self, request): + schema = get_schema() + return Response(schema) + + +class ListView(APIView): + def get(self, request): + return Response({ + 'method': request.method, + 'query_params': _get_query_params(request) + }) + + def post(self, request): + if request.content_type: + content_type = request.content_type.split(';')[0] + else: + content_type = None + + return Response({ + 'method': request.method, + 'query_params': _get_query_params(request), + 'data': _get_data(request), + 'files': _get_files(request), + 'content_type': content_type + }) + + +class DetailView(APIView): + def get(self, request, id): + return Response({ + 'id': id, + 'method': request.method, + 'query_params': _get_query_params(request) + }) + + +class UploadView(APIView): + parser_classes = [FileUploadParser] + + def post(self, request): + return Response({ + 'method': request.method, + 'files': _get_files(request), + 'content_type': request.content_type + }) + + +class DownloadView(APIView): + def get(self, request): + return HttpResponse('some file content', content_type='image/png') + + +class TextView(APIView): + def get(self, request): + return HttpResponse('123', content_type='text/plain') + + +urlpatterns = [ + url(r'^$', SchemaView.as_view()), + url(r'^example/$', ListView.as_view()), + url(r'^example/(?P[0-9]+)/$', DetailView.as_view()), + url(r'^upload/$', UploadView.as_view()), + url(r'^download/$', DownloadView.as_view()), + url(r'^text/$', TextView.as_view()), +] + + +@unittest.skipUnless(coreapi, 'coreapi not installed') +@override_settings(ROOT_URLCONF='tests.test_api_client') +class APIClientTests(APITestCase): + def test_api_client(self): + client = get_api_client() + schema = client.get('http://api.example.com/') + assert schema.title == 'Example API' + assert schema.url == 'https://api.example.com/' + assert schema['simple_link'].description == 'example link' + assert schema['location']['query'].fields[0].description == 'example field' + data = client.action(schema, ['simple_link']) + expected = { + 'method': 'GET', + 'query_params': {} + } + assert data == expected + + def test_query_params(self): + client = get_api_client() + schema = client.get('http://api.example.com/') + data = client.action(schema, ['location', 'query'], params={'example': 123}) + expected = { + 'method': 'GET', + 'query_params': {'example': '123'} + } + assert data == expected + + def test_query_params_with_multiple_values(self): + client = get_api_client() + schema = client.get('http://api.example.com/') + data = client.action(schema, ['location', 'query'], params={'example': [1, 2, 3]}) + expected = { + 'method': 'GET', + 'query_params': {'example': ['1', '2', '3']} + } + assert data == expected + + def test_form_params(self): + client = get_api_client() + schema = client.get('http://api.example.com/') + data = client.action(schema, ['location', 'form'], params={'example': 123}) + expected = { + 'method': 'POST', + 'content_type': 'application/json', + 'query_params': {}, + 'data': {'example': 123}, + 'files': {} + } + assert data == expected + + def test_body_params(self): + client = get_api_client() + schema = client.get('http://api.example.com/') + data = client.action(schema, ['location', 'body'], params={'example': 123}) + expected = { + 'method': 'POST', + 'content_type': 'application/json', + 'query_params': {}, + 'data': 123, + 'files': {} + } + assert data == expected + + def test_path_params(self): + client = get_api_client() + schema = client.get('http://api.example.com/') + data = client.action(schema, ['location', 'path'], params={'id': 123}) + expected = { + 'method': 'GET', + 'query_params': {}, + 'id': '123' + } + assert data == expected + + def test_multipart_encoding(self): + client = get_api_client() + schema = client.get('http://api.example.com/') + + temp = tempfile.NamedTemporaryFile() + temp.write(b'example file content') + temp.flush() + + with open(temp.name, 'rb') as upload: + name = os.path.basename(upload.name) + data = client.action(schema, ['encoding', 'multipart'], params={'example': upload}) + + expected = { + 'method': 'POST', + 'content_type': 'multipart/form-data', + 'query_params': {}, + 'data': {}, + 'files': {'example': {'name': name, 'content': 'example file content'}} + } + assert data == expected + + def test_multipart_encoding_no_file(self): + # When no file is included, multipart encoding should still be used. + client = get_api_client() + schema = client.get('http://api.example.com/') + + data = client.action(schema, ['encoding', 'multipart'], params={'example': 123}) + + expected = { + 'method': 'POST', + 'content_type': 'multipart/form-data', + 'query_params': {}, + 'data': {'example': '123'}, + 'files': {} + } + assert data == expected + + def test_multipart_encoding_multiple_values(self): + client = get_api_client() + schema = client.get('http://api.example.com/') + + data = client.action(schema, ['encoding', 'multipart'], params={'example': [1, 2, 3]}) + + expected = { + 'method': 'POST', + 'content_type': 'multipart/form-data', + 'query_params': {}, + 'data': {'example': ['1', '2', '3']}, + 'files': {} + } + assert data == expected + + def test_multipart_encoding_string_file_content(self): + # Test for `coreapi.utils.File` support. + from coreapi.utils import File + + client = get_api_client() + schema = client.get('http://api.example.com/') + + example = File(name='example.txt', content='123') + data = client.action(schema, ['encoding', 'multipart'], params={'example': example}) + + expected = { + 'method': 'POST', + 'content_type': 'multipart/form-data', + 'query_params': {}, + 'data': {}, + 'files': {'example': {'name': 'example.txt', 'content': '123'}} + } + assert data == expected + + def test_multipart_encoding_in_body(self): + from coreapi.utils import File + + client = get_api_client() + schema = client.get('http://api.example.com/') + + example = {'foo': File(name='example.txt', content='123'), 'bar': 'abc'} + data = client.action(schema, ['encoding', 'multipart-body'], params={'example': example}) + + expected = { + 'method': 'POST', + 'content_type': 'multipart/form-data', + 'query_params': {}, + 'data': {'bar': 'abc'}, + 'files': {'foo': {'name': 'example.txt', 'content': '123'}} + } + assert data == expected + + # URLencoded + + def test_urlencoded_encoding(self): + client = get_api_client() + schema = client.get('http://api.example.com/') + data = client.action(schema, ['encoding', 'urlencoded'], params={'example': 123}) + expected = { + 'method': 'POST', + 'content_type': 'application/x-www-form-urlencoded', + 'query_params': {}, + 'data': {'example': '123'}, + 'files': {} + } + assert data == expected + + def test_urlencoded_encoding_multiple_values(self): + client = get_api_client() + schema = client.get('http://api.example.com/') + data = client.action(schema, ['encoding', 'urlencoded'], params={'example': [1, 2, 3]}) + expected = { + 'method': 'POST', + 'content_type': 'application/x-www-form-urlencoded', + 'query_params': {}, + 'data': {'example': ['1', '2', '3']}, + 'files': {} + } + assert data == expected + + def test_urlencoded_encoding_in_body(self): + client = get_api_client() + schema = client.get('http://api.example.com/') + data = client.action(schema, ['encoding', 'urlencoded-body'], params={'example': {'foo': 123, 'bar': True}}) + expected = { + 'method': 'POST', + 'content_type': 'application/x-www-form-urlencoded', + 'query_params': {}, + 'data': {'foo': '123', 'bar': 'true'}, + 'files': {} + } + assert data == expected + + # Raw uploads + + def test_raw_upload(self): + client = get_api_client() + schema = client.get('http://api.example.com/') + + temp = tempfile.NamedTemporaryFile() + temp.write(b'example file content') + temp.flush() + + with open(temp.name, 'rb') as upload: + name = os.path.basename(upload.name) + data = client.action(schema, ['encoding', 'raw_upload'], params={'example': upload}) + + expected = { + 'method': 'POST', + 'files': {'file': {'name': name, 'content': 'example file content'}}, + 'content_type': 'application/octet-stream' + } + assert data == expected + + def test_raw_upload_string_file_content(self): + from coreapi.utils import File + + client = get_api_client() + schema = client.get('http://api.example.com/') + + example = File('example.txt', '123') + data = client.action(schema, ['encoding', 'raw_upload'], params={'example': example}) + + expected = { + 'method': 'POST', + 'files': {'file': {'name': 'example.txt', 'content': '123'}}, + 'content_type': 'text/plain' + } + assert data == expected + + def test_raw_upload_explicit_content_type(self): + from coreapi.utils import File + + client = get_api_client() + schema = client.get('http://api.example.com/') + + example = File('example.txt', '123', 'text/html') + data = client.action(schema, ['encoding', 'raw_upload'], params={'example': example}) + + expected = { + 'method': 'POST', + 'files': {'file': {'name': 'example.txt', 'content': '123'}}, + 'content_type': 'text/html' + } + assert data == expected + + # Responses + + def test_text_response(self): + client = get_api_client() + schema = client.get('http://api.example.com/') + + data = client.action(schema, ['response', 'text']) + + expected = '123' + assert data == expected + + def test_download_response(self): + client = get_api_client() + schema = client.get('http://api.example.com/') + + data = client.action(schema, ['response', 'download']) + assert data.basename == 'download.png' + assert data.read() == b'some file content' diff --git a/tests/test_request.py b/tests/test_request.py index dee636d766..dbfa695fd7 100644 --- a/tests/test_request.py +++ b/tests/test_request.py @@ -7,6 +7,7 @@ from django.contrib.auth import authenticate, login, logout from django.contrib.auth.models import User from django.contrib.sessions.middleware import SessionMiddleware +from django.core.files.uploadedfile import SimpleUploadedFile from django.test import TestCase, override_settings from django.utils import six @@ -78,6 +79,16 @@ def test_request_POST_with_form_content(self): request.parsers = (FormParser(), MultiPartParser()) self.assertEqual(list(request.POST.items()), list(data.items())) + def test_request_POST_with_files(self): + """ + Ensure request.POST returns no content for POST request with file content. + """ + upload = SimpleUploadedFile("file.txt", b"file_content") + request = Request(factory.post('/', {'upload': upload})) + request.parsers = (FormParser(), MultiPartParser()) + self.assertEqual(list(request.POST.keys()), []) + self.assertEqual(list(request.FILES.keys()), ['upload']) + def test_standard_behaviour_determines_form_content_PUT(self): """ Ensure request.data returns content for PUT request with form content. diff --git a/tests/test_schemas.py b/tests/test_schemas.py index 81b796c35a..c866e09bed 100644 --- a/tests/test_schemas.py +++ b/tests/test_schemas.py @@ -49,6 +49,7 @@ def custom_list_action(self, request): def get_serializer(self, *args, **kwargs): assert self.request + assert self.action return super(ExampleViewSet, self).get_serializer(*args, **kwargs) From 61b11890495d5bb014d97cd4a8ce3b1cf951454e Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Thu, 29 Sep 2016 10:24:30 +0100 Subject: [PATCH 26/43] Fix release notes --- docs/topics/release-notes.md | 3 --- 1 file changed, 3 deletions(-) diff --git a/docs/topics/release-notes.md b/docs/topics/release-notes.md index 6ef6cb83ab..446abdd14a 100644 --- a/docs/topics/release-notes.md +++ b/docs/topics/release-notes.md @@ -594,11 +594,8 @@ For older release notes, [please see the version 2.x documentation][old-release- [3.4.3-milestone]: https://github.com/tomchristie/django-rest-framework/issues?q=milestone%3A%223.4.3+Release%22 [3.4.4-milestone]: https://github.com/tomchristie/django-rest-framework/issues?q=milestone%3A%223.4.4+Release%22 [3.4.5-milestone]: https://github.com/tomchristie/django-rest-framework/issues?q=milestone%3A%223.4.5+Release%22 -<<<<<<< HEAD -======= [3.4.6-milestone]: https://github.com/tomchristie/django-rest-framework/issues?q=milestone%3A%223.4.6+Release%22 [3.4.7-milestone]: https://github.com/tomchristie/django-rest-framework/issues?q=milestone%3A%223.4.7+Release%22 ->>>>>>> master [gh2013]: https://github.com/tomchristie/django-rest-framework/issues/2013 From b689a3bdaa521b14d13266c866180d19585babe4 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Thu, 29 Sep 2016 12:03:14 +0100 Subject: [PATCH 27/43] Add note about 'User account is disabled.' vs 'Unable to log in' --- rest_framework/authtoken/serializers.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/rest_framework/authtoken/serializers.py b/rest_framework/authtoken/serializers.py index df0c48b86a..90d3bd96ed 100644 --- a/rest_framework/authtoken/serializers.py +++ b/rest_framework/authtoken/serializers.py @@ -16,6 +16,9 @@ def validate(self, attrs): user = authenticate(username=username, password=password) if user: + # From Django 1.10 onwards the `authenticate` call simply + # returns `None` for is_active=False users. + # (Assuming the default `ModelBackend` authentication backend.) if not user.is_active: msg = _('User account is disabled.') raise serializers.ValidationError(msg) From c3a9538ad90b9236ec91ec19e149075087b30765 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Fri, 30 Sep 2016 13:29:01 +0100 Subject: [PATCH 28/43] Clean up schema generation (#4527) --- rest_framework/schemas.py | 281 ++++++++++++++++++++++---------------- rest_framework/views.py | 1 + tests/test_schemas.py | 69 ++++++---- 3 files changed, 207 insertions(+), 144 deletions(-) diff --git a/rest_framework/schemas.py b/rest_framework/schemas.py index f2ec1f9e1f..39dd8d910c 100644 --- a/rest_framework/schemas.py +++ b/rest_framework/schemas.py @@ -32,85 +32,66 @@ def is_api_view(callback): return (cls is not None) and issubclass(cls, APIView) -class SchemaGenerator(object): - default_mapping = { - 'get': 'read', - 'post': 'create', - 'put': 'update', - 'patch': 'partial_update', - 'delete': 'destroy', - } - known_actions = ( - 'create', 'read', 'retrieve', 'list', - 'update', 'partial_update', 'destroy' - ) - - def __init__(self, title=None, url=None, patterns=None, urlconf=None): - assert coreapi, '`coreapi` must be installed for schema support.' - - if patterns is None and urlconf is not None: - if isinstance(urlconf, six.string_types): - urls = import_module(urlconf) - else: - urls = urlconf - self.patterns = urls.urlpatterns - elif patterns is None and urlconf is None: - urls = import_module(settings.ROOT_URLCONF) - self.patterns = urls.urlpatterns - else: - self.patterns = patterns - - if url and not url.endswith('/'): - url += '/' +def insert_into(target, keys, value): + """ + Nested dictionary insertion. - self.title = title - self.url = url - self.endpoints = None + >>> example = {} + >>> insert_into(example, ['a', 'b', 'c'], 123) + >>> example + {'a': {'b': {'c': 123}}} + """ + for key in keys[:-1]: + if key not in target: + target[key] = {} + target = target[key] + target[keys[-1]] = value - def get_schema(self, request=None): - if self.endpoints is None: - self.endpoints = self.get_api_endpoints(self.patterns) - links = [] - for path, method, category, action, callback in self.endpoints: - view = self.setup_view(callback, method, request) - if self.should_include_link(path, method, callback, view): - link = self.get_link(path, method, callback, view) - links.append((category, action, link)) +def is_custom_action(action): + return action not in set([ + 'read', 'retrieve', 'list', + 'create', 'update', 'partial_update', 'delete', 'destroy' + ]) - if not links: - return None - # Generate the schema content structure, eg: - # {'users': {'list': Link()}} - content = {} - for category, action, link in links: - if category is None: - content[action] = link - elif category in content: - content[category][action] = link +class EndpointInspector(object): + """ + A class to determine the available API endpoints that a project exposes. + """ + def __init__(self, patterns=None, urlconf=None): + if patterns is None: + if urlconf is None: + # Use the default Django URL conf + urls = import_module(settings.ROOT_URLCONF) + patterns = urls.urlpatterns else: - content[category] = {action: link} + # Load the given URLconf module + if isinstance(urlconf, six.string_types): + urls = import_module(urlconf) + else: + urls = urlconf + patterns = urls.urlpatterns - # Return the schema document. - return coreapi.Document(title=self.title, content=content, url=self.url) + self.patterns = patterns - def get_api_endpoints(self, patterns, prefix=''): + def get_api_endpoints(self, patterns=None, prefix=''): """ Return a list of all available API endpoints by inspecting the URL conf. """ + if patterns is None: + patterns = self.patterns + api_endpoints = [] for pattern in patterns: path_regex = prefix + pattern.regex.pattern if isinstance(pattern, RegexURLPattern): - path = self.get_path(path_regex) + path = self.get_path_from_regex(path_regex) callback = pattern.callback if self.should_include_endpoint(path, callback): for method in self.get_allowed_methods(callback): - action = self.get_action(path, method, callback) - category = self.get_category(path, method, callback, action) - endpoint = (path, method, category, action, callback) + endpoint = (path, method, callback) api_endpoints.append(endpoint) elif isinstance(pattern, RegexURLResolver): @@ -122,7 +103,7 @@ def get_api_endpoints(self, patterns, prefix=''): return api_endpoints - def get_path(self, path_regex): + def get_path_from_regex(self, path_regex): """ Given a URL conf regex, return a URI template string. """ @@ -157,47 +138,60 @@ def get_allowed_methods(self, callback): callback.cls().allowed_methods if method not in ('OPTIONS', 'HEAD') ] - def get_action(self, path, method, callback): - """ - Return a descriptive action string for the endpoint, eg. 'list'. - """ - actions = getattr(callback, 'actions', self.default_mapping) - return actions[method.lower()] - def get_category(self, path, method, callback, action): - """ - Return a descriptive category string for the endpoint, eg. 'users'. +class SchemaGenerator(object): + # Map methods onto 'actions' that are the names used in the link layout. + default_mapping = { + 'get': 'read', + 'post': 'create', + 'put': 'update', + 'patch': 'partial_update', + 'delete': 'destroy', + } + # Coerce the following viewset actions into different names. + coerce_actions = { + 'retrieve': 'read', + 'destroy': 'delete' + } + endpoint_inspector_cls = EndpointInspector + + def __init__(self, title=None, url=None, patterns=None, urlconf=None): + assert coreapi, '`coreapi` must be installed for schema support.' - Examples of category/action pairs that should be generated for various - endpoints: + if url and not url.endswith('/'): + url += '/' - /users/ [users][list], [users][create] - /users/{pk}/ [users][read], [users][update], [users][destroy] - /users/enabled/ [users][enabled] (custom action) - /users/{pk}/star/ [users][star] (custom action) - /users/{pk}/groups/ [groups][list], [groups][create] - /users/{pk}/groups/{pk}/ [groups][read], [groups][update], [groups][destroy] + self.endpoint_inspector = self.endpoint_inspector_cls(patterns, urlconf) + self.title = title + self.url = url + self.endpoints = None + + def get_schema(self, request=None): """ - path_components = path.strip('/').split('/') - path_components = [ - component for component in path_components - if '{' not in component - ] - if action in self.known_actions: - # Default action, eg "/users/", "/users/{pk}/" - idx = -1 - else: - # Custom action, eg "/users/{pk}/activate/", "/users/active/" - idx = -2 + Generate a `coreapi.Document` representing the API schema. + """ + if self.endpoints is None: + self.endpoints = self.endpoint_inspector.get_api_endpoints() - try: - return path_components[idx] - except IndexError: + links = {} + for path, method, callback in self.endpoints: + view = self.create_view(callback, method, request) + if not self.has_view_permissions(view): + continue + link = self.get_link(path, method, view) + keys = self.get_keys(path, method, view) + insert_into(links, keys, link) + + if not links: return None - def setup_view(self, callback, method, request): + return coreapi.Document(title=self.title, url=self.url, content=links) + + # Methods used when we generate a view instance from the raw callback... + + def create_view(self, callback, method, request=None): """ - Setup a view instance. + Given a callback, return an actual view instance. """ view = callback.cls() for attr, val in getattr(callback, 'initkwargs', {}).items(): @@ -205,6 +199,7 @@ def setup_view(self, callback, method, request): view.args = () view.kwargs = {} view.format_kwarg = None + view.request = None actions = getattr(callback, 'actions', None) if actions is not None: @@ -215,14 +210,13 @@ def setup_view(self, callback, method, request): if request is not None: view.request = clone_request(request, method) - else: - view.request = None return view - # Methods for generating each individual `Link` instance... - - def should_include_link(self, path, method, callback, view): + def has_view_permissions(self, view): + """ + Return `True` if the incoming request has the correct view permissions. + """ if view.request is None: return True @@ -230,20 +224,35 @@ def should_include_link(self, path, method, callback, view): view.check_permissions(view.request) except exceptions.APIException: return False + return True + def is_list_endpoint(self, path, method, view): + """ + Return True if the given path/method appears to represent a list endpoint. + """ + if hasattr(view, 'action'): + return view.action == 'list' + + if method.lower() != 'get': + return False + path_components = path.strip('/').split('/') + if path_components and '{' in path_components[-1]: + return False return True - def get_link(self, path, method, callback, view): + # Methods for generating each individual `Link` instance... + + def get_link(self, path, method, view): """ Return a `coreapi.Link` instance for the given endpoint. """ - fields = self.get_path_fields(path, method, callback, view) - fields += self.get_serializer_fields(path, method, callback, view) - fields += self.get_pagination_fields(path, method, callback, view) - fields += self.get_filter_fields(path, method, callback, view) + fields = self.get_path_fields(path, method, view) + fields += self.get_serializer_fields(path, method, view) + fields += self.get_pagination_fields(path, method, view) + fields += self.get_filter_fields(path, method, view) if fields and any([field.location in ('form', 'body') for field in fields]): - encoding = self.get_encoding(path, method, callback, view) + encoding = self.get_encoding(path, method, view) else: encoding = None @@ -257,7 +266,7 @@ def get_link(self, path, method, callback, view): fields=fields ) - def get_encoding(self, path, method, callback, view): + def get_encoding(self, path, method, view): """ Return the 'encoding' parameter to use for a given endpoint. """ @@ -278,7 +287,7 @@ def get_encoding(self, path, method, callback, view): return None - def get_path_fields(self, path, method, callback, view): + def get_path_fields(self, path, method, view): """ Return a list of `coreapi.Field` instances corresponding to any templated path variables. @@ -291,7 +300,7 @@ def get_path_fields(self, path, method, callback, view): return fields - def get_serializer_fields(self, path, method, callback, view): + def get_serializer_fields(self, path, method, view): """ Return a list of `coreapi.Field` instances corresponding to any request body input, as determined by the serializer class. @@ -327,11 +336,8 @@ def get_serializer_fields(self, path, method, callback, view): return fields - def get_pagination_fields(self, path, method, callback, view): - if method != 'GET': - return [] - - if hasattr(callback, 'actions') and ('list' not in callback.actions.values()): + def get_pagination_fields(self, path, method, view): + if not self.is_list_endpoint(path, method, view): return [] if not getattr(view, 'pagination_class', None): @@ -340,17 +346,54 @@ def get_pagination_fields(self, path, method, callback, view): paginator = view.pagination_class() return as_query_fields(paginator.get_fields(view)) - def get_filter_fields(self, path, method, callback, view): - if method != 'GET': - return [] - - if hasattr(callback, 'actions') and ('list' not in callback.actions.values()): + def get_filter_fields(self, path, method, view): + if not self.is_list_endpoint(path, method, view): return [] - if not hasattr(view, 'filter_backends'): + if not getattr(view, 'filter_backends', None): return [] fields = [] for filter_backend in view.filter_backends: fields += as_query_fields(filter_backend().get_fields(view)) return fields + + # Method for generating the link layout.... + + def get_keys(self, path, method, view): + """ + Return a list of keys that should be used to layout a link within + the schema document. + + /users/ ("users", "list"), ("users", "create") + /users/{pk}/ ("users", "read"), ("users", "update"), ("users", "delete") + /users/enabled/ ("users", "enabled") # custom viewset list action + /users/{pk}/star/ ("users", "star") # custom viewset detail action + /users/{pk}/groups/ ("users", "groups", "list"), ("users", "groups", "create") + /users/{pk}/groups/{pk}/ ("users", "groups", "read"), ("users", "groups", "update"), ("users", "groups", "delete") + """ + if hasattr(view, 'action'): + # Viewsets have explicitly named actions. + if view.action in self.coerce_actions: + action = self.coerce_actions[view.action] + else: + action = view.action + else: + # Views have no associated action, so we determine one from the method. + if self.is_list_endpoint(path, method, view): + action = 'list' + else: + action = self.default_mapping[method.lower()] + + named_path_components = [ + component for component + in path.strip('/').split('/') + if '{' not in component + ] + + if is_custom_action(action): + # Custom action, eg "/users/{pk}/activate/", "/users/active/" + return named_path_components[:-1] + [action] + + # Default action, eg "/users/", "/users/{pk}/" + return named_path_components + [action] diff --git a/rest_framework/views.py b/rest_framework/views.py index 15d8c6cde2..23d48962c0 100644 --- a/rest_framework/views.py +++ b/rest_framework/views.py @@ -130,6 +130,7 @@ def force_evaluation(): view = super(APIView, cls).as_view(**initkwargs) view.cls = cls + view.initkwargs = initkwargs # Note: session based authentication is explicitly CSRF validated, # all other authentication is CSRF exempt. diff --git a/tests/test_schemas.py b/tests/test_schemas.py index dc01d8cd84..1d98a4618c 100644 --- a/tests/test_schemas.py +++ b/tests/test_schemas.py @@ -6,7 +6,6 @@ from rest_framework import filters, pagination, permissions, serializers from rest_framework.compat import coreapi from rest_framework.decorators import detail_route, list_route -from rest_framework.response import Response from rest_framework.routers import DefaultRouter from rest_framework.schemas import SchemaGenerator from rest_framework.test import APIClient @@ -55,24 +54,11 @@ def get_serializer(self, *args, **kwargs): return super(ExampleViewSet, self).get_serializer(*args, **kwargs) -class ExampleView(APIView): - permission_classes = [permissions.IsAuthenticatedOrReadOnly] - - def get(self, request, *args, **kwargs): - return Response() - - def post(self, request, *args, **kwargs): - return Response() - - router = DefaultRouter(schema_title='Example API' if coreapi else None) router.register('example', ExampleViewSet, base_name='example') urlpatterns = [ url(r'^', include(router.urls)) ] -urlpatterns2 = [ - url(r'^example-view/$', ExampleView.as_view(), name='example-view') -] @unittest.skipUnless(coreapi, 'coreapi is not installed') @@ -99,7 +85,7 @@ def test_anonymous_request(self): url='/example/custom_list_action/', action='get' ), - 'retrieve': coreapi.Link( + 'read': coreapi.Link( url='/example/{pk}/', action='get', fields=[ @@ -138,7 +124,7 @@ def test_authenticated_request(self): coreapi.Field('b', required=False, location='form') ] ), - 'retrieve': coreapi.Link( + 'read': coreapi.Link( url='/example/{pk}/', action='get', fields=[ @@ -179,7 +165,7 @@ def test_authenticated_request(self): coreapi.Field('b', required=False, location='form') ] ), - 'destroy': coreapi.Link( + 'delete': coreapi.Link( url='/example/{pk}/', action='delete', fields=[ @@ -192,25 +178,58 @@ def test_authenticated_request(self): self.assertEqual(response.data, expected) +class ExampleListView(APIView): + permission_classes = [permissions.IsAuthenticatedOrReadOnly] + + def get(self, *args, **kwargs): + pass + + def post(self, request, *args, **kwargs): + pass + + +class ExampleDetailView(APIView): + permission_classes = [permissions.IsAuthenticatedOrReadOnly] + + def get(self, *args, **kwargs): + pass + + @unittest.skipUnless(coreapi, 'coreapi is not installed') class TestSchemaGenerator(TestCase): - def test_view(self): - schema_generator = SchemaGenerator(title='Test View', patterns=urlpatterns2) - schema = schema_generator.get_schema() + def setUp(self): + self.patterns = [ + url('^example/?$', ExampleListView.as_view()), + url('^example/(?P\d+)/?$', ExampleDetailView.as_view()), + ] + + def test_schema_for_regular_views(self): + """ + Ensure that schema generation works for APIView classes. + """ + generator = SchemaGenerator(title='Example API', patterns=self.patterns) + schema = generator.get_schema() expected = coreapi.Document( url='', - title='Test View', + title='Example API', content={ - 'example-view': { + 'example': { 'create': coreapi.Link( - url='/example-view/', + url='/example/', action='post', fields=[] ), - 'read': coreapi.Link( - url='/example-view/', + 'list': coreapi.Link( + url='/example/', action='get', fields=[] + ), + 'read': coreapi.Link( + url='/example/{pk}/', + action='get', + fields=[ + coreapi.Field('pk', required=True, location='path') + ] ) } } From 4ad5256e886a23cfbd3d854a878b20b6cfdc54fd Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Fri, 30 Sep 2016 15:23:47 +0100 Subject: [PATCH 29/43] Handle multiple methods on custom action (#4529) --- rest_framework/schemas.py | 23 +++++++++++++++++------ tests/test_schemas.py | 36 +++++++++++++++++++++++++++++++++++- 2 files changed, 52 insertions(+), 7 deletions(-) diff --git a/rest_framework/schemas.py b/rest_framework/schemas.py index 39dd8d910c..1c2f1a546d 100644 --- a/rest_framework/schemas.py +++ b/rest_framework/schemas.py @@ -173,6 +173,16 @@ def get_schema(self, request=None): if self.endpoints is None: self.endpoints = self.endpoint_inspector.get_api_endpoints() + links = self.get_links(request) + if not links: + return None + return coreapi.Document(title=self.title, url=self.url, content=links) + + def get_links(self, request=None): + """ + Return a dictionary containing all the links that should be + included in the API schema. + """ links = {} for path, method, callback in self.endpoints: view = self.create_view(callback, method, request) @@ -181,11 +191,7 @@ def get_schema(self, request=None): link = self.get_link(path, method, view) keys = self.get_keys(path, method, view) insert_into(links, keys, link) - - if not links: - return None - - return coreapi.Document(title=self.title, url=self.url, content=links) + return links # Methods used when we generate a view instance from the raw callback... @@ -200,6 +206,7 @@ def create_view(self, callback, method, request=None): view.kwargs = {} view.format_kwarg = None view.request = None + view.action_map = getattr(callback, 'actions', None) actions = getattr(callback, 'actions', None) if actions is not None: @@ -393,7 +400,11 @@ def get_keys(self, path, method, view): if is_custom_action(action): # Custom action, eg "/users/{pk}/activate/", "/users/active/" - return named_path_components[:-1] + [action] + if len(view.action_map) > 1: + action = self.default_mapping[method.lower()] + return named_path_components + [action] + else: + return named_path_components[:-1] + [action] # Default action, eg "/users/", "/users/{pk}/" return named_path_components + [action] diff --git a/tests/test_schemas.py b/tests/test_schemas.py index 1d98a4618c..5794f968ab 100644 --- a/tests/test_schemas.py +++ b/tests/test_schemas.py @@ -22,6 +22,10 @@ class ExamplePagination(pagination.PageNumberPagination): page_size = 100 +class EmptySerializer(serializers.Serializer): + pass + + class ExampleSerializer(serializers.Serializer): a = serializers.CharField(required=True, help_text='A field description') b = serializers.CharField(required=False) @@ -48,6 +52,10 @@ def custom_action(self, request, pk): def custom_list_action(self, request): return super(ExampleViewSet, self).list(self, request) + @list_route(methods=['post', 'get'], serializer_class=EmptySerializer) + def custom_list_action_multiple_methods(self, request): + return super(ExampleViewSet, self).list(self, request) + def get_serializer(self, *args, **kwargs): assert self.request assert self.action @@ -85,6 +93,12 @@ def test_anonymous_request(self): url='/example/custom_list_action/', action='get' ), + 'custom_list_action_multiple_methods': { + 'read': coreapi.Link( + url='/example/custom_list_action_multiple_methods/', + action='get' + ) + }, 'read': coreapi.Link( url='/example/{pk}/', action='get', @@ -145,6 +159,16 @@ def test_authenticated_request(self): url='/example/custom_list_action/', action='get' ), + 'custom_list_action_multiple_methods': { + 'read': coreapi.Link( + url='/example/custom_list_action_multiple_methods/', + action='get' + ), + 'create': coreapi.Link( + url='/example/custom_list_action_multiple_methods/', + action='post' + ) + }, 'update': coreapi.Link( url='/example/{pk}/', action='put', @@ -201,6 +225,7 @@ def setUp(self): self.patterns = [ url('^example/?$', ExampleListView.as_view()), url('^example/(?P\d+)/?$', ExampleDetailView.as_view()), + url('^example/(?P\d+)/sub/?$', ExampleDetailView.as_view()), ] def test_schema_for_regular_views(self): @@ -230,7 +255,16 @@ def test_schema_for_regular_views(self): fields=[ coreapi.Field('pk', required=True, location='path') ] - ) + ), + 'sub': { + 'list': coreapi.Link( + url='/example/{pk}/sub/', + action='get', + fields=[ + coreapi.Field('pk', required=True, location='path') + ] + ) + } } } ) From 8044d38c21453194a6f426e866573ab41cc7fa1c Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Tue, 4 Oct 2016 13:57:34 +0100 Subject: [PATCH 30/43] RequestsClient, CoreAPIClient --- docs/api-guide/testing.md | 93 +++++++++++++++++++++++++++++++++++ rest_framework/routers.py | 3 +- rest_framework/schemas.py | 17 ++++++- rest_framework/test.py | 53 ++++++++++++-------- tests/test_api_client.py | 62 +++++++++++++++-------- tests/test_requests_client.py | 63 ++++++++++++++---------- 6 files changed, 221 insertions(+), 70 deletions(-) diff --git a/docs/api-guide/testing.md b/docs/api-guide/testing.md index 18f9e19e90..fdd60b7f4c 100644 --- a/docs/api-guide/testing.md +++ b/docs/api-guide/testing.md @@ -184,6 +184,99 @@ As usual CSRF validation will only apply to any session authenticated views. Th --- +# RequestsClient + +REST framework also includes a client for interacting with your application +using the popular Python library, `requests`. + +This exposes exactly the same interface as if you were using a requests session +directly. + + client = RequestsClient() + response = client.get('http://testserver/users/') + +Note that the requests client requires you to pass fully qualified URLs. + +## Headers & Authentication + +Custom headers and authentication credentials can be provided in the same way +as [when using a standard `requests.Session` instance](http://docs.python-requests.org/en/master/user/advanced/#session-objects). + + from requests.auth import HTTPBasicAuth + + client.auth = HTTPBasicAuth('user', 'pass') + client.headers.update({'x-test': 'true'}) + +## CSRF + +If you're using `SessionAuthentication` then you'll need to include a CSRF token +for any `POST`, `PUT`, `PATCH` or `DELETE` requests. + +You can do so by following the same flow that a JavaScript based client would use. +First make a `GET` request in order to obtain a CRSF token, then present that +token in the following request. + +For example... + + client = RequestsClient() + + # Obtain a CSRF token. + response = client.get('/homepage/') + assert response.status_code == 200 + csrftoken = response.cookies['csrftoken'] + + # Interact with the API. + response = client.post('/organisations/', json={ + 'name': 'MegaCorp', + 'status': 'active' + }, headers={'X-CSRFToken': csrftoken}) + assert response.status_code == 200 + +## Live tests + +With careful usage both the `RequestsClient` and the `CoreAPIClient` provide +the ability to write test cases that can run either in development, or be run +directly against your staging server or production environment. + +Using this style to create basic tests of a few core piece of functionality is +a powerful way to validate your live service. Doing so may require some careful +attention to setup and teardown to ensure that the tests run in a way that they +do not directly affect customer data. + +--- + +# CoreAPIClient + +The CoreAPIClient allows you to interact with your API using the Python +`coreapi` client library. + + # Fetch the API schema + url = reverse('schema') + client = CoreAPIClient() + schema = client.get(url) + + # Create a new organisation + params = {'name': 'MegaCorp', 'status': 'active'} + client.action(schema, ['organisations', 'create'], params) + + # Ensure that the organisation exists in the listing + data = client.action(schema, ['organisations', 'list']) + assert(len(data) == 1) + assert(data == [{'name': 'MegaCorp', 'status': 'active'}]) + +## Headers & Authentication + +Custom headers and authentication may be used with `CoreAPIClient` in a +similar way as with `RequestsClient`. + + from requests.auth import HTTPBasicAuth + + client = CoreAPIClient() + client.session.auth = HTTPBasicAuth('user', 'pass') + client.session.headers.update({'x-test': 'true'}) + +--- + # Test cases REST framework includes the following test case classes, that mirror the existing Django test case classes, but use `APIClient` instead of Django's default `Client`. diff --git a/rest_framework/routers.py b/rest_framework/routers.py index a02a0f1bf3..859b604607 100644 --- a/rest_framework/routers.py +++ b/rest_framework/routers.py @@ -23,6 +23,7 @@ from rest_framework import exceptions, renderers, views from rest_framework.compat import NoReverseMatch +from rest_framework.renderers import BrowsableAPIRenderer from rest_framework.response import Response from rest_framework.reverse import reverse from rest_framework.schemas import SchemaGenerator @@ -281,7 +282,7 @@ class DefaultRouter(SimpleRouter): include_root_view = True include_format_suffixes = True root_view_name = 'api-root' - default_schema_renderers = [renderers.CoreJSONRenderer] + default_schema_renderers = [renderers.CoreJSONRenderer, BrowsableAPIRenderer] def __init__(self, *args, **kwargs): if 'schema_renderers' in kwargs: diff --git a/rest_framework/schemas.py b/rest_framework/schemas.py index 1c2f1a546d..0928e76614 100644 --- a/rest_framework/schemas.py +++ b/rest_framework/schemas.py @@ -1,3 +1,4 @@ +from collections import OrderedDict from importlib import import_module from django.conf import settings @@ -55,6 +56,18 @@ def is_custom_action(action): ]) +def endpoint_ordering(endpoint): + path, method, callback = endpoint + method_priority = { + 'GET': 0, + 'POST': 1, + 'PUT': 2, + 'PATCH': 3, + 'DELETE': 4 + }.get(method, 5) + return (path, method_priority) + + class EndpointInspector(object): """ A class to determine the available API endpoints that a project exposes. @@ -101,6 +114,8 @@ def get_api_endpoints(self, patterns=None, prefix=''): ) api_endpoints.extend(nested_endpoints) + api_endpoints = sorted(api_endpoints, key=endpoint_ordering) + return api_endpoints def get_path_from_regex(self, path_regex): @@ -183,7 +198,7 @@ def get_links(self, request=None): Return a dictionary containing all the links that should be included in the API schema. """ - links = {} + links = OrderedDict() for path, method, callback in self.endpoints: view = self.create_view(callback, method, request) if not self.has_view_permissions(view): diff --git a/rest_framework/test.py b/rest_framework/test.py index 1b3ad80c21..16b1b4cd52 100644 --- a/rest_framework/test.py +++ b/rest_framework/test.py @@ -7,6 +7,7 @@ import io from django.conf import settings +from django.core.exceptions import ImproperlyConfigured from django.core.handlers.wsgi import WSGIHandler from django.test import testcases from django.test.client import Client as DjangoClient @@ -105,36 +106,46 @@ def start_response(wsgi_status, wsgi_headers): def close(self): pass - class DjangoTestSession(requests.Session): - def __init__(self, *args, **kwargs): - super(DjangoTestSession, self).__init__(*args, **kwargs) + class NoExternalRequestsAdapter(requests.adapters.HTTPAdapter): + def send(self, request, *args, **kwargs): + msg = ( + 'RequestsClient refusing to make an outgoing network request ' + 'to "%s". Only "testserver" or hostnames in your ALLOWED_HOSTS ' + 'setting are valid.' % request.url + ) + raise RuntimeError(msg) + class RequestsClient(requests.Session): + def __init__(self, *args, **kwargs): + super(RequestsClient, self).__init__(*args, **kwargs) adapter = DjangoTestAdapter() - hostnames = list(settings.ALLOWED_HOSTS) + ['testserver'] - - for hostname in hostnames: - if hostname == '*': - hostname = '' - self.mount('http://%s' % hostname, adapter) - self.mount('https://%s' % hostname, adapter) + self.mount('http://', adapter) + self.mount('https://', adapter) def request(self, method, url, *args, **kwargs): if ':' not in url: - url = 'http://testserver/' + url.lstrip('/') - return super(DjangoTestSession, self).request(method, url, *args, **kwargs) + raise ValueError('Missing "http:" or "https:". Use a fully qualified URL, eg "http://testserver%s"' % url) + return super(RequestsClient, self).request(method, url, *args, **kwargs) + +else: + def RequestsClient(*args, **kwargs): + raise ImproperlyConfigured('requests must be installed in order to use RequestsClient.') -def get_requests_client(): - assert requests is not None, 'requests must be installed' - return DjangoTestSession() +if coreapi is not None: + class CoreAPIClient(coreapi.Client): + def __init__(self, *args, **kwargs): + self._session = RequestsClient() + kwargs['transports'] = [coreapi.transports.HTTPTransport(session=self.session)] + return super(CoreAPIClient, self).__init__(*args, **kwargs) + @property + def session(self): + return self._session -def get_api_client(): - assert coreapi is not None, 'coreapi must be installed' - session = get_requests_client() - return coreapi.Client(transports=[ - coreapi.transports.HTTPTransport(session=session) - ]) +else: + def CoreAPIClient(*args, **kwargs): + raise ImproperlyConfigured('coreapi must be installed in order to use CoreAPIClient.') class APIRequestFactory(DjangoRequestFactory): diff --git a/tests/test_api_client.py b/tests/test_api_client.py index 9daf3f3fe4..a6d72357aa 100644 --- a/tests/test_api_client.py +++ b/tests/test_api_client.py @@ -12,7 +12,7 @@ from rest_framework.parsers import FileUploadParser from rest_framework.renderers import CoreJSONRenderer from rest_framework.response import Response -from rest_framework.test import APITestCase, get_api_client +from rest_framework.test import APITestCase, CoreAPIClient from rest_framework.views import APIView @@ -22,6 +22,7 @@ def get_schema(): title='Example API', content={ 'simple_link': coreapi.Link('/example/', description='example link'), + 'headers': coreapi.Link('/headers/'), 'location': { 'query': coreapi.Link('/example/', fields=[ coreapi.Field(name='example', description='example field') @@ -165,6 +166,19 @@ def get(self, request): return HttpResponse('123', content_type='text/plain') +class HeadersView(APIView): + def get(self, request): + headers = { + key[5:].replace('_', '-'): value + for key, value in request.META.items() + if key.startswith('HTTP_') + } + return Response({ + 'method': request.method, + 'headers': headers + }) + + urlpatterns = [ url(r'^$', SchemaView.as_view()), url(r'^example/$', ListView.as_view()), @@ -172,6 +186,7 @@ def get(self, request): url(r'^upload/$', UploadView.as_view()), url(r'^download/$', DownloadView.as_view()), url(r'^text/$', TextView.as_view()), + url(r'^headers/$', HeadersView.as_view()), ] @@ -179,7 +194,7 @@ def get(self, request): @override_settings(ROOT_URLCONF='tests.test_api_client') class APIClientTests(APITestCase): def test_api_client(self): - client = get_api_client() + client = CoreAPIClient() schema = client.get('http://api.example.com/') assert schema.title == 'Example API' assert schema.url == 'https://api.example.com/' @@ -193,7 +208,7 @@ def test_api_client(self): assert data == expected def test_query_params(self): - client = get_api_client() + client = CoreAPIClient() schema = client.get('http://api.example.com/') data = client.action(schema, ['location', 'query'], params={'example': 123}) expected = { @@ -202,8 +217,15 @@ def test_query_params(self): } assert data == expected + def test_session_headers(self): + client = CoreAPIClient() + client.session.headers.update({'X-Custom-Header': 'foo'}) + schema = client.get('http://api.example.com/') + data = client.action(schema, ['headers']) + assert data['headers']['X-CUSTOM-HEADER'] == 'foo' + def test_query_params_with_multiple_values(self): - client = get_api_client() + client = CoreAPIClient() schema = client.get('http://api.example.com/') data = client.action(schema, ['location', 'query'], params={'example': [1, 2, 3]}) expected = { @@ -213,7 +235,7 @@ def test_query_params_with_multiple_values(self): assert data == expected def test_form_params(self): - client = get_api_client() + client = CoreAPIClient() schema = client.get('http://api.example.com/') data = client.action(schema, ['location', 'form'], params={'example': 123}) expected = { @@ -226,7 +248,7 @@ def test_form_params(self): assert data == expected def test_body_params(self): - client = get_api_client() + client = CoreAPIClient() schema = client.get('http://api.example.com/') data = client.action(schema, ['location', 'body'], params={'example': 123}) expected = { @@ -239,7 +261,7 @@ def test_body_params(self): assert data == expected def test_path_params(self): - client = get_api_client() + client = CoreAPIClient() schema = client.get('http://api.example.com/') data = client.action(schema, ['location', 'path'], params={'id': 123}) expected = { @@ -250,7 +272,7 @@ def test_path_params(self): assert data == expected def test_multipart_encoding(self): - client = get_api_client() + client = CoreAPIClient() schema = client.get('http://api.example.com/') temp = tempfile.NamedTemporaryFile() @@ -272,7 +294,7 @@ def test_multipart_encoding(self): def test_multipart_encoding_no_file(self): # When no file is included, multipart encoding should still be used. - client = get_api_client() + client = CoreAPIClient() schema = client.get('http://api.example.com/') data = client.action(schema, ['encoding', 'multipart'], params={'example': 123}) @@ -287,7 +309,7 @@ def test_multipart_encoding_no_file(self): assert data == expected def test_multipart_encoding_multiple_values(self): - client = get_api_client() + client = CoreAPIClient() schema = client.get('http://api.example.com/') data = client.action(schema, ['encoding', 'multipart'], params={'example': [1, 2, 3]}) @@ -305,7 +327,7 @@ def test_multipart_encoding_string_file_content(self): # Test for `coreapi.utils.File` support. from coreapi.utils import File - client = get_api_client() + client = CoreAPIClient() schema = client.get('http://api.example.com/') example = File(name='example.txt', content='123') @@ -323,7 +345,7 @@ def test_multipart_encoding_string_file_content(self): def test_multipart_encoding_in_body(self): from coreapi.utils import File - client = get_api_client() + client = CoreAPIClient() schema = client.get('http://api.example.com/') example = {'foo': File(name='example.txt', content='123'), 'bar': 'abc'} @@ -341,7 +363,7 @@ def test_multipart_encoding_in_body(self): # URLencoded def test_urlencoded_encoding(self): - client = get_api_client() + client = CoreAPIClient() schema = client.get('http://api.example.com/') data = client.action(schema, ['encoding', 'urlencoded'], params={'example': 123}) expected = { @@ -354,7 +376,7 @@ def test_urlencoded_encoding(self): assert data == expected def test_urlencoded_encoding_multiple_values(self): - client = get_api_client() + client = CoreAPIClient() schema = client.get('http://api.example.com/') data = client.action(schema, ['encoding', 'urlencoded'], params={'example': [1, 2, 3]}) expected = { @@ -367,7 +389,7 @@ def test_urlencoded_encoding_multiple_values(self): assert data == expected def test_urlencoded_encoding_in_body(self): - client = get_api_client() + client = CoreAPIClient() schema = client.get('http://api.example.com/') data = client.action(schema, ['encoding', 'urlencoded-body'], params={'example': {'foo': 123, 'bar': True}}) expected = { @@ -382,7 +404,7 @@ def test_urlencoded_encoding_in_body(self): # Raw uploads def test_raw_upload(self): - client = get_api_client() + client = CoreAPIClient() schema = client.get('http://api.example.com/') temp = tempfile.NamedTemporaryFile() @@ -403,7 +425,7 @@ def test_raw_upload(self): def test_raw_upload_string_file_content(self): from coreapi.utils import File - client = get_api_client() + client = CoreAPIClient() schema = client.get('http://api.example.com/') example = File('example.txt', '123') @@ -419,7 +441,7 @@ def test_raw_upload_string_file_content(self): def test_raw_upload_explicit_content_type(self): from coreapi.utils import File - client = get_api_client() + client = CoreAPIClient() schema = client.get('http://api.example.com/') example = File('example.txt', '123', 'text/html') @@ -435,7 +457,7 @@ def test_raw_upload_explicit_content_type(self): # Responses def test_text_response(self): - client = get_api_client() + client = CoreAPIClient() schema = client.get('http://api.example.com/') data = client.action(schema, ['response', 'text']) @@ -444,7 +466,7 @@ def test_text_response(self): assert data == expected def test_download_response(self): - client = get_api_client() + client = CoreAPIClient() schema = client.get('http://api.example.com/') data = client.action(schema, ['response', 'download']) diff --git a/tests/test_requests_client.py b/tests/test_requests_client.py index 37bde10922..791ca4ff2d 100644 --- a/tests/test_requests_client.py +++ b/tests/test_requests_client.py @@ -12,7 +12,7 @@ from rest_framework.compat import is_authenticated, requests from rest_framework.response import Response -from rest_framework.test import APITestCase, get_requests_client +from rest_framework.test import APITestCase, RequestsClient from rest_framework.views import APIView @@ -92,10 +92,10 @@ def post(self, request): urlpatterns = [ - url(r'^$', Root.as_view()), - url(r'^headers/$', HeadersView.as_view()), - url(r'^session/$', SessionView.as_view()), - url(r'^auth/$', AuthView.as_view()), + url(r'^$', Root.as_view(), name='root'), + url(r'^headers/$', HeadersView.as_view(), name='headers'), + url(r'^session/$', SessionView.as_view(), name='session'), + url(r'^auth/$', AuthView.as_view(), name='auth'), ] @@ -103,8 +103,8 @@ def post(self, request): @override_settings(ROOT_URLCONF='tests.test_requests_client') class RequestsClientTests(APITestCase): def test_get_request(self): - client = get_requests_client() - response = client.get('/') + client = RequestsClient() + response = client.get('http://testserver/') assert response.status_code == 200 assert response.headers['Content-Type'] == 'application/json' expected = { @@ -114,8 +114,8 @@ def test_get_request(self): assert response.json() == expected def test_get_request_query_params_in_url(self): - client = get_requests_client() - response = client.get('/?key=value') + client = RequestsClient() + response = client.get('http://testserver/?key=value') assert response.status_code == 200 assert response.headers['Content-Type'] == 'application/json' expected = { @@ -125,8 +125,8 @@ def test_get_request_query_params_in_url(self): assert response.json() == expected def test_get_request_query_params_by_kwarg(self): - client = get_requests_client() - response = client.get('/', params={'key': 'value'}) + client = RequestsClient() + response = client.get('http://testserver/', params={'key': 'value'}) assert response.status_code == 200 assert response.headers['Content-Type'] == 'application/json' expected = { @@ -136,16 +136,25 @@ def test_get_request_query_params_by_kwarg(self): assert response.json() == expected def test_get_with_headers(self): - client = get_requests_client() - response = client.get('/headers/', headers={'User-Agent': 'example'}) + client = RequestsClient() + response = client.get('http://testserver/headers/', headers={'User-Agent': 'example'}) + assert response.status_code == 200 + assert response.headers['Content-Type'] == 'application/json' + headers = response.json()['headers'] + assert headers['USER-AGENT'] == 'example' + + def test_get_with_session_headers(self): + client = RequestsClient() + client.headers.update({'User-Agent': 'example'}) + response = client.get('http://testserver/headers/') assert response.status_code == 200 assert response.headers['Content-Type'] == 'application/json' headers = response.json()['headers'] assert headers['USER-AGENT'] == 'example' def test_post_form_request(self): - client = get_requests_client() - response = client.post('/', data={'key': 'value'}) + client = RequestsClient() + response = client.post('http://testserver/', data={'key': 'value'}) assert response.status_code == 200 assert response.headers['Content-Type'] == 'application/json' expected = { @@ -158,8 +167,8 @@ def test_post_form_request(self): assert response.json() == expected def test_post_json_request(self): - client = get_requests_client() - response = client.post('/', json={'key': 'value'}) + client = RequestsClient() + response = client.post('http://testserver/', json={'key': 'value'}) assert response.status_code == 200 assert response.headers['Content-Type'] == 'application/json' expected = { @@ -172,11 +181,11 @@ def test_post_json_request(self): assert response.json() == expected def test_post_multipart_request(self): - client = get_requests_client() + client = RequestsClient() files = { 'file': ('report.csv', 'some,data,to,send\nanother,row,to,send\n') } - response = client.post('/', files=files) + response = client.post('http://testserver/', files=files) assert response.status_code == 200 assert response.headers['Content-Type'] == 'application/json' expected = { @@ -189,20 +198,20 @@ def test_post_multipart_request(self): assert response.json() == expected def test_session(self): - client = get_requests_client() - response = client.get('/session/') + client = RequestsClient() + response = client.get('http://testserver/session/') assert response.status_code == 200 assert response.headers['Content-Type'] == 'application/json' expected = {} assert response.json() == expected - response = client.post('/session/', json={'example': 'abc'}) + response = client.post('http://testserver/session/', json={'example': 'abc'}) assert response.status_code == 200 assert response.headers['Content-Type'] == 'application/json' expected = {'example': 'abc'} assert response.json() == expected - response = client.get('/session/') + response = client.get('http://testserver/session/') assert response.status_code == 200 assert response.headers['Content-Type'] == 'application/json' expected = {'example': 'abc'} @@ -210,8 +219,8 @@ def test_session(self): def test_auth(self): # Confirm session is not authenticated - client = get_requests_client() - response = client.get('/auth/') + client = RequestsClient() + response = client.get('http://testserver/auth/') assert response.status_code == 200 assert response.headers['Content-Type'] == 'application/json' expected = { @@ -226,7 +235,7 @@ def test_auth(self): user.save() # Perform a login - response = client.post('/auth/', json={ + response = client.post('http://testserver/auth/', json={ 'username': 'tom', 'password': 'password' }, headers={'X-CSRFToken': csrftoken}) @@ -238,7 +247,7 @@ def test_auth(self): assert response.json() == expected # Confirm session is authenticated - response = client.get('/auth/') + response = client.get('http://testserver/auth/') assert response.status_code == 200 assert response.headers['Content-Type'] == 'application/json' expected = { From a8501f72b78d833a7e95cb879e8856b0e265f2cc Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Wed, 5 Oct 2016 12:42:52 +0100 Subject: [PATCH 31/43] exclude_from_schema --- docs/api-guide/views.md | 12 ++++++++-- rest_framework/decorators.py | 3 ++- rest_framework/routers.py | 2 ++ rest_framework/schemas.py | 46 +++++++++++++++++++----------------- rest_framework/views.py | 3 +++ 5 files changed, 41 insertions(+), 25 deletions(-) diff --git a/docs/api-guide/views.md b/docs/api-guide/views.md index 62f14087fe..55f6664e07 100644 --- a/docs/api-guide/views.md +++ b/docs/api-guide/views.md @@ -127,7 +127,7 @@ REST framework also allows you to work with regular function based views. It pr ## @api_view() -**Signature:** `@api_view(http_method_names=['GET'])` +**Signature:** `@api_view(http_method_names=['GET'], exclude_from_schema=False)` The core of this functionality is the `api_view` decorator, which takes a list of HTTP methods that your view should respond to. For example, this is how you would write a very simple view that just manually returns some data: @@ -139,7 +139,7 @@ The core of this functionality is the `api_view` decorator, which takes a list o This view will use the default renderers, parsers, authentication classes etc specified in the [settings]. -By default only `GET` methods will be accepted. Other methods will respond with "405 Method Not Allowed". To alter this behavior, specify which methods the view allows, like so: +By default only `GET` methods will be accepted. Other methods will respond with "405 Method Not Allowed". To alter this behaviour, specify which methods the view allows, like so: @api_view(['GET', 'POST']) def hello_world(request): @@ -147,6 +147,13 @@ By default only `GET` methods will be accepted. Other methods will respond with return Response({"message": "Got some data!", "data": request.data}) return Response({"message": "Hello, world!"}) +You can also mark an API view as being omitted from any [auto-generated schema][schemas], +using the `exclude_from_schema` argument.: + + @api_view(['GET'], exclude_from_schema=True) + def api_docs(request): + ... + ## API policy decorators To override the default settings, REST framework provides a set of additional decorators which can be added to your views. These must come *after* (below) the `@api_view` decorator. For example, to create a view that uses a [throttle][throttling] to ensure it can only be called once per day by a particular user, use the `@throttle_classes` decorator, passing a list of throttle classes: @@ -178,3 +185,4 @@ Each of these decorators takes a single argument which must be a list or tuple o [cite2]: http://www.boredomandlaziness.org/2012/05/djangos-cbvs-are-not-mistake-but.html [settings]: settings.md [throttling]: throttling.md +[schemas]: schemas.md diff --git a/rest_framework/decorators.py b/rest_framework/decorators.py index 554e5236cc..bf9b32aaa7 100644 --- a/rest_framework/decorators.py +++ b/rest_framework/decorators.py @@ -15,7 +15,7 @@ from rest_framework.views import APIView -def api_view(http_method_names=None): +def api_view(http_method_names=None, exclude_from_schema=False): """ Decorator that converts a function-based view into an APIView subclass. Takes a list of allowed methods for the view as an argument. @@ -72,6 +72,7 @@ def handler(self, *args, **kwargs): WrappedAPIView.permission_classes = getattr(func, 'permission_classes', APIView.permission_classes) + WrappedAPIView.exclude_from_schema = exclude_from_schema return WrappedAPIView.as_view() return decorator diff --git a/rest_framework/routers.py b/rest_framework/routers.py index 859b604607..9ef8d11f06 100644 --- a/rest_framework/routers.py +++ b/rest_framework/routers.py @@ -311,6 +311,7 @@ def get_schema_root_view(self, api_urls=None): class APISchemaView(views.APIView): _ignore_model_permissions = True + exclude_from_schema = True renderer_classes = schema_renderers def get(self, request, *args, **kwargs): @@ -332,6 +333,7 @@ def get_api_root_view(self, api_urls=None): class APIRootView(views.APIView): _ignore_model_permissions = True + exclude_from_schema = True def get(self, request, *args, **kwargs): # Return a plain {"name": "hyperlink"} response. diff --git a/rest_framework/schemas.py b/rest_framework/schemas.py index 0928e76614..c3b36cd1ed 100644 --- a/rest_framework/schemas.py +++ b/rest_framework/schemas.py @@ -56,6 +56,22 @@ def is_custom_action(action): ]) +def is_list_view(path, method, view): + """ + Return True if the given path/method appears to represent a list view. + """ + if hasattr(view, 'action'): + # Viewsets have an explicitly defined action, which we can inspect. + return view.action == 'list' + + if method.lower() != 'get': + return False + path_components = path.strip('/').split('/') + if path_components and '{' in path_components[-1]: + return False + return True + + def endpoint_ordering(endpoint): path, method, callback = endpoint method_priority = { @@ -136,9 +152,6 @@ def should_include_endpoint(self, path, callback): if path.endswith('.{format}') or path.endswith('.{format}/'): return False # Ignore .json style URLs. - if path == '/': - return False # Ignore the root endpoint. - return True def get_allowed_methods(self, callback): @@ -201,7 +214,7 @@ def get_links(self, request=None): links = OrderedDict() for path, method, callback in self.endpoints: view = self.create_view(callback, method, request) - if not self.has_view_permissions(view): + if not self.should_include_view(path, method, view): continue link = self.get_link(path, method, view) keys = self.get_keys(path, method, view) @@ -235,10 +248,13 @@ def create_view(self, callback, method, request=None): return view - def has_view_permissions(self, view): + def should_include_view(self, path, method, view): """ Return `True` if the incoming request has the correct view permissions. """ + if getattr(view, 'exclude_from_schema', False): + return False + if view.request is None: return True @@ -248,20 +264,6 @@ def has_view_permissions(self, view): return False return True - def is_list_endpoint(self, path, method, view): - """ - Return True if the given path/method appears to represent a list endpoint. - """ - if hasattr(view, 'action'): - return view.action == 'list' - - if method.lower() != 'get': - return False - path_components = path.strip('/').split('/') - if path_components and '{' in path_components[-1]: - return False - return True - # Methods for generating each individual `Link` instance... def get_link(self, path, method, view): @@ -359,7 +361,7 @@ def get_serializer_fields(self, path, method, view): return fields def get_pagination_fields(self, path, method, view): - if not self.is_list_endpoint(path, method, view): + if not is_list_view(path, method, view): return [] if not getattr(view, 'pagination_class', None): @@ -369,7 +371,7 @@ def get_pagination_fields(self, path, method, view): return as_query_fields(paginator.get_fields(view)) def get_filter_fields(self, path, method, view): - if not self.is_list_endpoint(path, method, view): + if not is_list_view(path, method, view): return [] if not getattr(view, 'filter_backends', None): @@ -402,7 +404,7 @@ def get_keys(self, path, method, view): action = view.action else: # Views have no associated action, so we determine one from the method. - if self.is_list_endpoint(path, method, view): + if is_list_view(path, method, view): action = 'list' else: action = self.default_mapping[method.lower()] diff --git a/rest_framework/views.py b/rest_framework/views.py index 23d48962c0..bb9aba72ca 100644 --- a/rest_framework/views.py +++ b/rest_framework/views.py @@ -110,6 +110,9 @@ class APIView(View): # Allow dependency injection of other settings to make testing easier. settings = api_settings + # Mark the view as being included or excluded from schema generation. + exclude_from_schema = False + @classmethod def as_view(cls, **initkwargs): """ From cd826ce422b2bbf18a675c50b989f51e253e229e Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Wed, 5 Oct 2016 13:46:39 +0100 Subject: [PATCH 32/43] Added 'get_schema_view()' shortcut --- docs/api-guide/schemas.md | 75 ++++++++++--------- .../7-schemas-and-client-libraries.md | 21 ++++-- rest_framework/schemas.py | 52 ++++++++++--- 3 files changed, 93 insertions(+), 55 deletions(-) diff --git a/docs/api-guide/schemas.md b/docs/api-guide/schemas.md index 7f8af723eb..57b73addc9 100644 --- a/docs/api-guide/schemas.md +++ b/docs/api-guide/schemas.md @@ -102,15 +102,20 @@ REST framework includes functionality for auto-generating a schema, or allows you to specify one explicitly. There are a few different ways to add a schema to your API, depending on exactly what you need. -## Using DefaultRouter +## The get_schema_view shortcut -If you're using `DefaultRouter` then you can include an auto-generated schema, -simply by adding a `schema_title` argument to the router. +The simplest way to include a schema in your project is to use the +`get_schema_view()` function. - router = DefaultRouter(schema_title='Server Monitoring API') + schema_view = get_schema_view(title="Server Monitoring API") -The schema will be included at the root URL, `/`, and presented to clients -that include the Core JSON media type in their `Accept` header. + urlpatterns = [ + url('^$', schema_view), + ... + ] + +Once the view has been added, you'll be able to make API requests to retrieve +the auto-generated schema definition. $ http http://127.0.0.1:8000/ Accept:application/vnd.coreapi+json HTTP/1.0 200 OK @@ -125,48 +130,43 @@ that include the Core JSON media type in their `Accept` header. ... } -This is a great zero-configuration option for when you want to get up and -running really quickly. - -The other available options to `DefaultRouter` are: +The arguments to `get_schema_view()` are: -#### schema_renderers - -May be used to pass the set of renderer classes that can be used to render schema output. - - from rest_framework.renderers import CoreJSONRenderer - from my_custom_package import APIBlueprintRenderer +#### `title` - router = DefaultRouter(schema_title='Server Monitoring API', schema_renderers=[ - CoreJSONRenderer, APIBlueprintRenderer - ]) +May be used to provide a descriptive title for the schema definition. -#### schema_url +#### `url` -May be used to pass the root URL for the schema. This can either be used to ensure that -the schema URLs include a canonical hostname and schema, or to ensure that all the -schema URLs include a path prefix. +May be used to pass a canonical URL for the schema. - router = DefaultRouter( - schema_title='Server Monitoring API', - schema_url='https://www.example.org/api/' + schema_view = get_schema_view( + title='Server Monitoring API', + url='https://www.example.org/api/' ) -If you want more flexibility over the schema output then you'll need to consider -using `SchemaGenerator` instead. - -#### root_renderers +#### `renderer_classes` May be used to pass the set of renderer classes that can be used to render the API root endpoint. -## Using SchemaGenerator + from rest_framework.renderers import CoreJSONRenderer + from my_custom_package import APIBlueprintRenderer -The most common way to add a schema to your API is to use the `SchemaGenerator` -class to auto-generate the `Document` instance, and to return that from a view. + schema_view = get_schema_view( + title='Server Monitoring API', + url='https://www.example.org/api/', + renderer_classes=[CoreJSONRenderer, APIBlueprintRenderer] + ) + +## Using an explicit schema view + +If you need a little more control than the `get_schema_view()` shortcut gives you, +then you can use the `SchemaGenerator` class directly to auto-generate the +`Document` instance, and to return that from a view. This option gives you the flexibility of setting up the schema endpoint -with whatever behavior you want. For example, you can apply different -permission, throttling or authentication policies to the schema endpoint. +with whatever behaviour you want. For example, you can apply different +permission, throttling, or authentication policies to the schema endpoint. Here's an example of using `SchemaGenerator` together with a view to return the schema. @@ -176,12 +176,13 @@ return the schema. from rest_framework.decorators import api_view, renderer_classes from rest_framework import renderers, response, schemas + generator = schemas.SchemaGenerator(title='Bookings API') @api_view() @renderer_classes([renderers.CoreJSONRenderer]) def schema_view(request): - generator = schemas.SchemaGenerator(title='Bookings API') - return response.Response(generator.get_schema()) + schema = generator.get_schema(request) + return response.Response(schema) **urls.py:** diff --git a/docs/tutorial/7-schemas-and-client-libraries.md b/docs/tutorial/7-schemas-and-client-libraries.md index 77cfdd3b30..3d4d2c9410 100644 --- a/docs/tutorial/7-schemas-and-client-libraries.md +++ b/docs/tutorial/7-schemas-and-client-libraries.md @@ -33,10 +33,17 @@ API schema. $ pip install coreapi -We can now include a schema for our API, by adding a `schema_title` argument to -the router instantiation. +We can now include a schema for our API, by including an autogenerated schema +view in our URL configuration. - router = DefaultRouter(schema_title='Pastebin API') + from rest_framework.schemas import get_schema_view + + schema_view = get_schema_view(title='Pastebin API') + + urlpatterns = [ + url('^schema/$', schema_view), + ... + ] If you visit the API root endpoint in a browser you should now see `corejson` representation become available as an option. @@ -46,7 +53,7 @@ representation become available as an option. We can also request the schema from the command line, by specifying the desired content type in the `Accept` header. - $ http http://127.0.0.1:8000/ Accept:application/vnd.coreapi+json + $ http http://127.0.0.1:8000/schema/ Accept:application/vnd.coreapi+json HTTP/1.0 200 OK Allow: GET, HEAD, OPTIONS Content-Type: application/vnd.coreapi+json @@ -91,8 +98,8 @@ Now check that it is available on the command line... First we'll load the API schema using the command line client. - $ coreapi get http://127.0.0.1:8000/ - + $ coreapi get http://127.0.0.1:8000/schema/ + snippets: { highlight(pk) list() @@ -150,7 +157,7 @@ Now if we fetch the schema again, we should be able to see the full set of available interactions. $ coreapi reload - Pastebin API "http://127.0.0.1:8000/"> + Pastebin API "http://127.0.0.1:8000/schema/"> snippets: { create(code, [title], [linenos], [language], [style]) destroy(pk) diff --git a/rest_framework/schemas.py b/rest_framework/schemas.py index c3b36cd1ed..d3c9d50258 100644 --- a/rest_framework/schemas.py +++ b/rest_framework/schemas.py @@ -6,11 +6,13 @@ from django.utils import six from django.utils.encoding import force_text -from rest_framework import exceptions, serializers +from rest_framework import exceptions, renderers, serializers from rest_framework.compat import ( RegexURLPattern, RegexURLResolver, coreapi, uritemplate, urlparse ) from rest_framework.request import clone_request +from rest_framework.response import Response +from rest_framework.settings import api_settings from rest_framework.views import APIView @@ -92,15 +94,14 @@ def __init__(self, patterns=None, urlconf=None): if patterns is None: if urlconf is None: # Use the default Django URL conf - urls = import_module(settings.ROOT_URLCONF) - patterns = urls.urlpatterns + urlconf = settings.ROOT_URLCONF + + # Load the given URLconf module + if isinstance(urlconf, six.string_types): + urls = import_module(urlconf) else: - # Load the given URLconf module - if isinstance(urlconf, six.string_types): - urls = import_module(urlconf) - else: - urls = urlconf - patterns = urls.urlpatterns + urls = urlconf + patterns = urls.urlpatterns self.patterns = patterns @@ -189,7 +190,8 @@ def __init__(self, title=None, url=None, patterns=None, urlconf=None): if url and not url.endswith('/'): url += '/' - self.endpoint_inspector = self.endpoint_inspector_cls(patterns, urlconf) + self.patterns = patterns + self.urlconf = urlconf self.title = title self.url = url self.endpoints = None @@ -199,7 +201,8 @@ def get_schema(self, request=None): Generate a `coreapi.Document` representing the API schema. """ if self.endpoints is None: - self.endpoints = self.endpoint_inspector.get_api_endpoints() + inspector = self.endpoint_inspector_cls(self.patterns, self.urlconf) + self.endpoints = inspector.get_api_endpoints() links = self.get_links(request) if not links: @@ -425,3 +428,30 @@ def get_keys(self, path, method, view): # Default action, eg "/users/", "/users/{pk}/" return named_path_components + [action] + + +def get_schema_view(title=None, url=None, renderer_classes=None): + """ + Return a schema view. + """ + generator = SchemaGenerator(title=title, url=url) + if renderer_classes is None: + if renderers.BrowsableAPIRenderer in api_settings.DEFAULT_RENDERER_CLASSES: + rclasses = [renderers.CoreJSONRenderer, renderers.BrowsableAPIRenderer] + else: + rclasses = [renderers.CoreJSONRenderer] + else: + rclasses = renderer_classes + + class SchemaView(APIView): + _ignore_model_permissions = True + exclude_from_schema = True + renderer_classes = rclasses + + def get(self, request, *args, **kwargs): + schema = generator.get_schema(request) + if schema is None: + raise exceptions.PermissionDenied() + return Response(schema) + + return SchemaView.as_view() From 7e3a3a4081ad3c77769a1a5f69f40411401a6ca3 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Wed, 5 Oct 2016 15:37:25 +0100 Subject: [PATCH 33/43] Added schema descriptions --- docs/api-guide/schemas.md | 8 ++++++ rest_framework/schemas.py | 51 +++++++++++++++++++++++++++++---------- tests/test_schemas.py | 12 ++++----- 3 files changed, 52 insertions(+), 19 deletions(-) diff --git a/docs/api-guide/schemas.md b/docs/api-guide/schemas.md index 57b73addc9..868207dace 100644 --- a/docs/api-guide/schemas.md +++ b/docs/api-guide/schemas.md @@ -242,6 +242,14 @@ You could then either: --- +# Schemas as documentation + +One common usage of API schemas is to use them to build documentation pages. + +The schema generation in REST framework uses docstrings to automatically + +--- + # Alternate schema formats In order to support an alternate schema format, you need to implement a custom renderer diff --git a/rest_framework/schemas.py b/rest_framework/schemas.py index d3c9d50258..dc4d8acb79 100644 --- a/rest_framework/schemas.py +++ b/rest_framework/schemas.py @@ -1,3 +1,4 @@ +import re from collections import OrderedDict from importlib import import_module @@ -15,6 +16,8 @@ from rest_framework.settings import api_settings from rest_framework.views import APIView +header_regex = re.compile('^[a-zA-Z][0-9A-Za-z_]*:') + def as_query_fields(items): """ @@ -53,8 +56,7 @@ def insert_into(target, keys, value): def is_custom_action(action): return action not in set([ - 'read', 'retrieve', 'list', - 'create', 'update', 'partial_update', 'delete', 'destroy' + 'retrieve', 'list', 'create', 'update', 'partial_update', 'destroy' ]) @@ -171,17 +173,12 @@ def get_allowed_methods(self, callback): class SchemaGenerator(object): # Map methods onto 'actions' that are the names used in the link layout. default_mapping = { - 'get': 'read', + 'get': 'retrieve', 'post': 'create', 'put': 'update', 'patch': 'partial_update', 'delete': 'destroy', } - # Coerce the following viewset actions into different names. - coerce_actions = { - 'retrieve': 'read', - 'destroy': 'delete' - } endpoint_inspector_cls = EndpointInspector def __init__(self, title=None, url=None, patterns=None, urlconf=None): @@ -283,6 +280,8 @@ def get_link(self, path, method, view): else: encoding = None + description = self.get_description(path, method, view) + if self.url and path.startswith('/'): path = path[1:] @@ -290,9 +289,38 @@ def get_link(self, path, method, view): url=urlparse.urljoin(self.url, path), action=method.lower(), encoding=encoding, - fields=fields + fields=fields, + description=description ) + def get_description(self, path, method, view): + """ + Determine a link description. + + This with either be the class docstring, or a specific section for it. + + For views, use method names, eg `get:` to introduce a section. + For viewsets, use action names, eg `retrieve:` to introduce a section. + + The section names will correspond to the methods on the class. + """ + view_description = view.get_view_description() + lines = [line.strip() for line in view_description.splitlines()] + current_section = '' + sections = {'': ''} + + for line in lines: + if header_regex.match(line): + current_section, seperator, lead = line.partition(':') + sections[current_section] = lead.strip() + else: + sections[current_section] += line + '\n' + + header = getattr(view, 'action', method.lower()) + if header in sections: + return sections[header].strip() + return sections[''].strip() + def get_encoding(self, path, method, view): """ Return the 'encoding' parameter to use for a given endpoint. @@ -401,10 +429,7 @@ def get_keys(self, path, method, view): """ if hasattr(view, 'action'): # Viewsets have explicitly named actions. - if view.action in self.coerce_actions: - action = self.coerce_actions[view.action] - else: - action = view.action + action = view.action else: # Views have no associated action, so we determine one from the method. if is_list_view(path, method, view): diff --git a/tests/test_schemas.py b/tests/test_schemas.py index 5794f968ab..1fc7c218ec 100644 --- a/tests/test_schemas.py +++ b/tests/test_schemas.py @@ -94,12 +94,12 @@ def test_anonymous_request(self): action='get' ), 'custom_list_action_multiple_methods': { - 'read': coreapi.Link( + 'retrieve': coreapi.Link( url='/example/custom_list_action_multiple_methods/', action='get' ) }, - 'read': coreapi.Link( + 'retrieve': coreapi.Link( url='/example/{pk}/', action='get', fields=[ @@ -138,7 +138,7 @@ def test_authenticated_request(self): coreapi.Field('b', required=False, location='form') ] ), - 'read': coreapi.Link( + 'retrieve': coreapi.Link( url='/example/{pk}/', action='get', fields=[ @@ -160,7 +160,7 @@ def test_authenticated_request(self): action='get' ), 'custom_list_action_multiple_methods': { - 'read': coreapi.Link( + 'retrieve': coreapi.Link( url='/example/custom_list_action_multiple_methods/', action='get' ), @@ -189,7 +189,7 @@ def test_authenticated_request(self): coreapi.Field('b', required=False, location='form') ] ), - 'delete': coreapi.Link( + 'destroy': coreapi.Link( url='/example/{pk}/', action='delete', fields=[ @@ -249,7 +249,7 @@ def test_schema_for_regular_views(self): action='get', fields=[] ), - 'read': coreapi.Link( + 'retrieve': coreapi.Link( url='/example/{pk}/', action='get', fields=[ From 1084dca22b0b23e87ff37708f8071cfbf4f23bb4 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Wed, 5 Oct 2016 16:14:23 +0100 Subject: [PATCH 34/43] Better descriptions for schemas --- docs/api-guide/schemas.md | 55 +++++++++++++++++++++++++++++++++++++++ rest_framework/schemas.py | 21 ++++++++------- tests/test_schemas.py | 4 +++ 3 files changed, 71 insertions(+), 9 deletions(-) diff --git a/docs/api-guide/schemas.md b/docs/api-guide/schemas.md index 868207dace..16d2bbb011 100644 --- a/docs/api-guide/schemas.md +++ b/docs/api-guide/schemas.md @@ -247,6 +247,61 @@ You could then either: One common usage of API schemas is to use them to build documentation pages. The schema generation in REST framework uses docstrings to automatically +populate descriptions in the schema document. + +These descriptions will be based on: + +* The corresponding method docstring if one exists. +* A named section within the class docstring, which can be either single line or multi-line. +* The class docstring. + +## Examples + +An `APIView`, with an explicit method docstring. + + class ListUsernames(APIView): + def get(self, request): + """ + Return a list of all user names in the system. + """ + usernames = [user.username for user in User.objects.all()] + return Response(usernames) + +A `ViewSet`, with an explict action docstring. + + class ListUsernames(ViewSet): + def list(self, request): + """ + Return a list of all user names in the system. + """ + usernames = [user.username for user in User.objects.all()] + return Response(usernames) + +A generic view with sections in the class docstring, using single-line style. + + class UserList(generics.ListCreateAPIView): + """ + get: Create a new user. + post: List all the users. + """ + queryset = User.objects.all() + serializer_class = UserSerializer + permission_classes = (IsAdminUser,) + +A generic viewset with sections in the class docstring, using multi-line style. + + class UserViewSet(viewsets.ModelViewSet): + """ + API endpoint that allows users to be viewed or edited. + + retrieve: + Return a user instance. + + list: + Return all users, ordered by most recently joined. + """ + queryset = User.objects.all().order_by('-date_joined') + serializer_class = UserSerializer --- diff --git a/rest_framework/schemas.py b/rest_framework/schemas.py index dc4d8acb79..ad311dcfdf 100644 --- a/rest_framework/schemas.py +++ b/rest_framework/schemas.py @@ -5,7 +5,7 @@ from django.conf import settings from django.contrib.admindocs.views import simplify_regex from django.utils import six -from django.utils.encoding import force_text +from django.utils.encoding import force_text, smart_text from rest_framework import exceptions, renderers, serializers from rest_framework.compat import ( @@ -14,6 +14,7 @@ from rest_framework.request import clone_request from rest_framework.response import Response from rest_framework.settings import api_settings +from rest_framework.utils import formatting from rest_framework.views import APIView header_regex = re.compile('^[a-zA-Z][0-9A-Za-z_]*:') @@ -297,15 +298,17 @@ def get_description(self, path, method, view): """ Determine a link description. - This with either be the class docstring, or a specific section for it. - - For views, use method names, eg `get:` to introduce a section. - For viewsets, use action names, eg `retrieve:` to introduce a section. - - The section names will correspond to the methods on the class. + This will be based on the method docstring if one exists, + or else the class docstring. """ - view_description = view.get_view_description() - lines = [line.strip() for line in view_description.splitlines()] + method_name = getattr(view, 'action', method.lower()) + method_docstring = getattr(view, method_name, None).__doc__ + if method_docstring: + # An explicit docstring on the method or action. + return formatting.dedent(smart_text(method_docstring)) + + description = view.get_view_description() + lines = [line.strip() for line in description.splitlines()] current_section = '' sections = {'': ''} diff --git a/tests/test_schemas.py b/tests/test_schemas.py index 1fc7c218ec..edc602d986 100644 --- a/tests/test_schemas.py +++ b/tests/test_schemas.py @@ -46,6 +46,9 @@ class ExampleViewSet(ModelViewSet): @detail_route(methods=['post'], serializer_class=AnotherSerializer) def custom_action(self, request, pk): + """ + A description of custom action. + """ return super(ExampleSerializer, self).retrieve(self, request) @list_route() @@ -149,6 +152,7 @@ def test_authenticated_request(self): url='/example/{pk}/custom_action/', action='post', encoding='application/json', + description='A description of custom action.', fields=[ coreapi.Field('pk', required=True, location='path'), coreapi.Field('c', required=True, location='form'), From 7edee804aa0547339e097bdba4e56300a8955d59 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Thu, 6 Oct 2016 15:04:55 +0100 Subject: [PATCH 35/43] Add type annotation to schema generation --- rest_framework/renderers.py | 4 ++++ rest_framework/schemas.py | 27 +++++++++++++++++++++++++-- tests/test_schemas.py | 16 ++++++++-------- 3 files changed, 37 insertions(+), 10 deletions(-) diff --git a/rest_framework/renderers.py b/rest_framework/renderers.py index 11e9fb9607..5de138d966 100644 --- a/rest_framework/renderers.py +++ b/rest_framework/renderers.py @@ -276,6 +276,10 @@ class HTMLFormRenderer(BaseRenderer): 'base_template': 'input.html', 'input_type': 'number' }, + serializers.FloatField: { + 'base_template': 'input.html', + 'input_type': 'number' + }, serializers.DateTimeField: { 'base_template': 'input.html', 'input_type': 'datetime-local' diff --git a/rest_framework/schemas.py b/rest_framework/schemas.py index ad311dcfdf..db41539fdc 100644 --- a/rest_framework/schemas.py +++ b/rest_framework/schemas.py @@ -15,10 +15,25 @@ from rest_framework.response import Response from rest_framework.settings import api_settings from rest_framework.utils import formatting +from rest_framework.utils.field_mapping import ClassLookupDict from rest_framework.views import APIView + header_regex = re.compile('^[a-zA-Z][0-9A-Za-z_]*:') +types_lookup = ClassLookupDict({ + serializers.Field: 'string', + serializers.IntegerField: 'integer', + serializers.FloatField: 'number', + serializers.DecimalField: 'number', + serializers.BooleanField: 'boolean', + serializers.FileField: 'file', + serializers.MultipleChoiceField: 'array', + serializers.ManyRelatedField: 'array', + serializers.Serializer: 'object', + serializers.ListSerializer: 'array' +}) + def as_query_fields(items): """ @@ -372,7 +387,14 @@ def get_serializer_fields(self, path, method, view): serializer = view.get_serializer() if isinstance(serializer, serializers.ListSerializer): - return [coreapi.Field(name='data', location='body', required=True)] + return [ + coreapi.Field( + name='data', + location='body', + required=True, + type='array' + ) + ] if not isinstance(serializer, serializers.Serializer): return [] @@ -388,7 +410,8 @@ def get_serializer_fields(self, path, method, view): name=field.source, location='form', required=required, - description=description + description=description, + type=types_lookup[field] ) fields.append(field) diff --git a/tests/test_schemas.py b/tests/test_schemas.py index edc602d986..c8b40fbb41 100644 --- a/tests/test_schemas.py +++ b/tests/test_schemas.py @@ -137,8 +137,8 @@ def test_authenticated_request(self): action='post', encoding='application/json', fields=[ - coreapi.Field('a', required=True, location='form', description='A field description'), - coreapi.Field('b', required=False, location='form') + coreapi.Field('a', required=True, location='form', type='string', description='A field description'), + coreapi.Field('b', required=False, location='form', type='string') ] ), 'retrieve': coreapi.Link( @@ -155,8 +155,8 @@ def test_authenticated_request(self): description='A description of custom action.', fields=[ coreapi.Field('pk', required=True, location='path'), - coreapi.Field('c', required=True, location='form'), - coreapi.Field('d', required=False, location='form'), + coreapi.Field('c', required=True, location='form', type='string'), + coreapi.Field('d', required=False, location='form', type='string'), ] ), 'custom_list_action': coreapi.Link( @@ -179,8 +179,8 @@ def test_authenticated_request(self): encoding='application/json', fields=[ coreapi.Field('pk', required=True, location='path'), - coreapi.Field('a', required=True, location='form', description='A field description'), - coreapi.Field('b', required=False, location='form') + coreapi.Field('a', required=True, location='form', type='string', description='A field description'), + coreapi.Field('b', required=False, location='form', type='string') ] ), 'partial_update': coreapi.Link( @@ -189,8 +189,8 @@ def test_authenticated_request(self): encoding='application/json', fields=[ coreapi.Field('pk', required=True, location='path'), - coreapi.Field('a', required=False, location='form', description='A field description'), - coreapi.Field('b', required=False, location='form') + coreapi.Field('a', required=False, location='form', type='string', description='A field description'), + coreapi.Field('b', required=False, location='form', type='string') ] ), 'destroy': coreapi.Link( From b44ab76d2cdc382026e748f94b6a2d5450d9fcd0 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Thu, 6 Oct 2016 16:22:03 +0100 Subject: [PATCH 36/43] Coerce schema 'pk' in path to actual field name --- rest_framework/schemas.py | 20 ++++++++++++++++++++ tests/test_schemas.py | 32 ++++++++++++++++---------------- 2 files changed, 36 insertions(+), 16 deletions(-) diff --git a/rest_framework/schemas.py b/rest_framework/schemas.py index db41539fdc..48fb2a3924 100644 --- a/rest_framework/schemas.py +++ b/rest_framework/schemas.py @@ -16,6 +16,7 @@ from rest_framework.settings import api_settings from rest_framework.utils import formatting from rest_framework.utils.field_mapping import ClassLookupDict +from rest_framework.utils.model_meta import _get_pk from rest_framework.views import APIView @@ -35,6 +36,11 @@ }) +def get_pk_name(model): + meta = model._meta.concrete_model._meta + return _get_pk(meta).name + + def as_query_fields(items): """ Take a list of Fields and plain strings. @@ -196,6 +202,9 @@ class SchemaGenerator(object): 'delete': 'destroy', } endpoint_inspector_cls = EndpointInspector + # 'pk' isn't great as an externally exposed name for an identifier, + # so by default we prefer to use the actual model field name for schemas. + coerce_pk = True def __init__(self, title=None, url=None, patterns=None, urlconf=None): assert coreapi, '`coreapi` must be installed for schema support.' @@ -230,6 +239,7 @@ def get_links(self, request=None): links = OrderedDict() for path, method, callback in self.endpoints: view = self.create_view(callback, method, request) + path = self.coerce_path(path, method, view) if not self.should_include_view(path, method, view): continue link = self.get_link(path, method, view) @@ -280,6 +290,16 @@ def should_include_view(self, path, method, view): return False return True + def coerce_path(self, path, method, view): + if not self.coerce_pk or '{pk}' not in path: + return path + model = getattr(getattr(view, 'queryset', None), 'model', None) + if model: + field_name = get_pk_name(model) + else: + field_name = 'id' + return path.replace('{pk}', '{%s}' % field_name) + # Methods for generating each individual `Link` instance... def get_link(self, path, method, view): diff --git a/tests/test_schemas.py b/tests/test_schemas.py index c8b40fbb41..1914994a7c 100644 --- a/tests/test_schemas.py +++ b/tests/test_schemas.py @@ -103,10 +103,10 @@ def test_anonymous_request(self): ) }, 'retrieve': coreapi.Link( - url='/example/{pk}/', + url='/example/{id}/', action='get', fields=[ - coreapi.Field('pk', required=True, location='path') + coreapi.Field('id', required=True, location='path') ] ) } @@ -142,19 +142,19 @@ def test_authenticated_request(self): ] ), 'retrieve': coreapi.Link( - url='/example/{pk}/', + url='/example/{id}/', action='get', fields=[ - coreapi.Field('pk', required=True, location='path') + coreapi.Field('id', required=True, location='path') ] ), 'custom_action': coreapi.Link( - url='/example/{pk}/custom_action/', + url='/example/{id}/custom_action/', action='post', encoding='application/json', description='A description of custom action.', fields=[ - coreapi.Field('pk', required=True, location='path'), + coreapi.Field('id', required=True, location='path'), coreapi.Field('c', required=True, location='form', type='string'), coreapi.Field('d', required=False, location='form', type='string'), ] @@ -174,30 +174,30 @@ def test_authenticated_request(self): ) }, 'update': coreapi.Link( - url='/example/{pk}/', + url='/example/{id}/', action='put', encoding='application/json', fields=[ - coreapi.Field('pk', required=True, location='path'), + coreapi.Field('id', required=True, location='path'), coreapi.Field('a', required=True, location='form', type='string', description='A field description'), coreapi.Field('b', required=False, location='form', type='string') ] ), 'partial_update': coreapi.Link( - url='/example/{pk}/', + url='/example/{id}/', action='patch', encoding='application/json', fields=[ - coreapi.Field('pk', required=True, location='path'), + coreapi.Field('id', required=True, location='path'), coreapi.Field('a', required=False, location='form', type='string', description='A field description'), coreapi.Field('b', required=False, location='form', type='string') ] ), 'destroy': coreapi.Link( - url='/example/{pk}/', + url='/example/{id}/', action='delete', fields=[ - coreapi.Field('pk', required=True, location='path') + coreapi.Field('id', required=True, location='path') ] ) } @@ -254,18 +254,18 @@ def test_schema_for_regular_views(self): fields=[] ), 'retrieve': coreapi.Link( - url='/example/{pk}/', + url='/example/{id}/', action='get', fields=[ - coreapi.Field('pk', required=True, location='path') + coreapi.Field('id', required=True, location='path') ] ), 'sub': { 'list': coreapi.Link( - url='/example/{pk}/sub/', + url='/example/{id}/sub/', action='get', fields=[ - coreapi.Field('pk', required=True, location='path') + coreapi.Field('id', required=True, location='path') ] ) } From 072442025c79a4f8c22855a456fd7c2d8f277ce2 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Thu, 6 Oct 2016 16:59:10 +0100 Subject: [PATCH 37/43] Deprecations move into assertion errors --- rest_framework/serializers.py | 19 ++++++++----------- 1 file changed, 8 insertions(+), 11 deletions(-) diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 28f70bd406..7e99d40b31 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -13,7 +13,6 @@ from __future__ import unicode_literals import traceback -import warnings from django.db import models from django.db.models import DurationField as ModelDurationField @@ -1016,16 +1015,14 @@ def get_field_names(self, declared_fields, info): ) ) - if fields is None and exclude is None: - warnings.warn( - "Creating a ModelSerializer without either the 'fields' " - "attribute or the 'exclude' attribute is deprecated " - "since 3.3.0. Add an explicit fields = '__all__' to the " - "{serializer_class} serializer.".format( - serializer_class=self.__class__.__name__ - ), - DeprecationWarning - ) + assert not (fields is None and exclude is None), ( + "Creating a ModelSerializer without either the 'fields' attribute " + "or the 'exclude' attribute has been deprecated since 3.3.0, " + "and is now disallowed. Add an explicit fields = '__all__' to the " + "{serializer_class} serializer.".format( + serializer_class=self.__class__.__name__ + ), + ) if fields == ALL_FIELDS: fields = None From 3eb7fe6f64a507a9c3879aa6f9e6c79c4ddaa691 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Fri, 7 Oct 2016 12:09:41 +0100 Subject: [PATCH 38/43] Use get_schema_view in tests --- tests/test_schemas.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/tests/test_schemas.py b/tests/test_schemas.py index 1914994a7c..59ff8f3387 100644 --- a/tests/test_schemas.py +++ b/tests/test_schemas.py @@ -7,7 +7,7 @@ from rest_framework.compat import coreapi from rest_framework.decorators import detail_route, list_route from rest_framework.routers import DefaultRouter -from rest_framework.schemas import SchemaGenerator +from rest_framework.schemas import SchemaGenerator, get_schema_view from rest_framework.test import APIClient from rest_framework.views import APIView from rest_framework.viewsets import ModelViewSet @@ -65,9 +65,16 @@ def get_serializer(self, *args, **kwargs): return super(ExampleViewSet, self).get_serializer(*args, **kwargs) -router = DefaultRouter(schema_title='Example API' if coreapi else None) +if coreapi: + schema_view = get_schema_view(title='Example API') +else: + def schema_view(request): + pass + +router = DefaultRouter() router.register('example', ExampleViewSet, base_name='example') urlpatterns = [ + url(r'^$', schema_view), url(r'^', include(router.urls)) ] From 5858168d55653f54bb8b6d703c0e520ab5d05721 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Fri, 7 Oct 2016 12:56:05 +0100 Subject: [PATCH 39/43] Updte CoreJSON media type --- rest_framework/renderers.py | 2 +- tests/test_schemas.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/rest_framework/renderers.py b/rest_framework/renderers.py index 5de138d966..97984daf98 100644 --- a/rest_framework/renderers.py +++ b/rest_framework/renderers.py @@ -813,7 +813,7 @@ def render(self, data, accepted_media_type=None, renderer_context=None): class CoreJSONRenderer(BaseRenderer): - media_type = 'application/vnd.coreapi+json' + media_type = 'application/coreapi+json' charset = None format = 'corejson' diff --git a/tests/test_schemas.py b/tests/test_schemas.py index 59ff8f3387..330f6f4d56 100644 --- a/tests/test_schemas.py +++ b/tests/test_schemas.py @@ -84,7 +84,7 @@ def schema_view(request): class TestRouterGeneratedSchema(TestCase): def test_anonymous_request(self): client = APIClient() - response = client.get('/', HTTP_ACCEPT='application/vnd.coreapi+json') + response = client.get('/', HTTP_ACCEPT='application/coreapi+json') self.assertEqual(response.status_code, 200) expected = coreapi.Document( url='', @@ -124,7 +124,7 @@ def test_anonymous_request(self): def test_authenticated_request(self): client = APIClient() client.force_authenticate(MockUser()) - response = client.get('/', HTTP_ACCEPT='application/vnd.coreapi+json') + response = client.get('/', HTTP_ACCEPT='application/coreapi+json') self.assertEqual(response.status_code, 200) expected = coreapi.Document( url='', From 69b4acd88ff47a431863bd6dc08f65def14cb1e1 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Fri, 7 Oct 2016 15:13:11 +0100 Subject: [PATCH 40/43] Handle schema structure correctly when path prefixs exist. Closes #4401 --- rest_framework/schemas.py | 66 +++++++++++++++++++++++++++++++++------ tests/test_schemas.py | 53 +++++++++++++++++++++++++++++++ 2 files changed, 110 insertions(+), 9 deletions(-) diff --git a/rest_framework/schemas.py b/rest_framework/schemas.py index 48fb2a3924..ac81782a92 100644 --- a/rest_framework/schemas.py +++ b/rest_framework/schemas.py @@ -1,3 +1,4 @@ +import os import re from collections import OrderedDict from importlib import import_module @@ -237,18 +238,63 @@ def get_links(self, request=None): included in the API schema. """ links = OrderedDict() + + # Generate (path, method, view) given (path, method, callback). + paths = [] + view_endpoints = [] for path, method, callback in self.endpoints: view = self.create_view(callback, method, request) + if getattr(view, 'exclude_from_schema', False): + continue path = self.coerce_path(path, method, view) - if not self.should_include_view(path, method, view): + paths.append(path) + view_endpoints.append((path, method, view)) + + # Only generate the path prefix for paths that will be included + prefix = self.determine_path_prefix(paths) + + for path, method, view in view_endpoints: + if not self.has_view_permissions(path, method, view): continue link = self.get_link(path, method, view) - keys = self.get_keys(path, method, view) + subpath = path[len(prefix):] + keys = self.get_keys(subpath, method, view) insert_into(links, keys, link) return links # Methods used when we generate a view instance from the raw callback... + def determine_path_prefix(self, paths): + """ + Given a list of all paths, return the common prefix which should be + discounted when generating a schema structure. + + This will be the longest common string that does not include that last + component of the URL, or the last component before a path parameter. + + For example: + + /api/v1/users/ + /api/v1/users/{pk}/ + + The path prefix is '/api/v1/' + """ + prefixes = [] + for path in paths: + components = path.strip('/').split('/') + initial_components = [] + for component in components: + if '{' in component: + break + initial_components.append(component) + prefix = '/'.join(initial_components[:-1]) + if not prefix: + # We can just break early in the case that there's at least + # one URL that doesn't have a path prefix. + return '/' + prefixes.append('/' + prefix + '/') + return os.path.commonprefix(prefixes) + def create_view(self, callback, method, request=None): """ Given a callback, return an actual view instance. @@ -274,13 +320,10 @@ def create_view(self, callback, method, request=None): return view - def should_include_view(self, path, method, view): + def has_view_permissions(self, path, method, view): """ Return `True` if the incoming request has the correct view permissions. """ - if getattr(view, 'exclude_from_schema', False): - return False - if view.request is None: return True @@ -291,6 +334,11 @@ def should_include_view(self, path, method, view): return True def coerce_path(self, path, method, view): + """ + Coerce {pk} path arguments into the name of the model field, + where possible. This is cleaner for an external representation. + (Ie. "this is an identifier", not "this is a database primary key") + """ if not self.coerce_pk or '{pk}' not in path: return path model = getattr(getattr(view, 'queryset', None), 'model', None) @@ -461,7 +509,7 @@ def get_filter_fields(self, path, method, view): # Method for generating the link layout.... - def get_keys(self, path, method, view): + def get_keys(self, subpath, method, view): """ Return a list of keys that should be used to layout a link within the schema document. @@ -478,14 +526,14 @@ def get_keys(self, path, method, view): action = view.action else: # Views have no associated action, so we determine one from the method. - if is_list_view(path, method, view): + if is_list_view(subpath, method, view): action = 'list' else: action = self.default_mapping[method.lower()] named_path_components = [ component for component - in path.strip('/').split('/') + in subpath.strip('/').split('/') if '{' not in component ] diff --git a/tests/test_schemas.py b/tests/test_schemas.py index 330f6f4d56..0a422f078f 100644 --- a/tests/test_schemas.py +++ b/tests/test_schemas.py @@ -280,3 +280,56 @@ def test_schema_for_regular_views(self): } ) self.assertEqual(schema, expected) + + +@unittest.skipUnless(coreapi, 'coreapi is not installed') +class TestSchemaGeneratorNotAtRoot(TestCase): + def setUp(self): + self.patterns = [ + url('^api/v1/example/?$', ExampleListView.as_view()), + url('^api/v1/example/(?P\d+)/?$', ExampleDetailView.as_view()), + url('^api/v1/example/(?P\d+)/sub/?$', ExampleDetailView.as_view()), + ] + + def test_schema_for_regular_views(self): + """ + Ensure that schema generation with an API that is not at the URL + root continues to use correct structure for link keys. + """ + generator = SchemaGenerator(title='Example API', patterns=self.patterns) + schema = generator.get_schema() + expected = coreapi.Document( + url='', + title='Example API', + content={ + 'example': { + 'create': coreapi.Link( + url='/api/v1/example/', + action='post', + fields=[] + ), + 'list': coreapi.Link( + url='/api/v1/example/', + action='get', + fields=[] + ), + 'retrieve': coreapi.Link( + url='/api/v1/example/{id}/', + action='get', + fields=[ + coreapi.Field('id', required=True, location='path') + ] + ), + 'sub': { + 'list': coreapi.Link( + url='/api/v1/example/{id}/sub/', + action='get', + fields=[ + coreapi.Field('id', required=True, location='path') + ] + ) + } + } + } + ) + self.assertEqual(schema, expected) From fcf932f118c718d2bc3e0a0a3b7fe6e26899cc60 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Fri, 7 Oct 2016 16:51:26 +0100 Subject: [PATCH 41/43] Add PendingDeprecation to Router schema generation. --- rest_framework/routers.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/rest_framework/routers.py b/rest_framework/routers.py index 9ef8d11f06..7a2d981a3f 100644 --- a/rest_framework/routers.py +++ b/rest_framework/routers.py @@ -16,6 +16,7 @@ from __future__ import unicode_literals import itertools +import warnings from collections import OrderedDict, namedtuple from django.conf.urls import url @@ -285,6 +286,12 @@ class DefaultRouter(SimpleRouter): default_schema_renderers = [renderers.CoreJSONRenderer, BrowsableAPIRenderer] def __init__(self, *args, **kwargs): + if 'schema_title' in kwargs: + warnings.warn( + "Including a schema directly via a router is now pending " + "deprecation. Use `get_schema_view()` instead.", + PendingDeprecationWarning + ) if 'schema_renderers' in kwargs: assert 'schema_title' in kwargs, 'Missing "schema_title" argument.' if 'schema_url' in kwargs: From 18aebbbe01a86d7d6d7638364a428c67f8374fd8 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Mon, 10 Oct 2016 11:48:24 +0100 Subject: [PATCH 42/43] Added SCHEMA_COERCE_PATH_PK and SCHEMA_COERCE_METHOD_NAMES --- docs/api-guide/settings.md | 22 ++++++++++++++++ docs/tutorial/1-serialization.md | 14 +++++----- docs/tutorial/2-requests-and-responses.md | 8 +++--- docs/tutorial/3-class-based-views.md | 18 ++++++------- .../4-authentication-and-permissions.md | 4 +-- .../5-relationships-and-hyperlinked-apis.md | 14 +++++----- docs/tutorial/6-viewsets-and-routers.md | 6 ++--- .../7-schemas-and-client-libraries.md | 26 +++++++++---------- rest_framework/schemas.py | 24 ++++++++++++++--- rest_framework/settings.py | 7 +++++ tests/test_schemas.py | 14 +++++----- 11 files changed, 102 insertions(+), 55 deletions(-) diff --git a/docs/api-guide/settings.md b/docs/api-guide/settings.md index 67d317839f..58ceeeeb41 100644 --- a/docs/api-guide/settings.md +++ b/docs/api-guide/settings.md @@ -234,6 +234,28 @@ Default: --- +## Schema generation controls + +#### SCHEMA_COERCE_PATH_PK + +If set, this maps the `'pk'` identifier in the URL conf onto the actual field +name when generating a schema path parameter. Typically this will be `'id'`. +This gives a more suitable representation as "primary key" is an implementation +detail, wheras "identifier" is a more general concept. + +Default: `True` + +#### SCHEMA_COERCE_METHOD_NAMES + +If set, this is used to map internal viewset method names onto external action +names used in the schema generation. This allows us to generate names that +are more suitable for an external representation than those that are used +internally in the codebase. + +Default: `{'retrieve': 'read', 'destroy': 'delete'}` + +--- + ## Content type controls #### URL_FORMAT_OVERRIDE diff --git a/docs/tutorial/1-serialization.md b/docs/tutorial/1-serialization.md index 87856e0370..434072e11d 100644 --- a/docs/tutorial/1-serialization.md +++ b/docs/tutorial/1-serialization.md @@ -88,7 +88,7 @@ The first thing we need to get started on our Web API is to provide a way of ser class SnippetSerializer(serializers.Serializer): - pk = serializers.IntegerField(read_only=True) + id = serializers.IntegerField(read_only=True) title = serializers.CharField(required=False, allow_blank=True, max_length=100) code = serializers.CharField(style={'base_template': 'textarea.html'}) linenos = serializers.BooleanField(required=False) @@ -144,13 +144,13 @@ We've now got a few snippet instances to play with. Let's take a look at serial serializer = SnippetSerializer(snippet) serializer.data - # {'pk': 2, 'title': u'', 'code': u'print "hello, world"\n', 'linenos': False, 'language': u'python', 'style': u'friendly'} + # {'id': 2, 'title': u'', 'code': u'print "hello, world"\n', 'linenos': False, 'language': u'python', 'style': u'friendly'} At this point we've translated the model instance into Python native datatypes. To finalize the serialization process we render the data into `json`. content = JSONRenderer().render(serializer.data) content - # '{"pk": 2, "title": "", "code": "print \\"hello, world\\"\\n", "linenos": false, "language": "python", "style": "friendly"}' + # '{"id": 2, "title": "", "code": "print \\"hello, world\\"\\n", "linenos": false, "language": "python", "style": "friendly"}' Deserialization is similar. First we parse a stream into Python native datatypes... @@ -175,7 +175,7 @@ We can also serialize querysets instead of model instances. To do so we simply serializer = SnippetSerializer(Snippet.objects.all(), many=True) serializer.data - # [OrderedDict([('pk', 1), ('title', u''), ('code', u'foo = "bar"\n'), ('linenos', False), ('language', 'python'), ('style', 'friendly')]), OrderedDict([('pk', 2), ('title', u''), ('code', u'print "hello, world"\n'), ('linenos', False), ('language', 'python'), ('style', 'friendly')]), OrderedDict([('pk', 3), ('title', u''), ('code', u'print "hello, world"'), ('linenos', False), ('language', 'python'), ('style', 'friendly')])] + # [OrderedDict([('id', 1), ('title', u''), ('code', u'foo = "bar"\n'), ('linenos', False), ('language', 'python'), ('style', 'friendly')]), OrderedDict([('id', 2), ('title', u''), ('code', u'print "hello, world"\n'), ('linenos', False), ('language', 'python'), ('style', 'friendly')]), OrderedDict([('id', 3), ('title', u''), ('code', u'print "hello, world"'), ('linenos', False), ('language', 'python'), ('style', 'friendly')])] ## Using ModelSerializers @@ -259,12 +259,12 @@ Note that because we want to be able to POST to this view from clients that won' We'll also need a view which corresponds to an individual snippet, and can be used to retrieve, update or delete the snippet. @csrf_exempt - def snippet_detail(request, pk): + def snippet_detail(request, id): """ Retrieve, update or delete a code snippet. """ try: - snippet = Snippet.objects.get(pk=pk) + snippet = Snippet.objects.get(id=id) except Snippet.DoesNotExist: return HttpResponse(status=404) @@ -291,7 +291,7 @@ Finally we need to wire these views up. Create the `snippets/urls.py` file: urlpatterns = [ url(r'^snippets/$', views.snippet_list), - url(r'^snippets/(?P[0-9]+)/$', views.snippet_detail), + url(r'^snippets/(?P[0-9]+)/$', views.snippet_detail), ] We also need to wire up the root urlconf, in the `tutorial/urls.py` file, to include our snippet app's URLs. diff --git a/docs/tutorial/2-requests-and-responses.md b/docs/tutorial/2-requests-and-responses.md index 5c020a1f70..4aa0062a33 100644 --- a/docs/tutorial/2-requests-and-responses.md +++ b/docs/tutorial/2-requests-and-responses.md @@ -66,12 +66,12 @@ Our instance view is an improvement over the previous example. It's a little mo Here is the view for an individual snippet, in the `views.py` module. @api_view(['GET', 'PUT', 'DELETE']) - def snippet_detail(request, pk): + def snippet_detail(request, id): """ Retrieve, update or delete a snippet instance. """ try: - snippet = Snippet.objects.get(pk=pk) + snippet = Snippet.objects.get(id=id) except Snippet.DoesNotExist: return Response(status=status.HTTP_404_NOT_FOUND) @@ -104,7 +104,7 @@ Start by adding a `format` keyword argument to both of the views, like so. and - def snippet_detail(request, pk, format=None): + def snippet_detail(request, id, format=None): Now update the `urls.py` file slightly, to append a set of `format_suffix_patterns` in addition to the existing URLs. @@ -114,7 +114,7 @@ Now update the `urls.py` file slightly, to append a set of `format_suffix_patter urlpatterns = [ url(r'^snippets/$', views.snippet_list), - url(r'^snippets/(?P[0-9]+)$', views.snippet_detail), + url(r'^snippets/(?P[0-9]+)$', views.snippet_detail), ] urlpatterns = format_suffix_patterns(urlpatterns) diff --git a/docs/tutorial/3-class-based-views.md b/docs/tutorial/3-class-based-views.md index f018666f58..6303994cdc 100644 --- a/docs/tutorial/3-class-based-views.md +++ b/docs/tutorial/3-class-based-views.md @@ -36,27 +36,27 @@ So far, so good. It looks pretty similar to the previous case, but we've got be """ Retrieve, update or delete a snippet instance. """ - def get_object(self, pk): + def get_object(self, id): try: - return Snippet.objects.get(pk=pk) + return Snippet.objects.get(id=id) except Snippet.DoesNotExist: raise Http404 - def get(self, request, pk, format=None): - snippet = self.get_object(pk) + def get(self, request, id, format=None): + snippet = self.get_object(id) serializer = SnippetSerializer(snippet) return Response(serializer.data) - def put(self, request, pk, format=None): - snippet = self.get_object(pk) + def put(self, request, id, format=None): + snippet = self.get_object(id) serializer = SnippetSerializer(snippet, data=request.data) if serializer.is_valid(): serializer.save() return Response(serializer.data) return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) - def delete(self, request, pk, format=None): - snippet = self.get_object(pk) + def delete(self, request, id, format=None): + snippet = self.get_object(id) snippet.delete() return Response(status=status.HTTP_204_NO_CONTENT) @@ -70,7 +70,7 @@ We'll also need to refactor our `urls.py` slightly now we're using class-based v urlpatterns = [ url(r'^snippets/$', views.SnippetList.as_view()), - url(r'^snippets/(?P[0-9]+)/$', views.SnippetDetail.as_view()), + url(r'^snippets/(?P[0-9]+)/$', views.SnippetDetail.as_view()), ] urlpatterns = format_suffix_patterns(urlpatterns) diff --git a/docs/tutorial/4-authentication-and-permissions.md b/docs/tutorial/4-authentication-and-permissions.md index d69c38552d..958f9d3f03 100644 --- a/docs/tutorial/4-authentication-and-permissions.md +++ b/docs/tutorial/4-authentication-and-permissions.md @@ -88,7 +88,7 @@ Make sure to also import the `UserSerializer` class Finally we need to add those views into the API, by referencing them from the URL conf. Add the following to the patterns in `urls.py`. url(r'^users/$', views.UserList.as_view()), - url(r'^users/(?P[0-9]+)/$', views.UserDetail.as_view()), + url(r'^users/(?P[0-9]+)/$', views.UserDetail.as_view()), ## Associating Snippets with Users @@ -150,7 +150,7 @@ The `r'^api-auth/'` part of pattern can actually be whatever URL you want to use Now if you open up the browser again and refresh the page you'll see a 'Login' link in the top right of the page. If you log in as one of the users you created earlier, you'll be able to create code snippets again. -Once you've created a few code snippets, navigate to the '/users/' endpoint, and notice that the representation includes a list of the snippet pks that are associated with each user, in each user's 'snippets' field. +Once you've created a few code snippets, navigate to the '/users/' endpoint, and notice that the representation includes a list of the snippet ids that are associated with each user, in each user's 'snippets' field. ## Object level permissions diff --git a/docs/tutorial/5-relationships-and-hyperlinked-apis.md b/docs/tutorial/5-relationships-and-hyperlinked-apis.md index 8cda09c62b..9fb6c53e0a 100644 --- a/docs/tutorial/5-relationships-and-hyperlinked-apis.md +++ b/docs/tutorial/5-relationships-and-hyperlinked-apis.md @@ -48,7 +48,7 @@ We'll add a url pattern for our new API root in `snippets/urls.py`: And then add a url pattern for the snippet highlights: - url(r'^snippets/(?P[0-9]+)/highlight/$', views.SnippetHighlight.as_view()), + url(r'^snippets/(?P[0-9]+)/highlight/$', views.SnippetHighlight.as_view()), ## Hyperlinking our API @@ -67,7 +67,7 @@ In this case we'd like to use a hyperlinked style between entities. In order to The `HyperlinkedModelSerializer` has the following differences from `ModelSerializer`: -* It does not include the `pk` field by default. +* It does not include the `id` field by default. * It includes a `url` field, using `HyperlinkedIdentityField`. * Relationships use `HyperlinkedRelatedField`, instead of `PrimaryKeyRelatedField`. @@ -80,7 +80,7 @@ We can easily re-write our existing serializers to use hyperlinking. In your `sn class Meta: model = Snippet - fields = ('url', 'pk', 'highlight', 'owner', + fields = ('url', 'id', 'highlight', 'owner', 'title', 'code', 'linenos', 'language', 'style') @@ -89,7 +89,7 @@ We can easily re-write our existing serializers to use hyperlinking. In your `sn class Meta: model = User - fields = ('url', 'pk', 'username', 'snippets') + fields = ('url', 'id', 'username', 'snippets') Notice that we've also added a new `'highlight'` field. This field is of the same type as the `url` field, except that it points to the `'snippet-highlight'` url pattern, instead of the `'snippet-detail'` url pattern. @@ -116,16 +116,16 @@ After adding all those names into our URLconf, our final `snippets/urls.py` file url(r'^snippets/$', views.SnippetList.as_view(), name='snippet-list'), - url(r'^snippets/(?P[0-9]+)/$', + url(r'^snippets/(?P[0-9]+)/$', views.SnippetDetail.as_view(), name='snippet-detail'), - url(r'^snippets/(?P[0-9]+)/highlight/$', + url(r'^snippets/(?P[0-9]+)/highlight/$', views.SnippetHighlight.as_view(), name='snippet-highlight'), url(r'^users/$', views.UserList.as_view(), name='user-list'), - url(r'^users/(?P[0-9]+)/$', + url(r'^users/(?P[0-9]+)/$', views.UserDetail.as_view(), name='user-detail') ]) diff --git a/docs/tutorial/6-viewsets-and-routers.md b/docs/tutorial/6-viewsets-and-routers.md index 00152cc178..6e13210939 100644 --- a/docs/tutorial/6-viewsets-and-routers.md +++ b/docs/tutorial/6-viewsets-and-routers.md @@ -92,10 +92,10 @@ Now that we've bound our resources into concrete views, we can register the view urlpatterns = format_suffix_patterns([ url(r'^$', api_root), url(r'^snippets/$', snippet_list, name='snippet-list'), - url(r'^snippets/(?P[0-9]+)/$', snippet_detail, name='snippet-detail'), - url(r'^snippets/(?P[0-9]+)/highlight/$', snippet_highlight, name='snippet-highlight'), + url(r'^snippets/(?P[0-9]+)/$', snippet_detail, name='snippet-detail'), + url(r'^snippets/(?P[0-9]+)/highlight/$', snippet_highlight, name='snippet-highlight'), url(r'^users/$', user_list, name='user-list'), - url(r'^users/(?P[0-9]+)/$', user_detail, name='user-detail') + url(r'^users/(?P[0-9]+)/$', user_detail, name='user-detail') ]) ## Using Routers diff --git a/docs/tutorial/7-schemas-and-client-libraries.md b/docs/tutorial/7-schemas-and-client-libraries.md index 3d4d2c9410..705b79da6b 100644 --- a/docs/tutorial/7-schemas-and-client-libraries.md +++ b/docs/tutorial/7-schemas-and-client-libraries.md @@ -101,13 +101,13 @@ First we'll load the API schema using the command line client. $ coreapi get http://127.0.0.1:8000/schema/ snippets: { - highlight(pk) + highlight(id) list() - retrieve(pk) + read(id) } users: { list() - retrieve(pk) + read(id) } We haven't authenticated yet, so right now we're only able to see the read only @@ -119,7 +119,7 @@ Let's try listing the existing snippets, using the command line client: [ { "url": "http://127.0.0.1:8000/snippets/1/", - "pk": 1, + "id": 1, "highlight": "http://127.0.0.1:8000/snippets/1/highlight/", "owner": "lucy", "title": "Example", @@ -133,7 +133,7 @@ Let's try listing the existing snippets, using the command line client: Some of the API endpoints require named parameters. For example, to get back the highlight HTML for a particular snippet we need to provide an id. - $ coreapi action snippets highlight --param pk=1 + $ coreapi action snippets highlight --param id=1 @@ -160,16 +160,16 @@ set of available interactions. Pastebin API "http://127.0.0.1:8000/schema/"> snippets: { create(code, [title], [linenos], [language], [style]) - destroy(pk) - highlight(pk) + delete(id) + highlight(id) list() - partial_update(pk, [title], [code], [linenos], [language], [style]) - retrieve(pk) - update(pk, code, [title], [linenos], [language], [style]) + partial_update(id, [title], [code], [linenos], [language], [style]) + read(id) + update(id, code, [title], [linenos], [language], [style]) } users: { list() - retrieve(pk) + read(id) } We're now able to interact with these endpoints. For example, to create a new @@ -178,7 +178,7 @@ snippet: $ coreapi action snippets create --param title="Example" --param code="print('hello, world')" { "url": "http://127.0.0.1:8000/snippets/7/", - "pk": 7, + "id": 7, "highlight": "http://127.0.0.1:8000/snippets/7/highlight/", "owner": "lucy", "title": "Example", @@ -190,7 +190,7 @@ snippet: And to delete a snippet: - $ coreapi action snippets destroy --param pk=7 + $ coreapi action snippets delete --param id=7 As well as the command line client, developers can also interact with your API using client libraries. The Python client library is the first of these diff --git a/rest_framework/schemas.py b/rest_framework/schemas.py index ac81782a92..12119adfc9 100644 --- a/rest_framework/schemas.py +++ b/rest_framework/schemas.py @@ -194,7 +194,7 @@ def get_allowed_methods(self, callback): class SchemaGenerator(object): - # Map methods onto 'actions' that are the names used in the link layout. + # Map HTTP methods onto actions. default_mapping = { 'get': 'retrieve', 'post': 'create', @@ -203,9 +203,16 @@ class SchemaGenerator(object): 'delete': 'destroy', } endpoint_inspector_cls = EndpointInspector + + # Map the method names we use for viewset actions onto external schema names. + # These give us names that are more suitable for the external representation. + # Set by 'SCHEMA_COERCE_METHOD_NAMES'. + coerce_method_names = None + # 'pk' isn't great as an externally exposed name for an identifier, # so by default we prefer to use the actual model field name for schemas. - coerce_pk = True + # Set by 'SCHEMA_COERCE_PATH_PK'. + coerce_path_pk = None def __init__(self, title=None, url=None, patterns=None, urlconf=None): assert coreapi, '`coreapi` must be installed for schema support.' @@ -213,6 +220,9 @@ def __init__(self, title=None, url=None, patterns=None, urlconf=None): if url and not url.endswith('/'): url += '/' + self.coerce_method_names = api_settings.SCHEMA_COERCE_METHOD_NAMES + self.coerce_path_pk = api_settings.SCHEMA_COERCE_PATH_PK + self.patterns = patterns self.urlconf = urlconf self.title = title @@ -339,7 +349,7 @@ def coerce_path(self, path, method, view): where possible. This is cleaner for an external representation. (Ie. "this is an identifier", not "this is a database primary key") """ - if not self.coerce_pk or '{pk}' not in path: + if not self.coerce_path_pk or '{pk}' not in path: return path model = getattr(getattr(view, 'queryset', None), 'model', None) if model: @@ -405,6 +415,9 @@ def get_description(self, path, method, view): header = getattr(view, 'action', method.lower()) if header in sections: return sections[header].strip() + if header in self.coerce_method_names: + if self.coerce_method_names[header] in sections: + return sections[self.coerce_method_names[header]].strip() return sections[''].strip() def get_encoding(self, path, method, view): @@ -541,10 +554,15 @@ def get_keys(self, subpath, method, view): # Custom action, eg "/users/{pk}/activate/", "/users/active/" if len(view.action_map) > 1: action = self.default_mapping[method.lower()] + if action in self.coerce_method_names: + action = self.coerce_method_names[action] return named_path_components + [action] else: return named_path_components[:-1] + [action] + if action in self.coerce_method_names: + action = self.coerce_method_names[action] + # Default action, eg "/users/", "/users/{pk}/" return named_path_components + [action] diff --git a/rest_framework/settings.py b/rest_framework/settings.py index 89e27e743e..6d9ed23559 100644 --- a/rest_framework/settings.py +++ b/rest_framework/settings.py @@ -115,6 +115,13 @@ # Browseable API 'HTML_SELECT_CUTOFF': 1000, 'HTML_SELECT_CUTOFF_TEXT': "More than {count} items...", + + # Schemas + 'SCHEMA_COERCE_PATH_PK': True, + 'SCHEMA_COERCE_METHOD_NAMES': { + 'retrieve': 'read', + 'destroy': 'delete' + }, } diff --git a/tests/test_schemas.py b/tests/test_schemas.py index 0a422f078f..c43fc1effa 100644 --- a/tests/test_schemas.py +++ b/tests/test_schemas.py @@ -104,12 +104,12 @@ def test_anonymous_request(self): action='get' ), 'custom_list_action_multiple_methods': { - 'retrieve': coreapi.Link( + 'read': coreapi.Link( url='/example/custom_list_action_multiple_methods/', action='get' ) }, - 'retrieve': coreapi.Link( + 'read': coreapi.Link( url='/example/{id}/', action='get', fields=[ @@ -148,7 +148,7 @@ def test_authenticated_request(self): coreapi.Field('b', required=False, location='form', type='string') ] ), - 'retrieve': coreapi.Link( + 'read': coreapi.Link( url='/example/{id}/', action='get', fields=[ @@ -171,7 +171,7 @@ def test_authenticated_request(self): action='get' ), 'custom_list_action_multiple_methods': { - 'retrieve': coreapi.Link( + 'read': coreapi.Link( url='/example/custom_list_action_multiple_methods/', action='get' ), @@ -200,7 +200,7 @@ def test_authenticated_request(self): coreapi.Field('b', required=False, location='form', type='string') ] ), - 'destroy': coreapi.Link( + 'delete': coreapi.Link( url='/example/{id}/', action='delete', fields=[ @@ -260,7 +260,7 @@ def test_schema_for_regular_views(self): action='get', fields=[] ), - 'retrieve': coreapi.Link( + 'read': coreapi.Link( url='/example/{id}/', action='get', fields=[ @@ -313,7 +313,7 @@ def test_schema_for_regular_views(self): action='get', fields=[] ), - 'retrieve': coreapi.Link( + 'read': coreapi.Link( url='/api/v1/example/{id}/', action='get', fields=[ From 3f3213b5d04149cdf852a36397f4626502862182 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Mon, 10 Oct 2016 12:51:19 +0100 Subject: [PATCH 43/43] Renamed and documented 'get_schema_fields' interface. --- docs/api-guide/filtering.md | 6 ++++++ docs/api-guide/pagination.md | 6 ++++++ rest_framework/filters.py | 28 +++++++++++++++++++--------- rest_framework/pagination.py | 34 ++++++++++++++++++++++++---------- rest_framework/schemas.py | 15 ++------------- 5 files changed, 57 insertions(+), 32 deletions(-) diff --git a/docs/api-guide/filtering.md b/docs/api-guide/filtering.md index 0ccd51dd35..9c3f28895d 100644 --- a/docs/api-guide/filtering.md +++ b/docs/api-guide/filtering.md @@ -416,6 +416,12 @@ Generic filters may also present an interface in the browsable API. To do so you The method should return a rendered HTML string. +## Pagination & schemas + +You can also make the filter controls available to the schema autogeneration +that REST framework provides, by implementing a `get_schema_fields()` method, +which should return a list of `coreapi.Field` instances. + # Third party packages The following third party packages provide additional filter implementations. diff --git a/docs/api-guide/pagination.md b/docs/api-guide/pagination.md index 088e071840..f990128c54 100644 --- a/docs/api-guide/pagination.md +++ b/docs/api-guide/pagination.md @@ -276,6 +276,12 @@ To have your custom pagination class be used by default, use the `DEFAULT_PAGINA API responses for list endpoints will now include a `Link` header, instead of including the pagination links as part of the body of the response, for example: +## Pagination & schemas + +You can also make the pagination controls available to the schema autogeneration +that REST framework provides, by implementing a `get_schema_fields()` method, +which should return a list of `coreapi.Field` instances. + --- ![Link Header][link-header] diff --git a/rest_framework/filters.py b/rest_framework/filters.py index aaf3134919..27167adc43 100644 --- a/rest_framework/filters.py +++ b/rest_framework/filters.py @@ -16,7 +16,7 @@ from django.utils.translation import ugettext_lazy as _ from rest_framework.compat import ( - crispy_forms, distinct, django_filters, guardian, template_render + coreapi, crispy_forms, distinct, django_filters, guardian, template_render ) from rest_framework.settings import api_settings @@ -72,7 +72,8 @@ def filter_queryset(self, request, queryset, view): """ raise NotImplementedError(".filter_queryset() must be overridden.") - def get_fields(self, view): + def get_schema_fields(self, view): + assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`' return [] @@ -131,14 +132,21 @@ def to_html(self, request, queryset, view): template = loader.get_template(self.template) return template_render(template, context) - def get_fields(self, view): + def get_schema_fields(self, view): + assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`' filter_class = getattr(view, 'filter_class', None) if filter_class: - return list(filter_class().filters.keys()) + return [ + coreapi.Field(name=field_name, required=False, location='query') + for field_name in filter_class().filters.keys() + ] filter_fields = getattr(view, 'filter_fields', None) if filter_fields: - return filter_fields + return [ + coreapi.Field(name=field_name, required=False, location='query') + for field_name in filter_fields + ] return [] @@ -231,8 +239,9 @@ def to_html(self, request, queryset, view): template = loader.get_template(self.template) return template_render(template, context) - def get_fields(self, view): - return [self.search_param] + def get_schema_fields(self, view): + assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`' + return [coreapi.Field(name=self.search_param, required=False, location='query')] class OrderingFilter(BaseFilterBackend): @@ -347,8 +356,9 @@ def to_html(self, request, queryset, view): context = self.get_template_context(request, queryset, view) return template_render(template, context) - def get_fields(self, view): - return [self.ordering_param] + def get_schema_fields(self, view): + assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`' + return [coreapi.Field(name=self.ordering_param, required=False, location='query')] class DjangoObjectPermissionsFilter(BaseFilterBackend): diff --git a/rest_framework/pagination.py b/rest_framework/pagination.py index c54894efc5..6b539b9008 100644 --- a/rest_framework/pagination.py +++ b/rest_framework/pagination.py @@ -15,7 +15,7 @@ from django.utils.six.moves.urllib import parse as urlparse from django.utils.translation import ugettext_lazy as _ -from rest_framework.compat import template_render +from rest_framework.compat import coreapi, template_render from rest_framework.exceptions import NotFound from rest_framework.response import Response from rest_framework.settings import api_settings @@ -157,7 +157,8 @@ def to_html(self): # pragma: no cover def get_results(self, data): return data['results'] - def get_fields(self, view): + def get_schema_fields(self, view): + assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`' return [] @@ -283,10 +284,16 @@ def to_html(self): context = self.get_html_context() return template_render(template, context) - def get_fields(self, view): - if self.page_size_query_param is None: - return [self.page_query_param] - return [self.page_query_param, self.page_size_query_param] + def get_schema_fields(self, view): + assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`' + fields = [ + coreapi.Field(name=self.page_query_param, required=False, location='query') + ] + if self.page_size_query_param is not None: + fields.append([ + coreapi.Field(name=self.page_size_query_param, required=False, location='query') + ]) + return fields class LimitOffsetPagination(BasePagination): @@ -415,8 +422,12 @@ def to_html(self): context = self.get_html_context() return template_render(template, context) - def get_fields(self, view): - return [self.limit_query_param, self.offset_query_param] + def get_schema_fields(self, view): + assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`' + return [ + coreapi.Field(name=self.limit_query_param, required=False, location='query'), + coreapi.Field(name=self.offset_query_param, required=False, location='query') + ] class CursorPagination(BasePagination): @@ -721,5 +732,8 @@ def to_html(self): context = self.get_html_context() return template_render(template, context) - def get_fields(self, view): - return [self.cursor_query_param] + def get_schema_fields(self, view): + assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`' + return [ + coreapi.Field(name=self.cursor_query_param, required=False, location='query') + ] diff --git a/rest_framework/schemas.py b/rest_framework/schemas.py index 12119adfc9..69698db0e5 100644 --- a/rest_framework/schemas.py +++ b/rest_framework/schemas.py @@ -42,17 +42,6 @@ def get_pk_name(model): return _get_pk(meta).name -def as_query_fields(items): - """ - Take a list of Fields and plain strings. - Convert any pain strings into `location='query'` Field instances. - """ - return [ - item if isinstance(item, coreapi.Field) else coreapi.Field(name=item, required=False, location='query') - for item in items - ] - - def is_api_view(callback): """ Return `True` if the given view callback is a REST framework view/viewset. @@ -506,7 +495,7 @@ def get_pagination_fields(self, path, method, view): return [] paginator = view.pagination_class() - return as_query_fields(paginator.get_fields(view)) + return paginator.get_schema_fields(view) def get_filter_fields(self, path, method, view): if not is_list_view(path, method, view): @@ -517,7 +506,7 @@ def get_filter_fields(self, path, method, view): fields = [] for filter_backend in view.filter_backends: - fields += as_query_fields(filter_backend().get_fields(view)) + fields += filter_backend().get_schema_fields(view) return fields # Method for generating the link layout....