Skip to content

Commit 1e53642

Browse files
committed
Merging new changes
2 parents e0dc68f + 6e82d29 commit 1e53642

File tree

3 files changed

+95
-19
lines changed

3 files changed

+95
-19
lines changed

msal/authority.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,10 +38,14 @@ def __init__(self, authority_url, validate_authority=True,
3838
self.proxies = proxies
3939
self.timeout = timeout
4040
canonicalized, self.instance, tenant = canonicalize(authority_url)
41-
tenant_discovery_endpoint = ( # Hard code a V2 pattern as default value
42-
'https://{}/{}/v2.0/.well-known/openid-configuration'
43-
.format(self.instance, tenant))
44-
if validate_authority and self.instance not in WELL_KNOWN_AUTHORITY_HOSTS:
41+
tenant_discovery_endpoint = (
42+
'https://{}/{}{}/.well-known/openid-configuration'.format(
43+
self.instance,
44+
tenant,
45+
"" if tenant == "adfs" else "/v2.0" # the AAD v2 endpoint
46+
))
47+
if (tenant != "adfs" and validate_authority
48+
and self.instance not in WELL_KNOWN_AUTHORITY_HOSTS):
4549
tenant_discovery_endpoint = instance_discovery(
4650
canonicalized + "/oauth2/v2.0/authorize",
4751
verify=verify, proxies=proxies, timeout=timeout)

msal/token_cache.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -121,15 +121,16 @@ def add(self, event, now=None):
121121
id_token_claims = (
122122
decode_id_token(id_token, client_id=event["client_id"])
123123
if id_token else {})
124-
client_info = (
125-
json.loads(base64decode(response["client_info"]))
126-
if "client_info" in response
127-
else { # ADFS scenario
124+
client_info = {}
125+
if "client_info" in response: # We asked for it, and AAD will provide it
126+
client_info = json.loads(base64decode(response["client_info"]))
127+
elif id_token_claims: # This would be an end user on ADFS-direct scenario
128+
client_info = {
128129
"uid": id_token_claims.get("sub"),
129-
"utid": environment, # TBD
130+
"utid": realm, # which, in ADFS scenario, would typically be "adfs"
130131
}
131-
)
132-
home_account_id = "{uid}.{utid}".format(**client_info)
132+
home_account_id = ( # It would remain None in client_credentials flow
133+
"{uid}.{utid}".format(**client_info) if client_info else None)
133134
target = ' '.join(event.get("scope", [])) # Per schema, we don't sort it
134135

135136
with self._lock:

tests/test_token_cache.py

Lines changed: 79 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,30 +16,29 @@ class TokenCacheTestCase(unittest.TestCase):
1616
@staticmethod
1717
def build_id_token(
1818
iss="issuer", sub="subject", aud="my_client_id", exp=None, iat=None,
19-
preferred_username="me", **claims):
19+
**claims): # AAD issues "preferred_username", ADFS issues "upn"
2020
return "header.%s.signature" % base64.b64encode(json.dumps(dict({
2121
"iss": iss,
2222
"sub": sub,
2323
"aud": aud,
2424
"exp": exp or (time.time() + 100),
2525
"iat": iat or time.time(),
26-
"preferred_username": preferred_username,
2726
}, **claims)).encode()).decode('utf-8')
2827

2928
@staticmethod
3029
def build_response( # simulate a response from AAD
31-
uid="uid", utid="utid", # They will form client_info
30+
uid=None, utid=None, # If present, they will form client_info
3231
access_token=None, expires_in=3600, token_type="some type",
3332
refresh_token=None,
3433
foci=None,
3534
id_token=None, # or something generated by build_id_token()
3635
error=None,
3736
):
38-
response = {
39-
"client_info": base64.b64encode(json.dumps({
37+
response = {}
38+
if uid and utid: # Mimic the AAD behavior for "client_info=1" request
39+
response["client_info"] = base64.b64encode(json.dumps({
4040
"uid": uid, "utid": utid,
41-
}).encode()).decode('utf-8'),
42-
}
41+
}).encode()).decode('utf-8')
4342
if error:
4443
response["error"] = error
4544
if access_token:
@@ -59,7 +58,7 @@ def build_response( # simulate a response from AAD
5958
def setUp(self):
6059
self.cache = TokenCache()
6160

62-
def testAdd(self):
61+
def testAddByAad(self):
6362
client_id = "my_client_id"
6463
id_token = self.build_id_token(
6564
oid="object1234", preferred_username="John Doe", aud=client_id)
@@ -132,6 +131,78 @@ def testAdd(self):
132131
"appmetadata-login.example.com-my_client_id")
133132
)
134133

134+
def testAddByAdfs(self):
135+
client_id = "my_client_id"
136+
id_token = self.build_id_token(aud=client_id, upn="[email protected]")
137+
self.cache.add({
138+
"client_id": client_id,
139+
"scope": ["s2", "s1", "s3"], # Not in particular order
140+
"token_endpoint": "https://fs.msidlab8.com/adfs/oauth2/token",
141+
"response": self.build_response(
142+
uid=None, utid=None, # ADFS will provide no client_info
143+
expires_in=3600, access_token="an access token",
144+
id_token=id_token, refresh_token="a refresh token"),
145+
}, now=1000)
146+
self.assertEqual(
147+
{
148+
'cached_at': "1000",
149+
'client_id': 'my_client_id',
150+
'credential_type': 'AccessToken',
151+
'environment': 'fs.msidlab8.com',
152+
'expires_on': "4600",
153+
'extended_expires_on': "4600",
154+
'home_account_id': "subject.adfs",
155+
'realm': 'adfs',
156+
'secret': 'an access token',
157+
'target': 's2 s1 s3',
158+
},
159+
self.cache._cache["AccessToken"].get(
160+
'subject.adfs-fs.msidlab8.com-accesstoken-my_client_id-adfs-s2 s1 s3')
161+
)
162+
self.assertEqual(
163+
{
164+
'client_id': 'my_client_id',
165+
'credential_type': 'RefreshToken',
166+
'environment': 'fs.msidlab8.com',
167+
'home_account_id': "subject.adfs",
168+
'secret': 'a refresh token',
169+
'target': 's2 s1 s3',
170+
},
171+
self.cache._cache["RefreshToken"].get(
172+
'subject.adfs-fs.msidlab8.com-refreshtoken-my_client_id--s2 s1 s3')
173+
)
174+
self.assertEqual(
175+
{
176+
'home_account_id': "subject.adfs",
177+
'environment': 'fs.msidlab8.com',
178+
'realm': 'adfs',
179+
'local_account_id': "subject",
180+
'username': "[email protected]",
181+
'authority_type': "ADFS",
182+
},
183+
self.cache._cache["Account"].get('subject.adfs-fs.msidlab8.com-adfs')
184+
)
185+
self.assertEqual(
186+
{
187+
'credential_type': 'IdToken',
188+
'secret': id_token,
189+
'home_account_id': "subject.adfs",
190+
'environment': 'fs.msidlab8.com',
191+
'realm': 'adfs',
192+
'client_id': 'my_client_id',
193+
},
194+
self.cache._cache["IdToken"].get(
195+
'subject.adfs-fs.msidlab8.com-idtoken-my_client_id-adfs-')
196+
)
197+
self.assertEqual(
198+
{
199+
"client_id": "my_client_id",
200+
'environment': 'fs.msidlab8.com',
201+
},
202+
self.cache._cache.get("AppMetadata", {}).get(
203+
"appmetadata-fs.msidlab8.com-my_client_id")
204+
)
205+
135206

136207
class SerializableTokenCacheTestCase(TokenCacheTestCase):
137208
# Run all inherited test methods, and have extra check in tearDown()

0 commit comments

Comments
 (0)