Skip to content

Commit 0d47cae

Browse files
authored
Merge pull request #77 from AzureAD/adfs-direct
Adjusting token cache to work with ADFS2019
2 parents ab4cac4 + 7b8f6f5 commit 0d47cae

File tree

3 files changed

+107
-22
lines changed

3 files changed

+107
-22
lines changed

msal/authority.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,13 +38,20 @@ def __init__(self, authority_url, validate_authority=True,
3838
self.proxies = proxies
3939
self.timeout = timeout
4040
canonicalized, self.instance, tenant = canonicalize(authority_url)
41-
tenant_discovery_endpoint = ( # Hard code a V2 pattern as default value
42-
'https://{}/{}/v2.0/.well-known/openid-configuration'
43-
.format(self.instance, tenant))
44-
if validate_authority and self.instance not in WELL_KNOWN_AUTHORITY_HOSTS:
41+
tenant_discovery_endpoint = (
42+
'https://{}/{}{}/.well-known/openid-configuration'.format(
43+
self.instance,
44+
tenant,
45+
"" if tenant == "adfs" else "/v2.0" # the AAD v2 endpoint
46+
))
47+
if (tenant != "adfs" and validate_authority
48+
and self.instance not in WELL_KNOWN_AUTHORITY_HOSTS):
4549
tenant_discovery_endpoint = instance_discovery(
4650
canonicalized + "/oauth2/v2.0/authorize",
4751
verify=verify, proxies=proxies, timeout=timeout)
52+
if tenant.lower() == "adfs":
53+
tenant_discovery_endpoint = ("https://{}/adfs/.well-known/openid-configuration"
54+
.format(self.instance))
4855
openid_config = tenant_discovery(
4956
tenant_discovery_endpoint,
5057
verify=verify, proxies=proxies, timeout=timeout)

msal/token_cache.py

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -111,18 +111,25 @@ def add(self, event, now=None):
111111
event, indent=4, sort_keys=True,
112112
default=str, # A workaround when assertion is in bytes in Python 3
113113
))
114+
environment = realm = None
115+
if "token_endpoint" in event:
116+
_, environment, realm = canonicalize(event["token_endpoint"])
114117
response = event.get("response", {})
115118
access_token = response.get("access_token")
116119
refresh_token = response.get("refresh_token")
117120
id_token = response.get("id_token")
121+
id_token_claims = (
122+
decode_id_token(id_token, client_id=event["client_id"])
123+
if id_token else {})
118124
client_info = {}
119-
home_account_id = None
120-
if "client_info" in response:
125+
home_account_id = None # It would remain None in client_credentials flow
126+
if "client_info" in response: # We asked for it, and AAD will provide it
121127
client_info = json.loads(base64decode(response["client_info"]))
122128
home_account_id = "{uid}.{utid}".format(**client_info)
123-
environment = realm = None
124-
if "token_endpoint" in event:
125-
_, environment, realm = canonicalize(event["token_endpoint"])
129+
elif id_token_claims: # This would be an end user on ADFS-direct scenario
130+
client_info["uid"] = id_token_claims.get("sub")
131+
home_account_id = id_token_claims.get("sub")
132+
126133
target = ' '.join(event.get("scope", [])) # Per schema, we don't sort it
127134

128135
with self._lock:
@@ -148,15 +155,15 @@ def add(self, event, now=None):
148155
self.modify(self.CredentialType.ACCESS_TOKEN, at, at)
149156

150157
if client_info:
151-
decoded_id_token = decode_id_token(
152-
id_token, client_id=event["client_id"]) if id_token else {}
153158
account = {
154159
"home_account_id": home_account_id,
155160
"environment": environment,
156161
"realm": realm,
157-
"local_account_id": decoded_id_token.get(
158-
"oid", decoded_id_token.get("sub")),
159-
"username": decoded_id_token.get("preferred_username"),
162+
"local_account_id": id_token_claims.get(
163+
"oid", id_token_claims.get("sub")),
164+
"username": id_token_claims.get("preferred_username") # AAD
165+
or id_token_claims.get("upn") # ADFS 2019
166+
or "", # The schema does not like null
160167
"authority_type":
161168
self.AuthorityType.ADFS if realm == "adfs"
162169
else self.AuthorityType.MSSTS,

tests/test_token_cache.py

Lines changed: 79 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,30 +16,29 @@ class TokenCacheTestCase(unittest.TestCase):
1616
@staticmethod
1717
def build_id_token(
1818
iss="issuer", sub="subject", aud="my_client_id", exp=None, iat=None,
19-
preferred_username="me", **claims):
19+
**claims): # AAD issues "preferred_username", ADFS issues "upn"
2020
return "header.%s.signature" % base64.b64encode(json.dumps(dict({
2121
"iss": iss,
2222
"sub": sub,
2323
"aud": aud,
2424
"exp": exp or (time.time() + 100),
2525
"iat": iat or time.time(),
26-
"preferred_username": preferred_username,
2726
}, **claims)).encode()).decode('utf-8')
2827

2928
@staticmethod
3029
def build_response( # simulate a response from AAD
31-
uid="uid", utid="utid", # They will form client_info
30+
uid=None, utid=None, # If present, they will form client_info
3231
access_token=None, expires_in=3600, token_type="some type",
3332
refresh_token=None,
3433
foci=None,
3534
id_token=None, # or something generated by build_id_token()
3635
error=None,
3736
):
38-
response = {
39-
"client_info": base64.b64encode(json.dumps({
37+
response = {}
38+
if uid and utid: # Mimic the AAD behavior for "client_info=1" request
39+
response["client_info"] = base64.b64encode(json.dumps({
4040
"uid": uid, "utid": utid,
41-
}).encode()).decode('utf-8'),
42-
}
41+
}).encode()).decode('utf-8')
4342
if error:
4443
response["error"] = error
4544
if access_token:
@@ -59,7 +58,7 @@ def build_response( # simulate a response from AAD
5958
def setUp(self):
6059
self.cache = TokenCache()
6160

62-
def testAdd(self):
61+
def testAddByAad(self):
6362
client_id = "my_client_id"
6463
id_token = self.build_id_token(
6564
oid="object1234", preferred_username="John Doe", aud=client_id)
@@ -132,6 +131,78 @@ def testAdd(self):
132131
"appmetadata-login.example.com-my_client_id")
133132
)
134133

134+
def testAddByAdfs(self):
135+
client_id = "my_client_id"
136+
id_token = self.build_id_token(aud=client_id, upn="[email protected]")
137+
self.cache.add({
138+
"client_id": client_id,
139+
"scope": ["s2", "s1", "s3"], # Not in particular order
140+
"token_endpoint": "https://fs.msidlab8.com/adfs/oauth2/token",
141+
"response": self.build_response(
142+
uid=None, utid=None, # ADFS will provide no client_info
143+
expires_in=3600, access_token="an access token",
144+
id_token=id_token, refresh_token="a refresh token"),
145+
}, now=1000)
146+
self.assertEqual(
147+
{
148+
'cached_at': "1000",
149+
'client_id': 'my_client_id',
150+
'credential_type': 'AccessToken',
151+
'environment': 'fs.msidlab8.com',
152+
'expires_on': "4600",
153+
'extended_expires_on': "4600",
154+
'home_account_id': "subject",
155+
'realm': 'adfs',
156+
'secret': 'an access token',
157+
'target': 's2 s1 s3',
158+
},
159+
self.cache._cache["AccessToken"].get(
160+
'subject-fs.msidlab8.com-accesstoken-my_client_id-adfs-s2 s1 s3')
161+
)
162+
self.assertEqual(
163+
{
164+
'client_id': 'my_client_id',
165+
'credential_type': 'RefreshToken',
166+
'environment': 'fs.msidlab8.com',
167+
'home_account_id': "subject",
168+
'secret': 'a refresh token',
169+
'target': 's2 s1 s3',
170+
},
171+
self.cache._cache["RefreshToken"].get(
172+
'subject-fs.msidlab8.com-refreshtoken-my_client_id--s2 s1 s3')
173+
)
174+
self.assertEqual(
175+
{
176+
'home_account_id': "subject",
177+
'environment': 'fs.msidlab8.com',
178+
'realm': 'adfs',
179+
'local_account_id': "subject",
180+
'username': "[email protected]",
181+
'authority_type': "ADFS",
182+
},
183+
self.cache._cache["Account"].get('subject-fs.msidlab8.com-adfs')
184+
)
185+
self.assertEqual(
186+
{
187+
'credential_type': 'IdToken',
188+
'secret': id_token,
189+
'home_account_id': "subject",
190+
'environment': 'fs.msidlab8.com',
191+
'realm': 'adfs',
192+
'client_id': 'my_client_id',
193+
},
194+
self.cache._cache["IdToken"].get(
195+
'subject-fs.msidlab8.com-idtoken-my_client_id-adfs-')
196+
)
197+
self.assertEqual(
198+
{
199+
"client_id": "my_client_id",
200+
'environment': 'fs.msidlab8.com',
201+
},
202+
self.cache._cache.get("AppMetadata", {}).get(
203+
"appmetadata-fs.msidlab8.com-my_client_id")
204+
)
205+
135206

136207
class SerializableTokenCacheTestCase(TokenCacheTestCase):
137208
# Run all inherited test methods, and have extra check in tearDown()

0 commit comments

Comments
 (0)