Skip to content
Merged
Prev Previous commit
Next Next commit
Extract response validation to middleware
  • Loading branch information
RobbeSneyders committed Sep 18, 2022
commit 141ba1c6c7f089f36291935ade436ad06b93d310
135 changes: 0 additions & 135 deletions connexion/decorators/response.py

This file was deleted.

25 changes: 1 addition & 24 deletions connexion/decorators/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from ..exceptions import BadRequestProblem, ExtraParameterProblem
from ..http_facts import FORM_CONTENT_TYPES
from ..json_schema import Draft4RequestValidator, Draft4ResponseValidator
from ..json_schema import Draft4RequestValidator
from ..lifecycle import ConnexionResponse
from ..utils import boolean, is_null, is_nullable

Expand Down Expand Up @@ -196,29 +196,6 @@ def validate_schema(self, data: dict, url: str) -> t.Optional[ConnexionResponse]
return None


class ResponseBodyValidator:
def __init__(self, schema, validator=None):
"""
:param schema: The schema of the response body
:param validator: Validator class that should be used to validate passed data
against API schema. Default is Draft4ResponseValidator.
:type validator: jsonschema.IValidator
"""
ValidatorClass = validator or Draft4ResponseValidator
self.validator = ValidatorClass(schema, format_checker=draft4_format_checker)

def validate_schema(self, data: dict, url: str) -> t.Optional[ConnexionResponse]:
try:
self.validator.validate(data)
except ValidationError as exception:
logger.error(
f"{url} validation error: {exception}", extra={"validator": "response"}
)
raise exception

return None


class ParameterValidator:
def __init__(self, parameters, api, strict_validation=False):
"""
Expand Down
1 change: 1 addition & 0 deletions connexion/middleware/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
try:
connexion_context = scope["extensions"][ROUTING_CONTEXT]
except KeyError:
# TODO: update message
raise MissingMiddleware(
"Could not find routing information in scope. Please make sure "
"you have a routing middleware registered upstream. "
Expand Down
6 changes: 4 additions & 2 deletions connexion/middleware/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@

from connexion.middleware.abstract import AppMiddleware
from connexion.middleware.exceptions import ExceptionMiddleware
from connexion.middleware.request_validation import RequestValidationMiddleware
from connexion.middleware.response_validation import ResponseValidationMiddleware
from connexion.middleware.routing import RoutingMiddleware
from connexion.middleware.security import SecurityMiddleware
from connexion.middleware.swagger_ui import SwaggerUIMiddleware
from connexion.middleware.validation import ValidationMiddleware


class ConnexionMiddleware:
Expand All @@ -18,7 +19,8 @@ class ConnexionMiddleware:
SwaggerUIMiddleware,
RoutingMiddleware,
SecurityMiddleware,
ValidationMiddleware,
RequestValidationMiddleware,
ResponseValidationMiddleware,
]

def __init__(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,34 +11,23 @@
from connexion.middleware.abstract import RoutedAPI, RoutedMiddleware
from connexion.operations import AbstractOperation
from connexion.utils import is_nullable
from connexion.validators import JSONBodyValidator

from ..decorators.response import ResponseValidator
from ..decorators.validation import ParameterValidator
from connexion.validators import VALIDATOR_MAP

logger = logging.getLogger("connexion.middleware.validation")

VALIDATOR_MAP = {
"parameter": ParameterValidator,
"body": {"application/json": JSONBodyValidator},
"response": ResponseValidator,
}


class ValidationOperation:
class RequestValidationOperation:
def __init__(
self,
next_app: ASGIApp,
*,
operation: AbstractOperation,
validate_responses: bool = False,
strict_validation: bool = False,
validator_map: t.Optional[dict] = None,
uri_parser_class: t.Optional[AbstractURIParser] = None,
) -> None:
self.next_app = next_app
self._operation = operation
self.validate_responses = validate_responses
self.strict_validation = strict_validation
self._validator_map = VALIDATOR_MAP
self._validator_map.update(validator_map or {})
Expand Down Expand Up @@ -112,15 +101,14 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send):
await self.next_app(scope, receive, send)


class ValidationAPI(RoutedAPI[ValidationOperation]):
class RequestValidationAPI(RoutedAPI[RequestValidationOperation]):
"""Validation API."""

operation_cls = ValidationOperation
operation_cls = RequestValidationOperation

def __init__(
self,
*args,
validate_responses=False,
strict_validation=False,
validator_map=None,
uri_parser_class=None,
Expand All @@ -129,33 +117,31 @@ def __init__(
super().__init__(*args, **kwargs)
self.validator_map = validator_map

logger.debug("Validate Responses: %s", str(validate_responses))
self.validate_responses = validate_responses

logger.debug("Strict Request Validation: %s", str(strict_validation))
self.strict_validation = strict_validation

self.uri_parser_class = uri_parser_class

self.add_paths()

def make_operation(self, operation: AbstractOperation) -> ValidationOperation:
return ValidationOperation(
def make_operation(
self, operation: AbstractOperation
) -> RequestValidationOperation:
return RequestValidationOperation(
self.next_app,
operation=operation,
validate_responses=self.validate_responses,
strict_validation=self.strict_validation,
validator_map=self.validator_map,
uri_parser_class=self.uri_parser_class,
)


class ValidationMiddleware(RoutedMiddleware[ValidationAPI]):
class RequestValidationMiddleware(RoutedMiddleware[RequestValidationAPI]):
"""Middleware for validating requests according to the API contract."""

@property
def api_cls(self) -> t.Type[ValidationAPI]:
return ValidationAPI
def api_cls(self) -> t.Type[RequestValidationAPI]:
return RequestValidationAPI


class MissingValidationOperation(Exception):
Expand Down
Loading