Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
f53699f
Explicitly add license file
rayluo Jul 17, 2021
7218aaf
Merge pull request #384 from AzureAD/release-1.13.0
rayluo Jul 20, 2021
3bd1a9d
Use dot-env for convenient local testing
rayluo Jul 15, 2021
89a10d4
Merge branch 'dot-env' into dev
rayluo Jul 23, 2021
ab1f353
It was skipped and recently broken. Now it works.
rayluo Jul 30, 2021
9309743
Merge branch 'worldwide-confidential-client-test' into dev
rayluo Jul 30, 2021
ffc5953
Switch to github action badge
rayluo Jul 8, 2021
efbc83a
Merge branch 'switch-to-github-action-badge' into dev
rayluo Aug 1, 2021
db6f001
Survive issue 387
rayluo Aug 4, 2021
ee96522
Merge pull request #390 from AzureAD/fix-issue-387
rayluo Aug 4, 2021
4398d23
obtain_token_by_browser(..., error_template=...)
rayluo Mar 26, 2021
4f2f2de
Merge branch 'error_template' into dev
rayluo Mar 29, 2021
c89f193
Customizable browser_name
rayluo Mar 31, 2021
bbb9af5
Merge branch 'customizable-browser' into dev
rayluo Apr 5, 2021
efa5668
Merge in staged oauth2cli changes
rayluo Aug 5, 2021
e94dda5
Merge branch 'o2c' into dev
rayluo Aug 5, 2021
bce6cc0
Turns out webbrowser.open() is more robust
rayluo Apr 8, 2021
ed4c796
Merge branch 'customizable-browser' into dev
rayluo Apr 12, 2021
c687d5b
Merge remote-tracking branch 'oauth2cli_github/dev' into o2c
rayluo Aug 5, 2021
71802d0
Change regional endpoint domain name
rayluo Aug 10, 2021
96140b0
Regional endpoint test cases do not rely on env var REGION_NAME
rayluo Aug 10, 2021
5fdae2d
REGION_NAME has no unified format across services
rayluo Aug 10, 2021
3e2b56d
Merge pull request #394 from AzureAD/region-endpoint-specs-changes
rayluo Aug 10, 2021
f565493
Prefer Edge when running on Linux
rayluo Jul 30, 2021
ef89877
Merge pull request #388 from AzureAD/prefer-edge-on-linux
rayluo Aug 19, 2021
27097e6
An individual cache, after 3+ prototypes
rayluo May 7, 2020
e33b055
ThrottledHttpClient
rayluo Jul 8, 2021
b4401a1
Use throttled_http_client
rayluo Jul 8, 2021
deb7900
Merge pull request #379 from AzureAD/http-cache
rayluo Aug 19, 2021
8eb5c18
Convert staticmethod to module-wide public method
rayluo Aug 18, 2021
e969e64
Merge branch 'refactor-token-cache-test-cases' into dev
rayluo Aug 19, 2021
a1f9ca7
Enable ThrottledHttpClient.close()
rayluo Aug 26, 2021
a621e50
Merge branch 'http-cache' into dev
rayluo Aug 26, 2021
24959a9
MSAL Python 1.14
rayluo Aug 25, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
ThrottledHttpClient
Decorate the http_client for http_cache behavior

Wrap http_client instead of decorate it

Rename to throttled_http_client.py

Refactor and change default retry-after delay to 60 seconds

ThrottledHttpClient test case contains params
  • Loading branch information
rayluo committed Aug 19, 2021
commit e33b055bce6820fa816bde277917340acbae22dd
134 changes: 134 additions & 0 deletions msal/throttled_http_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
from threading import Lock
from hashlib import sha256

from .individual_cache import _IndividualCache as IndividualCache
from .individual_cache import _ExpiringMapping as ExpiringMapping


# Grant type identifier of the OAuth2 Device Authorization Grant ("device flow").
# Polling retries in this flow are expected and regulated by the protocol,
# so responses to it are deliberately excluded from the HTTP 400 cache below.
# https://datatracker.ietf.org/doc/html/rfc8628#section-3.4
DEVICE_AUTH_GRANT = "urn:ietf:params:oauth:grant-type:device_code"


def _hash(raw):
return sha256(repr(raw).encode("utf-8")).hexdigest()


def _parse_http_429_5xx_retry_after(result=None, **ignored):
"""Return seconds to throttle"""
assert result is not None, """
The signature defines it with a default value None,
only because the its shape is already decided by the
IndividualCache's.__call__().
In actual code path, the result parameter here won't be None.
"""
response = result
lowercase_headers = {k.lower(): v for k, v in getattr(
# Historically, MSAL's HttpResponse does not always have headers
response, "headers", {}).items()}
if not (response.status_code == 429 or response.status_code >= 500
or "retry-after" in lowercase_headers):
return 0 # Quick exit
default = 60 # Recommended at the end of
# https://identitydivision.visualstudio.com/devex/_git/AuthLibrariesApiReview?version=GBdev&path=%2FService%20protection%2FIntial%20set%20of%20protection%20measures.md&_a=preview
retry_after = int(lowercase_headers.get("retry-after", default))
try:
# AAD's retry_after uses integer format only
# https://stackoverflow.microsoft.com/questions/264931/264932
delay_seconds = int(retry_after)
except ValueError:
delay_seconds = default
return min(3600, delay_seconds)


def _extract_data(kwargs, key, default=None):
data = kwargs.get("data", {}) # data is usually a dict, but occasionally a string
return data.get(key) if isinstance(data, dict) else default


class ThrottledHttpClient(object):
    # Exposes the same post()/get() surface as the wrapped http_client,
    # with three throttling caches layered on top:
    #   1. POST responses that are 429/5xx or carry Retry-After,
    #      keyed per endpoint + client_id + scope + account approximation;
    #   2. POST responses that are exactly HTTP 400 ("UI required" cache),
    #      keyed per endpoint + full request parameters;
    #   3. GET 2xx responses (discovery documents), cached for 24 hours.
    def __init__(self, http_client, http_cache):
        """Throttle the given http_client by storing and retrieving data from cache.

        This wrapper exists so that our patching post() and get() would prevent
        re-patching side effect when/if same http_client being reused.
        """
        expiring_mapping = ExpiringMapping(  # It will automatically clean up
            mapping=http_cache if http_cache is not None else {},
            capacity=1024,  # To prevent cache blowing up especially for CCA
            lock=Lock(),  # TODO: This should ideally also allow customization
            )

        _post = http_client.post  # We'll patch _post, and keep original post() intact

        _post = IndividualCache(
            # Internal specs requires throttling on at least token endpoint,
            # here we have a generic patch for POST on all endpoints.
            mapping=expiring_mapping,
            key_maker=lambda func, args, kwargs:
                "POST {} client_id={} scope={} hash={} 429/5xx/Retry-After".format(
                    args[0],  # It is the url, typically containing authority and tenant
                    _extract_data(kwargs, "client_id"),  # Per internal specs
                    _extract_data(kwargs, "scope"),  # Per internal specs
                    _hash(
                        # The followings are all approximations of the "account" concept
                        # to support per-account throttling.
                        # TODO: We may want to disable it for confidential client, though
                        _extract_data(kwargs, "refresh_token",  # "account" during refresh
                            _extract_data(kwargs, "code",  # "account" of auth code grant
                                _extract_data(kwargs, "username")))),  # "account" of ROPC
                    ),
            expires_in=_parse_http_429_5xx_retry_after,
            )(_post)

        _post = IndividualCache(  # It covers the "UI required cache"
            mapping=expiring_mapping,
            key_maker=lambda func, args, kwargs: "POST {} hash={} 400".format(
                args[0],  # It is the url, typically containing authority and tenant
                _hash(
                    # Here we use literally all parameters, even those short-lived
                    # parameters containing timestamps (WS-Trust or POP assertion),
                    # because they will automatically be cleaned up by ExpiringMapping.
                    #
                    # Furthermore, there is no need to implement
                    # "interactive requests would reset the cache",
                    # because acquire_token_silent()'s would be automatically unblocked
                    # due to token cache layer operates on top of http cache layer.
                    #
                    # And, acquire_token_silent(..., force_refresh=True) will NOT
                    # bypass http cache, because there is no real gain from that.
                    # We won't bother implement it, nor do we want to encourage
                    # acquire_token_silent(..., force_refresh=True) pattern.
                    str(kwargs.get("params")) + str(kwargs.get("data"))),
                ),
            expires_in=lambda result=None, data=None, **ignored:
                60
                if result.status_code == 400
                # Here we choose to cache exact HTTP 400 errors only (rather than 4xx)
                # because they are the ones defined in OAuth2
                # (https://datatracker.ietf.org/doc/html/rfc6749#section-5.2)
                # Other 4xx errors might have different requirements e.g.
                # "407 Proxy auth required" would need a key including http headers.
                and not(  # Exclude Device Flow cause its retry is expected and regulated
                    isinstance(data, dict) and data.get("grant_type") == DEVICE_AUTH_GRANT
                    )
                and "retry-after" not in set(  # Leave it to the Retry-After decorator
                    h.lower() for h in getattr(result, "headers", {}).keys())
                else 0,
            )(_post)

        self.post = _post

        self.get = IndividualCache(  # Typically those discovery GETs
            mapping=expiring_mapping,
            key_maker=lambda func, args, kwargs: "GET {} hash={} 2xx".format(
                args[0],  # It is the url, sometimes containing inline params
                _hash(kwargs.get("params", "")),
                ),
            expires_in=lambda result=None, **ignored:
                3600*24 if 200 <= result.status_code < 300 else 0,
            )(http_client.get)

    # The following 2 methods have been defined dynamically by __init__()
    #def post(self, *args, **kwargs): pass
    #def get(self, *args, **kwargs): pass

165 changes: 165 additions & 0 deletions tests/test_throttled_http_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
# Test cases for https://identitydivision.visualstudio.com/devex/_git/AuthLibrariesApiReview?version=GBdev&path=%2FService%20protection%2FIntial%20set%20of%20protection%20measures.md&_a=preview&anchor=common-test-cases
from time import sleep
from random import random
import logging
from msal.throttled_http_client import ThrottledHttpClient
from tests import unittest
from tests.http_client import MinimalResponse


logger = logging.getLogger(__name__)
# DEBUG level so the http_cache dumps in the tests below are visible.
logging.basicConfig(level=logging.DEBUG)


class DummyHttpResponse(MinimalResponse):
    """A MinimalResponse variant that also carries HTTP headers."""
    def __init__(self, headers=None, **kwargs):
        # Default to an empty header map so callers may omit headers entirely.
        self.headers = headers if headers is not None else {}
        super(DummyHttpResponse, self).__init__(**kwargs)


class DummyHttpClient(object):
    """A fake transport that answers every request with a canned status/headers.

    Each call produces a fresh random text body, so tests can tell whether
    a response was served from cache (same text) or newly generated.
    """
    def __init__(self, status_code=None, response_headers=None):
        self._status_code = status_code
        self._response_headers = response_headers

    def _build_dummy_response(self):
        return DummyHttpResponse(
            text=random(),  # So that we'd know whether a new response is received
            status_code=self._status_code,
            headers=self._response_headers,
            )

    def post(self, url, params=None, data=None, headers=None, **kwargs):
        return self._build_dummy_response()

    def get(self, url, params=None, headers=None, **kwargs):
        return self._build_dummy_response()


class TestHttpDecoration(unittest.TestCase):
    """Exercises ThrottledHttpClient's three caches: Retry-After/429/5xx,
    HTTP 400 ("UI required"), and GET 2xx discovery caching."""

    def test_throttled_http_client_should_not_alter_original_http_client(self):
        http_cache = {}
        original_http_client = DummyHttpClient()
        original_get = original_http_client.get
        original_post = original_http_client.post
        throttled_http_client = ThrottledHttpClient(original_http_client, http_cache)
        goal = """The implementation should wrap original http_client
            and keep it intact, instead of monkey-patching it"""
        self.assertNotEqual(throttled_http_client, original_http_client, goal)
        self.assertEqual(original_post, original_http_client.post)
        self.assertEqual(original_get, original_http_client.get)

    def _test_RetryAfter_N_seconds_should_keep_entry_for_N_seconds(
            self, http_client, retry_after):
        # Shared helper: a response carrying Retry-After should be served from
        # cache for that many seconds, then evicted so a new request goes out.
        http_cache = {}
        http_client = ThrottledHttpClient(http_client, http_cache)
        resp1 = http_client.post("https://example.com")  # We implemented POST only
        resp2 = http_client.post("https://example.com")  # We implemented POST only
        logger.debug(http_cache)
        self.assertEqual(resp1.text, resp2.text, "Should return a cached response")
        sleep(retry_after + 1)
        resp3 = http_client.post("https://example.com")  # We implemented POST only
        self.assertNotEqual(resp1.text, resp3.text, "Should return a new response")

    def test_429_with_RetryAfter_N_seconds_should_keep_entry_for_N_seconds(self):
        retry_after = 1
        self._test_RetryAfter_N_seconds_should_keep_entry_for_N_seconds(
            DummyHttpClient(
                status_code=429, response_headers={"Retry-After": retry_after}),
            retry_after)

    def test_5xx_with_RetryAfter_N_seconds_should_keep_entry_for_N_seconds(self):
        retry_after = 1
        self._test_RetryAfter_N_seconds_should_keep_entry_for_N_seconds(
            DummyHttpClient(
                status_code=503, response_headers={"Retry-After": retry_after}),
            retry_after)

    def test_400_with_RetryAfter_N_seconds_should_keep_entry_for_N_seconds(self):
        """Retry-After is supposed to be shown only in http 429/5xx responses,
        but we choose to support Retry-After for arbitrary http response."""
        retry_after = 1
        self._test_RetryAfter_N_seconds_should_keep_entry_for_N_seconds(
            DummyHttpClient(
                status_code=400, response_headers={"Retry-After": retry_after}),
            retry_after)

    def test_one_RetryAfter_request_should_block_a_similar_request(self):
        # Same scope -> same throttle key, even though other params differ.
        http_cache = {}
        http_client = DummyHttpClient(
            status_code=429, response_headers={"Retry-After": 2})
        http_client = ThrottledHttpClient(http_client, http_cache)
        resp1 = http_client.post("https://example.com", data={
            "scope": "one", "claims": "bar", "grant_type": "authorization_code"})
        resp2 = http_client.post("https://example.com", data={
            "scope": "one", "claims": "foo", "grant_type": "password"})
        logger.debug(http_cache)
        self.assertEqual(resp1.text, resp2.text, "Should return a cached response")

    def test_one_RetryAfter_request_should_not_block_a_different_request(self):
        # Different scope -> different throttle key.
        http_cache = {}
        http_client = DummyHttpClient(
            status_code=429, response_headers={"Retry-After": 2})
        http_client = ThrottledHttpClient(http_client, http_cache)
        resp1 = http_client.post("https://example.com", data={"scope": "one"})
        resp2 = http_client.post("https://example.com", data={"scope": "two"})
        logger.debug(http_cache)
        self.assertNotEqual(resp1.text, resp2.text, "Should return a new response")

    def test_one_invalid_grant_should_block_a_similar_request(self):
        # The 400 cache keys on the full parameter set, so identical requests
        # share a cached error while any parameter change escapes it.
        http_cache = {}
        http_client = DummyHttpClient(
            status_code=400)  # It covers invalid_grant and interaction_required
        http_client = ThrottledHttpClient(http_client, http_cache)
        resp1 = http_client.post("https://example.com", data={"claims": "foo"})
        logger.debug(http_cache)
        resp1_again = http_client.post("https://example.com", data={"claims": "foo"})
        self.assertEqual(resp1.text, resp1_again.text, "Should return a cached response")
        resp2 = http_client.post("https://example.com", data={"claims": "bar"})
        self.assertNotEqual(resp1.text, resp2.text, "Should return a new response")
        resp2_again = http_client.post("https://example.com", data={"claims": "bar"})
        self.assertEqual(resp2.text, resp2_again.text, "Should return a cached response")

    def test_one_foci_app_recovering_from_invalid_grant_should_also_unblock_another(self):
        """
        Need not test multiple FOCI app's acquire_token_silent() here. By design,
        one FOCI app's successful populating token cache would result in another
        FOCI app's acquire_token_silent() to hit a token without invoking http request.
        """

    def test_forcefresh_behavior(self):
        """
        The implementation let token cache and http cache operate in different
        layers. They do not couple with each other.
        Therefore, acquire_token_silent(..., force_refresh=True)
        would bypass the token cache yet technically still hit the http cache.

        But that is OK, cause the customer need no force_refresh in the first place.
        After a successful AT/RT acquisition, AT/RT will be in the token cache,
        and a normal acquire_token_silent(...) without force_refresh would just work.
        This was discussed in https://identitydivision.visualstudio.com/DevEx/_git/AuthLibrariesApiReview/pullrequest/3618?_a=files
        """

    def test_http_get_200_should_be_cached(self):
        http_cache = {}
        http_client = DummyHttpClient(
            status_code=200)  # It covers UserRealm discovery and OIDC discovery
        http_client = ThrottledHttpClient(http_client, http_cache)
        resp1 = http_client.get("https://example.com?foo=bar")
        resp2 = http_client.get("https://example.com?foo=bar")
        logger.debug(http_cache)
        self.assertEqual(resp1.text, resp2.text, "Should return a cached response")

    def test_device_flow_retry_should_not_be_cached(self):
        DEVICE_AUTH_GRANT = "urn:ietf:params:oauth:grant-type:device_code"
        http_cache = {}
        http_client = DummyHttpClient(status_code=400)
        http_client = ThrottledHttpClient(http_client, http_cache)
        # Bug fix: this test previously issued GET requests, but the GET cache
        # ignores the data payload and only caches 2xx responses, so a 400 was
        # never cached and the test passed without exercising the device-flow
        # exclusion at all. Device flow polls the token endpoint via POST, and
        # the POST 400 cache is where the exclusion lives — so POST is what
        # must be exercised here.
        resp1 = http_client.post(
            "https://example.com", data={"grant_type": DEVICE_AUTH_GRANT})
        resp2 = http_client.post(
            "https://example.com", data={"grant_type": DEVICE_AUTH_GRANT})
        logger.debug(http_cache)
        self.assertNotEqual(resp1.text, resp2.text, "Should return a new response")