Skip to content

Commit

Permalink
Feature/add rate limit logic (#39)
Browse files Browse the repository at this point in the history
* Adding Backoff Logic

* Caching authenticator and adding backoff

* Adding configurable backoff options

* Assign backoff variables

* Use the parent backoff_wait_generator if there is no custom backoff_type

* Adding additional backoff offset

* Adding ability to add full JSON message.

* Reformatted using Black

* Resolving typos and AWS Authentication issue

* Resolving linting issues and bumping SDK version

---------

Co-authored-by: Steve Clarke <stephen.clarke@health.govt.nz>
  • Loading branch information
s7clarke10 and Steve Clarke authored Nov 6, 2023
1 parent 02fdcd9 commit 28ebe25
Show file tree
Hide file tree
Showing 9 changed files with 622 additions and 415 deletions.
14 changes: 13 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,14 @@ plugins:
kind: string
- name: use_request_body_not_params
kind: boolean
- name: backoff_type
kind: string
- name: backoff_param
kind: string
- name: backoff_time_extension
kind: integer
- name: store_raw_json_message
kind: boolean
- name: pagination_page_size
kind: integer
- name: pagination_results_limit
Expand Down Expand Up @@ -137,7 +145,11 @@ provided at the top-level will be the default values for each stream.:
- `api_url`: required: the base url/endpoint for the desired api.
- `pagination_request_style`: optional: style for requesting pagination, defaults to `default` which is the `jsonpath_paginator`, see Pagination below.
- `pagination_response_style`: optional: style of pagination results, defaults to `default` which is the `page` style response, see Pagination below.
- `use_request_body_not_params`: optional: sends the request parameters in the request body. This is normally not required, a few API's like OpenSearch require this. Defaults to `False`"
- `use_request_body_not_params`: optional: sends the request parameters in the request body. This is normally not required, a few API's like OpenSearch require this. Defaults to `False`.
- `backoff_type`: optional: The style of Backoff [message|header] applied to rate limited APIs. Backoff times (seconds) come from response either the `message` or `header`. Defaults to `None`.
- `backoff_param`: optional: the header parameter to inspect for a backoff time. Defaults to `Retry-After`.
- `backoff_time_extension`: optional: An additional extension (seconds) to the backoff time over and above a jitter value - use where an API is not precise in it's backoff times. Defaults to `0`.
- `store_raw_json_message`: optional: An additional extension which will emit the whole message into an field `_sdc_raw_json`. Useful for a dynamic schema which cannot be automatically discovered. Defaults to `False`.
- `pagination_page_size`: optional: limit for size of page, defaults to None.
- `pagination_results_limit`: optional: limits the max number of records. Note: Will cause an exception if the limit is hit (except for the `restapi_header_link_paginator`). This should be used for development purposes to restrict the total number of records returned by the API. Defaults to None.
- `pagination_next_page_param`: optional: The name of the param that indicates the page/offset. Defaults to None.
Expand Down
8 changes: 8 additions & 0 deletions meltano.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,14 @@ plugins:
kind: string
- name: use_request_body_not_params
kind: boolean
- name: backoff_type
kind: string
- name: backoff_param
kind: string
- name: backoff_time_extension
kind: integer
- name: store_raw_json_message
kind: boolean
- name: pagination_page_size
kind: integer
- name: pagination_results_limit
Expand Down
801 changes: 421 additions & 380 deletions poetry.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ readme = "README.md"
[tool.poetry.dependencies]
python = "<3.12,>=3.7.1"
requests = "^2.25.1"
singer-sdk = "^0.30.0"
singer-sdk = "^0.33.0"
genson = "^1.2.2"
atomicwrites = "^1.4.0"
requests-aws4auth = "^1.2.3"
Expand Down
36 changes: 36 additions & 0 deletions tap_rest_api_msdk/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import boto3
from requests_aws4auth import AWS4Auth
from singer_sdk.authenticators import (
APIAuthenticatorBase,
APIKeyAuthenticator,
BasicAuthenticator,
BearerTokenAuthenticator,
Expand Down Expand Up @@ -271,3 +272,38 @@ def select_authenticator(self) -> Any:
f"Unknown authentication method {auth_method}. Use api_key, basic, oauth, "
f"bearer_token, or aws."
)


def get_authenticator(self) -> Any:
"""Retrieve the appropriate authenticator in tap and stream.
If the authenticator already exists, use the cached
Authenticator
Note: Store the authenticator in class variables used by the SDK.
Returns:
None
"""
# Test where the config is located in self
if self.config: # Tap Config
my_config = self.config
elif self._config: # Stream Config
my_config = self._config

auth_method = my_config.get("auth_method", None)
self.http_auth = None

if not self._authenticator:
self._authenticator = select_authenticator(self)
if not self._authenticator:
# No Auth Method, use default Authenticator
self._authenticator = APIAuthenticatorBase(stream=self)
if auth_method == "oauth":
if not self._authenticator.is_token_valid():
# Obtain a new OAuth token as it has expired
self._authenticator = select_authenticator(self)
if auth_method == "aws":
# Set the http_auth which is used in the Request call for AWS
self.http_auth = self._authenticator
24 changes: 7 additions & 17 deletions tap_rest_api_msdk/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,19 @@
from pathlib import Path
from typing import Any

from singer_sdk.authenticators import APIAuthenticatorBase
from singer_sdk.streams import RESTStream
from tap_rest_api_msdk.auth import select_authenticator
from tap_rest_api_msdk.auth import get_authenticator

SCHEMAS_DIR = Path(__file__).parent / Path("./schemas")


class RestApiStream(RESTStream):
"""rest-api stream class."""

# Intialise self.http_auth used by prepare_request
http_auth = None
# Cache the authenticator using a Smart Singleton pattern
_authenticator = None
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.http_auth = None
self._authenticator = getattr(self, "assigned_authenticator", None)

@property
def url_base(self) -> Any:
Expand Down Expand Up @@ -50,16 +49,7 @@ def authenticator(self) -> Any:
A SDK Authenticator or APIAuthenticatorBase if no auth_method supplied.
"""
auth_method = self.config.get("auth_method", None)

if not self._authenticator:
self._authenticator = select_authenticator(self)
if not self._authenticator:
# No Auth Method, use default Authenticator
self._authenticator = APIAuthenticatorBase(stream=self)
elif auth_method == "oauth":
if not self._authenticator.is_token_valid():
# Obtain a new OAuth token as it has expired
self._authenticator = select_authenticator(self)
# Obtaining Authenticator for authorisation to extract data.
get_authenticator(self)

return self._authenticator
62 changes: 59 additions & 3 deletions tap_rest_api_msdk/streams.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import json
from datetime import datetime
from string import Template
from typing import Any, Dict, Iterable, Optional
from typing import Any, Dict, Generator, Iterable, Optional, Union
from urllib.parse import parse_qs, parse_qsl, urlparse

import requests
Expand Down Expand Up @@ -64,6 +64,11 @@ def __init__(
source_search_field: Optional[str] = None,
source_search_query: Optional[str] = None,
use_request_body_not_params: Optional[bool] = False,
backoff_type: Optional[str] = None,
backoff_param: Optional[str] = "Retry-After",
backoff_time_extension: Optional[int] = 0,
store_raw_json_message: Optional[bool] = False,
authenticator: Optional[object] = None,
) -> None:
"""Class initialization.
Expand All @@ -90,6 +95,11 @@ def __init__(
source_search_field: see tap.py
source_search_query: see tap.py
use_request_body_not_params: see tap.py
backoff_type: see tap.py
backoff_param: see tap.py
backoff_time_extension: see tap.py
store_raw_json_message: see tap.py
authenticator: see tap.py
"""
super().__init__(tap=tap, name=tap.name, schema=schema)
Expand All @@ -101,6 +111,8 @@ def __init__(
self.path = path
self.params = params if params else {}
self.headers = headers
self.assigned_authenticator = authenticator
self._authenticator = authenticator
self.primary_keys = primary_keys
self.replication_key = replication_key
self.except_keys = except_keys
Expand Down Expand Up @@ -129,6 +141,10 @@ def __init__(
# processing is invoked.

self.use_request_body_not_params = use_request_body_not_params
self.backoff_type = backoff_type
self.backoff_param = backoff_param
self.backoff_time_extension = backoff_time_extension
self.store_raw_json_message = store_raw_json_message
if self.use_request_body_not_params:
self.prepare_request_payload = get_url_params_styles.get( # type: ignore
pagination_response_style, self._get_url_params_page_style
Expand Down Expand Up @@ -212,11 +228,51 @@ def http_headers(self) -> dict:

return headers

def backoff_wait_generator(
self,
) -> Generator[Union[int, float], None, None]:
"""Return a backoff generator as required to manage Rate Limited APIs.
Supply a backoff_type in the config to indicate the style of backoff.
If the backoff response is in a header, supply a backoff_param
indicating what key contains the backoff delay.
Note: If the backoff_type is message, the message is parsed for numeric
values. It is assumed that the highest numeric value discovered is the
backoff value in seconds.
Returns:
Backoff Generator with value to wait based on the API Response.
"""

def _backoff_from_headers(exception):
response_headers = exception.response.headers

return (
int(response_headers.get(self.backoff_param, 0))
+ self.backoff_time_extension
)

def _get_wait_time_from_response(exception):
response_message = exception.response.json().get("message", 0)
res = [int(i) for i in response_message.split() if i.isdigit()]

return int(max(res)) + self.backoff_time_extension

if self.backoff_type == "message":
return self.backoff_runtime(value=_get_wait_time_from_response)
elif self.backoff_type == "header":
return self.backoff_runtime(value=_backoff_from_headers)
else:
# No override required. Use SDK backoff_wait_generator
return super().backoff_wait_generator()

def get_new_paginator(self):
"""Return the requested paginator required to retrieve all data from the API.
Returns:
Paginator Class.
Paginator Class.
"""
self.logger.info(
Expand Down Expand Up @@ -521,4 +577,4 @@ def post_process(self, row: dict, context: Optional[dict] = None) -> dict:
A record that has been processed.
"""
return flatten_json(row, self.except_keys)
return flatten_json(row, self.except_keys, self.store_raw_json_message)
79 changes: 67 additions & 12 deletions tap_rest_api_msdk/tap.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from singer_sdk import Tap
from singer_sdk import typing as th
from singer_sdk.helpers.jsonpath import extract_jsonpath
from tap_rest_api_msdk.auth import select_authenticator
from tap_rest_api_msdk.auth import get_authenticator
from tap_rest_api_msdk.streams import DynamicStream
from tap_rest_api_msdk.utils import flatten_json

Expand All @@ -22,6 +22,10 @@ class TapRestApiMsdk(Tap):
# Required for Authentication in tap.py - function APIAuthenticatorBase
tap_name = name

# Used to cache the Authenticator to prevent over hitting the Authentication
# end-point for each stream.
_authenticator = None

common_properties = th.PropertiesList(
th.Property(
"path",
Expand Down Expand Up @@ -303,6 +307,46 @@ class TapRestApiMsdk(Tap):
"This is normally not required, a few API's like OpenSearch"
"require this. Defaults to `False`",
),
th.Property(
"backoff_type",
th.StringType,
default=None,
required=False,
allowed_values=[None, "message", "header"],
description="The style of Backoff applied to rate limited APIs."
"None: Default Meltano SDK backoff_wait_generator, message: Scans "
"the response message for a time interval, header: retrieves the "
"backoff value from a header key response."
" Defaults to `None`",
),
th.Property(
"backoff_param",
th.StringType,
default="Retry-After",
required=False,
description="The name of the key which contains a the "
"backoff value in the response. This is very applicable to backoff"
" values in headers. Defaults to `Retry-After`",
),
th.Property(
"backoff_time_extension",
th.IntegerType,
default=0,
required=False,
description="A time extension (in seconds) to add to the backoff "
"value from the API plus jitter. Some APIs are not precise"
", this adds an additional wait delay. Defaults to `0`",
),
th.Property(
"store_raw_json_message",
th.BooleanType,
default=False,
required=False,
description="Adds an additional _SDC_RAW_JSON column as an "
"object. This will store the raw incoming message in this "
"column when provisioned. Useful for semi-structured records "
"when the schema is not well defined. Defaults to `False`",
),
th.Property(
"pagination_page_size",
th.IntegerType,
Expand Down Expand Up @@ -475,6 +519,11 @@ def discover_streams(self) -> List[DynamicStream]: # type: ignore
use_request_body_not_params=self.config.get(
"use_request_body_not_params"
),
backoff_type=self.config.get("backoff_type"),
backoff_param=self.config.get("backoff_param"),
backoff_time_extension=self.config.get("backoff_time_extension"),
store_raw_json_message=self.config.get("store_raw_json_message"),
authenticator=self._authenticator,
)
)

Expand All @@ -491,9 +540,11 @@ def get_schema(
) -> Any:
"""Infer schema from the first records returned by api. Creates a Stream object.
If auth_method is set, will call select_authenticator to obtain credentials
to issue a request to sample some records. The select_authenticator will
set the self.http_auth if required by the request authenticator.
If auth_method is set, will call get_authenticator to obtain credentials
to issue a request to sample some records. The get_authenticator will:
- stores the authenticator in self._authenticator
- sets the self.http_auth if required by a given authenticator
- use an existing authenticator if one exists and is cached.
Args:
records_path: required - see config_jsonschema.
Expand All @@ -517,13 +568,11 @@ def get_schema(
self.http_auth = None

if auth_method and not auth_method == "no_auth":
# Initializing Authenticator for authorisation to obtain a schema.
# Will set the self.http_auth if required by a given authenticator
authenticator = select_authenticator(self)
if hasattr(authenticator, "auth_headers"):
headers.update(authenticator.auth_headers or {})
if hasattr(authenticator, "auth_params"):
params.update(authenticator.auth_params or {})
# Obtaining Authenticator for authorisation to obtain a schema.
get_authenticator(self)

headers.update(getattr(self._authenticator, "auth_headers", {}))
params.update(getattr(self._authenticator, "auth_params", {}))

r = requests.get(
self.config["api_url"] + path,
Expand All @@ -544,8 +593,14 @@ def get_schema(
self.logger.error("Input must be a dict object.")
raise ValueError("Input must be a dict object.")

flat_record = flatten_json(record, except_keys)
flat_record = flatten_json(
record, except_keys, store_raw_json_message=False
)

builder.add_object(flat_record)
# Optional add _sdc_raw_json field to store the raw message
if self.config.get("store_raw_json_message"):
builder.add_object({"_sdc_raw_json": {}})

if i >= inference_records:
break
Expand Down
Loading

0 comments on commit 28ebe25

Please sign in to comment.