diff --git a/connexion/content_types.py b/connexion/content_types.py
new file mode 100644
index 000000000..1b911ae7f
--- /dev/null
+++ b/connexion/content_types.py
@@ -0,0 +1,233 @@
+import logging
+import re
+
+from jsonschema import ValidationError
+
+from .exceptions import BadRequestProblem, ExtraParameterProblem
+from .types import TypeValidationError, coerce_type
+from .utils import is_null
+
+logger = logging.getLogger(__name__)
+
+
+def _media_type(content_type):
+    """Return the bare, lower-cased media type of a Content-Type value."""
+    # Parameters (e.g. "; charset=utf-8") are not part of the media type, and
+    # media types are case-insensitive (RFC 7231, section 3.1.1.1).
+    return content_type.split(";", 1)[0].strip().lower()
+
+
+class ContentHandlerFactory(object):
+    """Creates one handler per known media type and selects one per request."""
+
+    def __init__(self, validator, schema, strict_validation,
+                 is_null_value_valid, consumes):
+        self.validator = validator
+        self.schema = schema
+        self.strict_validation = strict_validation
+        self.is_null_value_valid = is_null_value_valid
+        self.consumes = consumes
+        self._content_handlers = self._discover()
+
+    def _discover(self):
+        """Instantiate every ContentHandler subclass, keyed by its ``name``."""
+        content_handlers = ContentHandler.discover_subclasses()
+        return {
+            name: cls(self.validator, self.schema,
+                      self.strict_validation, self.is_null_value_valid)
+            for name, cls in content_handlers.items()
+        }
+
+    def get_handler(self, content_type):
+        """
+        Select the handler for the value of a request's Content-Type header.
+
+        :type content_type: str|None
+        :return: the handler, or None if the header is missing, the media
+                 type is not listed in ``consumes``, or no handler matches
+        """
+        if content_type is None:
+            return None
+
+        media_type = _media_type(content_type)
+        if media_type not in {_media_type(c) for c in self.consumes}:
+            return None
+
+        try:
+            return self._content_handlers[media_type]
+        except KeyError:
+            pass
+
+        # No handler is registered under this exact name; try the patterns.
+        # Match the bare media type so that "$"-anchored patterns such as
+        # "+json" still apply when the header carries parameters.
+        matches = [
+            (name, handler) for name, handler in self._content_handlers.items()
+            if handler.regex.match(media_type)
+        ]
+        if len(matches) > 1:
+            logger.warning(
+                "Content could be handled by multiple validators: %s", matches)
+
+        if matches:
+            return matches[0][1]
+        return None
+
+
+class ContentHandler(object):
+    """Base handler: deserializes a request body and validates it."""
+
+    def __init__(self, validator, schema, strict, is_null_value_valid):
+        self.schema = schema
+        self.strict_validation = strict
+        self.is_null_value_valid = is_null_value_valid
+        self.validator = validator
+        self.default = schema.get('default')
+
+    def validate_schema(self, data, url):
+        """
+        Validate ``data`` with the configured validator.
+
+        A null body is replaced by the schema default when there is one, and
+        is otherwise accepted without validation if null values are valid.
+
+        :param url: request URL, used only in the log message
+        :raises BadRequestProblem: if validation fails
+        """
+        if is_null(data):
+            # Compare against None: falsy defaults (0, False, '', {}) count.
+            if self.default is not None:
+                # TODO do we need to do this? If the spec is valid, this will pass
+                data = self.default
+            elif self.is_null_value_valid:
+                return
+
+        try:
+            self.validator.validate(data)
+        except ValidationError as exception:
+            error_path = '.'.join(str(item) for item in exception.path)
+            error_path_msg = " - '{path}'".format(path=error_path) \
+                if error_path else ""
+            logger.error(
+                "{url} validation error: {error}{error_path_msg}".format(
+                    url=url, error=exception.message,
+                    error_path_msg=error_path_msg),
+                extra={'validator': 'body'})
+            raise BadRequestProblem(detail="{message}{error_path_msg}".format(
+                message=exception.message,
+                error_path_msg=error_path_msg))
+
+    def deserialize(self, request):
+        """Return the data to validate; the base handler uses the raw body."""
+        return request.body
+
+    def validate_request(self, request):
+        """Deserialize the request body and validate it against the schema."""
+        data = self.deserialize(request)
+        self.validate_schema(data, request.url)
+
+    @classmethod
+    def discover_subclasses(cls):
+        """Return ``{name: class}`` for all direct and indirect subclasses."""
+        subclasses = {c.name: c for c in cls.__subclasses__()}
+        for s in cls.__subclasses__():
+            subclasses.update(s.discover_subclasses())
+        return subclasses
+
+
+class StreamingContentHandler(ContentHandler):
+    name = "application/octet-stream"
+    regex = re.compile(r'^application\/octet-stream.*')
+
+    def validate_request(self, request):
+        # Don't validate, leave stream for user to read
+        pass
+
+
+class TextPlainContentHandler(ContentHandler):
+    name = "text/plain"
+    regex = re.compile(r'^text\/plain.*')
+
+    def validate_request(self, request):
+        # Don't validate, leave stream for user to read
+        pass
+
+
+class JSONContentHandler(ContentHandler):
+    name = "application/json"
+    regex = re.compile(r'^application\/json.*|^.*\+json$')
+
+    def deserialize(self, request):
+        """
+        Return the parsed JSON body.
+
+        :raises BadRequestProblem: if the request has content that did not
+            parse as JSON and null values are not valid
+        """
+        data = request.json
+        empty_body = not(request.body or request.form or request.files)
+        if data is None and not empty_body and not self.is_null_value_valid:
+            # Content-Type is json but actual body was not parsed
+            raise BadRequestProblem(detail="Request body is not valid JSON")
+        return data
+
+
+def validate_parameter_list(request_params, spec_params):
+    """Return the set of request parameter names that are not in the spec."""
+    request_params = set(request_params)
+    spec_params = set(spec_params)
+
+    return request_params.difference(spec_params)
+
+
+class FormDataContentHandler(ContentHandler):
+    name = "application/x-www-form-urlencoded"
+    regex = re.compile(
+        r'^application\/x-www-form-urlencoded.*'
+    )
+
+    def _validate_formdata_parameter_list(self, request):
+        request_params = request.form.keys()
+        spec_params = self.schema.get('properties', {}).keys()
+        return validate_parameter_list(request_params, spec_params)
+
+    def deserialize(self, request):
+        """
+        Build the form data dict, coerce its values and store it on the request.
+
+        :raises ExtraParameterProblem: in strict mode, if the form has fields
+            that the schema does not define
+        :raises BadRequestProblem: if a value cannot be coerced to its type
+        """
+        data = dict(request.form.items()) or \
+            (request.body if len(request.body) > 0 else {})
+        data.update(dict.fromkeys(request.files, ''))  # validator expects string..
+        logger.debug('%s validating schema...', request.url)
+
+        if self.strict_validation:
+            formdata_errors = self._validate_formdata_parameter_list(request)
+            if formdata_errors:
+                raise ExtraParameterProblem(formdata_errors, [])
+
+        if data:
+            props = self.schema.get("properties", {})
+            errs = []
+            for k, param_defn in props.items():
+                if k in data:
+                    try:
+                        data[k] = coerce_type(param_defn, data[k], 'requestBody', k)
+                    except TypeValidationError as e:
+                        # Report a client error instead of letting it escape.
+                        errs.append(str(e))
+            if errs:
+                raise BadRequestProblem(detail=errs)
+            # XXX it's surprising to hide this in validation
+            request.form = data
+        return data
+
+
+class MultiPartFormDataContentHandler(FormDataContentHandler):
+    name = "multipart/form-data"
+    regex = re.compile(
+        r'^multipart\/form-data.*'
+    )
diff --git a/connexion/decorators/parameter.py b/connexion/decorators/parameter.py index c5341ea0a..6c73dc62e 100644 --- a/connexion/decorators/parameter.py +++ b/connexion/decorators/parameter.py @@ -6,9 +6,8 @@ import inflection import six -from ..http_facts import FORM_CONTENT_TYPES from ..lifecycle import ConnexionRequest # NOQA -from ..utils import all_json +from ..utils import is_form_mimetype, is_json_mimetype try: import builtins @@ -73,7 +72,6 @@ def parameter_to_arg(operation, function, pythonic_params=False, request context will be passed as that argument. 
:type pass_context_arg_name: str|None """ - consumes = operation.consumes def sanitized(name): return name and re.sub('^[^a-zA-Z_]+', '', re.sub('[^0-9a-zA-Z_]', '', name)) @@ -91,9 +89,9 @@ def wrapper(request): logger.debug('Function Arguments: %s', arguments) kwargs = {} - if all_json(consumes): + if is_json_mimetype(request.content_type): request_body = request.json - elif consumes[0] in FORM_CONTENT_TYPES: + elif is_form_mimetype(request.content_type): request_body = {sanitize(k): v for k, v in request.form.items()} else: request_body = request.body diff --git a/connexion/decorators/validation.py b/connexion/decorators/validation.py index 81e044f1d..bcb4d5e82 100644 --- a/connexion/decorators/validation.py +++ b/connexion/decorators/validation.py @@ -13,6 +13,8 @@ from ..exceptions import ExtraParameterProblem, BadRequestProblem, UnsupportedMediaTypeProblem from ..http_facts import FORM_CONTENT_TYPES from ..json_schema import Draft4RequestValidator, Draft4ResponseValidator +from ..content_types import ContentHandlerFactory +from ..types import TypeValidationError, coerce_type from ..utils import all_json, boolean, is_json_mimetype, is_null, is_nullable _jsonschema_3_or_newer = pkg_resources.parse_version( @@ -21,77 +23,6 @@ logger = logging.getLogger('connexion.decorators.validation') -TYPE_MAP = { - 'integer': int, - 'number': float, - 'boolean': boolean, - 'object': dict -} - - -class TypeValidationError(Exception): - def __init__(self, schema_type, parameter_type, parameter_name): - """ - Exception raise when type validation fails - - :type schema_type: str - :type parameter_type: str - :type parameter_name: str - :return: - """ - self.schema_type = schema_type - self.parameter_type = parameter_type - self.parameter_name = parameter_name - - def __str__(self): - msg = "Wrong type, expected '{schema_type}' for {parameter_type} parameter '{parameter_name}'" - return msg.format(**vars(self)) - - -def coerce_type(param, value, parameter_type, 
parameter_name=None): - - def make_type(value, type_literal): - type_func = TYPE_MAP.get(type_literal) - return type_func(value) - - param_schema = param.get("schema", param) - if is_nullable(param_schema) and is_null(value): - return None - - param_type = param_schema.get('type') - parameter_name = parameter_name if parameter_name else param.get('name') - if param_type == "array": - converted_params = [] - for v in value: - try: - converted = make_type(v, param_schema["items"]["type"]) - except (ValueError, TypeError): - converted = v - converted_params.append(converted) - return converted_params - elif param_type == 'object': - if param_schema.get('properties'): - def cast_leaves(d, schema): - if type(d) is not dict: - try: - return make_type(d, schema['type']) - except (ValueError, TypeError): - return d - for k, v in d.items(): - if k in schema['properties']: - d[k] = cast_leaves(v, schema['properties'][k]) - return d - - return cast_leaves(value, param_schema) - return value - else: - try: - return make_type(value, param_type) - except ValueError: - raise TypeValidationError(param_type, parameter_type, parameter_name) - except TypeError: - return value - def validate_parameter_list(request_params, spec_params): request_params = set(request_params) @@ -121,11 +52,13 @@ def __init__(self, schema, consumes, api, is_null_value_valid=False, validator=N self.validator = validatorClass(schema, format_checker=draft4_format_checker) self.api = api self.strict_validation = strict_validation - - def validate_formdata_parameter_list(self, request): - request_params = request.form.keys() - spec_params = self.schema.get('properties', {}).keys() - return validate_parameter_list(request_params, spec_params) + self.content_handler_factory = ContentHandlerFactory( + self.validator, + self.schema, + self.strict_validation, + self.is_null_value_valid, + self.consumes + ) def __call__(self, function): """ @@ -135,81 +68,20 @@ def __call__(self, function): 
@functools.wraps(function) def wrapper(request): - if all_json(self.consumes): - data = request.json - - empty_body = not(request.body or request.form or request.files) - if data is None and not empty_body and not self.is_null_value_valid: - try: - ctype_is_json = is_json_mimetype(request.headers.get("Content-Type", "")) - except ValueError: - ctype_is_json = False - - if ctype_is_json: - # Content-Type is json but actual body was not parsed - raise BadRequestProblem(detail="Request body is not valid JSON") - else: - # the body has contents that were not parsed as JSON - raise UnsupportedMediaTypeProblem( - "Invalid Content-type ({content_type}), expected JSON data".format( - content_type=request.headers.get("Content-Type", "") - )) - - logger.debug("%s validating schema...", request.url) - if data is not None or not self.has_default: - self.validate_schema(data, request.url) - elif self.consumes[0] in FORM_CONTENT_TYPES: - data = dict(request.form.items()) or (request.body if len(request.body) > 0 else {}) - data.update(dict.fromkeys(request.files, '')) # validator expects string.. 
- logger.debug('%s validating schema...', request.url) - - if self.strict_validation: - formdata_errors = self.validate_formdata_parameter_list(request) - if formdata_errors: - raise ExtraParameterProblem(formdata_errors, []) - - if data: - props = self.schema.get("properties", {}) - errs = [] - for k, param_defn in props.items(): - if k in data: - try: - data[k] = coerce_type(param_defn, data[k], 'requestBody', k) - except TypeValidationError as e: - errs += [str(e)] - print(errs) - if errs: - raise BadRequestProblem(detail=errs) - - self.validate_schema(data, request.url) + content_handler = self.content_handler_factory.get_handler(request.content_type) + if content_handler is None: + raise UnsupportedMediaTypeProblem( + "Unsupported Content-type ({content_type})".format( + content_type=request.headers.get("Content-Type", "") + )) + + content_handler.validate_request(request) response = function(request) return response return wrapper - def validate_schema(self, data, url): - # type: (dict, AnyStr) -> Union[ConnexionResponse, None] - if self.is_null_value_valid and is_null(data): - return None - - try: - self.validator.validate(data) - except ValidationError as exception: - error_path = '.'.join(str(item) for item in exception.path) - error_path_msg = " - '{path}'".format(path=error_path) \ - if error_path else "" - logger.error( - "{url} validation error: {error}{error_path_msg}".format( - url=url, error=exception.message, - error_path_msg=error_path_msg), - extra={'validator': 'body'}) - raise BadRequestProblem(detail="{message}{error_path_msg}".format( - message=exception.message, - error_path_msg=error_path_msg)) - - return None - class ResponseBodyValidator(object): def __init__(self, schema, validator=None): diff --git a/connexion/lifecycle.py b/connexion/lifecycle.py index 32fb0f26f..b86de4e93 100644 --- a/connexion/lifecycle.py +++ b/connexion/lifecycle.py @@ -1,4 +1,3 @@ - class ConnexionRequest(object): def __init__(self, url, @@ -22,6 +21,10 @@ def 
__init__(self, self.files = files self.context = context if context is not None else {} + @property + def content_type(self): + return self.headers.get("Content-Type") + @property def json(self): return self.json_getter() diff --git a/connexion/operations/abstract.py b/connexion/operations/abstract.py index f009f5991..d8a426d23 100644 --- a/connexion/operations/abstract.py +++ b/connexion/operations/abstract.py @@ -340,8 +340,29 @@ def _uri_parsing_decorator(self): @property def function(self): - """ - Operation function with decorators + """ Operation function with decorators + + request comes in -> connexionrequest (framework generic) + API metrics recorded + API security applied + + arrays are deserialized (spec dependent) + - query params + - form params + - path params + + body is deserialized (spec dependent, content dependent) + + validation is applied (spec dependent, content dependent) + - query params + - form params + - body params + - path params ? + + parameter_to_arg (handler, what to call it) + + json serializer is applied + response is validated if json :rtype: types.FunctionType """ diff --git a/connexion/operations/openapi.py b/connexion/operations/openapi.py index 93f3d8953..764f47715 100644 --- a/connexion/operations/openapi.py +++ b/connexion/operations/openapi.py @@ -272,7 +272,7 @@ def _get_body_argument(self, body, arguments, has_kwargs, sanitize): res = self._get_typed_body_values(body_arg, body_props, additional_props) if x_body_name in arguments or has_kwargs: - return {x_body_name: res} + return {x_body_name: body_arg} return {} def _get_typed_body_values(self, body_arg, body_props, additional_props): diff --git a/connexion/types.py b/connexion/types.py new file mode 100644 index 000000000..c509d3f3e --- /dev/null +++ b/connexion/types.py @@ -0,0 +1,71 @@ +from .utils import boolean, is_null, is_nullable + +TYPE_MAP = { + 'integer': int, + 'number': float, + 'boolean': boolean +} + + +class TypeValidationError(Exception): + def 
__init__(self, schema_type, parameter_type, parameter_name): + """ + Exception raise when type validation fails + + :type schema_type: str + :type parameter_type: str + :type parameter_name: str + :return: + """ + self.schema_type = schema_type + self.parameter_type = parameter_type + self.parameter_name = parameter_name + + def __str__(self): + msg = "Wrong type, expected '{schema_type}' for {parameter_type} parameter '{parameter_name}'" + return msg.format(**vars(self)) + + +def coerce_type(param, value, parameter_type, parameter_name=None): + + def make_type(value, type_literal): + type_func = TYPE_MAP.get(type_literal) + return type_func(value) + + param_schema = param.get("schema", param) + if is_nullable(param_schema) and is_null(value): + return None + + param_type = param_schema.get('type') + parameter_name = parameter_name if parameter_name else param.get('name') + if param_type == "array": + converted_params = [] + for v in value: + try: + converted = make_type(v, param_schema["items"]["type"]) + except (ValueError, TypeError): + converted = v + converted_params.append(converted) + return converted_params + elif param_type == 'object': + if param_schema.get('properties'): + def cast_leaves(d, schema): + if type(d) is not dict: + try: + return make_type(d, schema['type']) + except (ValueError, TypeError): + return d + for k, v in d.items(): + if k in schema['properties']: + d[k] = cast_leaves(v, schema['properties'][k]) + return d + + return cast_leaves(value, param_schema) + return value + else: + try: + return make_type(value, param_type) + except ValueError: + raise TypeValidationError(param_type, parameter_type, parameter_name) + except TypeError: + return value diff --git a/connexion/utils.py b/connexion/utils.py index 145d3b366..67d970545 100644 --- a/connexion/utils.py +++ b/connexion/utils.py @@ -114,12 +114,27 @@ def get_function_from_name(function_name): return function +def is_form_mimetype(mimetype): + try: + mimetype = mimetype.split(";")[0] + 
maintype, subtype = mimetype.split('/') # type: str, str + except (ValueError, AttributeError): + return False + + multipart = maintype == 'multipart' and subtype.startswith("form-data") + urlenc = maintype == 'application' and subtype.startswith("x-www-form-urlencoded") + return multipart or urlenc + + def is_json_mimetype(mimetype): """ :type mimetype: str :rtype: bool """ - maintype, subtype = mimetype.split('/') # type: str, str + try: + maintype, subtype = mimetype.split(';')[0].strip().split('/') # type: str, str + except (ValueError, AttributeError): + return False return maintype == 'application' and (subtype == 'json' or subtype.endswith('+json')) diff --git a/tests/aiohttp/test_aiohttp_simple_api.py b/tests/aiohttp/test_aiohttp_simple_api.py index 94147c17d..6a82ee3e7 100644 --- a/tests/aiohttp/test_aiohttp_simple_api.py +++ b/tests/aiohttp/test_aiohttp_simple_api.py @@ -301,7 +301,7 @@ def test_get_users(aiohttp_client, aiohttp_app): def test_create_user(aiohttp_client, aiohttp_app): app_client = yield from aiohttp_client(aiohttp_app.app) user = {'name': 'Maksim'} - resp = yield from app_client.post('/v1.0/users', json=user, headers={'Content-type': 'application/json'}) + resp = yield from app_client.post('/v1.0/users', json=user) assert resp.status == 201 diff --git a/tests/api/test_parameters.py b/tests/api/test_parameters.py index 361105706..28bec2581 100644 --- a/tests/api/test_parameters.py +++ b/tests/api/test_parameters.py @@ -159,9 +159,11 @@ def test_falsy_param(simple_app): def test_formdata_param(simple_app): + headers = {'Content-Type': 'multipart/form-data'} app_client = simple_app.app.test_client() resp = app_client.post('/v1.0/test-formData-param', - data={'formData': 'test'}) + data={'formData': 'test'}, + headers=headers) assert resp.status_code == 200 response = json.loads(resp.data.decode('utf-8', 'replace')) assert response == 'test' @@ -169,7 +171,8 @@ def test_formdata_param(simple_app): def test_formdata_bad_request(simple_app): app_client = 
simple_app.app.test_client() - resp = app_client.post('/v1.0/test-formData-param') + headers = {'Content-Type': 'multipart/form-data'} + resp = app_client.post('/v1.0/test-formData-param', headers=headers) assert resp.status_code == 400 response = json.loads(resp.data.decode('utf-8', 'replace')) assert response['detail'] in [ @@ -187,17 +190,21 @@ def test_formdata_missing_param(simple_app): def test_formdata_extra_param(simple_app): app_client = simple_app.app.test_client() + headers = {'content-type': 'multipart/form-data'} resp = app_client.post('/v1.0/test-formData-param', data={'formData': 'test', - 'extra_formData': 'test'}) + 'extra_formData': 'test'}, + headers=headers) assert resp.status_code == 200 def test_strict_formdata_extra_param(strict_app): app_client = strict_app.app.test_client() + headers = {'content-type': 'multipart/form-data'} resp = app_client.post('/v1.0/test-formData-param', data={'formData': 'test', - 'extra_formData': 'test'}) + 'extra_formData': 'test'}, + headers=headers) assert resp.status_code == 400 response = json.loads(resp.data.decode('utf-8', 'replace')) assert response['detail'] == "Extra formData parameter(s) extra_formData not in spec" @@ -225,7 +232,8 @@ def test_formdata_multiple_file_upload(simple_app): def test_formdata_file_upload_bad_request(simple_app): app_client = simple_app.app.test_client() - resp = app_client.post('/v1.0/test-formData-file-upload') + headers = {'Content-Type': 'multipart/form-data'} + resp = app_client.post('/v1.0/test-formData-file-upload', headers=headers) assert resp.status_code == 400 response = json.loads(resp.data.decode('utf-8', 'replace')) assert response['detail'] in [ @@ -277,6 +285,17 @@ def test_bool_param(simple_app): assert response is False +def test_empty_object_body(simple_app): + app_client = simple_app.app.test_client() + resp = app_client.post( + '/v1.0/test-empty-object-body', + data=json.dumps({}), + headers={'Content-Type': 'application/json'}) + assert resp.status_code == 
200 + response = json.loads(resp.data.decode('utf-8', 'replace')) + assert response['stack'] == {} + + def test_bool_array_param(simple_app): app_client = simple_app.app.test_client() resp = app_client.get('/v1.0/test-bool-array-param?thruthiness=true,true,true') @@ -367,19 +386,21 @@ def test_args_kwargs(simple_app): def test_param_sanitization(simple_app): app_client = simple_app.app.test_client() - resp = app_client.post('/v1.0/param-sanitization') + headers = {'content-type': 'multipart/form-data'} + resp = app_client.post('/v1.0/param-sanitization', headers=headers) assert resp.status_code == 200 assert json.loads(resp.data.decode('utf-8', 'replace')) == {} + headers = {'content-type': 'application/x-www-form-urlencoded'} resp = app_client.post('/v1.0/param-sanitization?$query=queryString', - data={'$form': 'formString'}) + data={'$form': 'formString'}, headers=headers) assert resp.status_code == 200 assert json.loads(resp.data.decode('utf-8', 'replace')) == { 'query': 'queryString', 'form': 'formString', } - body = { 'body1': 'bodyString', 'body2': 'otherString' } + body = {'body1': 'bodyString', 'body2': 'otherString'} resp = app_client.post( '/v1.0/body-sanitization', data=json.dumps(body), diff --git a/tests/api/test_responses.py b/tests/api/test_responses.py index e5a787a4c..b637b18d6 100644 --- a/tests/api/test_responses.py +++ b/tests/api/test_responses.py @@ -158,12 +158,13 @@ def test_redirect_response_endpoint(simple_app): def test_default_object_body(simple_app): app_client = simple_app.app.test_client() - resp = app_client.post('/v1.0/test-default-object-body') + headers = {'content-type': 'application/json'} + resp = app_client.post('/v1.0/test-default-object-body', headers=headers) assert resp.status_code == 200 response = json.loads(resp.data.decode('utf-8', 'replace')) assert response['stack'] == {'image_version': 'default_image'} - resp = app_client.post('/v1.0/test-default-integer-body') + resp = 
app_client.post('/v1.0/test-default-integer-body', headers=headers) assert resp.status_code == 200 response = json.loads(resp.data.decode('utf-8', 'replace')) assert response == 1 @@ -238,8 +239,8 @@ def test_bad_operations(bad_operations_app): def test_text_request(simple_app): app_client = simple_app.app.test_client() - - resp = app_client.post('/v1.0/text-request', data='text') + headers = {'content-type': 'text/plain'} + resp = app_client.post('/v1.0/text-request', data='text', headers=headers) assert resp.status_code == 200 diff --git a/tests/fakeapi/hello.py b/tests/fakeapi/hello.py index b6de06204..500ae3440 100755 --- a/tests/fakeapi/hello.py +++ b/tests/fakeapi/hello.py @@ -34,6 +34,10 @@ def post(): return '' +def test_empty_object_body(stack): + return {"stack": stack} + + def post_greeting(name, **kwargs): data = {'greeting': 'Hello {name}'.format(name=name)} return data diff --git a/tests/fixtures/simple/openapi.yaml b/tests/fixtures/simple/openapi.yaml index fdc0e68b0..9656f3c34 100644 --- a/tests/fixtures/simple/openapi.yaml +++ b/tests/fixtures/simple/openapi.yaml @@ -540,7 +540,7 @@ paths: description: OK requestBody: content: - application/x-www-form-urlencoded: + multipart/form-data: schema: type: object properties: @@ -770,6 +770,19 @@ paths: type: array items: type: string + /test-empty-object-body: + post: + summary: Test if empty object body param is passed to handler. 
+ operationId: fakeapi.hello.test_empty_object_body + responses: + '200': + description: OK + requestBody: + content: + application/json: + schema: + x-body-name: stack + type: object /nullable-parameters: post: operationId: fakeapi.hello.test_nullable_param_post3 @@ -904,6 +917,20 @@ paths: type: object requestBody: content: + application/json: + schema: + type: object + properties: + 'form$': + description: Just a testing parameter in the form data + type: string + application/x-www-form-urlencoded: + schema: + type: object + properties: + 'form$': + description: Just a testing parameter in the form data + type: string multipart/form-data: schema: type: object diff --git a/tests/fixtures/simple/swagger.yaml b/tests/fixtures/simple/swagger.yaml index 32d680e60..bab2374b0 100644 --- a/tests/fixtures/simple/swagger.yaml +++ b/tests/fixtures/simple/swagger.yaml @@ -63,7 +63,18 @@ paths: description: Name of the person to greet. required: true type: string - + /test-empty-object-body: + post: + summary: Test if empty object body param is passed to handler. + operationId: fakeapi.hello.test_empty_object_body + parameters: + - name: stack + in: body + schema: + type: object + responses: + 200: + description: OK /bye/{name}: get: summary: Generate goodbye