Skip to content
Prev Previous commit
Next Next commit
content handler autodiscover
  • Loading branch information
dtkav committed Dec 12, 2019
commit 0023799dc3d5bd6dbacd00e87652ef78d0e05fa1
78 changes: 66 additions & 12 deletions connexion/content_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,53 @@
from .types import coerce_type
from .utils import is_null

logger = logging.getLogger('connexion.content_types')
logger = logging.getLogger(__name__)


class ContentHandlerFactory(object):

def __init__(self, validator, schema, strict_validation,
is_null_value_valid, consumes):
self.validator = validator
self.schema = schema
self.strict_validation = strict_validation
self.is_null_value_valid = is_null_value_valid
self.consumes = consumes
self._content_handlers = self._discover()

def _discover(self):
content_handlers = ContentHandler.discover_subclasses()
return {
name: cls(self.validator, self.schema,
self.strict_validation, self.is_null_value_valid)
for name, cls in content_handlers.items()
}

def get_handler(self, content_type):
match = None

if content_type is None:
return match

media_type = content_type.split(";", 1)[0]
if media_type not in self.consumes:
return None

try:
return self._content_handlers[media_type]
except KeyError:
pass

matches = [
(name, handler) for name, handler in self._content_handlers.items()
if handler.regex.match(content_type)
]
if len(matches) > 1:
logger.warning(f"Content could be handled by multiple validators: {matches}")

if matches:
name, handler = matches[0]
return handler


class ContentHandler(object):
Expand All @@ -23,7 +69,7 @@ def validate_schema(self, data, url):
# type: (dict, AnyStr) -> Union[ConnexionResponse, None]
if is_null(data):
if self.default:
# XXX do we need to do this? If the spec is valid, this will pass
# TODO do we need to do this? If the spec is valid, this will pass
data = self.default
elif self.is_null_value_valid:
return
Expand All @@ -46,16 +92,32 @@ def validate_schema(self, data, url):
def deserialize(self, request):
return request.body

def validate(self, request):
def validate_request(self, request):
data = self.deserialize(request)
self.validate_schema(data, request.url)

@classmethod
def discover_subclasses(cls):
subclasses = {c.name: c for c in cls.__subclasses__()}
for s in cls.__subclasses__():
subclasses.update(s.discover_subclasses())
return subclasses


class StreamingContentHandler(ContentHandler):
name = "application/octet-stream"
regex = re.compile(r'^application\/octet-stream.*')

def validate(self, request):
def validate_request(self, request):
# Don't validate, leave stream for user to read
pass


class TextPlainContentHandler(ContentHandler):
name = "text/plain"
regex = re.compile(r'^text\/plain.*')

def validate_request(self, request):
# Don't validate, leave stream for user to read
pass

Expand Down Expand Up @@ -117,11 +179,3 @@ class MultiPartFormDataContentHandler(FormDataContentHandler):
regex = re.compile(
r'^multipart\/form-data.*'
)


KNOWN_CONTENT_TYPES = (
StreamingContentHandler,
JSONContentHandler,
FormDataContentHandler,
MultiPartFormDataContentHandler
)
53 changes: 12 additions & 41 deletions connexion/decorators/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from ..exceptions import ExtraParameterProblem, BadRequestProblem, UnsupportedMediaTypeProblem
from ..http_facts import FORM_CONTENT_TYPES
from ..json_schema import Draft4RequestValidator, Draft4ResponseValidator
from ..content_types import KNOWN_CONTENT_TYPES
from ..content_types import ContentHandlerFactory
from ..types import TypeValidationError, coerce_type
from ..utils import all_json, boolean, is_json_mimetype, is_null, is_nullable

Expand Down Expand Up @@ -52,30 +52,13 @@ def __init__(self, schema, consumes, api, is_null_value_valid=False, validator=N
self.validator = validatorClass(schema, format_checker=draft4_format_checker)
self.api = api
self.strict_validation = strict_validation
self._content_handlers = [
de(self.validator,
self.schema,
self.strict_validation,
self.is_null_value_valid) for de in KNOWN_CONTENT_TYPES
]

def register_content_handler(self, cv):
deser = cv(self.validator,
self.schema,
self.strict_validation,
self.is_null_value_valid)
self._content_handlers += [deser]

def lookup_content_handler(self, request):
matches = [
v for v in self._content_handlers
if request.content_type is not None and
v.regex.match(request.content_type)
]
if len(matches) > 1:
logger.warning("Content could be handled by multiple validators")
if matches:
return matches[0]
self.content_handler_factory = ContentHandlerFactory(
self.validator,
self.schema,
self.strict_validation,
self.is_null_value_valid,
self.consumes
)

def __call__(self, function):
"""
Expand All @@ -85,26 +68,14 @@ def __call__(self, function):

@functools.wraps(function)
def wrapper(request):
content_handler = self.lookup_content_handler(request)
exact_match = (
request.content_type is not None and
request.content_type in self.consumes
)
partial_match = content_handler and content_handler.name in self.consumes
if not (exact_match or partial_match):
content_handler = self.content_handler_factory.get_handler(request.content_type)
if content_handler is None:
raise UnsupportedMediaTypeProblem(
"Invalid Content-type ({content_type}), expected JSON data".format(
"Unsupported Content-type ({content_type})".format(
content_type=request.headers.get("Content-Type", "")
))

if content_handler:
error = content_handler.validate(request)
if error:
return error

elif self.strict_validation:
logger.warning("No handler for ({content_type})".format(
content_type=request.content_type))
content_handler.validate_request(request)

response = function(request)
return response
Expand Down