Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
bc836aa
Initial pass at schema support
lovelydinosaur Jun 8, 2016
b64340b
Add coreapi to optional requirements.
lovelydinosaur Jun 9, 2016
c890ad4
Clean up test failures
lovelydinosaur Jun 9, 2016
2d28390
Add missing newline
lovelydinosaur Jun 9, 2016
744dba4
Minor docs update
lovelydinosaur Jun 9, 2016
56ece73
Version bump for coreapi in requirements
lovelydinosaur Jun 9, 2016
99adbf1
Catch SyntaxError when importing coreapi with python 3.2
lovelydinosaur Jun 9, 2016
80c595e
Add --diff to isort
lovelydinosaur Jun 9, 2016
47c7765
Import coreapi from compat
lovelydinosaur Jun 9, 2016
29e228d
Fail gracefully if attempting to use schemas without coreapi being in…
lovelydinosaur Jun 9, 2016
eeffca4
Tutorial updates
lovelydinosaur Jun 9, 2016
6c60f58
Docs update
lovelydinosaur Jun 9, 2016
b7fcdd2
Initial schema generation & first tutorial 7 draft
lovelydinosaur Jun 10, 2016
2e60f41
Spelling
lovelydinosaur Jun 10, 2016
2ffa145
Remove unused variable
lovelydinosaur Jun 10, 2016
b709dd4
Docs tweak
lovelydinosaur Jun 10, 2016
474a23e
Merge branch 'master' into schema-support
lovelydinosaur Jun 14, 2016
4822896
Added SchemaGenerator class
lovelydinosaur Jun 15, 2016
1f76cca
Fail gracefully if coreapi is not installed and SchemaGenerator is used
lovelydinosaur Jun 15, 2016
cad24b1
Schema docs, pagination controls, filter controls
lovelydinosaur Jun 21, 2016
8fb2602
Resolve NameError
lovelydinosaur Jun 21, 2016
b438281
Add 'view' argument to 'get_fields()'
lovelydinosaur Jun 21, 2016
8519b4e
Remove extranous blank line
lovelydinosaur Jun 21, 2016
2f5c974
Add integration tests for schema generation
lovelydinosaur Jun 22, 2016
e78753d
Only set 'encoding' if a 'form' or 'body' field exists
lovelydinosaur Jun 22, 2016
84bb5ea
Do not use schmea in tests if coreapi is not installed
lovelydinosaur Jun 24, 2016
63e8467
Inital pass at API client docs
lovelydinosaur Jun 29, 2016
bdbcb33
Inital pass at API client docs
lovelydinosaur Jun 29, 2016
7236af3
More work towards client documentation
lovelydinosaur Jun 30, 2016
89540ab
Add coreapi to optional packages list
lovelydinosaur Jul 4, 2016
e3ced75
Clean up API clients docs
lovelydinosaur Jul 4, 2016
12be5b3
Resolve typo
lovelydinosaur Jul 4, 2016
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Added SchemaGenerator class
  • Loading branch information
lovelydinosaur committed Jun 15, 2016
commit 482289695d0f1c0960de5645a05f503cea1d8c5a
79 changes: 9 additions & 70 deletions rest_framework/routers.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,15 @@
import itertools
from collections import OrderedDict, namedtuple

import uritemplate
from django.conf.urls import url
from django.core.exceptions import ImproperlyConfigured
from django.core.urlresolvers import NoReverseMatch

from rest_framework import exceptions, renderers, views
from rest_framework.compat import coreapi
from rest_framework.request import override_method
from rest_framework.response import Response
from rest_framework.reverse import reverse
from rest_framework.schemas import SchemaGenerator
from rest_framework.settings import api_settings
from rest_framework.urlpatterns import format_suffix_patterns

Expand Down Expand Up @@ -263,63 +262,6 @@ def get_urls(self):

return ret

def get_links(self, request=None):
content = {}

for prefix, viewset, basename in self.registry:
lookup_field = getattr(viewset, 'lookup_field', 'pk')
lookup_url_kwarg = getattr(viewset, 'lookup_url_kwarg', None) or lookup_field
lookup_placeholder = '{' + lookup_url_kwarg + '}'

routes = self.get_routes(viewset)

for route in routes:
url = '/' + route.url.format(
prefix=prefix,
lookup=lookup_placeholder,
trailing_slash=self.trailing_slash
).lstrip('^').rstrip('$')

mapping = self.get_method_map(viewset, route.mapping)
if not mapping:
continue

for method, action in mapping.items():
link = self.get_link(viewset, url, method, request)
if link is None:
continue # User does not have permissions.
if prefix not in content:
content[prefix] = {}
content[prefix][action] = link
return content

def get_link(self, viewset, url, method, request=None):
view_instance = viewset()
if request is not None:
with override_method(view_instance, request, method.upper()) as request:
try:
view_instance.check_permissions(request)
except exceptions.APIException:
return None

fields = []

for variable in uritemplate.variables(url):
field = coreapi.Field(name=variable, location='path', required=True)
fields.append(field)

if method in ('put', 'patch', 'post'):
cls = view_instance.get_serializer_class()
serializer = cls()
for field in serializer.fields.values():
if field.read_only:
continue
required = field.required and method != 'patch'
field = coreapi.Field(name=field.source, location='form', required=required)
fields.append(field)

return coreapi.Link(url=url, action=method, fields=fields)


class DefaultRouter(SimpleRouter):
"""
Expand All @@ -334,7 +276,7 @@ def __init__(self, *args, **kwargs):
self.schema_title = kwargs.pop('schema_title', None)
super(DefaultRouter, self).__init__(*args, **kwargs)

def get_api_root_view(self):
def get_api_root_view(self, schema_urls=None):
"""
Return a view to use as the API root.
"""
Expand All @@ -345,21 +287,20 @@ def get_api_root_view(self):

view_renderers = list(api_settings.DEFAULT_RENDERER_CLASSES)

if self.schema_title:
if schema_urls and self.schema_title:
assert coreapi, '`coreapi` must be installed for schema support.'
view_renderers += [renderers.CoreJSONRenderer]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wouldn't it make more sense to expect users to include the CoreJSONRenderer in their default renderers, and raise an error if it is missing from the defaults?

router = self
schema_generator = SchemaGenerator(patterns=schema_urls)

class APIRoot(views.APIView):
_ignore_model_permissions = True
renderer_classes = view_renderers

def get(self, request, *args, **kwargs):
if request.accepted_renderer.format == 'corejson':
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there any reason why this couldn't be brought into the CoreJSONRenderer?

Copy link
Contributor Author

@lovelydinosaur lovelydinosaur Jun 15, 2016

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The problem is less what to do with CoreJSONRenderer and more about making sure that we preserve the behavior of that view for anything that isn't a schema renderer.

(Also I'd consider that part still slightly in flux - clearly needs refinement)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

making sure that we preserve the behavior of that view for anything that isn't a schema renderer

Right, my main concern with that line (which I'm sure will probably change another few times before this is merged) is that it's hard-coding corejson as the only renderer with schema support.

Also I'd consider that part still slightly in flux - clearly needs refinement

Definitely understandable.

content = router.get_links(request)
if not content:
schema = schema_generator.get_schema(request)
if schema is None:
raise exceptions.PermissionDenied()
schema = coreapi.Document(title=router.schema_title, content=content)
return Response(schema)

ret = OrderedDict()
Expand Down Expand Up @@ -388,15 +329,13 @@ def get_urls(self):
Generate the list of URL patterns, including a default root view
for the API, and appending `.json` style format suffixes.
"""
urls = []
urls = super(DefaultRouter, self).get_urls()

if self.include_root_view:
root_url = url(r'^$', self.get_api_root_view(), name=self.root_view_name)
view = self.get_api_root_view(schema_urls=urls)
root_url = url(r'^$', view, name=self.root_view_name)
urls.append(root_url)

default_urls = super(DefaultRouter, self).get_urls()
urls.extend(default_urls)

if self.include_format_suffixes:
urls = format_suffix_patterns(urls)

Expand Down
176 changes: 176 additions & 0 deletions rest_framework/schemas.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
from importlib import import_module

import coreapi
import uritemplate
from django.conf import settings
from django.contrib.admindocs.views import simplify_regex
from django.core.urlresolvers import RegexURLPattern, RegexURLResolver
from django.utils import six

from rest_framework import exceptions
from rest_framework.request import clone_request
from rest_framework.views import APIView


class SchemaGenerator(object):
default_mapping = {
'get': 'read',
'post': 'create',
'put': 'update',
'patch': 'partial_update',
'delete': 'destroy',
}

def __init__(self, schema_title=None, patterns=None, urlconf=None):
if patterns is None and urlconf is not None:
if isinstance(urlconf, six.string_types):
urls = import_module(urlconf)
else:
urls = urlconf
patterns = urls.urlpatterns
elif patterns is None and urlconf is None:
urls = import_module(settings.ROOT_URLCONF)
patterns = urls.urlpatterns

self.schema_title = schema_title
self.endpoints = self.get_api_endpoints(patterns)

def get_schema(self, request=None):
if request is None:
endpoints = self.endpoints
else:
# Filter the list of endpoints to only include those that
# the user has permission on.
endpoints = []
for key, link, callback in self.endpoints:
method = link.action.upper()
view = callback.cls()
view.request = clone_request(request, method)
try:
view.check_permissions(view.request)
except exceptions.APIException:
pass
else:
endpoints.append((key, link, callback))

if not endpoints:
return None

# Generate the schema content structure, from the endpoints.
# ('users', 'list'), Link -> {'users': {'list': Link()}}
content = {}
for key, link, callback in endpoints:
insert_into = content
for item in key[:1]:
if item not in insert_into:
insert_into[item] = {}
insert_into = insert_into[item]
insert_into[key[-1]] = link

# Return the schema document.
return coreapi.Document(title=self.schema_title, content=content)

def get_api_endpoints(self, patterns, prefix=''):
"""
Return a list of all available API endpoints by inspecting the URL conf.
"""
api_endpoints = []

for pattern in patterns:
path_regex = prefix + pattern.regex.pattern

if isinstance(pattern, RegexURLPattern):
path = self.get_path(path_regex)
callback = pattern.callback
if self.include_endpoint(path, callback):
for method in self.get_allowed_methods(callback):
key = self.get_key(path, method, callback)
link = self.get_link(path, method, callback)
endpoint = (key, link, callback)
api_endpoints.append(endpoint)

elif isinstance(pattern, RegexURLResolver):
nested_endpoints = self.get_api_endpoints(
patterns=pattern.url_patterns,
prefix=path_regex
)
api_endpoints.extend(nested_endpoints)

return api_endpoints

def get_path(self, path_regex):
"""
Given a URL conf regex, return a URI template string.
"""
path = simplify_regex(path_regex)
path = path.replace('<', '{').replace('>', '}')
return path

def include_endpoint(self, path, callback):
"""
Return True if the given endpoint should be included.
"""
cls = getattr(callback, 'cls', None)
if (cls is None) or not issubclass(cls, APIView):
return False

if path.endswith('.{format}') or path.endswith('.{format}/'):
return False

if path == '/':
return False

return True

def get_allowed_methods(self, callback):
"""
Return a list of the valid HTTP methods for this endpoint.
"""
if hasattr(callback, 'actions'):
return [method.upper() for method in callback.actions.keys()]

return [
method for method in
callback.cls().allowed_methods if method != 'OPTIONS'
]

def get_key(self, path, method, callback):
"""
Return a tuple of strings, indicating the identity to use for a
given endpoint. eg. ('users', 'list').
"""
category = None
for item in path.strip('/').split('/'):
if '{' in item:
break
category = item

actions = getattr(callback, 'actions', self.default_mapping)
action = actions[method.lower()]

if category:
return (category, action)
return (action,)

def get_link(self, path, method, callback):
"""
Return a `coreapi.Link` instance for the given endpoint.
"""
view = callback.cls()
fields = []

for variable in uritemplate.variables(path):
field = coreapi.Field(name=variable, location='path', required=True)
fields.append(field)

if method in ('PUT', 'PATCH', 'POST'):
serializer_class = view.get_serializer_class()
serializer = serializer_class()
for field in serializer.fields.values():
if field.read_only:
continue
required = field.required and method != 'PATCH'
field = coreapi.Field(name=field.source, location='form', required=required)
fields.append(field)

return coreapi.Link(url=path, action=method.lower(), fields=fields)
1 change: 1 addition & 0 deletions rest_framework/viewsets.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ def view(request, *args, **kwargs):
# resolved URL.
view.cls = cls
view.suffix = initkwargs.get('suffix', None)
view.actions = actions
return csrf_exempt(view)

def initialize_request(self, request, *args, **kwargs):
Expand Down
2 changes: 1 addition & 1 deletion tests/test_routers.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ class NoteViewSet(viewsets.ModelViewSet):

def test_router_has_custom_name(self):
expected = 'nameable-root'
self.assertEqual(expected, self.urls[0].name)
self.assertEqual(expected, self.urls[-1].name)


class TestActionKeywordArgs(TestCase):
Expand Down