Skip to content

Commit 8c5a1e2

Browse files
committed
Move some tests from test_application to test_e2e
1 parent 986ac01 commit 8c5a1e2

File tree

3 files changed

+144
-172
lines changed

3 files changed

+144
-172
lines changed

tests/test_application.py

Lines changed: 4 additions & 161 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
# Note: Since Aug 2019 we move all e2e tests into test_e2e.py,
2+
# so this test_application file contains only unit tests without dependency.
13
import os
24
import json
35
import logging
@@ -13,91 +15,11 @@
1315
from tests.test_token_cache import TokenCacheTestCase
1416

1517

16-
THIS_FOLDER = os.path.dirname(__file__)
17-
CONFIG_FILE = os.path.join(THIS_FOLDER, 'config.json')
18-
CONFIG = {}
19-
if os.path.exists(CONFIG_FILE):
20-
with open(CONFIG_FILE) as conf:
21-
CONFIG = json.load(conf)
22-
2318
logger = logging.getLogger(__name__)
2419
logging.basicConfig(level=logging.DEBUG)
2520

2621

27-
class Oauth2TestCase(unittest.TestCase):
28-
29-
def assertLoosely(self, response, assertion=None,
30-
skippable_errors=("invalid_grant", "interaction_required")):
31-
if response.get("error") in skippable_errors:
32-
logger.debug("Response = %s", response)
33-
# Some of these errors are configuration issues, not library issues
34-
raise unittest.SkipTest(response.get("error_description"))
35-
else:
36-
if assertion is None:
37-
assertion = lambda: self.assertIn(
38-
"access_token", response,
39-
"{error}: {error_description}".format(
40-
# Do explicit response.get(...) rather than **response
41-
error=response.get("error"),
42-
error_description=response.get("error_description")))
43-
assertion()
44-
45-
def assertCacheWorks(self, result_from_wire):
46-
result = result_from_wire
47-
# You can filter by predefined username, or let end user to choose one
48-
accounts = self.app.get_accounts(username=CONFIG.get("username"))
49-
self.assertNotEqual(0, len(accounts))
50-
account = accounts[0]
51-
# Going to test acquire_token_silent(...) to locate an AT from cache
52-
result_from_cache = self.app.acquire_token_silent(
53-
CONFIG["scope"], account=account)
54-
self.assertIsNotNone(result_from_cache)
55-
self.assertEqual(result['access_token'], result_from_cache['access_token'],
56-
"We should get a cached AT")
57-
58-
# Going to test acquire_token_silent(...) to obtain an AT by a RT from cache
59-
self.app.token_cache._cache["AccessToken"] = {} # A hacky way to clear ATs
60-
result_from_cache = self.app.acquire_token_silent(
61-
CONFIG["scope"], account=account)
62-
self.assertIsNotNone(result_from_cache,
63-
"We should get a result from acquire_token_silent(...) call")
64-
self.assertNotEqual(result['access_token'], result_from_cache['access_token'],
65-
"We should get a fresh AT (via RT)")
66-
67-
68-
@unittest.skipUnless("client_id" in CONFIG, "client_id missing")
69-
class TestConfidentialClientApplication(unittest.TestCase):
70-
71-
def assertCacheWorks(self, result_from_wire, result_from_cache):
72-
self.assertIsNotNone(result_from_cache)
73-
self.assertEqual(
74-
result_from_wire['access_token'], result_from_cache['access_token'])
75-
76-
@unittest.skipUnless("client_secret" in CONFIG, "Missing client secret")
77-
def test_client_secret(self):
78-
app = ConfidentialClientApplication(
79-
CONFIG["client_id"], client_credential=CONFIG.get("client_secret"),
80-
authority=CONFIG.get("authority"))
81-
scope = CONFIG.get("scope", [])
82-
result = app.acquire_token_for_client(scope)
83-
self.assertIn('access_token', result)
84-
self.assertCacheWorks(result, app.acquire_token_silent(scope, account=None))
85-
86-
@unittest.skipUnless("client_certificate" in CONFIG, "Missing client cert")
87-
def test_client_certificate(self):
88-
client_certificate = CONFIG["client_certificate"]
89-
assert ("private_key_path" in client_certificate
90-
and "thumbprint" in client_certificate)
91-
key_path = os.path.join(THIS_FOLDER, client_certificate['private_key_path'])
92-
with open(key_path) as f:
93-
pem = f.read()
94-
app = ConfidentialClientApplication(
95-
CONFIG['client_id'],
96-
{"private_key": pem, "thumbprint": client_certificate["thumbprint"]})
97-
scope = CONFIG.get("scope", [])
98-
result = app.acquire_token_for_client(scope)
99-
self.assertIn('access_token', result)
100-
self.assertCacheWorks(result, app.acquire_token_silent(scope, account=None))
22+
class TestHelperExtractCerts(unittest.TestCase): # It is used by SNI scenario
10123

10224
def test_extract_a_tag_less_public_cert(self):
10325
pem = "my_cert"
@@ -116,92 +38,13 @@ def test_extract_multiple_tag_enclosed_certs(self):
11638
-----BEGIN CERTIFICATE-----
11739
my_cert1
11840
-----END CERTIFICATE-----
119-
41+
12042
-----BEGIN CERTIFICATE-----
12143
my_cert2
12244
-----END CERTIFICATE-----
12345
"""
12446
self.assertEqual(["my_cert1", "my_cert2"], extract_certs(pem))
12547

126-
@unittest.skipUnless("public_certificate" in CONFIG, "Missing Public cert")
127-
def test_subject_name_issuer_authentication(self):
128-
assert ("private_key_file" in CONFIG
129-
and "thumbprint" in CONFIG and "public_certificate" in CONFIG)
130-
with open(os.path.join(THIS_FOLDER, CONFIG['private_key_file'])) as f:
131-
pem = f.read()
132-
with open(os.path.join(THIS_FOLDER, CONFIG['public_certificate'])) as f:
133-
public_certificate = f.read()
134-
app = ConfidentialClientApplication(
135-
CONFIG['client_id'], authority=CONFIG["authority"],
136-
client_credential={"private_key": pem, "thumbprint": CONFIG["thumbprint"],
137-
"public_certificate": public_certificate})
138-
scope = CONFIG.get("scope", [])
139-
result = app.acquire_token_for_client(scope)
140-
self.assertIn('access_token', result)
141-
self.assertCacheWorks(result, app.acquire_token_silent(scope, account=None))
142-
143-
@unittest.skipUnless("client_id" in CONFIG, "client_id missing")
144-
class TestPublicClientApplication(Oauth2TestCase):
145-
146-
@unittest.skipUnless("username" in CONFIG and "password" in CONFIG, "Missing U/P")
147-
def test_username_password(self):
148-
self.app = PublicClientApplication(
149-
CONFIG["client_id"], authority=CONFIG["authority"])
150-
result = self.app.acquire_token_by_username_password(
151-
CONFIG["username"], CONFIG["password"], scopes=CONFIG.get("scope"))
152-
self.assertLoosely(result)
153-
self.assertCacheWorks(result)
154-
155-
def test_device_flow(self):
156-
self.app = PublicClientApplication(
157-
CONFIG["client_id"], authority=CONFIG["authority"])
158-
flow = self.app.initiate_device_flow(scopes=CONFIG.get("scope"))
159-
assert "user_code" in flow, str(flow) # Provision or policy might block DF
160-
logging.warning(flow["message"])
161-
162-
duration = 30
163-
logging.warning("We will wait up to %d seconds for you to sign in" % duration)
164-
flow["expires_at"] = time.time() + duration # Shorten the time for quick test
165-
result = self.app.acquire_token_by_device_flow(flow)
166-
self.assertLoosely(
167-
result,
168-
assertion=lambda: self.assertIn('access_token', result),
169-
skippable_errors=self.app.client.DEVICE_FLOW_RETRIABLE_ERRORS)
170-
171-
if "access_token" in result:
172-
self.assertCacheWorks(result)
173-
174-
175-
@unittest.skipUnless("client_id" in CONFIG, "client_id missing")
176-
class TestClientApplication(Oauth2TestCase):
177-
178-
@classmethod
179-
def setUpClass(cls):
180-
cls.app = ClientApplication(
181-
CONFIG["client_id"], client_credential=CONFIG.get("client_secret"),
182-
authority=CONFIG.get("authority"))
183-
184-
@unittest.skipUnless("scope" in CONFIG, "Missing scope")
185-
def test_auth_code(self):
186-
from msal.oauth2cli.authcode import obtain_auth_code
187-
port = CONFIG.get("listen_port", 44331)
188-
redirect_uri = "http://localhost:%s" % port
189-
auth_request_uri = self.app.get_authorization_request_url(
190-
CONFIG["scope"], redirect_uri=redirect_uri)
191-
ac = obtain_auth_code(port, auth_uri=auth_request_uri)
192-
self.assertNotEqual(ac, None)
193-
194-
result = self.app.acquire_token_by_authorization_code(
195-
ac, CONFIG["scope"], redirect_uri=redirect_uri)
196-
logging.debug("cache = %s", json.dumps(self.app.token_cache._cache, indent=4))
197-
self.assertIn(
198-
"access_token", result,
199-
"{error}: {error_description}".format(
200-
# Note: No interpolation here, cause error won't always present
201-
error=result.get("error"),
202-
error_description=result.get("error_description")))
203-
self.assertCacheWorks(result)
204-
20548

20649
class TestClientApplicationAcquireTokenSilentFociBehaviors(unittest.TestCase):
20750

tests/test_client.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ def load_conf(filename):
7878

7979
# Since the OAuth2 specs uses snake_case, this test config also uses snake_case
8080
@unittest.skipUnless("client_id" in CONFIG, "client_id missing")
81+
@unittest.skipUnless(CONFIG.get("openid_configuration"), "openid_configuration missing")
8182
class TestClient(Oauth2TestCase):
8283

8384
@classmethod

tests/test_e2e.py

Lines changed: 139 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import logging
22
import os
33
import json
4+
import time
45

56
import requests
67

@@ -31,25 +32,34 @@ def assertLoosely(self, response, assertion=None,
3132
error_description=response.get("error_description")))
3233
assertion()
3334

34-
def assertCacheWorks(self, result_from_wire, username, scope):
35-
result = result_from_wire
35+
def assertCacheWorksForUser(self, result_from_wire, scope, username=None):
3636
# You can filter by predefined username, or let end user to choose one
3737
accounts = self.app.get_accounts(username=username)
3838
self.assertNotEqual(0, len(accounts))
3939
account = accounts[0]
4040
# Going to test acquire_token_silent(...) to locate an AT from cache
4141
result_from_cache = self.app.acquire_token_silent(scope, account=account)
4242
self.assertIsNotNone(result_from_cache)
43-
self.assertEqual(result['access_token'], result_from_cache['access_token'],
44-
"We should get a cached AT")
43+
self.assertEqual(
44+
result_from_wire['access_token'], result_from_cache['access_token'],
45+
"We should get a cached AT")
4546

4647
# Going to test acquire_token_silent(...) to obtain an AT by a RT from cache
4748
self.app.token_cache._cache["AccessToken"] = {} # A hacky way to clear ATs
4849
result_from_cache = self.app.acquire_token_silent(scope, account=account)
4950
self.assertIsNotNone(result_from_cache,
5051
"We should get a result from acquire_token_silent(...) call")
51-
self.assertNotEqual(result['access_token'], result_from_cache['access_token'],
52-
"We should get a fresh AT (via RT)")
52+
self.assertNotEqual(
53+
result_from_wire['access_token'], result_from_cache['access_token'],
54+
"We should get a fresh AT (via RT)")
55+
56+
def assertCacheWorksForApp(self, result_from_wire, scope):
57+
# Going to test acquire_token_silent(...) to locate an AT from cache
58+
result_from_cache = self.app.acquire_token_silent(scope, account=None)
59+
self.assertIsNotNone(result_from_cache)
60+
self.assertEqual(
61+
result_from_wire['access_token'], result_from_cache['access_token'],
62+
"We should get a cached AT")
5363

5464
def _test_username_password(self,
5565
authority=None, client_id=None, username=None, password=None, scope=None,
@@ -60,19 +70,137 @@ def _test_username_password(self,
6070
username, password, scopes=scope)
6171
self.assertLoosely(result)
6272
# self.assertEqual(None, result.get("error"), str(result))
63-
self.assertCacheWorks(result, username, scope)
73+
self.assertCacheWorksForUser(result, scope, username=username)
6474

6575

66-
CONFIG = os.path.join(os.path.dirname(__file__), "config.json")
67-
@unittest.skipIf(not os.path.exists(CONFIG), "Optional %s not found" % CONFIG)
76+
THIS_FOLDER = os.path.dirname(__file__)
77+
CONFIG = os.path.join(THIS_FOLDER, "config.json")
78+
@unittest.skipUnless(os.path.exists(CONFIG), "Optional %s not found" % CONFIG)
6879
class FileBasedTestCase(E2eTestCase):
69-
def setUp(self):
80+
# This covers scenarios that are not currently available for test automation.
81+
# So they mean to be run on maintainer's machine for semi-automated tests.
82+
83+
@classmethod
84+
def setUpClass(cls):
7085
with open(CONFIG) as f:
71-
self.config = json.load(f)
86+
cls.config = json.load(f)
87+
88+
def skipUnlessWithConfig(self, fields):
89+
for field in fields:
90+
if field not in self.config:
91+
self.skipTest('Skipping due to lack of configuration "%s"' % field)
7292

7393
def test_username_password(self):
94+
self.skipUnlessWithConfig(["client_id", "username", "password", "scope"])
7495
self._test_username_password(**self.config)
7596

97+
def test_auth_code(self):
98+
self.skipUnlessWithConfig(["client_id", "scope"])
99+
from msal.oauth2cli.authcode import obtain_auth_code
100+
self.app = msal.ClientApplication(
101+
self.config["client_id"],
102+
client_credential=self.config.get("client_secret"),
103+
authority=self.config.get("authority"))
104+
port = self.config.get("listen_port", 44331)
105+
redirect_uri = "http://localhost:%s" % port
106+
auth_request_uri = self.app.get_authorization_request_url(
107+
self.config["scope"], redirect_uri=redirect_uri)
108+
ac = obtain_auth_code(port, auth_uri=auth_request_uri)
109+
self.assertNotEqual(ac, None)
110+
111+
result = self.app.acquire_token_by_authorization_code(
112+
ac, self.config["scope"], redirect_uri=redirect_uri)
113+
logger.debug("%s.cache = %s",
114+
self.id(), json.dumps(self.app.token_cache._cache, indent=4))
115+
self.assertIn(
116+
"access_token", result,
117+
"{error}: {error_description}".format(
118+
# Note: No interpolation here, cause error won't always present
119+
error=result.get("error"),
120+
error_description=result.get("error_description")))
121+
self.assertCacheWorksForUser(result, self.config["scope"], username=None)
122+
123+
def test_client_secret(self):
124+
self.skipUnlessWithConfig(["client_id", "client_secret"])
125+
self.app = msal.ConfidentialClientApplication(
126+
self.config["client_id"],
127+
client_credential=self.config.get("client_secret"),
128+
authority=self.config.get("authority"))
129+
scope = self.config.get("scope", [])
130+
result = self.app.acquire_token_for_client(scope)
131+
self.assertIn('access_token', result)
132+
self.assertCacheWorksForApp(result, scope)
133+
134+
def test_client_certificate(self):
135+
self.skipUnlessWithConfig(["client_id", "client_certificate"])
136+
client_cert = self.config["client_certificate"]
137+
assert "private_key_path" in client_cert and "thumbprint" in client_cert
138+
with open(os.path.join(THIS_FOLDER, client_cert['private_key_path'])) as f:
139+
private_key = f.read() # Should be in PEM format
140+
self.app = msal.ConfidentialClientApplication(
141+
self.config['client_id'],
142+
{"private_key": private_key, "thumbprint": client_cert["thumbprint"]})
143+
scope = self.config.get("scope", [])
144+
result = self.app.acquire_token_for_client(scope)
145+
self.assertIn('access_token', result)
146+
self.assertCacheWorksForApp(result, scope)
147+
148+
def test_subject_name_issuer_authentication(self):
149+
self.skipUnlessWithConfig(["client_id", "client_certificate"])
150+
client_cert = self.config["client_certificate"]
151+
assert "private_key_path" in client_cert and "thumbprint" in client_cert
152+
if not "public_certificate" in client_cert:
153+
self.skipTest("Skipping SNI test due to lack of public_certificate")
154+
with open(os.path.join(THIS_FOLDER, client_cert['private_key_path'])) as f:
155+
private_key = f.read() # Should be in PEM format
156+
with open(os.path.join(THIS_FOLDER, client_cert['public_certificate'])) as f:
157+
public_certificate = f.read()
158+
self.app = msal.ConfidentialClientApplication(
159+
self.config['client_id'], authority=self.config["authority"],
160+
client_credential={
161+
"private_key": private_key,
162+
"thumbprint": self.config["thumbprint"],
163+
"public_certificate": public_certificate,
164+
})
165+
scope = self.config.get("scope", [])
166+
result = self.app.acquire_token_for_client(scope)
167+
self.assertIn('access_token', result)
168+
self.assertCacheWorksForApp(result, scope)
169+
170+
171+
@unittest.skipUnless(os.path.exists(CONFIG), "Optional %s not found" % CONFIG)
172+
class DeviceFlowTestCase(E2eTestCase): # A leaf class so it will be run only once
173+
@classmethod
174+
def setUpClass(cls):
175+
with open(CONFIG) as f:
176+
cls.config = json.load(f)
177+
178+
def test_device_flow(self):
179+
scopes = self.config["scope"]
180+
self.app = msal.PublicClientApplication(
181+
self.config['client_id'], authority=self.config["authority"])
182+
flow = self.app.initiate_device_flow(scopes=scopes)
183+
assert "user_code" in flow, "DF does not seem to be provisioned: %s".format(
184+
json.dumps(flow, indent=4))
185+
logger.info(flow["message"])
186+
187+
duration = 60
188+
logger.info("We will wait up to %d seconds for you to sign in" % duration)
189+
flow["expires_at"] = min( # Shorten the time for quick test
190+
flow["expires_at"], time.time() + duration)
191+
result = self.app.acquire_token_by_device_flow(flow)
192+
self.assertLoosely( # It will skip this test if there is no user interaction
193+
result,
194+
assertion=lambda: self.assertIn('access_token', result),
195+
skippable_errors=self.app.client.DEVICE_FLOW_RETRIABLE_ERRORS)
196+
if "access_token" not in result:
197+
self.skip("End user did not complete Device Flow in time")
198+
self.assertCacheWorksForUser(result, scopes, username=None)
199+
result["access_token"] = result["refresh_token"] = "************"
200+
logger.info(
201+
"%s obtained tokens: %s", self.id(), json.dumps(result, indent=4))
202+
203+
76204
def get_lab_user(mam=False, mfa=False, isFederated=False, federationProvider=None):
77205
# Based on https://microsoft.sharepoint-df.com/teams/MSIDLABSExtended/SitePages/LAB.aspx
78206
user = requests.get("https://api.msidlab.com/api/user", params=dict( # Publicly available

0 commit comments

Comments
 (0)