Skip to content

Commit

Permalink
Fix rest_async and async tests (#556)
Browse files Browse the repository at this point in the history
### Changes

Broke rest_sync in #518 and an issue in the async tests meant it didn't
get picked up.

### References

fixes #555
  • Loading branch information
adamjmcgrath authored Nov 28, 2023
2 parents 5dce1cc + 00fa6fa commit 7625804
Show file tree
Hide file tree
Showing 5 changed files with 48 additions and 41 deletions.
7 changes: 7 additions & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,15 @@ jobs:
poetry self add "poetry-dynamic-versioning[plugin]"
- name: Run tests
if: ${{ matrix.python-version != '3.7' }}
run: |
poetry run pytest --cov=auth0 --cov-report=term-missing:skip-covered --cov-report=xml
- name: Run tests 3.7
# Skip async tests in 3.7
if: ${{ matrix.python-version == '3.7' }}
run: |
poetry run pytest auth0/test
# bwrap ${{ env.BUBBLEWRAP_ARGUMENTS }} bash

# - name: Run lint
Expand Down
4 changes: 2 additions & 2 deletions auth0/rest_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,11 +86,11 @@ async def _request(self, *args: Any, **kwargs: Any) -> Any:
kwargs["timeout"] = self.timeout
if self._session is not None:
# Request with re-usable session
return self._request_with_session(self.session, *args, **kwargs)
return await self._request_with_session(self._session, *args, **kwargs)
else:
# Request without re-usable session
async with aiohttp.ClientSession() as session:
return self._request_with_session(session, *args, **kwargs)
return await self._request_with_session(session, *args, **kwargs)

async def get(
self,
Expand Down
8 changes: 4 additions & 4 deletions auth0/test_async/test_async_auth0.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,13 @@ def callback(url, **kwargs):
return callback, mock


class TestAuth0(unittest.TestCase):
class TestAuth0(unittest.IsolatedAsyncioTestCase):
@pytest.mark.asyncio
@aioresponses()
async def test_get(self, mocked):
callback, mock = get_callback()

await mocked.get(clients, callback=callback)
mocked.get(clients, callback=callback)

auth0 = Auth0(domain="example.com", token="jwt")

Expand All @@ -48,8 +48,8 @@ async def test_shared_session(self, mocked):
callback, mock = get_callback()
callback2, mock2 = get_callback()

await mocked.get(clients, callback=callback)
await mocked.put(factors, callback=callback2)
mocked.get(clients, callback=callback)
mocked.put(factors, callback=callback2)

async with Auth0(domain="example.com", token="jwt") as auth0:
self.assertEqual(await auth0.clients.all_async(), payload)
Expand Down
30 changes: 15 additions & 15 deletions auth0/test_async/test_async_token_verifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,12 +55,12 @@ def get_pem_bytes(rsa_public_key):
)


class TestAsyncAsymmetricSignatureVerifier(unittest.TestCase):
class TestAsyncAsymmetricSignatureVerifier(unittest.IsolatedAsyncioTestCase):
@pytest.mark.asyncio
@aioresponses()
async def test_async_asymmetric_verifier_fetches_key(self, mocked):
callback, mock = get_callback(200, JWKS_RESPONSE_SINGLE_KEY)
await mocked.get(JWKS_URI, callback=callback)
mocked.get(JWKS_URI, callback=callback)

verifier = AsyncAsymmetricSignatureVerifier(JWKS_URI)

Expand All @@ -69,7 +69,7 @@ async def test_async_asymmetric_verifier_fetches_key(self, mocked):
self.assertEqual(get_pem_bytes(key), RSA_PUB_KEY_1_PEM)


class TestAsyncJwksFetcher(unittest.TestCase):
class TestAsyncJwksFetcher(unittest.IsolatedAsyncioTestCase):
@pytest.mark.asyncio
@aioresponses()
@unittest.mock.patch(
Expand All @@ -81,8 +81,8 @@ async def test_async_get_jwks_json_twice_on_cache_expired(
fetcher = AsyncJwksFetcher(JWKS_URI, cache_ttl=100)

callback, mock = get_callback(200, JWKS_RESPONSE_SINGLE_KEY)
await mocked.get(JWKS_URI, callback=callback)
await mocked.get(JWKS_URI, callback=callback)
mocked.get(JWKS_URI, callback=callback)
mocked.get(JWKS_URI, callback=callback)

key_1 = await fetcher.get_key("test-key-1")
expected_key_1_pem = get_pem_bytes(key_1)
Expand Down Expand Up @@ -119,8 +119,8 @@ async def test_async_get_jwks_json_once_on_cache_hit(self, mocked):
fetcher = AsyncJwksFetcher(JWKS_URI, cache_ttl=1)

callback, mock = get_callback(200, JWKS_RESPONSE_MULTIPLE_KEYS)
await mocked.get(JWKS_URI, callback=callback)
await mocked.get(JWKS_URI, callback=callback)
mocked.get(JWKS_URI, callback=callback)
mocked.get(JWKS_URI, callback=callback)

key_1 = await fetcher.get_key("test-key-1")
key_2 = await fetcher.get_key("test-key-2")
Expand All @@ -144,7 +144,7 @@ async def test_async_fetches_jwks_json_forced_on_cache_miss(self, mocked):
fetcher = AsyncJwksFetcher(JWKS_URI, cache_ttl=1)

callback, mock = get_callback(200, {"keys": [RSA_PUB_KEY_1_JWK]})
await mocked.get(JWKS_URI, callback=callback)
mocked.get(JWKS_URI, callback=callback)

# Triggers the first call
key_1 = await fetcher.get_key("test-key-1")
Expand All @@ -161,7 +161,7 @@ async def test_async_fetches_jwks_json_forced_on_cache_miss(self, mocked):
self.assertEqual(mock.call_count, 1)

callback, mock = get_callback(200, JWKS_RESPONSE_MULTIPLE_KEYS)
await mocked.get(JWKS_URI, callback=callback)
mocked.get(JWKS_URI, callback=callback)

# Triggers the second call
key_2 = await fetcher.get_key("test-key-2")
Expand All @@ -183,7 +183,7 @@ async def test_async_fetches_jwks_json_once_on_cache_miss(self, mocked):
fetcher = AsyncJwksFetcher(JWKS_URI, cache_ttl=1)

callback, mock = get_callback(200, JWKS_RESPONSE_SINGLE_KEY)
await mocked.get(JWKS_URI, callback=callback)
mocked.get(JWKS_URI, callback=callback)

with self.assertRaises(Exception) as err:
await fetcher.get_key("missing-key")
Expand All @@ -206,8 +206,8 @@ async def test_async_fails_to_fetch_jwks_json_after_retrying_twice(self, mocked)
fetcher = AsyncJwksFetcher(JWKS_URI, cache_ttl=1)

callback, mock = get_callback(500, {})
await mocked.get(JWKS_URI, callback=callback)
await mocked.get(JWKS_URI, callback=callback)
mocked.get(JWKS_URI, callback=callback)
mocked.get(JWKS_URI, callback=callback)

with self.assertRaises(Exception) as err:
await fetcher.get_key("id1")
Expand All @@ -225,12 +225,12 @@ async def test_async_fails_to_fetch_jwks_json_after_retrying_twice(self, mocked)
self.assertEqual(mock.call_count, 2)


class TestAsyncTokenVerifier(unittest.TestCase):
class TestAsyncTokenVerifier(unittest.IsolatedAsyncioTestCase):
@pytest.mark.asyncio
@aioresponses()
async def test_RS256_token_signature_passes(self, mocked):
callback, mock = get_callback(200, {"keys": [PUBLIC_KEY]})
await mocked.get(JWKS_URI, callback=callback)
mocked.get(JWKS_URI, callback=callback)

issuer = "https://tokens-test.auth0.com/"
audience = "tokens-test-123"
Expand Down Expand Up @@ -261,7 +261,7 @@ async def test_RS256_token_signature_fails(self, mocked):
callback, mock = get_callback(
200, {"keys": [RSA_PUB_KEY_1_JWK]}
) # different pub key
await mocked.get(JWKS_URI, callback=callback)
mocked.get(JWKS_URI, callback=callback)

issuer = "https://tokens-test.auth0.com/"
audience = "tokens-test-123"
Expand Down
40 changes: 20 additions & 20 deletions auth0/test_async/test_asyncify.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,12 +54,12 @@ def callback(url, **kwargs):
return callback, mock


class TestAsyncify(unittest.TestCase):
class TestAsyncify(unittest.IsolatedAsyncioTestCase):
@pytest.mark.asyncio
@aioresponses()
async def test_get(self, mocked):
callback, mock = get_callback()
await mocked.get(clients, callback=callback)
mocked.get(clients, callback=callback)
c = asyncify(Clients)(domain="example.com", token="jwt")
self.assertEqual(await c.all_async(), payload)
mock.assert_called_with(
Expand All @@ -74,7 +74,7 @@ async def test_get(self, mocked):
@aioresponses()
async def test_post(self, mocked):
callback, mock = get_callback()
await mocked.post(clients, callback=callback)
mocked.post(clients, callback=callback)
c = asyncify(Clients)(domain="example.com", token="jwt")
data = {"client": 1}
self.assertEqual(await c.create_async(data), payload)
Expand All @@ -90,7 +90,7 @@ async def test_post(self, mocked):
@aioresponses()
async def test_post_auth(self, mocked):
callback, mock = get_callback()
await mocked.post(token, callback=callback)
mocked.post(token, callback=callback)
c = asyncify(GetToken)("example.com", "cid", client_secret="clsec")
self.assertEqual(
await c.login_async(username="usrnm", password="pswd"), payload
Expand All @@ -116,7 +116,7 @@ async def test_post_auth(self, mocked):
@aioresponses()
async def test_user_info(self, mocked):
callback, mock = get_callback()
await mocked.get(user_info, callback=callback)
mocked.get(user_info, callback=callback)
c = asyncify(Users)(domain="example.com")
self.assertEqual(
await c.userinfo_async(access_token="access-token-example"), payload
Expand All @@ -133,7 +133,7 @@ async def test_user_info(self, mocked):
@aioresponses()
async def test_file_post(self, mocked):
callback, mock = get_callback()
await mocked.post(users_imports, callback=callback)
mocked.post(users_imports, callback=callback)
j = asyncify(Jobs)(domain="example.com", token="jwt")
users = TemporaryFile()
self.assertEqual(await j.import_users_async("connection-1", users), payload)
Expand All @@ -158,7 +158,7 @@ async def test_file_post(self, mocked):
@aioresponses()
async def test_patch(self, mocked):
callback, mock = get_callback()
await mocked.patch(clients, callback=callback)
mocked.patch(clients, callback=callback)
c = asyncify(Clients)(domain="example.com", token="jwt")
data = {"client": 1}
self.assertEqual(await c.update_async("client-1", data), payload)
Expand All @@ -174,7 +174,7 @@ async def test_patch(self, mocked):
@aioresponses()
async def test_put(self, mocked):
callback, mock = get_callback()
await mocked.put(factors, callback=callback)
mocked.put(factors, callback=callback)
g = asyncify(Guardian)(domain="example.com", token="jwt")
data = {"factor": 1}
self.assertEqual(await g.update_factor_async("factor-1", data), payload)
Expand All @@ -190,7 +190,7 @@ async def test_put(self, mocked):
@aioresponses()
async def test_delete(self, mocked):
callback, mock = get_callback()
await mocked.delete(clients, callback=callback)
mocked.delete(clients, callback=callback)
c = asyncify(Clients)(domain="example.com", token="jwt")
self.assertEqual(await c.delete_async("client-1"), payload)
mock.assert_called_with(
Expand All @@ -206,7 +206,7 @@ async def test_delete(self, mocked):
@aioresponses()
async def test_shared_session(self, mocked):
callback, mock = get_callback()
await mocked.get(clients, callback=callback)
mocked.get(clients, callback=callback)
async with asyncify(Clients)(domain="example.com", token="jwt") as c:
self.assertEqual(await c.all_async(), payload)
mock.assert_called_with(
Expand All @@ -221,10 +221,10 @@ async def test_shared_session(self, mocked):
@aioresponses()
async def test_rate_limit(self, mocked):
callback, mock = get_callback(status=429)
await mocked.get(clients, callback=callback)
await mocked.get(clients, callback=callback)
await mocked.get(clients, callback=callback)
await mocked.get(clients, payload=payload)
mocked.get(clients, callback=callback)
mocked.get(clients, callback=callback)
mocked.get(clients, callback=callback)
mocked.get(clients, payload=payload)
c = asyncify(Clients)(domain="example.com", token="jwt")
rest_client = c._async_client.client
rest_client._skip_sleep = True
Expand All @@ -237,21 +237,21 @@ async def test_rate_limit(self, mocked):
@aioresponses()
async def test_rate_limit_post(self, mocked):
callback, mock = get_callback(status=429)
await mocked.post(clients, callback=callback)
await mocked.post(clients, callback=callback)
await mocked.post(clients, callback=callback)
await mocked.post(clients, payload=payload)
mocked.post(clients, callback=callback)
mocked.post(clients, callback=callback)
mocked.post(clients, callback=callback)
mocked.post(clients, payload=payload)
c = asyncify(Clients)(domain="example.com", token="jwt")
rest_client = c._async_client.client
rest_client._skip_sleep = True
self.assertEqual(await c.all_async(), payload)
self.assertEqual(await c.create_async({}), payload)
self.assertEqual(3, mock.call_count)

@pytest.mark.asyncio
@aioresponses()
async def test_timeout(self, mocked):
callback, mock = get_callback()
await mocked.get(clients, callback=callback)
mocked.get(clients, callback=callback)
c = asyncify(Clients)(domain="example.com", token="jwt", timeout=(8.8, 9.9))
self.assertEqual(await c.all_async(), payload)
mock.assert_called_with(
Expand Down

0 comments on commit 7625804

Please sign in to comment.