Skip to content

Commit

Permalink
Merge pull request #470 from stripe/mickjermsurawong/python-retry
Browse files Browse the repository at this point in the history
Python retry mechanism
  • Loading branch information
mickjermsurawong-stripe authored Sep 7, 2018
2 parents abc1d45 + 5e002cd commit 0c06d7b
Show file tree
Hide file tree
Showing 7 changed files with 438 additions and 29 deletions.
1 change: 1 addition & 0 deletions stripe/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
proxy = None
default_http_client = None
app_info = None
max_network_retries = 0

# Set to either 'debug' or 'info', controls console logging
log = None
Expand Down
5 changes: 3 additions & 2 deletions stripe/api_requestor.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import json
import platform
import time
import uuid

import stripe
from stripe import error, oauth_error, http_client, version, util, six
Expand Down Expand Up @@ -220,6 +221,7 @@ def request_headers(self, api_key, method):

if method == 'post':
headers['Content-Type'] = 'application/x-www-form-urlencoded'
headers.setdefault('Idempotency-Key', str(uuid.uuid4()))

if self.api_version is not None:
headers['Stripe-Version'] = self.api_version
Expand Down Expand Up @@ -271,7 +273,6 @@ def request_raw(self, method, url, params=None, supplied_headers=None):
'assistance.' % (method,))

headers = self.request_headers(my_api_key, method)

if supplied_headers is not None:
for key, value in six.iteritems(supplied_headers):
headers[key] = value
Expand All @@ -281,7 +282,7 @@ def request_raw(self, method, url, params=None, supplied_headers=None):
'Post details',
post_data=encoded_params, api_version=self.api_version)

rbody, rcode, rheaders = self._client.request(
rbody, rcode, rheaders = self._client.request_with_retries(
method, abs_url, headers, post_data)

util.log_info(
Expand Down
7 changes: 6 additions & 1 deletion stripe/error.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,12 @@ class APIError(StripeError):


class APIConnectionError(StripeError):
pass
def __init__(self, message, http_body=None, http_status=None,
json_body=None, headers=None, code=None, should_retry=False):
super(APIConnectionError, self).__init__(message, http_body,
http_status,
json_body, headers, code)
self.should_retry = should_retry


class StripeErrorWithParamCode(StripeError):
Expand Down
97 changes: 94 additions & 3 deletions stripe/http_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@
import textwrap
import warnings
import email
import time
import random

from stripe import error, util, six
from stripe import error, util, six, max_network_retries

# - Requests is the preferred HTTP library
# - Google App Engine has urlfetch
Expand Down Expand Up @@ -77,6 +79,9 @@ def new_default_http_client(*args, **kwargs):


class HTTPClient(object):
MAX_DELAY = 2
INITIAL_DELAY = 0.5

def __init__(self, verify_ssl_certs=True, proxy=None):
self._verify_ssl_certs = verify_ssl_certs
if proxy:
Expand All @@ -89,10 +94,74 @@ def __init__(self, verify_ssl_certs=True, proxy=None):
" ""https"" and/or ""http"" keys.")
self._proxy = proxy.copy() if proxy else None

def request_with_retries(self, method, url, headers, post_data=None):
num_retries = 0

while True:
try:
num_retries += 1
response = self.request(method, url, headers, post_data)
connection_error = None
except error.APIConnectionError as e:
connection_error = e
response = None

if self._should_retry(response, connection_error, num_retries):
if connection_error:
util.log_info("Encountered a retryable error %s" %
connection_error.user_message)

sleep_time = self._sleep_time_seconds(num_retries)
util.log_info(("Initiating retry %i for request %s %s after "
"sleeping %.2f seconds." %
(num_retries, method, url, sleep_time)))
time.sleep(sleep_time)
else:
if response is not None:
return response
else:
raise connection_error

def request(self, method, url, headers, post_data=None):
raise NotImplementedError(
'HTTPClient subclasses must implement `request`')

def _should_retry(self, response, api_connection_error, num_retries):
if response is not None:
_, status_code, _ = response
should_retry = status_code == 409
else:
# We generally want to retry on timeout and connection
# exceptions, but defer this decision to underlying subclass
# implementations. They should evaluate the driver-specific
# errors worthy of retries, and set flag on the error returned.
should_retry = api_connection_error.should_retry
return should_retry and num_retries < self._max_network_retries()

def _max_network_retries(self):
# Configured retries, isolated here for tests
return max_network_retries

def _sleep_time_seconds(self, num_retries):
# Apply exponential backoff with initial_network_retry_delay on the
# number of num_retries so far as inputs.
# Do not allow the number to exceed max_network_retry_delay.
sleep_seconds = min(
HTTPClient.INITIAL_DELAY * (2 ** (num_retries - 1)),
HTTPClient.MAX_DELAY)

sleep_seconds = self._add_jitter_time(sleep_seconds)

# But never sleep less than the base sleep seconds.
sleep_seconds = max(HTTPClient.INITIAL_DELAY, sleep_seconds)
return sleep_seconds

def _add_jitter_time(self, sleep_seconds):
# Randomize the value in [(sleep_seconds/ 2) to (sleep_seconds)]
# Also separated method here to isolate randomness for tests
sleep_seconds *= (0.5 * (1 + random.uniform(0, 1)))
return sleep_seconds

def close(self):
raise NotImplementedError(
'HTTPClient subclasses must implement `close`')
Expand Down Expand Up @@ -146,11 +215,31 @@ def request(self, method, url, headers, post_data=None):
return content, status_code, result.headers

def _handle_request_error(self, e):
if isinstance(e, requests.exceptions.RequestException):

# Catch SSL error first as it belongs to ConnectionError,
# but we don't want to retry
if isinstance(e, requests.exceptions.SSLError):
msg = ("Could not verify Stripe's SSL certificate. Please make "
"sure that your network is not intercepting certificates. "
"If this problem persists, let us know at "
"support@stripe.com.")
err = "%s: %s" % (type(e).__name__, str(e))
should_retry = False
# Retry only timeout and connect errors; similar to urllib3 Retry
elif isinstance(e, requests.exceptions.Timeout) or \
isinstance(e, requests.exceptions.ConnectionError):
msg = ("Unexpected error communicating with Stripe. "
"If this problem persists, let us know at "
"support@stripe.com.")
err = "%s: %s" % (type(e).__name__, str(e))
should_retry = True
# Catch remaining request exceptions
elif isinstance(e, requests.exceptions.RequestException):
msg = ("Unexpected error communicating with Stripe. "
"If this problem persists, let us know at "
"support@stripe.com.")
err = "%s: %s" % (type(e).__name__, str(e))
should_retry = False
else:
msg = ("Unexpected error communicating with Stripe. "
"It looks like there's probably a configuration "
Expand All @@ -161,8 +250,10 @@ def _handle_request_error(self, e):
err += " with error message %s" % (str(e),)
else:
err += " with no error message"
should_retry = False

msg = textwrap.fill(msg) + "\n\n(Network error: %s)" % (err,)
raise error.APIConnectionError(msg)
raise error.APIConnectionError(msg, should_retry=should_retry)

def close(self):
if self._session is not None:
Expand Down
64 changes: 56 additions & 8 deletions tests/test_api_requestor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import datetime
import json
import tempfile
import uuid

import pytest

Expand Down Expand Up @@ -35,28 +36,32 @@ class APIHeaderMatcher(object):
'User-Agent',
'X-Stripe-Client-User-Agent',
]
METHOD_EXTRA_KEYS = {"post": ["Content-Type"]}
METHOD_EXTRA_KEYS = {"post": ["Content-Type", "Idempotency-Key"]}

def __init__(self, api_key=None, extra={}, request_method=None,
user_agent=None, app_info=None):
user_agent=None, app_info=None, idempotency_key=None):
self.request_method = request_method
self.api_key = api_key or stripe.api_key
self.extra = extra
self.user_agent = user_agent
self.app_info = app_info
self.idempotency_key = idempotency_key

def __eq__(self, other):
return (self._keys_match(other) and
self._auth_match(other) and
self._user_agent_match(other) and
self._x_stripe_ua_contains_app_info(other) and
self._idempotency_key_match(other) and
self._extra_match(other))

def __repr__(self):
return ("APIHeaderMatcher(request_method=%s, api_key=%s, extra=%s, "
"user_agent=%s, app_info=%s)" %
"user_agent=%s, app_info=%s, idempotency_key=%s)" %
(repr(self.request_method), repr(self.api_key),
repr(self.extra), repr(self.user_agent), repr(self.app_info)))
repr(self.extra), repr(self.user_agent), repr(self.app_info),
repr(self.idempotency_key))
)

def _keys_match(self, other):
expected_keys = list(set(self.EXP_KEYS + list(self.extra.keys())))
Expand All @@ -74,6 +79,11 @@ def _user_agent_match(self, other):

return True

def _idempotency_key_match(self, other):
if self.idempotency_key is not None:
return other['Idempotency-Key'] == self.idempotency_key
return True

def _x_stripe_ua_contains_app_info(self, other):
if self.app_info:
ua = json.loads(other['X-Stripe-Client-User-Agent'])
Expand Down Expand Up @@ -129,6 +139,19 @@ def __repr__(self):
return ("UrlMatcher(exp_parts=%s)" % (repr(self.exp_parts)))


class AnyUUID4Matcher(object):

def __eq__(self, other):
try:
uuid.UUID(other, version=4)
except ValueError:
return False
return True

def __repr__(self):
return "AnyUUID4Matcher()"


class TestAPIRequestor(object):
ENCODE_INPUTS = {
'dict': {
Expand Down Expand Up @@ -198,7 +221,7 @@ def requestor(self, http_client):
def mock_response(self, mocker, http_client):
def mock_response(return_body, return_code, headers=None):
print(return_code)
http_client.request = mocker.Mock(
http_client.request_with_retries = mocker.Mock(
return_value=(return_body, return_code, headers or {}))
return mock_response

Expand All @@ -211,7 +234,7 @@ def check_call(method, abs_url=None, headers=None,
if not headers:
headers = APIHeaderMatcher(request_method=method)

http_client.request.assert_called_with(
http_client.request_with_retries.assert_called_with(
method, abs_url, headers, post_data)
return check_call

Expand Down Expand Up @@ -417,6 +440,31 @@ def test_uses_app_info(self, requestor, mock_response, check_call):
finally:
stripe.app_info = old

def test_uses_given_idempotency_key(self, requestor, mock_response,
check_call):
mock_response('{}', 200)
meth = 'post'
requestor.request(meth, self.valid_path, {},
{'Idempotency-Key': '123abc'})

header_matcher = APIHeaderMatcher(
request_method=meth,
idempotency_key='123abc'
)
check_call(meth, headers=header_matcher, post_data='')

def test_uuid4_idempotency_key_when_not_given(self, requestor,
mock_response, check_call):
mock_response('{}', 200)
meth = 'post'
requestor.request(meth, self.valid_path, {})

header_matcher = APIHeaderMatcher(
request_method=meth,
idempotency_key=AnyUUID4Matcher()
)
check_call(meth, headers=header_matcher, post_data='')

def test_fails_without_api_key(self, requestor):
stripe.api_key = None

Expand Down Expand Up @@ -535,12 +583,12 @@ def test_default_http_client_called(self, mocker):
hc = mocker.Mock(stripe.http_client.HTTPClient)
hc._verify_ssl_certs = True
hc.name = 'mockclient'
hc.request = mocker.Mock(return_value=("{}", 200, {}))
hc.request_with_retries = mocker.Mock(return_value=("{}", 200, {}))

stripe.default_http_client = hc
stripe.Charge.list(limit=3)

hc.request.assert_called_with(
hc.request_with_retries.assert_called_with(
'get',
'https://api.stripe.com/v1/charges?limit=3',
mocker.ANY,
Expand Down
9 changes: 9 additions & 0 deletions tests/test_error.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,3 +62,12 @@ def test_repr(self):
assert repr(err) == \
"CardError(message='öre', param='cparam', code='ccode', " \
"http_status=403, request_id='123')"


class TestApiConnectionError(object):
def test_default_no_retry(self):
err = error.APIConnectionError('msg')
assert err.should_retry is False

err = error.APIConnectionError('msg', should_retry=True)
assert err.should_retry
Loading

0 comments on commit 0c06d7b

Please sign in to comment.