11import logging
22import os
33import json
4+ import time
45
56import requests
67
@@ -31,25 +32,34 @@ def assertLoosely(self, response, assertion=None,
3132 error_description = response .get ("error_description" )))
3233 assertion ()
3334
34- def assertCacheWorks (self , result_from_wire , username , scope ):
35- result = result_from_wire
35+ def assertCacheWorksForUser (self , result_from_wire , scope , username = None ):
3636 # You can filter by predefined username, or let end user to choose one
3737 accounts = self .app .get_accounts (username = username )
3838 self .assertNotEqual (0 , len (accounts ))
3939 account = accounts [0 ]
4040 # Going to test acquire_token_silent(...) to locate an AT from cache
4141 result_from_cache = self .app .acquire_token_silent (scope , account = account )
4242 self .assertIsNotNone (result_from_cache )
43- self .assertEqual (result ['access_token' ], result_from_cache ['access_token' ],
44- "We should get a cached AT" )
43+ self .assertEqual (
44+ result_from_wire ['access_token' ], result_from_cache ['access_token' ],
45+ "We should get a cached AT" )
4546
4647 # Going to test acquire_token_silent(...) to obtain an AT by a RT from cache
4748 self .app .token_cache ._cache ["AccessToken" ] = {} # A hacky way to clear ATs
4849 result_from_cache = self .app .acquire_token_silent (scope , account = account )
4950 self .assertIsNotNone (result_from_cache ,
5051 "We should get a result from acquire_token_silent(...) call" )
51- self .assertNotEqual (result ['access_token' ], result_from_cache ['access_token' ],
52- "We should get a fresh AT (via RT)" )
52+ self .assertNotEqual (
53+ result_from_wire ['access_token' ], result_from_cache ['access_token' ],
54+ "We should get a fresh AT (via RT)" )
55+
56+ def assertCacheWorksForApp (self , result_from_wire , scope ):
57+ # Going to test acquire_token_silent(...) to locate an AT from cache
58+ result_from_cache = self .app .acquire_token_silent (scope , account = None )
59+ self .assertIsNotNone (result_from_cache )
60+ self .assertEqual (
61+ result_from_wire ['access_token' ], result_from_cache ['access_token' ],
62+ "We should get a cached AT" )
5363
5464 def _test_username_password (self ,
5565 authority = None , client_id = None , username = None , password = None , scope = None ,
@@ -60,19 +70,137 @@ def _test_username_password(self,
6070 username , password , scopes = scope )
6171 self .assertLoosely (result )
6272 # self.assertEqual(None, result.get("error"), str(result))
63- self .assertCacheWorks (result , username , scope )
73+ self .assertCacheWorksForUser (result , scope , username = username )
6474
6575
66- CONFIG = os .path .join (os .path .dirname (__file__ ), "config.json" )
67- @unittest .skipIf (not os .path .exists (CONFIG ), "Optional %s not found" % CONFIG )
76+ THIS_FOLDER = os .path .dirname (__file__ )
77+ CONFIG = os .path .join (THIS_FOLDER , "config.json" )
78+ @unittest .skipUnless (os .path .exists (CONFIG ), "Optional %s not found" % CONFIG )
6879class FileBasedTestCase (E2eTestCase ):
69- def setUp (self ):
80+ # This covers scenarios that are not currently available for test automation.
81+ # So they mean to be run on maintainer's machine for semi-automated tests.
82+
83+ @classmethod
84+ def setUpClass (cls ):
7085 with open (CONFIG ) as f :
71- self .config = json .load (f )
86+ cls .config = json .load (f )
87+
88+ def skipUnlessWithConfig (self , fields ):
89+ for field in fields :
90+ if field not in self .config :
91+ self .skipTest ('Skipping due to lack of configuration "%s"' % field )
7292
7393 def test_username_password (self ):
94+ self .skipUnlessWithConfig (["client_id" , "username" , "password" , "scope" ])
7495 self ._test_username_password (** self .config )
7596
97+ def test_auth_code (self ):
98+ self .skipUnlessWithConfig (["client_id" , "scope" ])
99+ from msal .oauth2cli .authcode import obtain_auth_code
100+ self .app = msal .ClientApplication (
101+ self .config ["client_id" ],
102+ client_credential = self .config .get ("client_secret" ),
103+ authority = self .config .get ("authority" ))
104+ port = self .config .get ("listen_port" , 44331 )
105+ redirect_uri = "http://localhost:%s" % port
106+ auth_request_uri = self .app .get_authorization_request_url (
107+ self .config ["scope" ], redirect_uri = redirect_uri )
108+ ac = obtain_auth_code (port , auth_uri = auth_request_uri )
109+ self .assertNotEqual (ac , None )
110+
111+ result = self .app .acquire_token_by_authorization_code (
112+ ac , self .config ["scope" ], redirect_uri = redirect_uri )
113+ logger .debug ("%s.cache = %s" ,
114+ self .id (), json .dumps (self .app .token_cache ._cache , indent = 4 ))
115+ self .assertIn (
116+ "access_token" , result ,
117+ "{error}: {error_description}" .format (
118+ # Note: No interpolation here, cause error won't always present
119+ error = result .get ("error" ),
120+ error_description = result .get ("error_description" )))
121+ self .assertCacheWorksForUser (result , self .config ["scope" ], username = None )
122+
123+ def test_client_secret (self ):
124+ self .skipUnlessWithConfig (["client_id" , "client_secret" ])
125+ self .app = msal .ConfidentialClientApplication (
126+ self .config ["client_id" ],
127+ client_credential = self .config .get ("client_secret" ),
128+ authority = self .config .get ("authority" ))
129+ scope = self .config .get ("scope" , [])
130+ result = self .app .acquire_token_for_client (scope )
131+ self .assertIn ('access_token' , result )
132+ self .assertCacheWorksForApp (result , scope )
133+
134+ def test_client_certificate (self ):
135+ self .skipUnlessWithConfig (["client_id" , "client_certificate" ])
136+ client_cert = self .config ["client_certificate" ]
137+ assert "private_key_path" in client_cert and "thumbprint" in client_cert
138+ with open (os .path .join (THIS_FOLDER , client_cert ['private_key_path' ])) as f :
139+ private_key = f .read () # Should be in PEM format
140+ self .app = msal .ConfidentialClientApplication (
141+ self .config ['client_id' ],
142+ {"private_key" : private_key , "thumbprint" : client_cert ["thumbprint" ]})
143+ scope = self .config .get ("scope" , [])
144+ result = self .app .acquire_token_for_client (scope )
145+ self .assertIn ('access_token' , result )
146+ self .assertCacheWorksForApp (result , scope )
147+
148+ def test_subject_name_issuer_authentication (self ):
149+ self .skipUnlessWithConfig (["client_id" , "client_certificate" ])
150+ client_cert = self .config ["client_certificate" ]
151+ assert "private_key_path" in client_cert and "thumbprint" in client_cert
152+ if not "public_certificate" in client_cert :
153+ self .skipTest ("Skipping SNI test due to lack of public_certificate" )
154+ with open (os .path .join (THIS_FOLDER , client_cert ['private_key_path' ])) as f :
155+ private_key = f .read () # Should be in PEM format
156+ with open (os .path .join (THIS_FOLDER , client_cert ['public_certificate' ])) as f :
157+ public_certificate = f .read ()
158+ self .app = msal .ConfidentialClientApplication (
159+ self .config ['client_id' ], authority = self .config ["authority" ],
160+ client_credential = {
161+ "private_key" : private_key ,
162+ "thumbprint" : self .config ["thumbprint" ],
163+ "public_certificate" : public_certificate ,
164+ })
165+ scope = self .config .get ("scope" , [])
166+ result = self .app .acquire_token_for_client (scope )
167+ self .assertIn ('access_token' , result )
168+ self .assertCacheWorksForApp (result , scope )
169+
170+
171+ @unittest .skipUnless (os .path .exists (CONFIG ), "Optional %s not found" % CONFIG )
172+ class DeviceFlowTestCase (E2eTestCase ): # A leaf class so it will be run only once
173+ @classmethod
174+ def setUpClass (cls ):
175+ with open (CONFIG ) as f :
176+ cls .config = json .load (f )
177+
178+ def test_device_flow (self ):
179+ scopes = self .config ["scope" ]
180+ self .app = msal .PublicClientApplication (
181+ self .config ['client_id' ], authority = self .config ["authority" ])
182+ flow = self .app .initiate_device_flow (scopes = scopes )
183+ assert "user_code" in flow , "DF does not seem to be provisioned: %s" .format (
184+ json .dumps (flow , indent = 4 ))
185+ logger .info (flow ["message" ])
186+
187+ duration = 60
188+ logger .info ("We will wait up to %d seconds for you to sign in" % duration )
189+ flow ["expires_at" ] = min ( # Shorten the time for quick test
190+ flow ["expires_at" ], time .time () + duration )
191+ result = self .app .acquire_token_by_device_flow (flow )
192+ self .assertLoosely ( # It will skip this test if there is no user interaction
193+ result ,
194+ assertion = lambda : self .assertIn ('access_token' , result ),
195+ skippable_errors = self .app .client .DEVICE_FLOW_RETRIABLE_ERRORS )
196+ if "access_token" not in result :
197+ self .skip ("End user did not complete Device Flow in time" )
198+ self .assertCacheWorksForUser (result , scopes , username = None )
199+ result ["access_token" ] = result ["refresh_token" ] = "************"
200+ logger .info (
201+ "%s obtained tokens: %s" , self .id (), json .dumps (result , indent = 4 ))
202+
203+
76204def get_lab_user (mam = False , mfa = False , isFederated = False , federationProvider = None ):
77205 # Based on https://microsoft.sharepoint-df.com/teams/MSIDLABSExtended/SitePages/LAB.aspx
78206 user = requests .get ("https://api.msidlab.com/api/user" , params = dict ( # Publicly available
0 commit comments