Skip to content

Commit e33b055

Browse files
committed
ThrottledHttpClient
Decorate the http_client for http_cache behavior Wrap http_client instead of decorate it Rename to throttled_http_client.py Refactor and change default retry-after delay to 60 seconds ThrottledHttpClient test case contains params
1 parent 27097e6 commit e33b055

File tree

2 files changed

+299
-0
lines changed

2 files changed

+299
-0
lines changed

msal/throttled_http_client.py

Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
from threading import Lock
2+
from hashlib import sha256
3+
4+
from .individual_cache import _IndividualCache as IndividualCache
5+
from .individual_cache import _ExpiringMapping as ExpiringMapping
6+
7+
8+
# https://datatracker.ietf.org/doc/html/rfc8628#section-3.4
9+
DEVICE_AUTH_GRANT = "urn:ietf:params:oauth:grant-type:device_code"
10+
11+
12+
def _hash(raw):
13+
return sha256(repr(raw).encode("utf-8")).hexdigest()
14+
15+
16+
def _parse_http_429_5xx_retry_after(result=None, **ignored):
17+
"""Return seconds to throttle"""
18+
assert result is not None, """
19+
The signature defines it with a default value None,
20+
only because the its shape is already decided by the
21+
IndividualCache's.__call__().
22+
In actual code path, the result parameter here won't be None.
23+
"""
24+
response = result
25+
lowercase_headers = {k.lower(): v for k, v in getattr(
26+
# Historically, MSAL's HttpResponse does not always have headers
27+
response, "headers", {}).items()}
28+
if not (response.status_code == 429 or response.status_code >= 500
29+
or "retry-after" in lowercase_headers):
30+
return 0 # Quick exit
31+
default = 60 # Recommended at the end of
32+
# https://identitydivision.visualstudio.com/devex/_git/AuthLibrariesApiReview?version=GBdev&path=%2FService%20protection%2FIntial%20set%20of%20protection%20measures.md&_a=preview
33+
retry_after = int(lowercase_headers.get("retry-after", default))
34+
try:
35+
# AAD's retry_after uses integer format only
36+
# https://stackoverflow.microsoft.com/questions/264931/264932
37+
delay_seconds = int(retry_after)
38+
except ValueError:
39+
delay_seconds = default
40+
return min(3600, delay_seconds)
41+
42+
43+
def _extract_data(kwargs, key, default=None):
44+
data = kwargs.get("data", {}) # data is usually a dict, but occasionally a string
45+
return data.get(key) if isinstance(data, dict) else default
46+
47+
48+
class ThrottledHttpClient(object):
49+
def __init__(self, http_client, http_cache):
50+
"""Throttle the given http_client by storing and retrieving data from cache.
51+
52+
This wrapper exists so that our patching post() and get() would prevent
53+
re-patching side effect when/if same http_client being reused.
54+
"""
55+
expiring_mapping = ExpiringMapping( # It will automatically clean up
56+
mapping=http_cache if http_cache is not None else {},
57+
capacity=1024, # To prevent cache blowing up especially for CCA
58+
lock=Lock(), # TODO: This should ideally also allow customization
59+
)
60+
61+
_post = http_client.post # We'll patch _post, and keep original post() intact
62+
63+
_post = IndividualCache(
64+
# Internal specs requires throttling on at least token endpoint,
65+
# here we have a generic patch for POST on all endpoints.
66+
mapping=expiring_mapping,
67+
key_maker=lambda func, args, kwargs:
68+
"POST {} client_id={} scope={} hash={} 429/5xx/Retry-After".format(
69+
args[0], # It is the url, typically containing authority and tenant
70+
_extract_data(kwargs, "client_id"), # Per internal specs
71+
_extract_data(kwargs, "scope"), # Per internal specs
72+
_hash(
73+
# The followings are all approximations of the "account" concept
74+
# to support per-account throttling.
75+
# TODO: We may want to disable it for confidential client, though
76+
_extract_data(kwargs, "refresh_token", # "account" during refresh
77+
_extract_data(kwargs, "code", # "account" of auth code grant
78+
_extract_data(kwargs, "username")))), # "account" of ROPC
79+
),
80+
expires_in=_parse_http_429_5xx_retry_after,
81+
)(_post)
82+
83+
_post = IndividualCache( # It covers the "UI required cache"
84+
mapping=expiring_mapping,
85+
key_maker=lambda func, args, kwargs: "POST {} hash={} 400".format(
86+
args[0], # It is the url, typically containing authority and tenant
87+
_hash(
88+
# Here we use literally all parameters, even those short-lived
89+
# parameters containing timestamps (WS-Trust or POP assertion),
90+
# because they will automatically be cleaned up by ExpiringMapping.
91+
#
92+
# Furthermore, there is no need to implement
93+
# "interactive requests would reset the cache",
94+
# because acquire_token_silent()'s would be automatically unblocked
95+
# due to token cache layer operates on top of http cache layer.
96+
#
97+
# And, acquire_token_silent(..., force_refresh=True) will NOT
98+
# bypass http cache, because there is no real gain from that.
99+
# We won't bother implement it, nor do we want to encourage
100+
# acquire_token_silent(..., force_refresh=True) pattern.
101+
str(kwargs.get("params")) + str(kwargs.get("data"))),
102+
),
103+
expires_in=lambda result=None, data=None, **ignored:
104+
60
105+
if result.status_code == 400
106+
# Here we choose to cache exact HTTP 400 errors only (rather than 4xx)
107+
# because they are the ones defined in OAuth2
108+
# (https://datatracker.ietf.org/doc/html/rfc6749#section-5.2)
109+
# Other 4xx errors might have different requirements e.g.
110+
# "407 Proxy auth required" would need a key including http headers.
111+
and not( # Exclude Device Flow cause its retry is expected and regulated
112+
isinstance(data, dict) and data.get("grant_type") == DEVICE_AUTH_GRANT
113+
)
114+
and "retry-after" not in set( # Leave it to the Retry-After decorator
115+
h.lower() for h in getattr(result, "headers", {}).keys())
116+
else 0,
117+
)(_post)
118+
119+
self.post = _post
120+
121+
self.get = IndividualCache( # Typically those discovery GETs
122+
mapping=expiring_mapping,
123+
key_maker=lambda func, args, kwargs: "GET {} hash={} 2xx".format(
124+
args[0], # It is the url, sometimes containing inline params
125+
_hash(kwargs.get("params", "")),
126+
),
127+
expires_in=lambda result=None, **ignored:
128+
3600*24 if 200 <= result.status_code < 300 else 0,
129+
)(http_client.get)
130+
131+
# The following 2 methods have been defined dynamically by __init__()
132+
#def post(self, *args, **kwargs): pass
133+
#def get(self, *args, **kwargs): pass
134+
Lines changed: 165 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,165 @@
1+
# Test cases for https://identitydivision.visualstudio.com/devex/_git/AuthLibrariesApiReview?version=GBdev&path=%2FService%20protection%2FIntial%20set%20of%20protection%20measures.md&_a=preview&anchor=common-test-cases
2+
from time import sleep
3+
from random import random
4+
import logging
5+
from msal.throttled_http_client import ThrottledHttpClient
6+
from tests import unittest
7+
from tests.http_client import MinimalResponse
8+
9+
10+
logger = logging.getLogger(__name__)
11+
logging.basicConfig(level=logging.DEBUG)
12+
13+
14+
class DummyHttpResponse(MinimalResponse):
15+
def __init__(self, headers=None, **kwargs):
16+
self.headers = {} if headers is None else headers
17+
super(DummyHttpResponse, self).__init__(**kwargs)
18+
19+
20+
class DummyHttpClient(object):
21+
def __init__(self, status_code=None, response_headers=None):
22+
self._status_code = status_code
23+
self._response_headers = response_headers
24+
25+
def _build_dummy_response(self):
26+
return DummyHttpResponse(
27+
status_code=self._status_code,
28+
headers=self._response_headers,
29+
text=random(), # So that we'd know whether a new response is received
30+
)
31+
32+
def post(self, url, params=None, data=None, headers=None, **kwargs):
33+
return self._build_dummy_response()
34+
35+
def get(self, url, params=None, headers=None, **kwargs):
36+
return self._build_dummy_response()
37+
38+
39+
class TestHttpDecoration(unittest.TestCase):
40+
41+
def test_throttled_http_client_should_not_alter_original_http_client(self):
42+
http_cache = {}
43+
original_http_client = DummyHttpClient()
44+
original_get = original_http_client.get
45+
original_post = original_http_client.post
46+
throttled_http_client = ThrottledHttpClient(original_http_client, http_cache)
47+
goal = """The implementation should wrap original http_client
48+
and keep it intact, instead of monkey-patching it"""
49+
self.assertNotEqual(throttled_http_client, original_http_client, goal)
50+
self.assertEqual(original_post, original_http_client.post)
51+
self.assertEqual(original_get, original_http_client.get)
52+
53+
def _test_RetryAfter_N_seconds_should_keep_entry_for_N_seconds(
54+
self, http_client, retry_after):
55+
http_cache = {}
56+
http_client = ThrottledHttpClient(http_client, http_cache)
57+
resp1 = http_client.post("https://example.com") # We implemented POST only
58+
resp2 = http_client.post("https://example.com") # We implemented POST only
59+
logger.debug(http_cache)
60+
self.assertEqual(resp1.text, resp2.text, "Should return a cached response")
61+
sleep(retry_after + 1)
62+
resp3 = http_client.post("https://example.com") # We implemented POST only
63+
self.assertNotEqual(resp1.text, resp3.text, "Should return a new response")
64+
65+
def test_429_with_RetryAfter_N_seconds_should_keep_entry_for_N_seconds(self):
66+
retry_after = 1
67+
self._test_RetryAfter_N_seconds_should_keep_entry_for_N_seconds(
68+
DummyHttpClient(
69+
status_code=429, response_headers={"Retry-After": retry_after}),
70+
retry_after)
71+
72+
def test_5xx_with_RetryAfter_N_seconds_should_keep_entry_for_N_seconds(self):
73+
retry_after = 1
74+
self._test_RetryAfter_N_seconds_should_keep_entry_for_N_seconds(
75+
DummyHttpClient(
76+
status_code=503, response_headers={"Retry-After": retry_after}),
77+
retry_after)
78+
79+
def test_400_with_RetryAfter_N_seconds_should_keep_entry_for_N_seconds(self):
80+
"""Retry-After is supposed to only shown in http 429/5xx,
81+
but we choose to support Retry-After for arbitrary http response."""
82+
retry_after = 1
83+
self._test_RetryAfter_N_seconds_should_keep_entry_for_N_seconds(
84+
DummyHttpClient(
85+
status_code=400, response_headers={"Retry-After": retry_after}),
86+
retry_after)
87+
88+
def test_one_RetryAfter_request_should_block_a_similar_request(self):
89+
http_cache = {}
90+
http_client = DummyHttpClient(
91+
status_code=429, response_headers={"Retry-After": 2})
92+
http_client = ThrottledHttpClient(http_client, http_cache)
93+
resp1 = http_client.post("https://example.com", data={
94+
"scope": "one", "claims": "bar", "grant_type": "authorization_code"})
95+
resp2 = http_client.post("https://example.com", data={
96+
"scope": "one", "claims": "foo", "grant_type": "password"})
97+
logger.debug(http_cache)
98+
self.assertEqual(resp1.text, resp2.text, "Should return a cached response")
99+
100+
def test_one_RetryAfter_request_should_not_block_a_different_request(self):
101+
http_cache = {}
102+
http_client = DummyHttpClient(
103+
status_code=429, response_headers={"Retry-After": 2})
104+
http_client = ThrottledHttpClient(http_client, http_cache)
105+
resp1 = http_client.post("https://example.com", data={"scope": "one"})
106+
resp2 = http_client.post("https://example.com", data={"scope": "two"})
107+
logger.debug(http_cache)
108+
self.assertNotEqual(resp1.text, resp2.text, "Should return a new response")
109+
110+
def test_one_invalid_grant_should_block_a_similar_request(self):
111+
http_cache = {}
112+
http_client = DummyHttpClient(
113+
status_code=400) # It covers invalid_grant and interaction_required
114+
http_client = ThrottledHttpClient(http_client, http_cache)
115+
resp1 = http_client.post("https://example.com", data={"claims": "foo"})
116+
logger.debug(http_cache)
117+
resp1_again = http_client.post("https://example.com", data={"claims": "foo"})
118+
self.assertEqual(resp1.text, resp1_again.text, "Should return a cached response")
119+
resp2 = http_client.post("https://example.com", data={"claims": "bar"})
120+
self.assertNotEqual(resp1.text, resp2.text, "Should return a new response")
121+
resp2_again = http_client.post("https://example.com", data={"claims": "bar"})
122+
self.assertEqual(resp2.text, resp2_again.text, "Should return a cached response")
123+
124+
def test_one_foci_app_recovering_from_invalid_grant_should_also_unblock_another(self):
125+
"""
126+
Need not test multiple FOCI app's acquire_token_silent() here. By design,
127+
one FOCI app's successful populating token cache would result in another
128+
FOCI app's acquire_token_silent() to hit a token without invoking http request.
129+
"""
130+
131+
def test_forcefresh_behavior(self):
132+
"""
133+
The implementation let token cache and http cache operate in different
134+
layers. They do not couple with each other.
135+
Therefore, acquire_token_silent(..., force_refresh=True)
136+
would bypass the token cache yet technically still hit the http cache.
137+
138+
But that is OK, cause the customer need no force_refresh in the first place.
139+
After a successful AT/RT acquisition, AT/RT will be in the token cache,
140+
and a normal acquire_token_silent(...) without force_refresh would just work.
141+
This was discussed in https://identitydivision.visualstudio.com/DevEx/_git/AuthLibrariesApiReview/pullrequest/3618?_a=files
142+
"""
143+
144+
def test_http_get_200_should_be_cached(self):
145+
http_cache = {}
146+
http_client = DummyHttpClient(
147+
status_code=200) # It covers UserRealm discovery and OIDC discovery
148+
http_client = ThrottledHttpClient(http_client, http_cache)
149+
resp1 = http_client.get("https://example.com?foo=bar")
150+
resp2 = http_client.get("https://example.com?foo=bar")
151+
logger.debug(http_cache)
152+
self.assertEqual(resp1.text, resp2.text, "Should return a cached response")
153+
154+
def test_device_flow_retry_should_not_be_cached(self):
155+
DEVICE_AUTH_GRANT = "urn:ietf:params:oauth:grant-type:device_code"
156+
http_cache = {}
157+
http_client = DummyHttpClient(status_code=400)
158+
http_client = ThrottledHttpClient(http_client, http_cache)
159+
resp1 = http_client.get(
160+
"https://example.com", data={"grant_type": DEVICE_AUTH_GRANT})
161+
resp2 = http_client.get(
162+
"https://example.com", data={"grant_type": DEVICE_AUTH_GRANT})
163+
logger.debug(http_cache)
164+
self.assertNotEqual(resp1.text, resp2.text, "Should return a new response")
165+

0 commit comments

Comments
 (0)