From 390336663618fdb50d80df203b753d2ac62521cd Mon Sep 17 00:00:00 2001 From: Robbe Sneyders Date: Fri, 16 Sep 2022 22:14:00 +0200 Subject: [PATCH 1/9] Extract boilerplate code into Routed base classes --- connexion/middleware/__init__.py | 2 +- connexion/middleware/abstract.py | 130 ++++++++++++++++++++++ connexion/middleware/exceptions.py | 2 +- connexion/middleware/routing.py | 121 ++++++++++---------- connexion/middleware/security.py | 169 +++++++++------------------- connexion/middleware/swagger_ui.py | 100 ++++++++--------- connexion/middleware/validation.py | 172 +++++++++-------------------- tests/test_operation2.py | 9 +- 8 files changed, 351 insertions(+), 354 deletions(-) diff --git a/connexion/middleware/__init__.py b/connexion/middleware/__init__.py index 302bc67c7..342848f42 100644 --- a/connexion/middleware/__init__.py +++ b/connexion/middleware/__init__.py @@ -1,4 +1,4 @@ -from .abstract import AppMiddleware # NOQA +from .abstract import AppMiddleware, RoutedMiddleware # NOQA from .main import ConnexionMiddleware # NOQA from .routing import RoutingMiddleware # NOQA from .swagger_ui import SwaggerUIMiddleware # NOQA diff --git a/connexion/middleware/abstract.py b/connexion/middleware/abstract.py index 69c065cb4..2654fbcb1 100644 --- a/connexion/middleware/abstract.py +++ b/connexion/middleware/abstract.py @@ -1,7 +1,20 @@ import abc +import logging import pathlib import typing as t +from starlette.types import ASGIApp, Receive, Scope, Send + +from connexion.apis.abstract import AbstractSpecAPI +from connexion.exceptions import MissingMiddleware +from connexion.http_facts import METHODS +from connexion.operations import AbstractOperation +from connexion.resolver import ResolverError + +logger = logging.getLogger("connexion.middleware.abstract") + +ROUTING_CONTEXT = "connexion_routing" + class AppMiddleware(abc.ABC): """Middlewares that need the APIs to be registered on them should inherit from this base @@ -12,3 +25,120 @@ def add_api( self, specification: t.Union[pathlib.Path, str, dict], **kwargs ) -> None: pass + + +class RoutedOperation(t.Protocol): + def __init__(self, next_app: ASGIApp, **kwargs) -> None: + ... + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + ... + + +OP = t.TypeVar("OP", bound=RoutedOperation) + + +class RoutedAPI(AbstractSpecAPI, t.Generic[OP]): + + operation_cls: t.Type[OP] + """The operation this middleware uses, which should implement the RoutingOperation protocol.""" + + def __init__( + self, + specification: t.Union[pathlib.Path, str, dict], + *args, + next_app: ASGIApp, + **kwargs, + ) -> None: + super().__init__(specification, *args, **kwargs) + self.next_app = next_app + self.operations: t.MutableMapping[str, OP] = {} + + def add_paths(self) -> None: + paths = self.specification.get("paths", {}) + for path, methods in paths.items(): + for method in methods: + if method not in METHODS: + continue + try: + self.add_operation(path, method) + except ResolverError: + # ResolverErrors are either raised or handled in routing middleware. + pass + + def add_operation(self, path: str, method: str) -> None: + operation_spec_cls = self.specification.operation_cls + operation = operation_spec_cls.from_spec( + self.specification, self, path, method, self.resolver + ) + routed_operation = self.make_operation(operation) + self.operations[operation.operation_id] = routed_operation + + @abc.abstractmethod + def make_operation(self, operation: AbstractOperation) -> OP: + """Create an operation of the `operation_cls` type.""" + raise NotImplementedError + + +API = t.TypeVar("API", bound="RoutedAPI") + + +class RoutedMiddleware(AppMiddleware, t.Generic[API]): + """Baseclass for middleware that wants to leverage the RoutingMiddleware to route requests to + its operations. + + The RoutingMiddleware adds the operation_id to the ASGI scope. This middleware registers its + operations by operation_id at startup. At request time, the operation is fetched by an + operation_id lookup. + """ + + @property + @abc.abstractmethod + def api_cls(self) -> t.Type[API]: + """The subclass of RoutedAPI this middleware uses.""" + raise NotImplementedError + + def __init__(self, app: ASGIApp) -> None: + self.app = app + self.apis: t.Dict[str, API] = {} + + def add_api( + self, specification: t.Union[pathlib.Path, str, dict], **kwargs + ) -> None: + api = self.api_cls(specification, next_app=self.app, **kwargs) + self.apis[api.base_path] = api + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + """Fetches the operation related to the request and calls it.""" + if scope["type"] != "http": + await self.app(scope, receive, send) + return + + try: + connexion_context = scope["extensions"][ROUTING_CONTEXT] + except KeyError: + raise MissingMiddleware( + "Could not find routing information in scope. Please make sure " + "you have a routing middleware registered upstream. " + ) + api_base_path = connexion_context.get("api_base_path") + if api_base_path: + api = self.apis[api_base_path] + operation_id = connexion_context.get("operation_id") + try: + operation = api.operations[operation_id] + except KeyError as e: + if operation_id is None: + logger.debug("Skipping validation check for operation without id.") + await self.app(scope, receive, send) + return + else: + raise MissingOperation("Encountered unknown operation_id.") from e + else: + return await operation(scope, receive, send) + + await self.app(scope, receive, send) + + +class MissingOperation(Exception): + """Missing operation""" diff --git a/connexion/middleware/exceptions.py b/connexion/middleware/exceptions.py index 81ea509ea..e23102299 100644 --- a/connexion/middleware/exceptions.py +++ b/connexion/middleware/exceptions.py @@ -39,7 +39,7 @@ def problem_handler(self, _, exception: ProblemException): def http_exception(self, request: Request, exc: HTTPException) -> Response: try: - headers = exc.headers + headers = exc.headers # type: ignore except AttributeError: # Starlette < 0.19 headers = {} diff --git a/connexion/middleware/routing.py b/connexion/middleware/routing.py index 54d409597..905ecb315 100644 --- a/connexion/middleware/routing.py +++ b/connexion/middleware/routing.py @@ -6,61 +6,36 @@ from starlette.types import ASGIApp, Receive, Scope, Send from connexion.apis import AbstractRoutingAPI -from connexion.middleware import AppMiddleware +from connexion.middleware.abstract import ROUTING_CONTEXT, AppMiddleware from connexion.operations import AbstractOperation from connexion.resolver import Resolver -ROUTING_CONTEXT = "connexion_routing" - - _scope: ContextVar[dict] = ContextVar("SCOPE") -class RoutingMiddleware(AppMiddleware): - def __init__(self, app: ASGIApp) -> None: - """Middleware that resolves the Operation for an incoming request and attaches it to the - scope. - - :param app: app to wrap in middleware. - """ - self.app = app - # Pass unknown routes to next app - self.router = Router(default=RoutingOperation(None, self.app)) - - def add_api( - self, - specification: t.Union[pathlib.Path, str, dict], - base_path: t.Optional[str] = None, - arguments: t.Optional[dict] = None, - **kwargs - ) -> None: - """Add an API to the router based on a OpenAPI spec. +class RoutingOperation: + def __init__(self, operation_id: t.Optional[str], next_app: ASGIApp) -> None: + self.operation_id = operation_id + self.next_app = next_app - :param specification: OpenAPI spec as dict or path to file. - :param base_path: Base path where to add this API. - :param arguments: Jinja arguments to replace in the spec. - """ - api = RoutingAPI( - specification, - base_path=base_path, - arguments=arguments, - next_app=self.app, - **kwargs - ) - self.router.mount(api.base_path, app=api.router) + @classmethod + def from_operation(cls, operation: AbstractOperation, next_app: ASGIApp): + return cls(operation.operation_id, next_app) async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: - """Route request to matching operation, and attach it to the scope before calling the - next app.""" - if scope["type"] != "http": - await self.app(scope, receive, send) - return + """Attach operation to scope and pass it to the next app""" + original_scope = _scope.get() - _scope.set(scope.copy()) # type: ignore + api_base_path = scope.get("root_path", "")[ + len(original_scope.get("root_path", "")) : + ] - # Needs to be set so starlette router throws exceptions instead of returning error responses - scope["app"] = self - await self.router(scope, receive, send) + extensions = original_scope.setdefault("extensions", {}) + connexion_routing = extensions.setdefault(ROUTING_CONTEXT, {}) + connexion_routing.update( + {"api_base_path": api_base_path, "operation_id": self.operation_id} + ) + await self.next_app(original_scope, receive, send) class RoutingAPI(AbstractRoutingAPI): @@ -105,26 +80,48 @@ def _add_operation_internal( self.router.add_route(path, operation, methods=[method]) -class RoutingOperation: - def __init__(self, operation_id: t.Optional[str], next_app: ASGIApp) -> None: - self.operation_id = operation_id - self.next_app = next_app +class RoutingMiddleware(AppMiddleware): + def __init__(self, app: ASGIApp) -> None: + """Middleware that resolves the Operation for an incoming request and attaches it to the + scope. - @classmethod - def from_operation(cls, operation: AbstractOperation, next_app: ASGIApp): - return cls(operation.operation_id, next_app) + :param app: app to wrap in middleware. + """ + self.app = app + # Pass unknown routes to next app + self.router = Router(default=RoutingOperation(None, self.app)) + + def add_api( + self, + specification: t.Union[pathlib.Path, str, dict], + base_path: t.Optional[str] = None, + arguments: t.Optional[dict] = None, + **kwargs + ) -> None: + """Add an API to the router based on a OpenAPI spec. + + :param specification: OpenAPI spec as dict or path to file. + :param base_path: Base path where to add this API. + :param arguments: Jinja arguments to replace in the spec. + """ + api = RoutingAPI( + specification, + base_path=base_path, + arguments=arguments, + next_app=self.app, + **kwargs + ) + self.router.mount(api.base_path, app=api.router) async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: - """Attach operation to scope and pass it to the next app""" - original_scope = _scope.get() + """Route request to matching operation, and attach it to the scope before calling the + next app.""" + if scope["type"] != "http": + await self.app(scope, receive, send) + return - api_base_path = scope.get("root_path", "")[ - len(original_scope.get("root_path", "")) : - ] + _scope.set(scope.copy()) # type: ignore - extensions = original_scope.setdefault("extensions", {}) - connexion_routing = extensions.setdefault(ROUTING_CONTEXT, {}) - connexion_routing.update( - {"api_base_path": api_base_path, "operation_id": self.operation_id} - ) - await self.next_app(original_scope, receive, send) + # Needs to be set so starlette router throws exceptions instead of returning error responses + scope["app"] = self + await self.router(scope, receive, send) diff --git a/connexion/middleware/security.py b/connexion/middleware/security.py index 64566174a..33e15ba83 100644 --- a/connexion/middleware/security.py +++ b/connexion/middleware/security.py @@ -1,136 +1,28 @@ import logging -import pathlib import typing as t from collections import defaultdict from starlette.types import ASGIApp, Receive, Scope, Send -from connexion.apis.abstract import AbstractSpecAPI -from connexion.exceptions import MissingMiddleware, ProblemException -from connexion.http_facts import METHODS +from connexion.exceptions import ProblemException from connexion.lifecycle import MiddlewareRequest -from connexion.middleware import AppMiddleware -from connexion.middleware.routing import ROUTING_CONTEXT +from connexion.middleware.abstract import RoutedAPI, RoutedMiddleware from connexion.operations import AbstractOperation -from connexion.resolver import ResolverError from connexion.security import SecurityHandlerFactory -from connexion.spec import Specification logger = logging.getLogger("connexion.middleware.security") -class SecurityMiddleware(AppMiddleware): - """Middleware to check if operation is accessible on scope.""" - - def __init__(self, app: ASGIApp) -> None: - self.app = app - self.apis: t.Dict[str, SecurityAPI] = {} - - def add_api( - self, specification: t.Union[pathlib.Path, str, dict], **kwargs - ) -> None: - api = SecurityAPI(specification, **kwargs) - self.apis[api.base_path] = api - - async def __call__(self, scope: Scope, receive: Receive, send: Send): - if scope["type"] != "http": - await self.app(scope, receive, send) - return - - try: - connexion_context = scope["extensions"][ROUTING_CONTEXT] - except KeyError: - raise MissingMiddleware( - "Could not find routing information in scope. Please make sure " - "you have a routing middleware registered upstream. " - ) - - api_base_path = connexion_context.get("api_base_path") - if api_base_path: - api = self.apis[api_base_path] - operation_id = connexion_context.get("operation_id") - try: - operation = api.operations[operation_id] - except KeyError as e: - if operation_id is None: - logger.debug( - "Skipping security check for operation without id. Enable " - "`auth_all_paths` to check security for unknown operations." - ) - else: - raise MissingSecurityOperation( - "Encountered unknown operation_id." - ) from e - - else: - request = MiddlewareRequest(scope) - await operation(request) - - await self.app(scope, receive, send) - - -class SecurityAPI(AbstractSpecAPI): - def __init__( - self, - specification: t.Union[pathlib.Path, str, dict], - auth_all_paths: bool = False, - *args, - **kwargs - ): - super().__init__(specification, *args, **kwargs) - self.security_handler_factory = SecurityHandlerFactory("context") - - if auth_all_paths: - self.add_auth_on_not_found() - else: - self.operations: t.Dict[str, SecurityOperation] = {} - - self.add_paths() - - def add_auth_on_not_found(self): - """Register a default SecurityOperation for routes that are not found.""" - default_operation = self.make_operation(self.specification) - self.operations = defaultdict(lambda: default_operation) - - def add_paths(self): - paths = self.specification.get("paths", {}) - for path, methods in paths.items(): - for method in methods: - if method not in METHODS: - continue - try: - self.add_operation(path, method) - except ResolverError: - # ResolverErrors are either raised or handled in routing middleware. - pass - - def add_operation(self, path: str, method: str) -> None: - operation_cls = self.specification.operation_cls - operation = operation_cls.from_spec( - self.specification, self, path, method, self.resolver - ) - security_operation = self.make_operation(operation) - self._add_operation_internal(operation.operation_id, security_operation) - - def make_operation(self, operation: t.Union[AbstractOperation, Specification]): - return SecurityOperation.from_operation( - operation, - security_handler_factory=self.security_handler_factory, - ) - - def _add_operation_internal( - self, operation_id: str, operation: "SecurityOperation" - ): - self.operations[operation_id] = operation - - class SecurityOperation: def __init__( self, + next_app: ASGIApp, + *, security_handler_factory: SecurityHandlerFactory, security: list, security_schemes: dict, ): + self.next_app = next_app self.security_handler_factory = security_handler_factory self.security = security self.security_schemes = security_schemes @@ -139,12 +31,14 @@ def __init__( @classmethod def from_operation( cls, - operation: t.Union[AbstractOperation, Specification], + operation: AbstractOperation, + *, + next_app: ASGIApp, security_handler_factory: SecurityHandlerFactory, - ): - # TODO: Turn Operation class into OperationSpec and use as init argument instead + ) -> "SecurityOperation": return cls( - security_handler_factory, + next_app=next_app, + security_handler_factory=security_handler_factory, security=operation.security, security_schemes=operation.security_schemes, ) @@ -304,8 +198,47 @@ def _get_verification_fn(self): return self.security_handler_factory.verify_security(auth_funcs) - async def __call__(self, request: MiddlewareRequest): + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + request = MiddlewareRequest(scope) await self.verification_fn(request) + await self.next_app(scope, receive, send) + + +class SecurityAPI(RoutedAPI[SecurityOperation]): + + operation_cls = SecurityOperation + + def __init__(self, *args, auth_all_paths: bool = False, **kwargs): + super().__init__(*args, **kwargs) + + self.security_handler_factory = SecurityHandlerFactory("context") + + if auth_all_paths: + self.add_auth_on_not_found() + else: + self.operations: t.MutableMapping[str, SecurityOperation] = {} + + self.add_paths() + + def add_auth_on_not_found(self) -> None: + """Register a default SecurityOperation for routes that are not found.""" + default_operation = self.make_operation(self.specification) + self.operations = defaultdict(lambda: default_operation) + + def make_operation(self, operation: AbstractOperation) -> SecurityOperation: + return SecurityOperation.from_operation( + operation, + next_app=self.next_app, + security_handler_factory=self.security_handler_factory, + ) + + +class SecurityMiddleware(RoutedMiddleware[SecurityAPI]): + """Middleware to check if operation is accessible on scope.""" + + @property + def api_cls(self) -> t.Type[SecurityAPI]: + return SecurityAPI class MissingSecurityOperation(ProblemException): diff --git a/connexion/middleware/swagger_ui.py b/connexion/middleware/swagger_ui.py index 71fb5f75e..6bf5666f1 100644 --- a/connexion/middleware/swagger_ui.py +++ b/connexion/middleware/swagger_ui.py @@ -21,56 +21,6 @@ _original_scope: ContextVar[Scope] = ContextVar("SCOPE") -class SwaggerUIMiddleware(AppMiddleware): - def __init__(self, app: ASGIApp) -> None: - """Middleware that hosts a swagger UI. - - :param app: app to wrap in middleware. - """ - self.app = app - # Set default to pass unknown routes to next app - self.router = Router(default=self.default_fn) - - def add_api( - self, - specification: t.Union[pathlib.Path, str, dict], - base_path: t.Optional[str] = None, - arguments: t.Optional[dict] = None, - **kwargs - ) -> None: - """Add an API to the router based on a OpenAPI spec. - - :param specification: OpenAPI spec as dict or path to file. - :param base_path: Base path where to add this API. - :param arguments: Jinja arguments to replace in the spec. - """ - api = SwaggerUIAPI( - specification, - base_path=base_path, - arguments=arguments, - default=self.default_fn, - **kwargs - ) - self.router.mount(api.base_path, app=api.router) - - async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: - _original_scope.set(scope.copy()) # type: ignore - await self.router(scope, receive, send) - - async def default_fn(self, _scope: Scope, receive: Receive, send: Send) -> None: - """ - Callback to call next app as default when no matching route is found. - - Unfortunately we cannot just pass the next app as default, since the router manipulates - the scope when descending into mounts, losing information about the base path. Therefore, - we use the original scope instead. - - This is caused by https://github.com/encode/starlette/issues/1336. - """ - original_scope = _original_scope.get() - await self.app(original_scope, receive, send) - - class SwaggerUIAPI(AbstractSwaggerUIAPI): def __init__(self, *args, default: ASGIApp, **kwargs): self.router = Router(default=default) @@ -206,3 +156,53 @@ async def _get_swagger_ui_config(self, request): media_type="application/json", content=self.jsonifier.dumps(self.options.openapi_console_ui_config), ) + + +class SwaggerUIMiddleware(AppMiddleware): + def __init__(self, app: ASGIApp) -> None: + """Middleware that hosts a swagger UI. + + :param app: app to wrap in middleware. + """ + self.app = app + # Set default to pass unknown routes to next app + self.router = Router(default=self.default_fn) + + def add_api( + self, + specification: t.Union[pathlib.Path, str, dict], + base_path: t.Optional[str] = None, + arguments: t.Optional[dict] = None, + **kwargs + ) -> None: + """Add an API to the router based on a OpenAPI spec. + + :param specification: OpenAPI spec as dict or path to file. + :param base_path: Base path where to add this API. + :param arguments: Jinja arguments to replace in the spec. + """ + api = SwaggerUIAPI( + specification, + base_path=base_path, + arguments=arguments, + default=self.default_fn, + **kwargs + ) + self.router.mount(api.base_path, app=api.router) + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + _original_scope.set(scope.copy()) # type: ignore + await self.router(scope, receive, send) + + async def default_fn(self, _scope: Scope, receive: Receive, send: Send) -> None: + """ + Callback to call next app as default when no matching route is found. + + Unfortunately we cannot just pass the next app as default, since the router manipulates + the scope when descending into mounts, losing information about the base path. Therefore, + we use the original scope instead. + + This is caused by https://github.com/encode/starlette/issues/1336. + """ + original_scope = _original_scope.get() + await self.app(original_scope, receive, send) diff --git a/connexion/middleware/validation.py b/connexion/middleware/validation.py index 0b7a050eb..ec8ab25cb 100644 --- a/connexion/middleware/validation.py +++ b/connexion/middleware/validation.py @@ -2,19 +2,14 @@ Validation Middleware. """ import logging -import pathlib import typing as t from starlette.types import ASGIApp, Receive, Scope, Send -from connexion.apis.abstract import AbstractSpecAPI from connexion.decorators.uri_parsing import AbstractURIParser -from connexion.exceptions import MissingMiddleware, UnsupportedMediaTypeProblem -from connexion.http_facts import METHODS -from connexion.middleware import AppMiddleware -from connexion.middleware.routing import ROUTING_CONTEXT +from connexion.exceptions import UnsupportedMediaTypeProblem +from connexion.middleware.abstract import RoutedAPI, RoutedMiddleware from connexion.operations import AbstractOperation -from connexion.resolver import ResolverError from connexion.utils import is_nullable from connexion.validators import JSONBodyValidator @@ -30,130 +25,19 @@ } -class ValidationMiddleware(AppMiddleware): - """Middleware for validating requests according to the API contract.""" - - def __init__(self, app: ASGIApp) -> None: - self.app = app - self.apis: t.Dict[str, ValidationAPI] = {} - - def add_api( - self, specification: t.Union[pathlib.Path, str, dict], **kwargs - ) -> None: - api = ValidationAPI(specification, next_app=self.app, **kwargs) - self.apis[api.base_path] = api - - async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: - if scope["type"] != "http": - await self.app(scope, receive, send) - return - - try: - connexion_context = scope["extensions"][ROUTING_CONTEXT] - except KeyError: - raise MissingMiddleware( - "Could not find routing information in scope. Please make sure " - "you have a routing middleware registered upstream. " - ) - api_base_path = connexion_context.get("api_base_path") - if api_base_path: - api = self.apis[api_base_path] - operation_id = connexion_context.get("operation_id") - try: - operation = api.operations[operation_id] - except KeyError as e: - if operation_id is None: - logger.debug("Skipping validation check for operation without id.") - await self.app(scope, receive, send) - return - else: - raise MissingValidationOperation( - "Encountered unknown operation_id." - ) from e - else: - return await operation(scope, receive, send) - - await self.app(scope, receive, send) - - -class ValidationAPI(AbstractSpecAPI): - """Validation API.""" - - def __init__( - self, - specification: t.Union[pathlib.Path, str, dict], - *args, - next_app: ASGIApp, - validate_responses=False, - strict_validation=False, - validator_map=None, - uri_parser_class=None, - **kwargs, - ): - super().__init__(specification, *args, **kwargs) - self.next_app = next_app - - self.validator_map = validator_map - - logger.debug("Validate Responses: %s", str(validate_responses)) - self.validate_responses = validate_responses - - logger.debug("Strict Request Validation: %s", str(strict_validation)) - self.strict_validation = strict_validation - - self.uri_parser_class = uri_parser_class - - self.operations: t.Dict[str, ValidationOperation] = {} - self.add_paths() - - def add_paths(self): - paths = self.specification.get("paths", {}) - for path, methods in paths.items(): - for method in methods: - if method not in METHODS: - continue - try: - self.add_operation(path, method) - except ResolverError: - # ResolverErrors are either raised or handled in routing middleware. - pass - - def add_operation(self, path: str, method: str) -> None: - operation_cls = self.specification.operation_cls - operation = operation_cls.from_spec( - self.specification, self, path, method, self.resolver - ) - validation_operation = self.make_operation(operation) - self._add_operation_internal(operation.operation_id, validation_operation) - - def make_operation(self, operation: AbstractOperation): - return ValidationOperation( - operation, - self.next_app, - validate_responses=self.validate_responses, - strict_validation=self.strict_validation, - validator_map=self.validator_map, - uri_parser_class=self.uri_parser_class, - ) - - def _add_operation_internal( - self, operation_id: str, operation: "ValidationOperation" - ): - self.operations[operation_id] = operation - - class ValidationOperation: def __init__( self, - operation: AbstractOperation, next_app: ASGIApp, + *, + operation: AbstractOperation, validate_responses: bool = False, strict_validation: bool = False, validator_map: t.Optional[dict] = None, uri_parser_class: t.Optional[AbstractURIParser] = None, ) -> None: - self._operation = operation self.next_app = next_app + self._operation = operation self.validate_responses = validate_responses self.strict_validation = strict_validation self._validator_map = VALIDATOR_MAP @@ -228,5 +112,51 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send): await self.next_app(scope, receive, send) +class ValidationAPI(RoutedAPI[ValidationOperation]): + """Validation API.""" + + operation_cls = ValidationOperation + + def __init__( + self, + *args, + validate_responses=False, + strict_validation=False, + validator_map=None, + uri_parser_class=None, + **kwargs, + ): + super().__init__(*args, **kwargs) + self.validator_map = validator_map + + logger.debug("Validate Responses: %s", str(validate_responses)) + self.validate_responses = validate_responses + + logger.debug("Strict Request Validation: %s", str(strict_validation)) + self.strict_validation = strict_validation + + self.uri_parser_class = uri_parser_class + + self.add_paths() + + def make_operation(self, operation: AbstractOperation) -> ValidationOperation: + return ValidationOperation( + self.next_app, + operation=operation, + validate_responses=self.validate_responses, + strict_validation=self.strict_validation, + validator_map=self.validator_map, + uri_parser_class=self.uri_parser_class, + ) + + +class ValidationMiddleware(RoutedMiddleware[ValidationAPI]): + """Middleware for validating requests according to the API contract.""" + + @property + def api_cls(self) -> t.Type[ValidationAPI]: + return ValidationAPI + + class MissingValidationOperation(Exception): """Missing validation operation""" diff --git a/tests/test_operation2.py b/tests/test_operation2.py index bac5be61f..76863d72c 100644 --- a/tests/test_operation2.py +++ b/tests/test_operation2.py @@ -425,6 +425,7 @@ def test_operation_remote_token_info(security_handler_factory): ) SecurityOperation( + next_app=mock.AsyncMock, security_handler_factory=security_handler_factory, security=[{"oauth": ["uid"]}], security_schemes=SECURITY_DEFINITIONS_REMOTE, @@ -494,6 +495,7 @@ def test_operation_local_security_oauth2(security_handler_factory): security_handler_factory.verify_oauth = verify_oauth SecurityOperation( + next_app=mock.AsyncMock, security_handler_factory=security_handler_factory, security=[{"oauth": ["uid"]}], security_schemes=SECURITY_DEFINITIONS_LOCAL, @@ -509,7 +511,8 @@ def test_operation_local_security_duplicate_token_info(security_handler_factory) security_handler_factory.verify_oauth = verify_oauth SecurityOperation( - security_handler_factory, + next_app=mock.AsyncMock, + security_handler_factory=security_handler_factory, security=[{"oauth": ["uid"]}], security_schemes=SECURITY_DEFINITIONS_BOTH, ) @@ -545,6 +548,7 @@ def test_multi_body(api): def test_no_token_info(security_handler_factory): SecurityOperation( + next_app=mock.AsyncMock, security_handler_factory=security_handler_factory, security=[{"oauth": ["uid"]}], security_schemes=SECURITY_DEFINITIONS_WO_INFO, @@ -565,6 +569,7 @@ def return_api_key_name(func, in_, name): security = [{"key1": [], "key2": []}] SecurityOperation( + next_app=mock.AsyncMock, security_handler_factory=security_handler_factory, security=security, security_schemes=SECURITY_DEFINITIONS_2_KEYS, @@ -589,6 +594,7 @@ def test_multiple_oauth_in_and(security_handler_factory, caplog): security = [{"oauth_1": ["uid"], "oauth_2": ["uid"]}] SecurityOperation( + next_app=mock.AsyncMock, security_handler_factory=security_handler_factory, security=security, security_schemes=SECURITY_DEFINITIONS_2_OAUTH, @@ -685,6 +691,7 @@ def test_oauth_scopes_in_or(security_handler_factory): security = [{"oauth": ["myscope"]}, {"oauth": ["myscope2"]}] SecurityOperation( + next_app=mock.AsyncMock, security_handler_factory=security_handler_factory, security=security, security_schemes=SECURITY_DEFINITIONS_LOCAL, From 927a144cbccf149e77e079803b9b76b6e6c1adfa Mon Sep 17 00:00:00 2001 From: Robbe Sneyders Date: Fri, 16 Sep 2022 22:24:44 +0200 Subject: [PATCH 2/9] Use typing_extensions for Python 3.7 Protocol support --- connexion/middleware/abstract.py | 3 ++- setup.py | 2 ++ tests/test_operation2.py | 2 +- 3 files changed, 5 insertions(+), 2 deletions(-) diff --git a/connexion/middleware/abstract.py b/connexion/middleware/abstract.py index 2654fbcb1..ebfb5cabb 100644 --- a/connexion/middleware/abstract.py +++ b/connexion/middleware/abstract.py @@ -3,6 +3,7 @@ import pathlib import typing as t +import typing_extensions as te from starlette.types import ASGIApp, Receive, Scope, Send from connexion.apis.abstract import AbstractSpecAPI @@ -27,7 +28,7 @@ def add_api( pass -class RoutedOperation(t.Protocol): +class RoutedOperation(te.Protocol): def __init__(self, next_app: ASGIApp, **kwargs) -> None: ... diff --git a/setup.py b/setup.py index c9c1a1117..eb0ef2940 100755 --- a/setup.py +++ b/setup.py @@ -28,6 +28,8 @@ def read_version(package): 'werkzeug>=2.2.1,<3', 'starlette>=0.15,<1', 'httpx>=0.15,<1', + 'typing-extensions>=4,<5', + 'mock>=3,<4' ] swagger_ui_require = 'swagger-ui-bundle>=0.0.2,<0.1' diff --git a/tests/test_operation2.py b/tests/test_operation2.py index 76863d72c..1d5ec8bd4 100644 --- a/tests/test_operation2.py +++ b/tests/test_operation2.py @@ -3,8 +3,8 @@ import math import pathlib import types -from unittest import mock +import mock import pytest from connexion.apis.flask_api import Jsonifier from connexion.exceptions import InvalidSpecification From 24b5725911f05bef16b315a4fc0330c035c17f5a Mon Sep 17 00:00:00 2001 From: Robbe Sneyders Date: Fri, 16 Sep 2022 23:40:02 +0200 Subject: [PATCH 3/9] Use Mock instead of AsyncMock --- setup.py | 1 - tests/test_operation2.py | 16 ++++++++-------- 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/setup.py b/setup.py index eb0ef2940..ea5db549e 100755 --- a/setup.py +++ b/setup.py @@ -29,7 +29,6 @@ def read_version(package): 'starlette>=0.15,<1', 'httpx>=0.15,<1', 'typing-extensions>=4,<5', - 'mock>=3,<4' ] swagger_ui_require = 'swagger-ui-bundle>=0.0.2,<0.1' diff --git a/tests/test_operation2.py b/tests/test_operation2.py index 1d5ec8bd4..cd37d14ff 100644 --- a/tests/test_operation2.py +++ b/tests/test_operation2.py @@ -3,8 +3,8 @@ import math import pathlib import types +from unittest import mock -import mock import pytest from connexion.apis.flask_api import Jsonifier from connexion.exceptions import InvalidSpecification @@ -425,7 +425,7 @@ def test_operation_remote_token_info(security_handler_factory): ) SecurityOperation( - next_app=mock.AsyncMock, + next_app=mock.Mock, security_handler_factory=security_handler_factory, security=[{"oauth": ["uid"]}], security_schemes=SECURITY_DEFINITIONS_REMOTE, @@ -495,7 +495,7 @@ def test_operation_local_security_oauth2(security_handler_factory): security_handler_factory.verify_oauth = verify_oauth SecurityOperation( - next_app=mock.AsyncMock, + next_app=mock.Mock, security_handler_factory=security_handler_factory, security=[{"oauth": ["uid"]}], security_schemes=SECURITY_DEFINITIONS_LOCAL, @@ -511,7 +511,7 @@ def test_operation_local_security_duplicate_token_info(security_handler_factory) security_handler_factory.verify_oauth = verify_oauth SecurityOperation( - next_app=mock.AsyncMock, + next_app=mock.Mock, security_handler_factory=security_handler_factory, security=[{"oauth": ["uid"]}], security_schemes=SECURITY_DEFINITIONS_BOTH, @@ -548,7 +548,7 @@ def test_multi_body(api): def test_no_token_info(security_handler_factory): SecurityOperation( - next_app=mock.AsyncMock, + next_app=mock.Mock, security_handler_factory=security_handler_factory, security=[{"oauth": ["uid"]}], security_schemes=SECURITY_DEFINITIONS_WO_INFO, @@ -569,7 +569,7 @@ def return_api_key_name(func, in_, name): security = [{"key1": [], "key2": []}] SecurityOperation( - next_app=mock.AsyncMock, + next_app=mock.Mock, security_handler_factory=security_handler_factory, security=security, security_schemes=SECURITY_DEFINITIONS_2_KEYS, @@ -594,7 +594,7 @@ def test_multiple_oauth_in_and(security_handler_factory, caplog): security = [{"oauth_1": ["uid"], "oauth_2": ["uid"]}] SecurityOperation( - next_app=mock.AsyncMock, + next_app=mock.Mock, security_handler_factory=security_handler_factory, security=security, security_schemes=SECURITY_DEFINITIONS_2_OAUTH, @@ -691,7 +691,7 @@ def test_oauth_scopes_in_or(security_handler_factory): security = [{"oauth": ["myscope"]}, {"oauth": ["myscope2"]}] SecurityOperation( - next_app=mock.AsyncMock, + next_app=mock.Mock, security_handler_factory=security_handler_factory, security=security, security_schemes=SECURITY_DEFINITIONS_LOCAL, From 141ba1c6c7f089f36291935ade436ad06b93d310 Mon Sep 17 00:00:00 2001 From: Robbe Sneyders Date: Sun, 18 Sep 2022 00:27:52 +0200 Subject: [PATCH 4/9] Extract response validation to middleware --- connexion/decorators/response.py | 135 -------------- connexion/decorators/validation.py | 25 +-- connexion/middleware/abstract.py | 1 + connexion/middleware/main.py | 6 +- .../{validation.py => request_validation.py} | 36 ++-- connexion/middleware/response_validation.py | 166 ++++++++++++++++++ connexion/operations/abstract.py | 17 -- connexion/validators.py | 98 ++++++++++- tests/api/test_headers.py | 2 +- tests/api/test_schema.py | 28 +-- tests/fakeapi/example_method_view.py | 17 +- tests/test_json_validation.py | 5 +- tests/test_resolver_methodview.py | 8 +- 13 files changed, 304 insertions(+), 240 deletions(-) delete mode 100644 connexion/decorators/response.py rename connexion/middleware/{validation.py => request_validation.py} (81%) create mode 100644 connexion/middleware/response_validation.py diff --git a/connexion/decorators/response.py b/connexion/decorators/response.py deleted file mode 100644 index 98511259a..000000000 --- a/connexion/decorators/response.py +++ /dev/null @@ -1,135 +0,0 @@ -""" -This module defines a view function decorator to validate its responses. -""" - -import asyncio -import functools -import logging - -from jsonschema import ValidationError - -from ..exceptions import NonConformingResponseBody, NonConformingResponseHeaders -from ..utils import all_json, has_coroutine -from .decorator import BaseDecorator -from .validation import ResponseBodyValidator - -logger = logging.getLogger("connexion.decorators.response") - - -class ResponseValidator(BaseDecorator): - def __init__(self, operation, mimetype, validator=None): - """ - :type operation: Operation - :type mimetype: str - :param validator: Validator class that should be used to validate passed data - against API schema. - :type validator: jsonschema.IValidator - """ - self.operation = operation - self.mimetype = mimetype - self.validator = validator - - def validate_response(self, data, status_code, headers, url): - """ - Validates the Response object based on what has been declared in the specification. - Ensures the response body matches the declared schema. - :type data: dict - :type status_code: int - :type headers: dict - :rtype bool | None - """ - # check against returned header, fall back to expected mimetype - content_type = headers.get("Content-Type", self.mimetype) - content_type = content_type.rsplit(";", 1)[ - 0 - ] # remove things like utf8 metadata - - response_definition = self.operation.response_definition( - str(status_code), content_type - ) - response_schema = self.operation.response_schema(str(status_code), content_type) - - if self.is_json_schema_compatible(response_schema): - v = ResponseBodyValidator(response_schema, validator=self.validator) - try: - data = self.operation.json_loads(data) - v.validate_schema(data, url) - except ValidationError as e: - raise NonConformingResponseBody(message=str(e)) - - if response_definition and response_definition.get("headers"): - required_header_keys = { - k - for (k, v) in response_definition.get("headers").items() - if v.get("required", False) - } - header_keys = set(headers.keys()) - missing_keys = required_header_keys - header_keys - if missing_keys: - pretty_list = ", ".join(missing_keys) - msg = ( - "Keys in header don't match response specification. " - "Difference: {}" - ).format(pretty_list) - raise NonConformingResponseHeaders(message=msg) - return True - - def is_json_schema_compatible(self, response_schema: dict) -> bool: - """ - Verify if the specified operation responses are JSON schema - compatible. - - All operations that specify a JSON schema and have content - type "application/json" or "text/plain" can be validated using - json_schema package. - """ - if not response_schema: - return False - return all_json([self.mimetype]) or self.mimetype == "text/plain" - - def __call__(self, function): - """ - :type function: types.FunctionType - :rtype: types.FunctionType - """ - - def _wrapper(request, response): - connexion_response = self.operation.api.get_connexion_response( - response, self.mimetype - ) - if not connexion_response.is_streamed: - self.validate_response( - connexion_response.body, - connexion_response.status_code, - connexion_response.headers, - request.url, - ) - else: - logger.warning("Skipping response validation for streamed response.") - - return response - - if has_coroutine(function): - - @functools.wraps(function) - async def wrapper(request): - response = function(request) - while asyncio.iscoroutine(response): - response = await response - - return _wrapper(request, response) - - else: # pragma: no cover - - @functools.wraps(function) - def wrapper(request): - response = function(request) - return _wrapper(request, response) - - return wrapper - - def __repr__(self): - """ - :rtype: str - """ - return "" # pragma: no cover diff --git a/connexion/decorators/validation.py b/connexion/decorators/validation.py index bfa004ed2..83a258d67 100644 --- a/connexion/decorators/validation.py +++ b/connexion/decorators/validation.py @@ -14,7 +14,7 @@ from ..exceptions import BadRequestProblem, ExtraParameterProblem from ..http_facts import FORM_CONTENT_TYPES -from ..json_schema import Draft4RequestValidator, Draft4ResponseValidator +from ..json_schema import Draft4RequestValidator from ..lifecycle import ConnexionResponse from ..utils import boolean, is_null, is_nullable @@ -196,29 +196,6 @@ def validate_schema(self, data: dict, url: str) -> t.Optional[ConnexionResponse] return None -class ResponseBodyValidator: - def __init__(self, schema, validator=None): - """ - :param schema: The schema of the response body - :param validator: Validator class that should be used to validate passed data - against API schema. Default is Draft4ResponseValidator. - :type validator: jsonschema.IValidator - """ - ValidatorClass = validator or Draft4ResponseValidator - self.validator = ValidatorClass(schema, format_checker=draft4_format_checker) - - def validate_schema(self, data: dict, url: str) -> t.Optional[ConnexionResponse]: - try: - self.validator.validate(data) - except ValidationError as exception: - logger.error( - f"{url} validation error: {exception}", extra={"validator": "response"} - ) - raise exception - - return None - - class ParameterValidator: def __init__(self, parameters, api, strict_validation=False): """ diff --git a/connexion/middleware/abstract.py b/connexion/middleware/abstract.py index ebfb5cabb..dd69d525a 100644 --- a/connexion/middleware/abstract.py +++ b/connexion/middleware/abstract.py @@ -118,6 +118,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: try: connexion_context = scope["extensions"][ROUTING_CONTEXT] except KeyError: + # TODO: update message raise MissingMiddleware( "Could not find routing information in scope. Please make sure " "you have a routing middleware registered upstream. " diff --git a/connexion/middleware/main.py b/connexion/middleware/main.py index 950cdb532..2e2acea1b 100644 --- a/connexion/middleware/main.py +++ b/connexion/middleware/main.py @@ -5,10 +5,11 @@ from connexion.middleware.abstract import AppMiddleware from connexion.middleware.exceptions import ExceptionMiddleware +from connexion.middleware.request_validation import RequestValidationMiddleware +from connexion.middleware.response_validation import ResponseValidationMiddleware from connexion.middleware.routing import RoutingMiddleware from connexion.middleware.security import SecurityMiddleware from connexion.middleware.swagger_ui import SwaggerUIMiddleware -from connexion.middleware.validation import ValidationMiddleware class ConnexionMiddleware: @@ -18,7 +19,8 @@ class ConnexionMiddleware: SwaggerUIMiddleware, RoutingMiddleware, SecurityMiddleware, - ValidationMiddleware, + RequestValidationMiddleware, + ResponseValidationMiddleware, ] def __init__( diff --git a/connexion/middleware/validation.py b/connexion/middleware/request_validation.py similarity index 81% rename from connexion/middleware/validation.py rename to connexion/middleware/request_validation.py index ec8ab25cb..89468d6e3 100644 --- a/connexion/middleware/validation.py +++ b/connexion/middleware/request_validation.py @@ -11,34 +11,23 @@ from connexion.middleware.abstract import RoutedAPI, RoutedMiddleware from connexion.operations import AbstractOperation from connexion.utils import is_nullable -from connexion.validators import JSONBodyValidator - -from ..decorators.response import ResponseValidator -from ..decorators.validation import ParameterValidator +from connexion.validators import VALIDATOR_MAP logger = logging.getLogger("connexion.middleware.validation") -VALIDATOR_MAP = { - "parameter": ParameterValidator, - "body": {"application/json": JSONBodyValidator}, - "response": ResponseValidator, -} - -class ValidationOperation: +class RequestValidationOperation: def __init__( self, next_app: ASGIApp, *, operation: AbstractOperation, - validate_responses: bool = False, strict_validation: bool = False, validator_map: t.Optional[dict] = None, uri_parser_class: t.Optional[AbstractURIParser] = None, ) -> None: self.next_app = next_app self._operation = operation - self.validate_responses = validate_responses self.strict_validation = strict_validation self._validator_map = VALIDATOR_MAP self._validator_map.update(validator_map or {}) @@ -112,15 +101,14 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send): await self.next_app(scope, receive, send) -class ValidationAPI(RoutedAPI[ValidationOperation]): +class RequestValidationAPI(RoutedAPI[RequestValidationOperation]): """Validation API.""" - operation_cls = ValidationOperation + operation_cls = RequestValidationOperation def __init__( self, *args, - validate_responses=False, strict_validation=False, validator_map=None, uri_parser_class=None, @@ -129,9 +117,6 @@ def __init__( super().__init__(*args, **kwargs) self.validator_map = validator_map - logger.debug("Validate Responses: %s", str(validate_responses)) - self.validate_responses = validate_responses - logger.debug("Strict Request Validation: %s", str(strict_validation)) self.strict_validation = strict_validation @@ -139,23 +124,24 @@ def __init__( self.add_paths() - def make_operation(self, operation: AbstractOperation) -> ValidationOperation: - return ValidationOperation( + def make_operation( + self, operation: AbstractOperation + ) -> RequestValidationOperation: + return RequestValidationOperation( self.next_app, operation=operation, - validate_responses=self.validate_responses, strict_validation=self.strict_validation, validator_map=self.validator_map, uri_parser_class=self.uri_parser_class, ) -class ValidationMiddleware(RoutedMiddleware[ValidationAPI]): +class RequestValidationMiddleware(RoutedMiddleware[RequestValidationAPI]): """Middleware for validating requests according to the API contract.""" @property - def api_cls(self) -> t.Type[ValidationAPI]: - return ValidationAPI + def api_cls(self) -> t.Type[RequestValidationAPI]: + return RequestValidationAPI class MissingValidationOperation(Exception): diff --git a/connexion/middleware/response_validation.py b/connexion/middleware/response_validation.py new file mode 100644 index 000000000..25b45a340 --- /dev/null +++ b/connexion/middleware/response_validation.py @@ -0,0 +1,166 @@ +""" +Validation Middleware. +""" +import logging +import typing as t + +from starlette.types import ASGIApp, Receive, Scope, Send + +from connexion.exceptions import NonConformingResponseHeaders +from connexion.middleware.abstract import RoutedAPI, RoutedMiddleware +from connexion.operations import AbstractOperation +from connexion.utils import is_nullable +from connexion.validators import VALIDATOR_MAP + +logger = logging.getLogger("connexion.middleware.validation") + + +class ResponseValidationOperation: + def __init__( + self, + next_app: ASGIApp, + *, + operation: AbstractOperation, + validator_map: t.Optional[dict] = None, + ) -> None: + self.next_app = next_app + self._operation = operation + self._validator_map = VALIDATOR_MAP + self._validator_map.update(validator_map or {}) + + def extract_content_type(self, headers: dict) -> t.Tuple[str, str]: + """Extract the mime type and encoding from the content type headers. + + :param headers: Header dict from ASGI scope + + :return: A tuple of mime type, encoding + """ + encoding = "utf-8" + for key, value in headers: + # Headers can always be decoded using latin-1: + # https://stackoverflow.com/a/27357138/4098821 + key = key.decode("latin-1") + if key.lower() == "content-type": + content_type = value.decode("latin-1") + if ";" in content_type: + mime_type, parameters = content_type.split(";", maxsplit=1) + + prefix = "charset=" + for parameter in parameters.split(";"): + if parameter.startswith(prefix): + encoding = parameter[len(prefix) :] + else: + mime_type = content_type + break + else: + # Content-type header is not required. Take a best guess. + mime_type = self._operation.consumes[0] + + return mime_type, encoding + + def validate_mime_type(self, mime_type: str) -> None: + """Validate the mime type against the spec. + + :param mime_type: mime type from content type header + """ + if mime_type.lower() not in [c.lower() for c in self._operation.produces]: + raise NonConformingResponseHeaders( + reason="Invalid Response Content-type", + message=f"Invalid Response Content-type ({mime_type}), " + f"expected {self._operation.produces}", + ) + + def validate_required_headers( + self, headers: t.List[tuple], response_definition: dict + ) -> None: + required_header_keys = { + k.lower() + for (k, v) in response_definition.get("headers", {}).items() + if v.get("required", False) + } + header_keys = set(header[0].decode("latin-1").lower() for header in headers) + missing_keys = required_header_keys - header_keys + if missing_keys: + pretty_list = ", ".join(missing_keys) + msg = ( + "Keys in header don't match response specification. " "Difference: {}" + ).format(pretty_list) + raise NonConformingResponseHeaders(message=msg) + + async def __call__(self, scope: Scope, receive: Receive, send: Send): + + send_fn = send + + async def wrapped_send(message: t.MutableMapping[str, t.Any]) -> None: + nonlocal send_fn + + if message["type"] == "http.response.start": + status = str(message["status"]) + headers = message["headers"] + mime_type, encoding = self.extract_content_type(headers) + # TODO: Add produces to all tests and fix response content types + # self.validate_mime_type(mime_type) + response_definition = self._operation.response_definition( + status, mime_type + ) + self.validate_required_headers(headers, response_definition) + + # Validate body + try: + body_validator = self._validator_map["response"][mime_type] # type: ignore + except KeyError: + logging.info( + f"Skipping validation. No validator registered for content type: " + f"{mime_type}." + ) + else: + validator = body_validator( + scope, + send, + schema=self._operation.response_schema(status, mime_type), + nullable=is_nullable(self._operation.body_definition), + encoding=encoding, + ) + send_fn = validator.send + + return await send_fn(message) + + await self.next_app(scope, receive, wrapped_send) + + +class ResponseValidationAPI(RoutedAPI[ResponseValidationOperation]): + """Validation API.""" + + operation_cls = ResponseValidationOperation + + def __init__( + self, + *args, + validator_map=None, + validate_responses=False, + **kwargs, + ): + super().__init__(*args, **kwargs) + self.validator_map = validator_map + self.validate_responses = validate_responses + self.add_paths() + + def make_operation( + self, operation: AbstractOperation + ) -> ResponseValidationOperation: + if self.validate_responses: + return ResponseValidationOperation( + self.next_app, + operation=operation, + validator_map=self.validator_map, + ) + else: + return self.next_app # type: ignore + + +class ResponseValidationMiddleware(RoutedMiddleware[ResponseValidationAPI]): + """Middleware for validating requests according to the API contract.""" + + @property + def api_cls(self) -> t.Type[ResponseValidationAPI]: + return ResponseValidationAPI diff --git a/connexion/operations/abstract.py b/connexion/operations/abstract.py index be6d971e9..65825c888 100644 --- a/connexion/operations/abstract.py +++ b/connexion/operations/abstract.py @@ -9,7 +9,6 @@ from ..decorators.decorator import RequestResponseDecorator from ..decorators.parameter import parameter_to_arg from ..decorators.produces import BaseSerializer, Produces -from ..decorators.response import ResponseValidator from ..decorators.validation import ParameterValidator, RequestBodyValidator from ..utils import all_json, is_nullable @@ -20,7 +19,6 @@ VALIDATOR_MAP = { "parameter": ParameterValidator, "body": RequestBodyValidator, - "response": ResponseValidator, } @@ -395,12 +393,6 @@ def function(self): self._pass_context_arg_name, ) - if self.validate_responses: - logger.debug("... Response validation enabled.") - response_decorator = self.__response_validation_decorator - logger.debug("... Adding response decorator (%r)", response_decorator) - function = response_decorator(function) - produces_decorator = self.__content_type_decorator logger.debug("... Adding produces decorator (%r)", produces_decorator) function = produces_decorator(function) @@ -479,15 +471,6 @@ def __validation_decorators(self): strict_validation=self.strict_validation, ) - @property - def __response_validation_decorator(self): - """ - Get a decorator for validating the generated Response. - :rtype: types.FunctionType - """ - ResponseValidator = self.validator_map["response"] - return ResponseValidator(self, self.get_mimetype()) - def json_loads(self, data): """ A wrapper for calling the API specific JSON loader. diff --git a/connexion/validators.py b/connexion/validators.py index 4d726a20d..50f9078d3 100644 --- a/connexion/validators.py +++ b/connexion/validators.py @@ -8,14 +8,15 @@ from jsonschema import Draft4Validator, ValidationError, draft4_format_checker from starlette.types import ASGIApp, Receive, Scope, Send -from connexion.exceptions import BadRequestProblem -from connexion.json_schema import Draft4RequestValidator +from connexion.decorators.validation import ParameterValidator +from connexion.exceptions import BadRequestProblem, NonConformingResponseBody +from connexion.json_schema import Draft4RequestValidator, Draft4ResponseValidator from connexion.utils import is_null logger = logging.getLogger("connexion.middleware.validators") -class JSONBodyValidator: +class JSONRequestBodyValidator: """Request body validator for json content types.""" def __init__( @@ -31,10 +32,8 @@ def __init__( self.schema = schema self.has_default = schema.get("default", False) self.nullable = nullable - self.validator_cls = validator or Draft4RequestValidator - self.validator = self.validator_cls( - schema, format_checker=draft4_format_checker - ) + validator_cls = validator or Draft4RequestValidator + self.validator = validator_cls(schema, format_checker=draft4_format_checker) self.encoding = encoding @classmethod @@ -44,7 +43,6 @@ def _error_path_message(cls, exception): return error_path_msg def validate(self, body: dict): - try: self.validator.validate(body) except ValidationError as exception: @@ -85,3 +83,87 @@ async def wrapped_receive(): return await receive() await self.next_app(scope, wrapped_receive, send) + + +class JSONResponseBodyValidator: + def __init__( + self, + scope: Scope, + send: Send, + *, + schema: dict, + validator: t.Type[Draft4Validator] = None, + nullable=False, + encoding: str, + ) -> None: + self._scope = scope + self._send = send + self.schema = schema + self.has_default = schema.get("default", False) + self.nullable = nullable + validator_cls = validator or Draft4ResponseValidator + self.validator = validator_cls(schema, format_checker=draft4_format_checker) + self.encoding = encoding + self._messages: t.List[t.MutableMapping[str, t.Any]] = [] + + @classmethod + def _error_path_message(cls, exception): + error_path = ".".join(str(item) for item in exception.path) + error_path_msg = f" - '{error_path}'" if error_path else "" + return error_path_msg + + def validate(self, body: dict): + try: + self.validator.validate(body) + except ValidationError as exception: + error_path_msg = self._error_path_message(exception=exception) + logger.error( + f"Validation error: {exception.message}{error_path_msg}", + extra={"validator": "body"}, + ) + raise NonConformingResponseBody( + message=f"{exception.message}{error_path_msg}" + ) + + @staticmethod + def parse(body: str) -> dict: + try: + return json.loads(body) + except json.decoder.JSONDecodeError as e: + raise BadRequestProblem(str(e)) + + async def send(self, message: t.MutableMapping[str, t.Any]) -> None: + self._messages.append(message) + + if message["type"] == "http.response.start" or message.get("more_body", False): + return + + # TODO: make json library pluggable + bytes_body = b"".join([message.get("body", b"") for message in self._messages]) + decoded_body = bytes_body.decode(self.encoding) + + if decoded_body and not (self.nullable and is_null(decoded_body)): + body = self.parse(decoded_body) + self.validate(body) + + while self._messages: + await self._send(self._messages.pop(0)) + + +class TextResonseBodyValidator(JSONResponseBodyValidator): + @staticmethod + def parse(body: str) -> str: # type: ignore + try: + return json.loads(body) + except json.decoder.JSONDecodeError: + return body + + +VALIDATOR_MAP = { + "parameter": ParameterValidator, + "body": {"application/json": JSONRequestBodyValidator}, + "response": { + "application/json": JSONResponseBodyValidator, + "text/plain": TextResonseBodyValidator, + }, +} diff --git a/tests/api/test_headers.py b/tests/api/test_headers.py index 709b8821f..6b6ec4236 100644 --- a/tests/api/test_headers.py +++ b/tests/api/test_headers.py @@ -34,7 +34,7 @@ def test_header_not_returned(simple_openapi_app): assert data["title"] == "Response headers do not conform to specification" assert ( data["detail"] - == "Keys in header don't match response specification. Difference: Location" + == "Keys in header don't match response specification. Difference: location" ) assert data["status"] == 500 diff --git a/tests/api/test_schema.py b/tests/api/test_schema.py index ec20cfd9a..fa48c35d7 100644 --- a/tests/api/test_schema.py +++ b/tests/api/test_schema.py @@ -80,59 +80,59 @@ def test_schema_response(schema_app): request = app_client.get( "/v1.0/test_schema/response/object/valid", headers={}, data=None ) # type: flask.Response - assert request.status_code == 200 + assert request.status_code == 200, request.text request = app_client.get( "/v1.0/test_schema/response/object/invalid_type", headers={}, data=None ) # type: flask.Response - assert request.status_code == 500 + assert request.status_code == 500, request.text request = app_client.get( "/v1.0/test_schema/response/object/invalid_requirements", headers={}, data=None ) # type: flask.Response - assert request.status_code == 500 + assert request.status_code == 500, request.text request = app_client.get( "/v1.0/test_schema/response/string/valid", headers={}, data=None ) # type: flask.Response - assert request.status_code == 200 + assert request.status_code == 200, request.text request = app_client.get( "/v1.0/test_schema/response/string/invalid", headers={}, data=None ) # type: flask.Response - assert request.status_code == 500 + assert request.status_code == 500, request.text request = app_client.get( "/v1.0/test_schema/response/integer/valid", headers={}, data=None ) # type: flask.Response - assert request.status_code == 200 + assert request.status_code == 200, request.text request = app_client.get( "/v1.0/test_schema/response/integer/invalid", headers={}, data=None ) # type: flask.Response - assert request.status_code == 500 + assert request.status_code == 500, request.text request = app_client.get( "/v1.0/test_schema/response/number/valid", headers={}, data=None ) # type: flask.Response - assert request.status_code == 200 + assert request.status_code == 200, request.text request = app_client.get( "/v1.0/test_schema/response/number/invalid", headers={}, data=None ) # type: flask.Response - assert request.status_code == 500 + assert request.status_code == 500, request.text request = app_client.get( "/v1.0/test_schema/response/boolean/valid", headers={}, data=None ) # type: flask.Response - assert request.status_code == 200 + assert request.status_code == 200, request.text request = app_client.get( "/v1.0/test_schema/response/boolean/invalid", headers={}, data=None ) # type: flask.Response - assert request.status_code == 500 + assert request.status_code == 500, request.text request = app_client.get( "/v1.0/test_schema/response/array/valid", headers={}, data=None ) # type: flask.Response - assert request.status_code == 200 + assert request.status_code == 200, request.text request = app_client.get( "/v1.0/test_schema/response/array/invalid_dict", headers={}, data=None ) # type: flask.Response - assert request.status_code == 500 + assert request.status_code == 500, request.text request = app_client.get( "/v1.0/test_schema/response/array/invalid_string", headers={}, data=None ) # type: flask.Response - assert request.status_code == 500 + assert request.status_code == 500, request.text def test_schema_in_query(schema_app): diff --git a/tests/fakeapi/example_method_view.py b/tests/fakeapi/example_method_view.py index f66d2b549..d6ad0e93b 100644 --- a/tests/fakeapi/example_method_view.py +++ b/tests/fakeapi/example_method_view.py @@ -6,19 +6,22 @@ class PetsView(MethodView): mycontent = "demonstrate return from MethodView class" def get(self, **kwargs): - kwargs.update({"method": "get"}) - return kwargs + if kwargs: + kwargs.update({"name": "get"}) + return kwargs + else: + return [{"name": "get"}] def search(self): - return "search" + return [{"name": "search"}] def post(self, **kwargs): - kwargs.update({"method": "post"}) - return kwargs + kwargs.update({"name": "post"}) + return kwargs, 201 def put(self, *args, **kwargs): - kwargs.update({"method": "put"}) - return kwargs + kwargs.update({"name": "put"}) + return kwargs, 201 # Test that operation_id can still override resolver diff --git a/tests/test_json_validation.py b/tests/test_json_validation.py index 3857fc3d8..36916dd0b 100644 --- a/tests/test_json_validation.py +++ b/tests/test_json_validation.py @@ -3,10 +3,9 @@ import pytest from connexion import App -from connexion.decorators.validation import RequestBodyValidator from connexion.json_schema import Draft4RequestValidator from connexion.spec import Specification -from connexion.validators import JSONBodyValidator +from connexion.validators import JSONRequestBodyValidator from jsonschema.validators import _utils, extend from conftest import build_app_from_fixture @@ -31,7 +30,7 @@ def validate_type(validator, types, instance, schema): MinLengthRequestValidator = extend(Draft4RequestValidator, {"type": validate_type}) - class MyJSONBodyValidator(JSONBodyValidator): + class MyJSONBodyValidator(JSONRequestBodyValidator): def __init__(self, *args, **kwargs): super().__init__(*args, validator=MinLengthRequestValidator, **kwargs) diff --git a/tests/test_resolver_methodview.py b/tests/test_resolver_methodview.py index 06436860a..5d315e1ac 100644 --- a/tests/test_resolver_methodview.py +++ b/tests/test_resolver_methodview.py @@ -192,13 +192,13 @@ def test_method_view_resolver_integration(method_view_app): client = method_view_app.app.test_client() r = client.get("/v1.0/pets") - assert r.json == {"method": "get"} + assert r.json == [{"name": "get"}] r = client.get("/v1.0/pets/1") - assert r.json == {"method": "get", "petId": 1} + assert r.json == {"name": "get", "petId": 1} r = client.post("/v1.0/pets", json={"name": "Musti"}) - assert r.json == {"method": "post", "body": {"name": "Musti"}} + assert r.json == {"name": "post", "body": {"name": "Musti"}} r = client.put("/v1.0/pets/1", json={"name": "Igor"}) - assert r.json == {"method": "put", "petId": 1, "body": {"name": "Igor"}} + assert r.json == {"name": "put", "petId": 1, "body": {"name": "Igor"}} From 08ff7654a8deb964a193f1aeda27362363d91a3f Mon Sep 17 00:00:00 2001 From: Robbe Sneyders Date: Sun, 18 Sep 2022 00:48:17 +0200 Subject: [PATCH 5/9] Refactor Request validation to match Response validation --- connexion/middleware/abstract.py | 6 ++-- connexion/middleware/request_validation.py | 13 ++++----- connexion/middleware/response_validation.py | 3 +- connexion/validators.py | 31 +++++++++------------ 4 files changed, 25 insertions(+), 28 deletions(-) diff --git a/connexion/middleware/abstract.py b/connexion/middleware/abstract.py index dd69d525a..9e109e56a 100644 --- a/connexion/middleware/abstract.py +++ b/connexion/middleware/abstract.py @@ -118,7 +118,6 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: try: connexion_context = scope["extensions"][ROUTING_CONTEXT] except KeyError: - # TODO: update message raise MissingMiddleware( "Could not find routing information in scope. Please make sure " "you have a routing middleware registered upstream. " @@ -131,7 +130,10 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: operation = api.operations[operation_id] except KeyError as e: if operation_id is None: - logger.debug("Skipping validation check for operation without id.") + logger.debug( + f"Skipping {self.__class__.__name__} check for operation without " + f"id." + ) await self.app(scope, receive, send) return else: diff --git a/connexion/middleware/request_validation.py b/connexion/middleware/request_validation.py index 89468d6e3..3a4b6504a 100644 --- a/connexion/middleware/request_validation.py +++ b/connexion/middleware/request_validation.py @@ -75,6 +75,8 @@ def validate_mime_type(self, mime_type: str) -> None: ) async def __call__(self, scope: Scope, receive: Receive, send: Send): + receive_fn = receive + headers = scope["headers"] mime_type, encoding = self.extract_content_type(headers) self.validate_mime_type(mime_type) @@ -91,14 +93,15 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send): ) else: validator = body_validator( - self.next_app, + scope, + receive, schema=self._operation.body_schema, nullable=is_nullable(self._operation.body_definition), encoding=encoding, ) - return await validator(scope, receive, send) + receive_fn = validator.receive - await self.next_app(scope, receive, send) + await self.next_app(scope, receive_fn, send) class RequestValidationAPI(RoutedAPI[RequestValidationOperation]): @@ -142,7 +145,3 @@ class RequestValidationMiddleware(RoutedMiddleware[RequestValidationAPI]): @property def api_cls(self) -> t.Type[RequestValidationAPI]: return RequestValidationAPI - - -class MissingValidationOperation(Exception): - """Missing validation operation""" diff --git a/connexion/middleware/response_validation.py b/connexion/middleware/response_validation.py index 25b45a340..7d17bd73c 100644 --- a/connexion/middleware/response_validation.py +++ b/connexion/middleware/response_validation.py @@ -70,8 +70,9 @@ def validate_mime_type(self, mime_type: str) -> None: f"expected {self._operation.produces}", ) + @staticmethod def validate_required_headers( - self, headers: t.List[tuple], response_definition: dict + headers: t.List[tuple], response_definition: dict ) -> None: required_header_keys = { k.lower() diff --git a/connexion/validators.py b/connexion/validators.py index 50f9078d3..c19851abd 100644 --- a/connexion/validators.py +++ b/connexion/validators.py @@ -6,7 +6,7 @@ import typing as t from jsonschema import Draft4Validator, ValidationError, draft4_format_checker -from starlette.types import ASGIApp, Receive, Scope, Send +from starlette.types import Receive, Scope, Send from connexion.decorators.validation import ParameterValidator from connexion.exceptions import BadRequestProblem, NonConformingResponseBody @@ -21,20 +21,23 @@ class JSONRequestBodyValidator: def __init__( self, - next_app: ASGIApp, + scope: Scope, + receive: Receive, *, schema: dict, validator: t.Type[Draft4Validator] = None, nullable=False, encoding: str, ) -> None: - self.next_app = next_app + self._scope = scope + self._receive = receive self.schema = schema self.has_default = schema.get("default", False) self.nullable = nullable validator_cls = validator or Draft4RequestValidator self.validator = validator_cls(schema, format_checker=draft4_format_checker) self.encoding = encoding + self._messages: t.List[t.MutableMapping[str, t.Any]] = [] @classmethod def _error_path_message(cls, exception): @@ -53,18 +56,15 @@ def validate(self, body: dict): ) raise BadRequestProblem(detail=f"{exception.message}{error_path_msg}") - async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: - # Based on https://github.com/encode/starlette/pull/1519#issuecomment-1060633787 - # Ingest all body messages from the ASGI `receive` callable. - messages = [] + async def receive(self) -> t.Optional[t.MutableMapping[str, t.Any]]: more_body = True while more_body: - message = await receive() - messages.append(message) + message = await self._receive() + self._messages.append(message) more_body = message.get("more_body", False) # TODO: make json library pluggable - bytes_body = b"".join([message.get("body", b"") for message in messages]) + bytes_body = b"".join([message.get("body", b"") for message in self._messages]) decoded_body = bytes_body.decode(self.encoding) if decoded_body and not (self.nullable and is_null(decoded_body)): @@ -75,14 +75,9 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: self.validate(body) - async def wrapped_receive(): - # First up we want to return any messages we've stashed. - if messages: - return messages.pop(0) - # Once that's done we can just await any other messages. - return await receive() - - await self.next_app(scope, wrapped_receive, send) + while self._messages: + return self._messages.pop(0) + return None class JSONResponseBodyValidator: From f4dcdbe3fc34002827e32b03826834e9e6ad08b5 Mon Sep 17 00:00:00 2001 From: Robbe Sneyders Date: Sun, 18 Sep 2022 12:16:26 +0200 Subject: [PATCH 6/9] Factor out shared functionality --- connexion/middleware/request_validation.py | 37 ++++++++------------- connexion/middleware/response_validation.py | 37 ++++++++------------- connexion/middleware/security.py | 5 --- connexion/utils.py | 30 +++++++++++++++++ connexion/validators.py | 2 ++ 5 files changed, 60 insertions(+), 51 deletions(-) diff --git a/connexion/middleware/request_validation.py b/connexion/middleware/request_validation.py index 3a4b6504a..93dc5898c 100644 --- a/connexion/middleware/request_validation.py +++ b/connexion/middleware/request_validation.py @@ -6,11 +6,11 @@ from starlette.types import ASGIApp, Receive, Scope, Send +from connexion import utils from connexion.decorators.uri_parsing import AbstractURIParser from connexion.exceptions import UnsupportedMediaTypeProblem from connexion.middleware.abstract import RoutedAPI, RoutedMiddleware from connexion.operations import AbstractOperation -from connexion.utils import is_nullable from connexion.validators import VALIDATOR_MAP logger = logging.getLogger("connexion.middleware.validation") @@ -33,33 +33,24 @@ def __init__( self._validator_map.update(validator_map or {}) self.uri_parser_class = uri_parser_class - def extract_content_type(self, headers: dict) -> t.Tuple[str, str]: + def extract_content_type( + self, headers: t.List[t.Tuple[bytes, bytes]] + ) -> t.Tuple[str, str]: """Extract the mime type and encoding from the content type headers. - :param headers: Header dict from ASGI scope + :param headers: Headers from ASGI scope :return: A tuple of mime type, encoding """ - encoding = "utf-8" - for key, value in headers: - # Headers can always be decoded using latin-1: - # https://stackoverflow.com/a/27357138/4098821 - key = key.decode("latin-1") - if key.lower() == "content-type": - content_type = value.decode("latin-1") - if ";" in content_type: - mime_type, parameters = content_type.split(";", maxsplit=1) - - prefix = "charset=" - for parameter in parameters.split(";"): - if parameter.startswith(prefix): - encoding = parameter[len(prefix) :] - else: - mime_type = content_type - break - else: + mime_type, encoding = utils.extract_content_type(headers) + if mime_type is None: # Content-type header is not required. Take a best guess. - mime_type = self._operation.consumes[0] + try: + mime_type = self._operation.consumes[0] + except IndexError: + mime_type = "application/octet-stream" + if encoding is None: + encoding = "utf-8" return mime_type, encoding @@ -96,7 +87,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send): scope, receive, schema=self._operation.body_schema, - nullable=is_nullable(self._operation.body_definition), + nullable=utils.is_nullable(self._operation.body_definition), encoding=encoding, ) receive_fn = validator.receive diff --git a/connexion/middleware/response_validation.py b/connexion/middleware/response_validation.py index 7d17bd73c..b33d081c9 100644 --- a/connexion/middleware/response_validation.py +++ b/connexion/middleware/response_validation.py @@ -6,10 +6,10 @@ from starlette.types import ASGIApp, Receive, Scope, Send +from connexion import utils from connexion.exceptions import NonConformingResponseHeaders from connexion.middleware.abstract import RoutedAPI, RoutedMiddleware from connexion.operations import AbstractOperation -from connexion.utils import is_nullable from connexion.validators import VALIDATOR_MAP logger = logging.getLogger("connexion.middleware.validation") @@ -28,33 +28,24 @@ def __init__( self._validator_map = VALIDATOR_MAP self._validator_map.update(validator_map or {}) - def extract_content_type(self, headers: dict) -> t.Tuple[str, str]: + def extract_content_type( + self, headers: t.List[t.Tuple[bytes, bytes]] + ) -> t.Tuple[str, str]: """Extract the mime type and encoding from the content type headers. - :param headers: Header dict from ASGI scope + :param headers: Headers from ASGI scope :return: A tuple of mime type, encoding """ - encoding = "utf-8" - for key, value in headers: - # Headers can always be decoded using latin-1: - # https://stackoverflow.com/a/27357138/4098821 - key = key.decode("latin-1") - if key.lower() == "content-type": - content_type = value.decode("latin-1") - if ";" in content_type: - mime_type, parameters = content_type.split(";", maxsplit=1) - - prefix = "charset=" - for parameter in parameters.split(";"): - if parameter.startswith(prefix): - encoding = parameter[len(prefix) :] - else: - mime_type = content_type - break - else: + mime_type, encoding = utils.extract_content_type(headers) + if mime_type is None: # Content-type header is not required. Take a best guess. - mime_type = self._operation.consumes[0] + try: + mime_type = self._operation.produces[0] + except IndexError: + mime_type = "application/octet-stream" + if encoding is None: + encoding = "utf-8" return mime_type, encoding @@ -119,7 +110,7 @@ async def wrapped_send(message: t.MutableMapping[str, t.Any]) -> None: scope, send, schema=self._operation.response_schema(status, mime_type), - nullable=is_nullable(self._operation.body_definition), + nullable=utils.is_nullable(self._operation.body_definition), encoding=encoding, ) send_fn = validator.send diff --git a/connexion/middleware/security.py b/connexion/middleware/security.py index 33e15ba83..4320dc616 100644 --- a/connexion/middleware/security.py +++ b/connexion/middleware/security.py @@ -4,7 +4,6 @@ from starlette.types import ASGIApp, Receive, Scope, Send -from connexion.exceptions import ProblemException from connexion.lifecycle import MiddlewareRequest from connexion.middleware.abstract import RoutedAPI, RoutedMiddleware from connexion.operations import AbstractOperation @@ -239,7 +238,3 @@ class SecurityMiddleware(RoutedMiddleware[SecurityAPI]): @property def api_cls(self) -> t.Type[SecurityAPI]: return SecurityAPI - - -class MissingSecurityOperation(ProblemException): - pass diff --git a/connexion/utils.py b/connexion/utils.py index e92d403ce..cc6cdd632 100644 --- a/connexion/utils.py +++ b/connexion/utils.py @@ -5,6 +5,7 @@ import asyncio import functools import importlib +import typing as t import yaml @@ -266,3 +267,32 @@ def _required_lib(exc, *args, **kwargs): raise exc return functools.partial(_required_lib, exc) + + +def extract_content_type( + headers: t.List[t.Tuple[bytes, bytes]] +) -> t.Tuple[t.Optional[str], t.Optional[str]]: + """Extract the mime type and encoding from the content type headers. + + :param headers: Headers from ASGI scope + + :return: A tuple of mime type, encoding + """ + mime_type, encoding = None, None + for key, value in headers: + # Headers can always be decoded using latin-1: + # https://stackoverflow.com/a/27357138/4098821 + decoded_key = key.decode("latin-1") + if decoded_key.lower() == "content-type": + content_type = value.decode("latin-1") + if ";" in content_type: + mime_type, parameters = content_type.split(";", maxsplit=1) + + prefix = "charset=" + for parameter in parameters.split(";"): + if parameter.startswith(prefix): + encoding = parameter[len(prefix) :] + else: + mime_type = content_type + break + return mime_type, encoding diff --git a/connexion/validators.py b/connexion/validators.py index c19851abd..55c7dd41e 100644 --- a/connexion/validators.py +++ b/connexion/validators.py @@ -81,6 +81,8 @@ async def receive(self) -> t.Optional[t.MutableMapping[str, t.Any]]: class JSONResponseBodyValidator: + """Response body validator for json content types.""" + def __init__( self, scope: Scope, From 88a003f826dbf40e8edb07136cd80edf15702d72 Mon Sep 17 00:00:00 2001 From: Robbe Sneyders Date: Fri, 23 Sep 2022 21:30:18 +0200 Subject: [PATCH 7/9] Fix typo in TextResponseBodyValidator class name --- connexion/validators.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/connexion/validators.py b/connexion/validators.py index 55c7dd41e..dff5c41bc 100644 --- a/connexion/validators.py +++ b/connexion/validators.py @@ -147,7 +147,7 @@ async def send(self, message: t.MutableMapping[str, t.Any]) -> None: await self._send(self._messages.pop(0)) -class TextResonseBodyValidator(JSONResponseBodyValidator): +class TextResponseBodyValidator(JSONResponseBodyValidator): @staticmethod def parse(body: str) -> str: # type: ignore try: @@ -161,6 +161,6 @@ def parse(body: str) -> str: # type: ignore "body": {"application/json": JSONRequestBodyValidator}, "response": { "application/json": JSONResponseBodyValidator, - "text/plain": TextResonseBodyValidator, + "text/plain": TextResponseBodyValidator, }, } From c996561f62e3878f00af67a76267c73fc54fcfeb Mon Sep 17 00:00:00 2001 From: Robbe Sneyders Date: Mon, 26 Sep 2022 21:18:54 +0200 Subject: [PATCH 8/9] Fix string formatting --- connexion/middleware/response_validation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/connexion/middleware/response_validation.py b/connexion/middleware/response_validation.py index 19d09fe13..f736b0581 100644 --- a/connexion/middleware/response_validation.py +++ b/connexion/middleware/response_validation.py @@ -75,7 +75,7 @@ def validate_required_headers( if missing_keys: pretty_list = ", ".join(missing_keys) msg = ( - "Keys in header don't match response specification. " "Difference: {}" + "Keys in header don't match response specification. Difference: {}" ).format(pretty_list) raise NonConformingResponseHeaders(message=msg) From dde3dd330f5aab3baea27ba20b2b07f656b7875b Mon Sep 17 00:00:00 2001 From: Robbe Sneyders Date: Mon, 3 Oct 2022 22:26:49 +0200 Subject: [PATCH 9/9] Use correct schema to check nullability in response validation --- connexion/middleware/response_validation.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/connexion/middleware/response_validation.py b/connexion/middleware/response_validation.py index f736b0581..bcd66b8c9 100644 --- a/connexion/middleware/response_validation.py +++ b/connexion/middleware/response_validation.py @@ -110,7 +110,9 @@ async def wrapped_send(message: t.MutableMapping[str, t.Any]) -> None: scope, send, schema=self._operation.response_schema(status, mime_type), - nullable=utils.is_nullable(self._operation.body_definition), + nullable=utils.is_nullable( + self._operation.response_definition(status, mime_type) + ), encoding=encoding, ) send_fn = validator.send