Skip to content

Commit

Permalink
fix: use new auth token on get() after reauthenticate()
Browse files Browse the repository at this point in the history
  • Loading branch information
sdewitt-newrelic committed Apr 4, 2024
1 parent 1eddda4 commit bb38364
Show file tree
Hide file tree
Showing 4 changed files with 91 additions and 11 deletions.
6 changes: 5 additions & 1 deletion src/newrelic_logging/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,11 @@ def get(
)
auth.reauthenticate(session)

response = session.get(url, headers=headers, stream=stream)
new_headers = {
'Authorization': f'Bearer {auth.get_access_token()}'
}

response = session.get(url, headers=new_headers, stream=stream)
if response.status_code == 200:
return cb(response)

Expand Down
10 changes: 9 additions & 1 deletion src/tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from requests import Session, RequestException

from newrelic_logging import DataFormat, LoginException, SalesforceApiException
from newrelic_logging.api import Api
from newrelic_logging.api import Api, ApiFactory
from newrelic_logging.auth import Authenticator
from newrelic_logging.cache import DataCache
from newrelic_logging.config import Config
Expand Down Expand Up @@ -84,6 +84,7 @@ def __init__(
data_cache: DataCache = None,
token_url: str = '',
access_token: str = '',
access_token_2: str = '',
instance_url: str = '',
grant_type: str = '',
authenticate_called: bool = False,
Expand All @@ -94,6 +95,7 @@ def __init__(
self.data_cache = data_cache
self.token_url = token_url
self.access_token = access_token
self.access_token_2 = access_token_2
self.instance_url = instance_url
self.grant_type = grant_type
self.authenticate_called = authenticate_called
Expand Down Expand Up @@ -143,6 +145,8 @@ def reauthenticate(
if self.raise_login_error:
raise LoginException('Unauthorized')

self.access_token = self.access_token_2


class AuthenticatorFactoryStub:
def __init__(self):
Expand Down Expand Up @@ -463,6 +467,7 @@ def __init__(
data_cache: DataCache,
authenticator: Authenticator,
pipeline: Pipeline,
api_factory: ApiFactory,
query_factory: QueryFactory,
initial_delay: int,
queries: list[dict] = None,
Expand All @@ -472,6 +477,7 @@ def __init__(
self.data_cache = data_cache
self.authenticator = authenticator
self.pipeline = pipeline
self.api_factory = api_factory
self.query_factory = query_factory
self.initial_delay = initial_delay
self.queries = queries
Expand All @@ -488,6 +494,7 @@ def new(
data_cache: DataCache,
authenticator: Authenticator,
pipeline: Pipeline,
api_factory: ApiFactory,
query_factory: QueryFactory,
initial_delay: int,
queries: list[dict] = None,
Expand All @@ -498,6 +505,7 @@ def new(
data_cache,
authenticator,
pipeline,
api_factory,
query_factory,
initial_delay,
queries,
Expand Down
69 changes: 61 additions & 8 deletions src/tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ def test_get_calls_reauthenticate_on_401_and_invokes_cb_with_response_on_200(sel
and when: response status code is 401
then: reauthenticate() is called
and when: reauthenticate() does not throw a LoginException
then: request is executed again with the same parameters
then: request is executed again with the same URL and stream setting as the first call to session.get() and the second access token
and when: it returns a 200
then: calls callback with response and returns result
'''
Expand All @@ -265,6 +265,7 @@ def test_get_calls_reauthenticate_on_401_and_invokes_cb_with_response_on_200(sel
auth = AuthenticatorStub(
instance_url='https://my.salesforce.test',
access_token='123456',
access_token_2='567890',
)
response1 = ResponseStub(401, 'Unauthorized', 'Unauthorized', [])
response2 = ResponseStub(200, 'OK', 'OK', [])
Expand Down Expand Up @@ -299,7 +300,7 @@ def cb(response):
self.assertTrue('Authorization' in session.requests[1]['headers'])
self.assertEqual(
session.requests[1]['headers']['Authorization'],
'Bearer 123456',
'Bearer 567890',
)
self.assertEqual(
session.requests[1]['stream'],
Expand All @@ -308,9 +309,9 @@ def cb(response):
self.assertIsNotNone(val)
self.assertEqual(val, 'OK')

def test_get_passes_same_params_to_get_on_reauthenticate(self):
def test_get_passed_correct_params_after_reauthenticate(self):
'''
get() receives the same set of parameters on the second call after reauthenticate() succeeds
get() receives the correct set of parameters when it is called after reauthenticate() succeeds
given: an authenticator
and given: a session
and given: a service url
Expand All @@ -320,13 +321,14 @@ def test_get_passes_same_params_to_get_on_reauthenticate(self):
and when: response status code is 401
then: reauthenticate() is called
and when: reauthenticate() does not throw a LoginException
then: request is executed again with the same set of parameters as the first call to session.get()
then: request is executed again with the same URL and stream setting as the first call to session.get() and the second access token
'''

# setup
auth = AuthenticatorStub(
instance_url='https://my.salesforce.test',
access_token='123456',
access_token_2='567890'
)
response1 = ResponseStub(401, 'Unauthorized', 'Unauthorized', [])
response2 = ResponseStub(200, 'OK', 'OK', [])
Expand All @@ -340,8 +342,33 @@ def cb(response):

# verify
self.assertEqual(len(session.requests), 2)
self.assertEqual(session.requests[0], session.requests[1])
self.assertTrue(auth.reauthenticate_called)
self.assertEqual(
session.requests[0]['url'],
'https://my.salesforce.test/foo',
)
self.assertTrue('Authorization' in session.requests[0]['headers'])
self.assertEqual(
session.requests[0]['headers']['Authorization'],
'Bearer 123456',
)
self.assertEqual(
session.requests[0]['stream'],
True,
)
self.assertEqual(
session.requests[1]['url'],
'https://my.salesforce.test/foo',
)
self.assertTrue('Authorization' in session.requests[1]['headers'])
self.assertEqual(
session.requests[1]['headers']['Authorization'],
'Bearer 567890',
)
self.assertEqual(
session.requests[1]['stream'],
True,
)
self.assertIsNotNone(val)
self.assertEqual(val, 'OK')

Expand All @@ -357,7 +384,7 @@ def test_get_calls_reauthenticate_on_401_and_raises_on_non_200(self):
and when: response status code is 401
then: reauthenticate() is called
and when: reauthenticate() does not throw a LoginException
then: request is executed again with the same parameters
then: request is executed again with the same URL and stream setting as the first call to session.get() and the second access token
and when: it returns a non-200 status code
then: throws a SalesforceApiException
'''
Expand All @@ -366,6 +393,7 @@ def test_get_calls_reauthenticate_on_401_and_raises_on_non_200(self):
auth = AuthenticatorStub(
instance_url='https://my.salesforce.test',
access_token='123456',
access_token_2='567890',
)
response1 = ResponseStub(401, 'Unauthorized', 'Unauthorized', [])
response2 = ResponseStub(401, 'Unauthorized', 'Unauthorized 2', [])
Expand All @@ -384,8 +412,33 @@ def cb(response):
)

self.assertEqual(len(session.requests), 2)
self.assertEqual(session.requests[0], session.requests[1])
self.assertTrue(auth.reauthenticate_called)
self.assertEqual(
session.requests[0]['url'],
'https://my.salesforce.test/foo',
)
self.assertTrue('Authorization' in session.requests[0]['headers'])
self.assertEqual(
session.requests[0]['headers']['Authorization'],
'Bearer 123456',
)
self.assertEqual(
session.requests[0]['stream'],
False,
)
self.assertEqual(
session.requests[1]['url'],
'https://my.salesforce.test/foo',
)
self.assertTrue('Authorization' in session.requests[1]['headers'])
self.assertEqual(
session.requests[1]['headers']['Authorization'],
'Bearer 567890',
)
self.assertEqual(
session.requests[1]['stream'],
False,
)

def test_stream_lines_sets_fallback_encoding_and_calls_iter_lines_with_chunk_size_and_decode_unicode(self):
'''
Expand Down
17 changes: 16 additions & 1 deletion src/tests/test_integration.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import unittest

from . import AuthenticatorFactoryStub, \
from . import ApiFactoryStub, \
AuthenticatorFactoryStub, \
CacheFactoryStub, \
NewRelicStub, \
NewRelicFactoryStub, \
Expand Down Expand Up @@ -38,6 +39,7 @@ def test_build_instance(self):
}
]
})
api_factory = ApiFactoryStub()
auth_factory = AuthenticatorFactoryStub()
cache_factory = CacheFactoryStub()
pipeline_factory = PipelineFactoryStub()
Expand All @@ -54,6 +56,7 @@ def test_build_instance(self):
cache_factory,
pipeline_factory,
salesforce_factory,
api_factory,
query_factory,
new_relic,
DataFormat.EVENTS,
Expand Down Expand Up @@ -141,6 +144,7 @@ def test_build_instance(self):
}
]
})
api_factory = ApiFactoryStub()
auth_factory = AuthenticatorFactoryStub()
cache_factory = CacheFactoryStub()
pipeline_factory = PipelineFactoryStub()
Expand All @@ -154,6 +158,7 @@ def test_build_instance(self):
cache_factory,
pipeline_factory,
salesforce_factory,
api_factory,
query_factory,
new_relic,
DataFormat.EVENTS,
Expand Down Expand Up @@ -200,6 +205,7 @@ def test_build_instance(self):
}
]
})
api_factory = ApiFactoryStub()
auth_factory = AuthenticatorFactoryStub()
cache_factory = CacheFactoryStub()
pipeline_factory = PipelineFactoryStub()
Expand All @@ -213,6 +219,7 @@ def test_build_instance(self):
cache_factory,
pipeline_factory,
salesforce_factory,
api_factory,
query_factory,
new_relic,
DataFormat.EVENTS,
Expand Down Expand Up @@ -273,6 +280,7 @@ def test_init(self):
}
})

api_factory = ApiFactoryStub()
auth_factory = AuthenticatorFactoryStub()
cache_factory = CacheFactoryStub()
pipeline_factory = PipelineFactoryStub()
Expand All @@ -290,6 +298,7 @@ def test_init(self):
cache_factory,
pipeline_factory,
salesforce_factory,
api_factory,
query_factory,
newrelic_factory,
event_type_fields_mapping,
Expand Down Expand Up @@ -345,6 +354,7 @@ def test_init(self):
}
})

api_factory = ApiFactoryStub()
auth_factory = AuthenticatorFactoryStub()
cache_factory = CacheFactoryStub()
pipeline_factory = PipelineFactoryStub()
Expand All @@ -361,6 +371,7 @@ def test_init(self):
cache_factory,
pipeline_factory,
salesforce_factory,
api_factory,
query_factory,
newrelic_factory,
event_type_fields_mapping,
Expand All @@ -384,6 +395,7 @@ def test_init(self):
}
})

api_factory = ApiFactoryStub()
auth_factory = AuthenticatorFactoryStub()
cache_factory = CacheFactoryStub()
pipeline_factory = PipelineFactoryStub()
Expand All @@ -399,6 +411,7 @@ def test_init(self):
cache_factory,
pipeline_factory,
salesforce_factory,
api_factory,
query_factory,
newrelic_factory,
event_type_fields_mapping,
Expand All @@ -425,6 +438,7 @@ def test_init(self):
}
})

api_factory = ApiFactoryStub()
auth_factory = AuthenticatorFactoryStub()
cache_factory = CacheFactoryStub()
pipeline_factory = PipelineFactoryStub()
Expand All @@ -440,6 +454,7 @@ def test_init(self):
cache_factory,
pipeline_factory,
salesforce_factory,
api_factory,
query_factory,
newrelic_factory,
event_type_fields_mapping,
Expand Down

0 comments on commit bb38364

Please sign in to comment.