Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Use async/await in security
  • Loading branch information
cognifloyd committed Jun 4, 2020
commit caacc78b1ea68236fabb0f9c038e98da6214d83f
5 changes: 2 additions & 3 deletions connexion/security/aiohttp_security_handler_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,12 @@ def get_token_info_remote(self, token_info_url):
:type token_info_url: str
:rtype: types.FunctionType
"""
@asyncio.coroutine
def wrapper(token):
async def wrapper(token):
if not self.client_session:
# Must be created in a coroutine
self.client_session = aiohttp.ClientSession()
headers = {'Authorization': 'Bearer {}'.format(token)}
token_request = yield from self.client_session.get(token_info_url, headers=headers, timeout=5)
token_request = await self.client_session.get(token_info_url, headers=headers, timeout=5)
if token_request.status != 200:
return None
return token_request.json()
Expand Down
17 changes: 7 additions & 10 deletions connexion/security/async_security_handler_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,15 @@ class AbstractAsyncSecurityHandlerFactory(AbstractSecurityHandlerFactory):
def _generic_check(self, func, exception_msg):
need_to_add_context, need_to_add_required_scopes = self._need_to_add_context_or_scopes(func)

@asyncio.coroutine
def wrapper(request, *args, required_scopes=None):
async def wrapper(request, *args, required_scopes=None):
kwargs = {}
if need_to_add_context:
kwargs[self.pass_context_arg_name] = request.context
if need_to_add_required_scopes:
kwargs[self.required_scopes_kw] = required_scopes
token_info = func(*args, **kwargs)
while asyncio.iscoroutine(token_info):
token_info = yield from token_info
token_info = await token_info
if token_info is self.no_value:
return self.no_value
if token_info is None:
Expand All @@ -37,9 +36,8 @@ def check_oauth_func(self, token_info_func, scope_validate_func):
get_token_info = self._generic_check(token_info_func, 'Provided token is not valid')
need_to_add_context, _ = self._need_to_add_context_or_scopes(scope_validate_func)

@asyncio.coroutine
def wrapper(request, token, required_scopes):
token_info = yield from get_token_info(request, token, required_scopes=required_scopes)
async def wrapper(request, token, required_scopes):
token_info = await get_token_info(request, token, required_scopes=required_scopes)

# Fallback to 'scopes' for backward compatibility
token_scopes = token_info.get('scope', token_info.get('scopes', ''))
Expand All @@ -49,7 +47,7 @@ def wrapper(request, token, required_scopes):
kwargs[self.pass_context_arg_name] = request.context
validation = scope_validate_func(required_scopes, token_scopes, **kwargs)
while asyncio.iscoroutine(validation):
validation = yield from validation
validation = await validation
if not validation:
raise OAuthScopeProblem(
description='Provided token doesn\'t have the required scope',
Expand All @@ -62,14 +60,13 @@ def wrapper(request, token, required_scopes):

@classmethod
def verify_security(cls, auth_funcs, required_scopes, function):
@asyncio.coroutine
@functools.wraps(function)
def wrapper(request):
async def wrapper(request):
token_info = None
for func in auth_funcs:
token_info = func(request, required_scopes)
while asyncio.iscoroutine(token_info):
token_info = yield from token_info
token_info = await token_info
if token_info is not cls.no_value:
break

Expand Down
80 changes: 37 additions & 43 deletions tests/aiohttp/test_aiohttp_api_secure.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,13 @@ def __init__(self, status_code, data):
self.data = data
self.ok = status_code == 200

@asyncio.coroutine
def json(self):
async def json(self):
return self.data


@pytest.fixture
def oauth_aiohttp_client(monkeypatch):
@asyncio.coroutine
def fake_get(url, params=None, headers=None, timeout=None):
async def fake_get(url, params=None, headers=None, timeout=None):
"""
:type url: str
:type params: dict| None
Expand All @@ -51,119 +49,115 @@ def fake_get(url, params=None, headers=None, timeout=None):
monkeypatch.setattr('aiohttp.ClientSession', MagicMock(return_value=client_instance))


@asyncio.coroutine
def test_auth_all_paths(oauth_aiohttp_client, aiohttp_api_spec_dir, aiohttp_client):
async def test_auth_all_paths(oauth_aiohttp_client, aiohttp_api_spec_dir, aiohttp_client):
app = AioHttpApp(__name__, port=5001,
specification_dir=aiohttp_api_spec_dir,
debug=True, auth_all_paths=True)
app.add_api('swagger_secure.yaml')

app_client = yield from aiohttp_client(app.app)
app_client = await aiohttp_client(app.app)

get_inexistent_endpoint = yield from app_client.get(
get_inexistent_endpoint = await app_client.get(
'/v1.0/does-not-exist-valid-token',
headers={'Authorization': 'Bearer 100'}
)
assert get_inexistent_endpoint.status == 404
assert get_inexistent_endpoint.content_type == 'application/problem+json'

get_inexistent_endpoint = yield from app_client.get(
get_inexistent_endpoint = await app_client.get(
'/v1.0/does-not-exist-no-token'
)
assert get_inexistent_endpoint.status == 401
assert get_inexistent_endpoint.content_type == 'application/problem+json'


@pytest.mark.parametrize('spec', ['swagger_secure.yaml', 'openapi_secure.yaml'])
@asyncio.coroutine
def test_secure_app(oauth_aiohttp_client, aiohttp_api_spec_dir, aiohttp_client, spec):
async def test_secure_app(oauth_aiohttp_client, aiohttp_api_spec_dir, aiohttp_client, spec):
"""
Test common authentication method between Swagger 2 and OpenApi 3
"""
app = AioHttpApp(__name__, port=5001, specification_dir=aiohttp_api_spec_dir, debug=True)
app.add_api(spec)
app_client = yield from aiohttp_client(app.app)
app_client = await aiohttp_client(app.app)

response = yield from app_client.get('/v1.0/all_auth')
response = await app_client.get('/v1.0/all_auth')
assert response.status == 401
assert response.content_type == 'application/problem+json'

response = yield from app_client.get('/v1.0/all_auth', headers={'Authorization': 'Bearer 100'})
response = await app_client.get('/v1.0/all_auth', headers={'Authorization': 'Bearer 100'})
assert response.status == 200
assert (yield from response.json()) == {"scope": ['myscope'], "uid": 'test-user'}
assert (await response.json()) == {"scope": ['myscope'], "uid": 'test-user'}

response = yield from app_client.get('/v1.0/all_auth', headers={'authorization': 'Bearer 100'})
response = await app_client.get('/v1.0/all_auth', headers={'authorization': 'Bearer 100'})
assert response.status == 200, "Authorization header in lower case should be accepted"
assert (yield from response.json()) == {"scope": ['myscope'], "uid": 'test-user'}
assert (await response.json()) == {"scope": ['myscope'], "uid": 'test-user'}

response = yield from app_client.get('/v1.0/all_auth', headers={'AUTHORIZATION': 'Bearer 100'})
response = await app_client.get('/v1.0/all_auth', headers={'AUTHORIZATION': 'Bearer 100'})
assert response.status == 200, "Authorization header in upper case should be accepted"
assert (yield from response.json()) == {"scope": ['myscope'], "uid": 'test-user'}
assert (await response.json()) == {"scope": ['myscope'], "uid": 'test-user'}

basic_header = 'Basic ' + base64.b64encode(b'username:username').decode('ascii')
response = yield from app_client.get('/v1.0/all_auth', headers={'Authorization': basic_header})
response = await app_client.get('/v1.0/all_auth', headers={'Authorization': basic_header})
assert response.status == 200
assert (yield from response.json()) == {"uid": 'username'}
assert (await response.json()) == {"uid": 'username'}

basic_header = 'Basic ' + base64.b64encode(b'username:wrong').decode('ascii')
response = yield from app_client.get('/v1.0/all_auth', headers={'Authorization': basic_header})
response = await app_client.get('/v1.0/all_auth', headers={'Authorization': basic_header})
assert response.status == 401, "Wrong password should trigger unauthorized"
assert response.content_type == 'application/problem+json'

response = yield from app_client.get('/v1.0/all_auth', headers={'X-API-Key': '{"foo": "bar"}'})
response = await app_client.get('/v1.0/all_auth', headers={'X-API-Key': '{"foo": "bar"}'})
assert response.status == 200
assert (yield from response.json()) == {"foo": "bar"}
assert (await response.json()) == {"foo": "bar"}


@asyncio.coroutine
def test_bearer_secure(aiohttp_api_spec_dir, aiohttp_client):
async def test_bearer_secure(aiohttp_api_spec_dir, aiohttp_client):
"""
Test authentication method specific to OpenApi 3
"""
app = AioHttpApp(__name__, port=5001, specification_dir=aiohttp_api_spec_dir, debug=True)
app.add_api('openapi_secure.yaml')
app_client = yield from aiohttp_client(app.app)
app_client = await aiohttp_client(app.app)

bearer_header = 'Bearer {"scope": ["myscope"], "uid": "test-user"}'
response = yield from app_client.get('/v1.0/bearer_auth', headers={'Authorization': bearer_header})
response = await app_client.get('/v1.0/bearer_auth', headers={'Authorization': bearer_header})
assert response.status == 200
assert (yield from response.json()) == {"scope": ['myscope'], "uid": 'test-user'}
assert (await response.json()) == {"scope": ['myscope'], "uid": 'test-user'}


@asyncio.coroutine
def test_async_secure(aiohttp_api_spec_dir, aiohttp_client):
async def test_async_secure(aiohttp_api_spec_dir, aiohttp_client):
app = AioHttpApp(__name__, port=5001, specification_dir=aiohttp_api_spec_dir, debug=True)
app.add_api('openapi_secure.yaml', pass_context_arg_name='request')
app_client = yield from aiohttp_client(app.app)
app_client = await aiohttp_client(app.app)

response = yield from app_client.get('/v1.0/async_auth')
response = await app_client.get('/v1.0/async_auth')
assert response.status == 401
assert response.content_type == 'application/problem+json'

bearer_header = 'Bearer {"scope": ["myscope"], "uid": "test-user"}'
response = yield from app_client.get('/v1.0/async_auth', headers={'Authorization': bearer_header})
response = await app_client.get('/v1.0/async_auth', headers={'Authorization': bearer_header})
assert response.status == 200
assert (yield from response.json()) == {"scope": ['myscope'], "uid": 'test-user'}
assert (await response.json()) == {"scope": ['myscope'], "uid": 'test-user'}

bearer_header = 'Bearer {"scope": ["myscope", "other_scope"], "uid": "test-user"}'
response = yield from app_client.get('/v1.0/async_auth', headers={'Authorization': bearer_header})
response = await app_client.get('/v1.0/async_auth', headers={'Authorization': bearer_header})
assert response.status == 403, "async_scope_validation should deny access if scopes are not strictly the same"

basic_header = 'Basic ' + base64.b64encode(b'username:username').decode('ascii')
response = yield from app_client.get('/v1.0/async_auth', headers={'Authorization': basic_header})
response = await app_client.get('/v1.0/async_auth', headers={'Authorization': basic_header})
assert response.status == 200
assert (yield from response.json()) == {"uid": 'username'}
assert (await response.json()) == {"uid": 'username'}

basic_header = 'Basic ' + base64.b64encode(b'username:wrong').decode('ascii')
response = yield from app_client.get('/v1.0/async_auth', headers={'Authorization': basic_header})
response = await app_client.get('/v1.0/async_auth', headers={'Authorization': basic_header})
assert response.status == 401, "Wrong password should trigger unauthorized"
assert response.content_type == 'application/problem+json'

response = yield from app_client.get('/v1.0/all_auth', headers={'X-API-Key': '{"foo": "bar"}'})
response = await app_client.get('/v1.0/all_auth', headers={'X-API-Key': '{"foo": "bar"}'})
assert response.status == 200
assert (yield from response.json()) == {"foo": "bar"}
assert (await response.json()) == {"foo": "bar"}

bearer_header = 'Bearer {"scope": ["myscope"], "uid": "test-user"}'
response = yield from app_client.get('/v1.0/async_bearer_auth', headers={'Authorization': bearer_header})
response = await app_client.get('/v1.0/async_bearer_auth', headers={'Authorization': bearer_header})
assert response.status == 200
assert (yield from response.json()) == {"scope": ['myscope'], "uid": 'test-user'}
assert (await response.json()) == {"scope": ['myscope'], "uid": 'test-user'}
9 changes: 3 additions & 6 deletions tests/fakeapi/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,13 @@ def fake_json_auth(token, required_scopes=None):
return None


@asyncio.coroutine
def async_basic_auth(username, password, required_scopes=None, request=None):
async def async_basic_auth(username, password, required_scopes=None, request=None):
return fake_basic_auth(username, password, required_scopes)


@asyncio.coroutine
def async_json_auth(token, required_scopes=None, request=None):
async def async_json_auth(token, required_scopes=None, request=None):
return fake_json_auth(token, required_scopes)


@asyncio.coroutine
def async_scope_validation(required_scopes, token_scopes, request):
async def async_scope_validation(required_scopes, token_scopes, request):
return required_scopes == token_scopes