From 4e3082d78e24ff5924abc837b1199e9370afc97c Mon Sep 17 00:00:00 2001 From: Scott DeWitt Date: Thu, 7 Mar 2024 14:27:26 -0500 Subject: [PATCH 01/11] feat: optimize memory usage part 1 --- src/newrelic_logging/__init__.py | 32 ++ src/newrelic_logging/auth.py | 311 +++++++++++++ src/newrelic_logging/cache.py | 193 ++++---- src/newrelic_logging/config.py | 28 +- src/newrelic_logging/env.py | 58 --- src/newrelic_logging/integration.py | 284 ++++-------- src/newrelic_logging/newrelic.py | 196 ++++---- src/newrelic_logging/pipeline.py | 462 +++++++++++++++++++ src/newrelic_logging/query.py | 134 +++--- src/newrelic_logging/salesforce.py | 663 +++++----------------------- src/newrelic_logging/util.py | 79 ++++ 11 files changed, 1395 insertions(+), 1045 deletions(-) create mode 100644 src/newrelic_logging/auth.py delete mode 100644 src/newrelic_logging/env.py create mode 100644 src/newrelic_logging/pipeline.py create mode 100644 src/newrelic_logging/util.py diff --git a/src/newrelic_logging/__init__.py b/src/newrelic_logging/__init__.py index 0770ec8..91deb11 100644 --- a/src/newrelic_logging/__init__.py +++ b/src/newrelic_logging/__init__.py @@ -1,6 +1,38 @@ +from enum import Enum + + # Integration definitions VERSION = "1.0.0" NAME = "salesforce-eventlogfile" PROVIDER = "newrelic-labs" COLLECTOR_NAME = "newrelic-logs-salesforce-eventlogfile" + + +class DataFormat(Enum): + LOGS = 1 + EVENTS = 2 + + +class ConfigException(Exception): + def __init__(self, prop_name: str = None, *args: object): + self.prop_name = prop_name + super().__init__(*args) + + +class LoginException(Exception): + pass + + +class SalesforceApiException(Exception): + def __init__(self, err_code: int = 0, *args: object): + self.err_code = err_code + super().__init__(*args) + + +class CacheException(Exception): + pass + + +class NewRelicApiException(Exception): + pass diff --git a/src/newrelic_logging/auth.py b/src/newrelic_logging/auth.py new file mode 100644 index 0000000..8080691 --- /dev/null +++ b/src/newrelic_logging/auth.py @@ -0,0 +1,311 @@ +from cryptography.hazmat.primitives import serialization +from datetime import datetime, timedelta +import jwt +from requests import RequestException, Session + +from . import ConfigException, LoginException +from .cache import DataCache +from .config import Config +from .telemetry import print_err, print_info, print_warn + + +AUTH_CACHE_KEY = 'com.newrelic.labs.sf_auth' +SF_GRANT_TYPE = 'SF_GRANT_TYPE' +SF_CLIENT_ID = 'SF_CLIENT_ID' +SF_CLIENT_SECRET = 'SF_CLIENT_SECRET' +SF_USERNAME = 'SF_USERNAME' +SF_PASSWORD = 'SF_PASSWORD' +SF_PRIVATE_KEY = 'SF_PRIVATE_KEY' +SF_SUBJECT = 'SF_SUBJECT' +SF_AUDIENCE = 'SF_AUDIENCE' +SF_TOKEN_URL = 'SF_TOKEN_URL' + + +class Authenticator: + def __init__( + self, + token_url: str, + auth_data: dict, + data_cache: DataCache + ): + self.token_url = token_url + self.auth_data = auth_data + self.data_cache = data_cache + self.access_token = None + self.instance_url = None + + def get_access_token(self) -> str: + return self.access_token + + def get_instance_url(self) -> str: + return self.instance_url + + def get_grant_type(self) -> str: + return self.auth_data['grant_type'] + + def set_auth_data(self, access_token: str, instance_url: str) -> None: + self.access_token = access_token + self.instance_url = instance_url + + def clear_auth(self): + self.set_auth_data(None, None) + + if self.data_cache: + try: + self.data_cache.redis.delete(AUTH_CACHE_KEY) + except Exception as e: + print_warn(f'Failed deleting data from cache: {e}') + + def load_auth_from_cache(self) -> bool: + try: + auth_exists = self.data_cache.redis.exists(AUTH_CACHE_KEY) + if auth_exists: + print_info('Retrieving credentials from Redis.') + #NOTE: hmget and hgetall both return byte arrays, not strings. We have to convert. + # We could fix it by adding the argument "decode_responses=True" to Redis constructor, + # but then we would have to change all places where we assume a byte array instead of a string, + # and refactoring in a language without static types is a pain. + try: + auth = self.data_cache.redis.hmget( + AUTH_CACHE_KEY, + ['access_token', 'instance_url'], + ) + + self.set_auth( + auth[0].decode("utf-8"), + auth[1].decode("utf-8"), + ) + + return True + except Exception as e: + print_err(f"Failed getting 'auth' key: {e}") + except Exception as e: + print_err(f"Failed checking 'auth' key: {e}") + + return False + + def store_auth(self, auth_resp: dict): + self.access_token = auth_resp['access_token'] + self.instance_url = auth_resp['instance_url'] + + if self.data_cache: + print_info('Storing credentials in cache.') + + auth = { + 'access_token': self.access_token, + 'instance_url': self.instance_url, + } + + try: + self.data_cache.redis.hmset(AUTH_CACHE_KEY, auth) + except Exception as e: + print_warn(f"Failed storing data in cache: {e}") + + def authenticate( + self, + session: Session, + ) -> None: + if self.data_cache and self.load_auth_from_cache(): + return + + oauth_type = self.get_grant_type() + if oauth_type == 'password': + self.authenticate_with_password(session) + print_info('Correctly authenticated with user/pass flow') + return + + self.authenticate_with_jwt(session) + print_info('Correctly authenticated with JWT flow') + + def authenticate_with_jwt(self, session: Session) -> None: + private_key_file = self.auth_data['private_key'] + client_id = self.auth_data['client_id'] + subject = self.auth_data['subject'] + audience = self.auth_data['audience'] + exp = int((datetime.utcnow() - timedelta(minutes=5)).timestamp()) + + with open(private_key_file, 'r') as f: + try: + private_key = f.read() + key = serialization.load_ssh_private_key(private_key.encode(), password=b'') + except ValueError as e: + raise LoginException(f'authentication failed for {self.instance_name}. error message: {str(e)}') + + jwt_claim_set = { + "iss": client_id, + "sub": subject, + "aud": audience, + "exp": exp + } + + signed_token = jwt.encode( + jwt_claim_set, + key, + algorithm='RS256', + ) + + params = { + "grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer", + "assertion": signed_token, + "format": "json" + } + + headers = { + "Content-Type": "application/x-www-form-urlencoded", + "Accept": "application/json" + } + + try: + print_info(f'retrieving salesforce token at {self.token_url}') + resp = session.post(self.token_url, params=params, + headers=headers) + if resp.status_code != 200: + raise LoginException(f'sfdc token request failed. http-status-code:{resp.status_code}, reason: {resp.text}') + + self.store_auth(resp.json()) + except ConnectionError as e: + raise LoginException(f'authentication failed for sfdc instance {self.instance_name}') from e + except RequestException as e: + raise LoginException(f'authentication failed for sfdc instance {self.instance_name}') from e + + def authenticate_with_password(self, session: Session) -> None: + client_id = self.auth_data['client_id'] + client_secret = self.auth_data['client_secret'] + username = self.auth_data['username'] + password = self.auth_data['password'] + + params = { + "grant_type": "password", + "client_id": client_id, + "client_secret": client_secret, + "username": username, + "password": password + } + + headers = { + "Content-Type": "application/x-www-form-urlencoded", + "Accept": "application/json" + } + + try: + print_info(f'retrieving salesforce token at {self.token_url}') + resp = session.post(self.token_url, params=params, + headers=headers) + if resp.status_code != 200: + raise LoginException(f'salesforce token request failed. status-code:{resp.status_code}, reason: {resp.reason}') + + self.store_auth(resp.json()) + except ConnectionError as e: + raise LoginException(f'authentication failed for sfdc instance {self.instance_name}') from e + except RequestException as e: + raise LoginException(f'authentication failed for sfdc instance {self.instance_name}') from e + + +def validate_oauth_config(auth: dict) -> dict: + if not auth['client_id']: + raise ConfigException('client_id', 'missing OAuth client id') + + if not auth['client_secret']: + raise ConfigException( + 'client_secret', + 'missing OAuth client secret', + ) + + if not auth['username']: + raise ConfigException('username', 'missing OAuth username') + + if not auth['password']: + raise ConfigException('password', 'missing OAuth client secret') + + return auth + + +def validate_jwt_config(auth: dict) -> dict: + if not auth['client_id']: + raise ConfigException('client_id', 'missing JWT client id') + + if not auth['private_key']: + raise ConfigException('private_key', 'missing JWT private key') + + if not auth['subject']: + raise ConfigException('subject', 'missing JWT subject') + + if not auth['audience']: + raise ConfigException('audience', 'missing JWT audience') + + return auth + + +def make_auth_from_config(auth: Config) -> dict: + grant_type = auth.get( + 'grant_type', + 'password', + env_var_name=SF_GRANT_TYPE, + ).lower() + + if grant_type == 'password': + return validate_oauth_config({ + 'grant_type': grant_type, + 'client_id': auth.get('client_id', env_var_name=SF_CLIENT_ID), + 'client_secret': auth.get( + 'client_secret', + env_var_name=SF_CLIENT_SECRET + ), + 'username': auth.get('username', env_var_name=SF_USERNAME), + 'password': auth.get('password', env_var_name=SF_PASSWORD), + }) + + if grant_type == 'urn:ietf:params:oauth:grant-type:jwt-bearer': + return validate_jwt_config({ + 'grant_type': grant_type, + 'client_id': auth.get('client_id', env_var_name=SF_CLIENT_ID), + 'private_key': auth.get('private_key', env_var_name=SF_PRIVATE_KEY), + 'subject': auth.get('subject', env_var_name=SF_SUBJECT), + 'audience': auth.get('audience', env_var_name=SF_AUDIENCE), + }) + + raise Exception(f'Wrong or missing grant_type') + + +def make_auth_from_env(config: Config) -> dict: + grant_type = config.getenv(SF_GRANT_TYPE, 'password').lower() + + if grant_type == 'password': + return validate_oauth_config({ + 'grant_type': grant_type, + 'client_id': config.getenv(SF_CLIENT_ID), + 'client_secret': config.getenv(SF_CLIENT_SECRET), + 'username': config.getenv(SF_USERNAME), + 'password': config.getenv(SF_PASSWORD), + }) + + if grant_type == 'urn:ietf:params:oauth:grant-type:jwt-bearer': + return validate_jwt_config({ + 'grant_type': grant_type, + 'client_id': config.getenv(SF_CLIENT_ID), + 'private_key': config.getenv(SF_PRIVATE_KEY), + 'subject': config.getenv(SF_SUBJECT), + 'audience': config.getenv(SF_AUDIENCE), + }) + + raise Exception(f'Wrong or missing grant_type') + + +def New(config: Config, data_cache: DataCache) -> Authenticator: + token_url = config.get('token_url', env_var_name=SF_TOKEN_URL) + + if not token_url: + raise ConfigException('token_url', 'missing token URL') + + if 'auth' in config: + return Authenticator( + token_url, + make_auth_from_config(config.sub('auth')), + data_cache, + ) + + return Authenticator( + token_url, + make_auth_from_env(config), + data_cache + ) diff --git a/src/newrelic_logging/cache.py b/src/newrelic_logging/cache.py index 536f6fc..a0dc8b6 100644 --- a/src/newrelic_logging/cache.py +++ b/src/newrelic_logging/cache.py @@ -1,8 +1,10 @@ +import gc import redis from datetime import timedelta +from . import CacheException from .config import Config -from .telemetry import print_err, print_info +from .telemetry import print_info CONFIG_CACHE_ENABLED = 'cache_enabled' @@ -20,117 +22,113 @@ DEFAULT_REDIS_SSL = False -# Local cache, to store data before sending it to Redis. -class DataCache: - redis = None - redis_expire = None - cached_events = {} - cached_logs = {} - - def __init__(self, redis, redis_expire) -> None: +class RedisBackend: + def __init__(self, redis): self.redis = redis - self.redis_expire = redis_expire - def set_redis_expire(self, key): - try: - self.redis.expire(key, timedelta(days=self.redis_expire)) - except Exception as e: - print_err(f"Failed setting expire time for key {key}: {e}") - exit(1) + def exists(self, key): + return self.redis.exists(key) - def persist_logs(self, record_id: str) -> bool: - if record_id in self.cached_logs: - for row_id in self.cached_logs[record_id]: - try: - self.redis.rpush(record_id, row_id) - except Exception as e: - print_err(f"Failed pushing record {record_id}: {e}") - exit(1) - # Set expire date for the whole list only once, when it find the first entry ('init') - if row_id == 'init': - self.set_redis_expire(record_id) - del self.cached_logs[record_id] - return True - else: - return False + def put(self, key, item): + self.redis.set(key, item) - def persist_event(self, record_id: str) -> bool: - if record_id in self.cached_events: - try: - self.redis.set(record_id, '') - except Exception as e: - print_err(f"Failed setting record {record_id}: {e}") - exit(1) - self.set_redis_expire(record_id) - del self.cached_events[record_id] - return True - else: - return False + def list_length(self, key): + return self.redis.llen(key) - def can_skip_downloading_record(self, record_id: str) -> bool: - try: - does_exist = self.redis.exists(record_id) - except Exception as e: - print_err(f"Failed checking record {record_id}: {e}") - exit(1) - if does_exist: - try: - return self.redis.llen(record_id) > 1 - except Exception as e: - print_err(f"Failed checking len for record {record_id}: {e}") - exit(1) + def list_slice(self, key, start, end): + return self.redis.lrange(key, start, end) + + def list_append(self, key, item): + self.redis.rpush(key, item) + + def set_expiry(self, key, days): + self.redis.expire(key, timedelta(days=days)) - return False - def retrieve_cached_message_list(self, record_id: str): +class DataCache: + def __init__(self, backend, expiry): + self.backend = backend + self.expiry = expiry + self.cached_events = {} + self.cached_logs = {} + + def can_skip_downloading_logfile(self, record_id: str) -> bool: try: - cache_key_exists = self.redis.exists(record_id) + return self.backend.exists(record_id) and \ + self.backend.list_length(record_id) > 1 except Exception as e: - print_err(f"Failed checking record {record_id}: {e}") - exit(1) + raise CacheException(f'failed checking record {record_id}: {e}') + + def load_cached_log_lines(self, record_id: str) -> None: + try: + if self.backend.exists(record_id): + self.cached_logs[record_id] = \ + self.backend.list_slice(record_id, 0, -1) + return - if cache_key_exists: - try: - cached_messages = self.redis.lrange(record_id, 0, -1) - except Exception as e: - print_err(f"Failed getting list range for record {record_id}: {e}") - exit(1) - return cached_messages - else: self.cached_logs[record_id] = ['init'] + except Exception as e: + raise CacheException(f'failed checking log record {record_id}: {e}') - return None + # Cache log + def check_and_set_log_line(self, record_id: str, row: dict) -> bool: + row_id = row["REQUEST_ID"] + + row_id_b = row_id.encode('utf-8') + if row_id_b in self.cached_logs[record_id]: + return True + + self.cached_logs[record_id].append(row_id) + + return False # Cache event - def check_cached_id(self, record_id: str): + def check_and_set_event_id(self, record_id: str) -> bool: try: - does_exist = self.redis.exists(record_id) - except Exception as e: - print_err(f"Failed checking record {record_id}: {e}") - exit(1) + if self.backend.exists(record_id): + return True - if does_exist: - return True - else: self.cached_events[record_id] = '' + return False + except Exception as e: + raise CacheException(f'failed checking record {record_id}: {e}') - # Cache log - def record_or_skip_row(self, record_id: str, row: dict, cached_messages: dict) -> bool: - row_id = row["REQUEST_ID"] + def flush(self) -> None: + # Flush cached log line ids for each log record + for record_id in self.cached_logs: + for row_id in self.cached_logs[record_id]: + try: + self.backend.list_append(record_id, row_id) - if cached_messages is not None: - row_id_b = row_id.encode('utf-8') - if row_id_b in cached_messages: - return True - self.cached_logs[record_id].append(row_id) - else: - self.cached_logs[record_id].append(row_id) + # Set expire date for the whole list only once, when we find + # the first entry ('init') + if row_id == 'init': + self.backend.set_expiry(record_id, self.expiry) + except Exception as e: + raise CacheException( + f'failed pushing row {row_id} for record {record_id}: {e}' + ) - return False + # Attempt to release memory + del self.cached_logs[record_id] + # Flush any cached event record ids + for record_id in self.cached_events: + try: + self.backend.put(record_id, '') + self.backend.set_expiry(record_id, self.expiry) + + # Attempt to release memory + del self.cached_events[record_id] + except Exception as e: + raise CacheException(f"failed setting record {record_id}: {e}") -def make_cache(config: Config): + # Run a gc in an attempt to reclaim memory + gc.collect() + + +def New(config: Config): if config.get_bool(CONFIG_CACHE_ENABLED, DEFAULT_CACHE_ENABLED): host = config.get(CONFIG_REDIS_HOST, DEFAULT_REDIS_HOST) port = config.get_int(CONFIG_REDIS_PORT, DEFAULT_REDIS_PORT) @@ -144,13 +142,16 @@ def make_cache(config: Config): f'Cache enabled, connecting to redis instance {host}:{port}:{db}, ssl={ssl}, password={password_display}' ) - return DataCache(redis.Redis( - host=host, - port=port, - db=db, - password=password, - ssl=ssl - ), expire_days) + return DataCache( + RedisBackend( + redis.Redis( + host=host, + port=port, + db=db, + password=password, + ssl=ssl + ), expire_days) + ) print_info('Cache disabled') diff --git a/src/newrelic_logging/config.py b/src/newrelic_logging/config.py index 81a4b42..a8e6d15 100644 --- a/src/newrelic_logging/config.py +++ b/src/newrelic_logging/config.py @@ -4,6 +4,18 @@ from typing import Any +CONFIG_DATE_FIELD = 'date_field' +CONFIG_GENERATION_INTERVAL = 'generation_interval' +CONFIG_TIME_LAG_MINUTES = 'time_lag_minutes' + +DATE_FIELD_LOG_DATE = 'LogDate' +DATE_FIELD_CREATE_DATE = 'CreateDate' +GENERATION_INTERVAL_DAILY = 'Daily' +GENERATION_INTERVAL_HOURLY = 'Hourly' + +DEFAULT_TIME_LAG_MINUTES = 300 +DEFAULT_GENERATION_INTERVAL = GENERATION_INTERVAL_DAILY + BOOL_TRUE_VALS = ['true', '1', 'on', 'yes'] NOT_FOUND = SimpleNamespace() @@ -74,15 +86,21 @@ def set_prefix(self, prefix: str) -> None: def getenv(self, env_var_name: str, default = None) -> str: return getenv(env_var_name, default, self.prefix) - def get(self, key: str, default = None, allow_none = False) -> Any: + def get( + self, + key: str, + default = None, + allow_none = False, + env_var_name = None, + ) -> Any: val = get_nested(self.config, key) if not val == NOT_FOUND and not allow_none and not val == None: return val - return self.getenv( - re.sub(r'[^a-zA-Z0-9_]', '_', key.upper()), - default, - ) + var_name = env_var_name if env_var_name else \ + re.sub(r'[^a-zA-Z0-9_]', '_', key.upper()) + + return self.getenv(var_name, default) def get_int(self, key: str, default = None) -> int: val = self.get(key, default) diff --git a/src/newrelic_logging/env.py b/src/newrelic_logging/env.py deleted file mode 100644 index 9fadebe..0000000 --- a/src/newrelic_logging/env.py +++ /dev/null @@ -1,58 +0,0 @@ -SF_GRANT_TYPE = 'SF_GRANT_TYPE' -SF_CLIENT_ID = 'SF_CLIENT_ID' -SF_CLIENT_SECRET = 'SF_CLIENT_SECRET' -SF_USERNAME = 'SF_USERNAME' -SF_PASSWORD = 'SF_PASSWORD' -SF_PRIVATE_KEY = 'SF_PRIVATE_KEY' -SF_SUBJECT = 'SF_SUBJECT' -SF_AUDIENCE = 'SF_AUDIENCE' -SF_TOKEN_URL = 'SF_TOKEN_URL' - -class AuthEnv: - def __init__(self, config): - self.config = config - - def get_grant_type(self): - return self.config.getenv(SF_GRANT_TYPE) - - def get_client_id(self): - return self.config.getenv(SF_CLIENT_ID) - - def get_client_secret(self): - return self.config.getenv(SF_CLIENT_SECRET) - - def get_username(self): - return self.config.getenv(SF_USERNAME) - - def get_password(self): - return self.config.getenv(SF_PASSWORD) - - def get_private_key(self): - return self.config.getenv(SF_PRIVATE_KEY) - - def get_subject(self): - return self.config.getenv(SF_SUBJECT) - - def get_audience(self): - return self.config.getenv(SF_AUDIENCE) - - def get_token_url(self): - return self.config.getenv(SF_TOKEN_URL) - - -class Auth: - access_token = None - instance_url = None - # Never used, maybe in the future - token_type = None - - def __init__(self, access_token: str, instance_url: str, token_type: str) -> None: - self.access_token = access_token - self.instance_url = instance_url - self.token_type = token_type - - def get_access_token(self) -> str: - return self.access_token - - def get_instance_url(self) -> str: - return self.instance_url diff --git a/src/newrelic_logging/integration.py b/src/newrelic_logging/integration.py index ba50de3..295b444 100644 --- a/src/newrelic_logging/integration.py +++ b/src/newrelic_logging/integration.py @@ -1,39 +1,50 @@ -import sys +from requests import Session + +from . import \ + ConfigException, \ + CacheException, \ + DataFormat, \ + LoginException, \ + NewRelicApiException, \ + SalesforceApiException +from . import auth +from . import cache +from . import config as mod_config +from . import newrelic +from . import pipeline from .http_session import new_retry_session -from .newrelic import NewRelic -from .cache import DataCache -from .salesforce import SalesForce, SalesforceApiException -from .env import AuthEnv -from enum import Enum -from .config import Config, getenv +from .salesforce import SalesForce from .telemetry import Telemetry, print_info, print_err -NR_LICENSE_KEY = 'NR_LICENSE_KEY' -NR_ACCOUNT_ID = 'NR_ACCOUNT_ID' - - -class DataFormat(Enum): - LOGS = 1 - EVENTS = 2 - -# TODO: move queries to the instance level, so we can have different queries for +# @TODO: move queries to the instance level, so we can have different queries for # each instance. -# TODO: also keep general queries that apply to all instances. +# @TODO: also keep general queries that apply to all instances. -class Integration: - numeric_fields_list = set() +class Integration: def __init__( self, - config: Config, + config: mod_config.Config, event_type_fields_mapping: dict = {}, numeric_fields_list: set = set(), initial_delay: int = 0, ): - Integration.numeric_fields_list = numeric_fields_list - self.instances = [] Telemetry(config["integration_name"]) + + data_format = config.get('newrelic.data_format', 'logs').lower() + if data_format == 'logs': + data_format = DataFormat.LOGS + elif data_format == 'events': + data_format = DataFormat.EVENTS + else: + raise ConfigException(f'invalid data format {data_format}') + + # Fill credentials for NR APIs + + new_relic = newrelic.New(config) + + self.instances = [] for count, instance in enumerate(config['instances']): instance_name = instance['name'] labels = instance['labels'] @@ -43,72 +54,30 @@ def __init__( instance_config['auth_env_prefix'] \ if 'auth_env_prefix' in instance_config else '' ) - auth_env = AuthEnv(instance_config) - - if 'queries' in config: - client = SalesForce(auth_env, instance_name, instance_config, event_type_fields_mapping, initial_delay, config['queries']) - else: - client = SalesForce(auth_env, instance_name, instance_config, event_type_fields_mapping, initial_delay) - - if 'auth' in instance_config: - auth = instance_config['auth'] - if 'grant_type' in auth: - oauth_type = auth['grant_type'] - else: - sys.exit("No 'grant_type' specified under 'auth' section in config.yml for instance '" + instance_name + "'") - else: - oauth_type = auth_env.get_grant_type() - - self.instances.append({'labels': labels, 'client': client, "oauth_type": oauth_type, 'name': instance_name}) - - newrelic_config = config['newrelic'] - data_format = newrelic_config['data_format'].lower() \ - if 'data_format' in newrelic_config else 'logs' - - if data_format == "logs": - self.data_format = DataFormat.LOGS - elif data_format == "events": - self.data_format = DataFormat.EVENTS - else: - sys.exit(f'invalid data_format specified. valid values are "logs" or "events"') - - # Fill credentials for NR APIs - if 'license_key' in newrelic_config: - NewRelic.logs_license_key = NewRelic.events_api_key = newrelic_config['license_key'] - else: - NewRelic.logs_license_key = NewRelic.events_api_key = getenv(NR_LICENSE_KEY) - - if self.data_format == DataFormat.EVENTS: - if 'account_id' in newrelic_config: - account_id = newrelic_config['account_id'] - else: - account_id = getenv(NR_ACCOUNT_ID) - NewRelic.set_api_endpoint(newrelic_config['api_endpoint'], account_id) - NewRelic.set_logs_endpoint(newrelic_config['api_endpoint']) - - def run(self): - sfdc_session = new_retry_session() - - for instance in self.instances: - print_info(f"Running instance '{instance['name']}'") - - labels = instance['labels'] - client = instance['client'] - oauth_type = instance['oauth_type'] - - logs = self.auth_and_fetch(True, client, oauth_type, sfdc_session) - if self.response_empty(logs): - print_info("No data to be sent") - self.process_telemetry() - continue - - if self.data_format == DataFormat.LOGS: - self.process_logs(logs, labels, client.data_cache) - else: - self.process_events(logs, labels, client.data_cache) - - self.process_telemetry() + data_cache = cache.New(instance_config) + authenticator = auth.New(instance_config, data_cache) + + self.instances.append({ + 'client': SalesForce( + instance_name, + instance_config, + data_cache, + authenticator, + pipeline.New( + instance_config, + data_cache, + new_relic, + data_format, + labels, + event_type_fields_mapping, + numeric_fields_list, + ), + initial_delay, + config['queries'] if 'queries' in config else None, + ), + 'name': instance_name, + }) def process_telemetry(self): if not Telemetry().is_empty(): @@ -118,118 +87,43 @@ def process_telemetry(self): else: print_info("No telemetry data") - def auth_and_fetch(self, retry, client, oauth_type, sfdc_session): - if not client.authenticate(oauth_type, sfdc_session): - return None + def auth_and_fetch( + self, + client: SalesForce, + session: Session, + retry: bool = True, + ) -> None: - logs = None try: - logs = client.fetch_logs(sfdc_session) + client.authenticate(session) + return client.fetch_logs(session) + except LoginException as e: + print_err(f'authentication failed: {e}') except SalesforceApiException as e: if e.err_code == 401: if retry: - print_err("Invalid token, retry auth and fetch...") + print_err('authentication failed, retrying...') client.clear_auth() - return self.auth_and_fetch(False, client, oauth_type, sfdc_session) - else: - print_err(f"Exception while fetching data from SF: {e}") - return None - else: - print_err(f"Exception while fetching data from SF: {e}") - return None + self.auth_and_fetch( + client, + session, + False, + ) + return + + print_err(f'exception while fetching data from SF: {e}') + return + + print_err(f'exception while fetching data from SF: {e}') + except CacheException as e: + print_err(f'exception while accessing Redis cache: {e}') + except NewRelicApiException as e: + print_err(f'exception while posting data to New Relic: {e}') except Exception as e: - print_err(f"Exception while fetching data from SF: {e}") - return None - - return logs - - @staticmethod - def response_empty(logs): - # Empty or None - if not logs: - return True - for l in logs: - if "log_entries" in l and l["log_entries"]: - return False - return True - - @staticmethod - def cache_processed_data(log_file_id, log_entries, data_cache: DataCache): - if data_cache and data_cache.redis: - if log_file_id == '': - # Events - for log in log_entries: - log_id = log.get('attributes', {}).get('Id', '') - data_cache.persist_event(log_id) - else: - # Logs - data_cache.persist_logs(log_file_id) - - @staticmethod - def process_logs(logs, labels, data_cache: DataCache): - nr_session = new_retry_session() - for log_file_obj in logs: - log_entries = log_file_obj['log_entries'] - if len(log_entries) == 0: - continue - - payload = [{'common': labels, 'logs': log_entries}] - log_type = log_file_obj.get('log_type', '') - log_file_id = log_file_obj.get('Id', '') - - status_code = NewRelic.post_logs(nr_session, payload) - if status_code != 202: - print_err(f'newrelic logs api returned code- {status_code}') - else: - print_info(f"Sent {len(log_entries)} log messages from log file {log_type}/{log_file_id}") - Integration.cache_processed_data(log_file_id, log_entries, data_cache) - - @staticmethod - def process_events(logs, labels, data_cache: DataCache): - nr_session = new_retry_session() - for log_file_obj in logs: - log_file_id = log_file_obj.get('Id', '') - log_entries = log_file_obj['log_entries'] - if len(log_entries) == 0: - continue - log_events = [] - for log_entry in log_entries: - log_event = {} - attributes = log_entry['attributes'] - for event_name in attributes: - # currently no need to modify as we did not see any special chars that need to be removed - modified_event_name = event_name - event_value = attributes[event_name] - if event_name in Integration.numeric_fields_list: - if event_value: - try: - log_event[modified_event_name] = int(event_value) - except (TypeError, ValueError) as e: - try: - log_event[modified_event_name] = float(event_value) - except (TypeError, ValueError) as e: - print_err(f'Type conversion error for {event_name}[{event_value}]') - log_event[modified_event_name] = event_value - else: - log_event[modified_event_name] = 0 - else: - log_event[modified_event_name] = event_value - log_event.update(labels) - event_type = log_event.get('EVENT_TYPE', "UnknownSFEvent") - log_event['eventType'] = event_type - log_events.append(log_event) - - # NOTE: this is probably unnecessary now, because we already have a slicing method with a limit of 1000 in SalesForce.extract_row_slice - # since the max number of events that can be posted in a single payload to New Relic is 2000 - max_events = 2000 - x = [log_events[i:i + max_events] for i in range(0, len(log_events), max_events)] - - for log_entries_slice in x: - status_code = NewRelic.post_events(nr_session, log_entries_slice) - if status_code != 200: - print_err(f'newrelic events api returned code- {status_code}') - else: - log_type = log_file_obj.get('log_type', '') - log_file_id = log_file_obj.get('Id', '') - print_info(f"Posted {len(log_entries_slice)} events from log file {log_type}/{log_file_id}") - Integration.cache_processed_data(log_file_id, log_entries, data_cache) + print_err(f'unknown exception occurred: {e}') + + def run(self): + for instance in self.instances: + print_info(f"Running instance '{instance['name']}'") + self.auth_and_fetch(instance['client'], new_retry_session()) + self.process_telemetry() diff --git a/src/newrelic_logging/newrelic.py b/src/newrelic_logging/newrelic.py index 0df766c..ed9d7e1 100644 --- a/src/newrelic_logging/newrelic.py +++ b/src/newrelic_logging/newrelic.py @@ -1,31 +1,45 @@ import gzip import json -from .telemetry import print_info, print_err -from requests import RequestException -from newrelic_logging import VERSION, NAME, PROVIDER, COLLECTOR_NAME +from requests import RequestException, Session -class NewRelicApiException(Exception): - pass +from . import \ + VERSION, \ + NAME, \ + PROVIDER, \ + COLLECTOR_NAME, \ + NewRelicApiException +from .config import Config +from .telemetry import print_info -class NewRelic: - INGEST_SERVICE_VERSION = "v1" - US_LOGGING_ENDPOINT = "https://log-api.newrelic.com/log/v1" - EU_LOGGING_ENDPOINT = "https://log-api.eu.newrelic.com/log/v1" - LOGS_EVENT_SOURCE = 'logs' - US_EVENTS_ENDPOINT = "https://insights-collector.newrelic.com/v1/accounts/{account_id}/events" - EU_EVENTS_ENDPOINT = "https://insights-collector.eu01.nr-data.net/v1/accounts/{account_id}/events" +NR_LICENSE_KEY = 'NR_LICENSE_KEY' +NR_ACCOUNT_ID = 'NR_ACCOUNT_ID' + +US_LOGGING_ENDPOINT = 'https://log-api.newrelic.com/log/v1' +EU_LOGGING_ENDPOINT = 'https://log-api.eu.newrelic.com/log/v1' +LOGS_EVENT_SOURCE = 'logs' - CONTENT_ENCODING = 'gzip' +US_EVENTS_ENDPOINT = 'https://insights-collector.newrelic.com/v1/accounts/{account_id}/events' +EU_EVENTS_ENDPOINT = 'https://insights-collector.eu01.nr-data.net/v1/accounts/{account_id}/events' - logs_api_endpoint = US_LOGGING_ENDPOINT - logs_license_key = '' +CONTENT_ENCODING = 'gzip' +MAX_EVENTS = 2000 - events_api_endpoint = US_EVENTS_ENDPOINT - events_api_key = '' - @classmethod - def post_logs(cls, session, data): +class NewRelic: + def __init__( + self, + logs_api_endpoint, + logs_license_key, + events_api_endpoint, + events_api_key, + ): + self.logs_api_endpoint = logs_api_endpoint + self.logs_license_key = logs_license_key + self.events_api_endpoint = events_api_endpoint + self.events_api_key = events_api_key + + def post_logs(self, session: Session, data: list[dict]) -> None: # Append integration attributes for log in data[0]['logs']: if not 'attributes' in log: @@ -35,77 +49,87 @@ def post_logs(cls, session, data): log['attributes']['instrumentation.version'] = VERSION log['attributes']['collector.name'] = COLLECTOR_NAME - json_payload = json.dumps(data).encode() - - # print("----- POST DATA (LOGS) -----") - # print(json_payload.decode("utf-8")) - # print("----------------------------") - # return 202 - - payload = gzip.compress(json_payload) - headers = { - "X-License-Key": cls.logs_license_key, - "X-Event-Source": cls.LOGS_EVENT_SOURCE, - "Content-Encoding": cls.CONTENT_ENCODING, - } try: - r = session.post(cls.logs_api_endpoint, data=payload, - headers=headers) - except RequestException as e: - print_err(f"Failed posting logs to New Relic: {repr(e)}") - return 0 - - response = r.content.decode("utf-8") - print_info(f"NR Log API response body = {response}") - - return r.status_code - - @classmethod - def post_events(cls, session, data): + r = session.post( + self.logs_api_endpoint, + data=gzip.compress(json.dumps(data).encode()), + headers={ + 'X-License-Key': self.logs_license_key, + 'X-Event-Source': LOGS_EVENT_SOURCE, + 'Content-Encoding': CONTENT_ENCODING, + }, + ) + + if r.status_code != 202: + raise NewRelicApiException( + f'newrelic logs api returned code {r.status_code}' + ) + + response = r.content.decode("utf-8") + print_info(f"NR Log API response body = {response}") + except RequestException: + raise NewRelicApiException('newrelic logs api request failed') + + def post_events(self, session: Session, events: list[dict]) -> None: # Append integration attributes - for event in data: + for event in events: event['instrumentation.name'] = NAME event['instrumentation.provider'] = PROVIDER event['instrumentation.version'] = VERSION event['collector.name'] = COLLECTOR_NAME - json_payload = json.dumps(data).encode() - - # print("----- POST DATA (EVENTS) -----") - # print(json_payload.decode("utf-8")) - # print("------------------------------") - # return 200 - - payload = gzip.compress(json_payload) - headers = { - "Api-Key": cls.events_api_key, - "Content-Encoding": cls.CONTENT_ENCODING, - } - try: - r = session.post(cls.events_api_endpoint, data=payload, - headers=headers) - except RequestException as e: - print_err(f"Failed posting events to New Relic: {repr(e)}") - return 0 - - response = r.content.decode("utf-8") - print_info(f"NR Event API response body = {response}") - - return r.status_code - - @classmethod - def set_api_endpoint(cls, api_endpoint, account_id): - if api_endpoint == "US": - api_endpoint = NewRelic.US_EVENTS_ENDPOINT; - elif api_endpoint == "EU": - api_endpoint = NewRelic.EU_EVENTS_ENDPOINT - NewRelic.events_api_endpoint = api_endpoint.format(account_id='account_id') - - @classmethod - def set_logs_endpoint(cls, api_endpoint): - if api_endpoint == "US": - NewRelic.logs_api_endpoint = NewRelic.US_LOGGING_ENDPOINT - elif api_endpoint == "EU": - NewRelic.logs_api_endpoint = NewRelic.EU_LOGGING_ENDPOINT - else: - NewRelic.logs_api_endpoint = api_endpoint + # This funky code produces an array of arrays where each one will be at most + # length 2000 with the last one being <= 2000. This is done to account for + # the fact that only 2000 events can be posted at a time. + + slices = [events[i:(i + MAX_EVENTS)] \ + for i in range(0, len(events), MAX_EVENTS)] + + for slice in slices: + try: + r = session.post( + self.events_api_endpoint, + data=gzip.compress(json.dumps(slice).encode()), + headers={ + 'Api-Key': self.events_api_key, + 'Content-Encoding': CONTENT_ENCODING, + }, + ) + + if r.status_code != 200: + raise NewRelicApiException( + f'newrelic events api returned code {r.status_code}' + ) + + response = r.content.decode("utf-8") + print_info(f"NR Event API response body = {response}") + except RequestException: + raise NewRelicApiException('newrelic events api request failed') + + +def New( + config: Config, +): + license_key = config.get( + 'newrelic.license_key', + env_var_name=NR_LICENSE_KEY, + ) + + region = config.get('newrelic.api_endpoint') + account_id = config.get('newrelic.account_id', env_var_name=NR_ACCOUNT_ID) + + if region == "US": + logs_api_endpoint = US_LOGGING_ENDPOINT + events_api_endpoint = US_EVENTS_ENDPOINT.format(account_id=account_id) + elif region == "EU": + logs_api_endpoint = EU_LOGGING_ENDPOINT + events_api_endpoint = EU_EVENTS_ENDPOINT.format(account_id=account_id) + else: + raise NewRelicApiException(f'Invalid region {region}') + + return NewRelic( + logs_api_endpoint, + license_key, + events_api_endpoint, + license_key, + ) diff --git a/src/newrelic_logging/pipeline.py b/src/newrelic_logging/pipeline.py new file mode 100644 index 0000000..310fb93 --- /dev/null +++ b/src/newrelic_logging/pipeline.py @@ -0,0 +1,462 @@ +from copy import deepcopy +import csv +import datetime +import gc +from requests import Session + +from . import DataFormat, SalesforceApiException +from .cache import DataCache +from .config import Config +from .http_session import new_retry_session +from .newrelic import NewRelic +from .query import Query +from .telemetry import print_err, print_info +from .util import generate_record_id, get_row_timestamp, is_logfile_response + + +DEFAULT_CHUNK_SIZE = 4096 +DEFAULT_MAX_ROWS = 1000 +MAX_ROWS = 2000 + + +def pack_csv_into_log( + query: Query, + record_id: str, + record_event_type: str, + row: dict, + row_index: int, + event_type_fields_mapping: dict, +) -> dict: + attrs = {} + if record_event_type in event_type_fields_mapping: + for field in event_type_fields_mapping[record_event_type]: + attrs[field] = row[field] + else: + attrs = row + + timestamp = get_row_timestamp(row) + attrs.pop('TIMESTAMP', None) + + attrs['LogFileId'] = record_id + + actual_event_type = attrs.pop('EVENT_TYPE', "SFEvent") + new_event_type = query.get("event_type", actual_event_type) + attrs['EVENT_TYPE'] = new_event_type + + timestamp_field_name = query.get("rename_timestamp", "timestamp") + attrs[timestamp_field_name] = int(timestamp) + + log_entry = { + 'message': "LogFile " + record_id + " row " + str(row_index), + 'attributes': attrs + } + + if timestamp_field_name == 'timestamp': + log_entry[timestamp_field_name] = int(timestamp) + + return log_entry + + +def export_log_lines( + session: Session, + url: str, + access_token: str, + chunk_size: int, +): + print_info(f'Downloading log lines for log file: {url}') + + # Request the log lines for the log file record url + response = session.get( + url, + headers={ + 'Authorization': f'Bearer {access_token}' + }, + stream=True, + ) + if response.status_code != 200: + error_message = f'salesforce event log file download failed. ' \ + f'status-code: {response.status_code}, ' \ + f'reason: {response.reason} ' \ + f'response: {response.text}' + raise SalesforceApiException(response.status_code, error_message) + + # Stream the response as a set of lines. This function will return an + # iterator that yields one line at a time holding only the minimum + # amount of data chunks in memory to make up a single line + return response.iter_lines(chunk_size=chunk_size, decode_unicode=True) + + +def transform_log_lines( + iter, + query: Query, + record_id: str, + record_event_type: str, + data_cache: DataCache, + event_type_fields_mapping: dict, +): + # iter is a generator iterator that yields a single line at a time + reader = csv.DictReader(iter) + + # This should cause the reader to request the next line from the iterator + # which will cause the generator iterator to yield the next line + + row_index = 0 + + for row in reader: + # If we've already seen this log line, skip it + if data_cache and data_cache.check_and_set_log_line(record_id, row): + continue + + # Otherwise, pack it up for shipping and yield it for consumption + yield pack_csv_into_log( + query, + record_id, + record_event_type, + row, + row_index, + event_type_fields_mapping, + ) + + row_index += 1 + + +def pack_event_into_log( + query: Query, + record_id: str, + row: dict, +): + # Make a copy of it so we aren't modifying the row passed by the caller, and + # set attributes appropriately + attrs = deepcopy(row) + if record_id: + attrs['Id'] = record_id + + timestamp_attr = query.get('timestamp_attr', 'CreatedDate') + if timestamp_attr in attrs: + created_date = attrs[timestamp_attr] + timestamp = int(datetime.strptime( + created_date, '%Y-%m-%dT%H:%M:%S.%f%z').timestamp() * 1000 + ) + else: + created_date = "" + timestamp = int(datetime.now().timestamp() * 1000) + + message = query.get('event_type', 'SFEvent') + if 'attributes' in attrs and type(attrs['attributes']) == dict: + attributes = attrs.pop('attributes', []) + if 'type' in attributes and type(attributes['type']) == str: + attrs['EVENT_TYPE'] = message = \ + query.get('event_type', attributes['type']) + + if created_date != "": + message = message + " " + created_date + + timestamp_field_name = query.get('rename_timestamp', 'timestamp') + attrs[timestamp_field_name] = int(timestamp) + + log_entry = { + 'message': message, + 'attributes': attrs, + } + + if timestamp_field_name == 'timestamp': + log_entry[timestamp_field_name] = timestamp + + return log_entry + + +def transform_event_records(iter, query: Query, data_cache: DataCache): + # iter here is a list which does mean it's entirely held in memory but these + # are event records not log lines so hopefully it is # not as bad. + # @TODO figure out if we can stream event records + for row in iter: + record_id = row['Id'] if 'Id' in row \ + else generate_record_id(query.get('id', []), row) + + # If we've already seen this event record, skip it. + if data_cache and data_cache.check_and_set_event_id(record_id): + return None + + # Build a New Relic log record from the SF event record + yield pack_event_into_log( + query, + record_id, + row, + data_cache, + ) + + +def load_as_logs(iter, new_relic: NewRelic, labels: dict, max_rows: int): + nr_session = new_retry_session() + + logs = [] + count = total = 0 + + def send_logs(): + nonlocal logs + nonlocal count + + new_relic.post_logs(nr_session, [{'common': labels, 'logs': logs}]) + + print_info(f'Sent {count} log messages.') + + # Attempt to release memory + del logs + + logs = [] + count = 0 + + for log in iter: + if count == max_rows: + send_logs() + + logs.append(log) + + count += 1 + total += 1 + + if len(logs) > 0: + send_logs() + + print_info(f'Sent a total of {total} log messages.') + + # Attempt to reclaim memory + gc.collect() + + +def pack_log_into_event(log: dict, labels: dict, numeric_fields_list: set): + log_event = {} + + attributes = log['attributes'] + for event_name in attributes: + # currently no need to modify as we did not see any special chars + # that need to be removed + modified_event_name = event_name + event_value = attributes[event_name] + if event_name in numeric_fields_list: + if event_value: + try: + log_event[modified_event_name] = int(event_value) + except (TypeError, ValueError) as _: + try: + log_event[modified_event_name] = float(event_value) + except (TypeError, ValueError) as _: + print_err(f'Type conversion error for {event_name}[{event_value}]') + log_event[modified_event_name] = event_value + else: + log_event[modified_event_name] = 0 + else: + log_event[modified_event_name] = event_value + + log_event.update(labels) + log_event['eventType'] = log_event.get('EVENT_TYPE', "UnknownSFEvent") + + return log_event + + +def load_as_events( + iter, + new_relic: NewRelic, + labels: dict, + max_rows: int, + numeric_fields_list: set, +): + nr_session = new_retry_session() + + events = [] + count = total = 0 + + def send_events(): + nonlocal events + nonlocal count + + new_relic.post_events(nr_session, events) + + print_info(f'Sent {count} events.') + + # Attempt to release memory + del events + + events = [] + count = 0 + + + for log_entry in iter: + if count == max_rows: + send_events() + + events.append(pack_log_into_event( + log_entry, + labels, + numeric_fields_list, + )) + + count += 1 + total += 1 + + if len(events) > 0: + send_events() + + print_info(f'Sent a total of {total} events.') + + # Attempt to reclaim memory + gc.collect() + + +def load_data( + logs: dict, + new_relic: NewRelic, + data_format: DataFormat, + labels: dict, + max_rows: int, + numeric_fields_list: set, +): + if data_format == DataFormat.LOGS: + load_as_logs( + logs, + new_relic, + labels, + max_rows, + ) + return + + load_as_events( + logs, + new_relic, + labels, + max_rows, + numeric_fields_list, + ) + +class Pipeline: + def __init__( + self, + config: Config, + data_cache: DataCache, + new_relic: NewRelic, + data_format: DataFormat, + labels: dict, + event_type_field_mappings: dict, + numeric_fields_list: set, + ): + self.config = config + self.data_cache = data_cache + self.new_relic = new_relic + self.data_format = data_format + self.labels = labels + self.event_type_field_mappings = event_type_field_mappings + self.numeric_fields_list = numeric_fields_list + self.max_rows = max( + self.config.get('max_rows', DEFAULT_MAX_ROWS), + MAX_ROWS, + ) + + def process_log_record( + self, + session: Session, + query: Query, + instance_url: str, + access_token: str, + record: dict, + ): + record_id = str(record['Id']) + record_event_type = query.get("event_type", record['EventType']) + record_file_name = record['LogFile'] + interval = record['Interval'] + + # NOTE: only Hourly logs can be skipped, because Daily logs can change + # and the same record_id can contain different data. + if interval == 'Hourly' and self.data_cache and \ + self.data_cache.can_skip_downloading_logfile(record_id): + print_info( + f'Log lines for logfile with id {record_id} already cached, skipping download' + ) + return None + + if self.data_cache: + self.data_cache.load_cached_log_lines(record_id) + + load_data( + transform_log_lines( + export_log_lines( + session, + f'{instance_url}{record_file_name}', + access_token, + self.config.get('chunk_size', DEFAULT_CHUNK_SIZE) + ), + query, + record_id, + record_event_type, + self.data_cache, + self.event_type_field_mappings, + ), + self.new_relic, + self.data_format, + self.labels, + self.max_rows, + self.numeric_fields_list, + ) + + def process_event_records( + self, + query: Query, + records: list[dict], + ): + load_data( + transform_event_records( + records, + query, + self.data_cache, + self.config.get('max_rows', DEFAULT_MAX_ROWS) + ), + self.new_relic, + self.data_format, + self.labels, + self.max_rows, + self.numeric_fields_list, + ) + + def execute( + self, + session: Session, + query: Query, + instance_url: str, + access_token: str, + records: list[dict], + ): + if is_logfile_response(records): + for record in records: + if 'LogFile' in record: + self.process_log_record( + session, + query, + instance_url, + access_token, + record, + ) + + return + + self.process_event_records(query, records) + + # Flush the cache + self.data_cache.flush() + +def New( + config: Config, + data_cache: DataCache, + new_relic: NewRelic, + data_format: DataFormat, + labels: dict, + event_type_field_mappings: dict, + numeric_fields_list: set, +): + return Pipeline( + config, + data_cache, + new_relic, + data_format, + labels, + event_type_field_mappings, + numeric_fields_list, + ) diff --git a/src/newrelic_logging/query.py b/src/newrelic_logging/query.py index 26bb501..9299b53 100644 --- a/src/newrelic_logging/query.py +++ b/src/newrelic_logging/query.py @@ -1,55 +1,83 @@ +import copy +from datetime import datetime, timedelta +from requests import RequestException, Session + +from . import SalesforceApiException +from .config import Config +from .telemetry import print_info +from .util import substitute + class Query: - query = None - env = None - - def __init__(self, query) -> None: - if type(query) == dict: - self.query = query.get("query", "") - query.pop('query', None) - self.env = query - elif type(query) == str: - self.query = query - self.env = {} - - def get_query(self) -> str: - return self.query - - def set_query(self, query: str) -> None: + def __init__( + self, + query: str, + config: Config, + api_ver: str, + ): self.query = query - - def get_env(self) -> dict: - return self.env - -# NOTE: this sandbox can be jailbroken using the trick to exec statements inside an exec block, and run an import (and other tricks): -# https://book.hacktricks.xyz/generic-methodologies-and-resources/python/bypass-python-sandboxes#operators-and-short-tricks -# https://stackoverflow.com/a/3068475/2076108 -# Would be better to use a real sandbox like https://pypi.org/project/RestrictedPython/ or https://doc.pypy.org/en/latest/sandbox.html -# or parse a small language that only supports funcion calls and binary expressions. -def sandbox(code): - __import__ = None - __loader__ = None - __build_class__ = None - exec = None - - from datetime import datetime, timedelta - - def sf_time(t: datetime): - return t.isoformat(timespec='milliseconds') + "Z" - - def now(delta: timedelta = None): - if delta: - return sf_time(datetime.utcnow() + delta) - else: - return sf_time(datetime.utcnow()) - - try: - return eval(code) - except Exception as e: - return e - -def substitute(args: dict, query_template: str, env: dict) -> str: - for key, command in env.items(): - args[key] = sandbox(command) - for key, val in args.items(): - query_template = query_template.replace('{' + key + '}', val) - return query_template + self.config = config + self.api_ver = api_ver + + def get(self, key: str, default = None): + return self.config.get(key, default) + + def get_config(self): + return self.config + + def execute( + self, + session: Session, + instance_url: str, + access_token: str, + ): + url = f'{instance_url}/services/data/v{self.api_ver}/query?q={self.query}' + + try: + print_info(f'Running query {self.query} using url {url}') + + query_response = session.get(url, headers={ + 'Authorization': f'Bearer {access_token}' + }) + if query_response.status_code != 200: + raise SalesforceApiException( + query_response.status_code, + f'error when trying to run SOQL query. ' \ + f'status-code:{query_response.status_code}, ' \ + f'reason: {query_response.reason} ' \ + f'response: {query_response.text} ' + ) + + return query_response.json() + except RequestException as e: + raise SalesforceApiException( + -1, + f'error when trying to run SOQL query. cause: {e}', + ) from e + + +def New( + q: dict, + time_lag_minutes: int, + last_to_timestamp: str, + generation_interval: str, + default_api_ver: str, +) -> Query: + to_timestamp = ( + datetime.utcnow() - timedelta(minutes=time_lag_minutes) + ).isoformat(timespec='milliseconds') + "Z" + from_timestamp = last_to_timestamp + + qp = copy.deepcopy(q) + qq = qp.pop('query', '') + + args = { + 'to_timestamp': to_timestamp, + 'from_timestamp': from_timestamp, + 'log_interval_type': generation_interval, + } + + return Query( + substitute(args, qq, qp).replace(' ', '+'), + Config(qp), + qp.get('api_ver', default_api_ver) + ) diff --git a/src/newrelic_logging/salesforce.py b/src/newrelic_logging/salesforce.py index 61085da..c8d15cb 100644 --- a/src/newrelic_logging/salesforce.py +++ b/src/newrelic_logging/salesforce.py @@ -1,29 +1,16 @@ -import base64 -import csv -import json -import sys from datetime import datetime, timedelta -import jwt -from cryptography.hazmat.primitives import serialization -import pytz -from requests import RequestException -import copy -import hashlib -from .cache import make_cache -from .env import Auth, AuthEnv -from .query import Query, substitute -from .telemetry import print_info, print_err +from requests import Session -class LoginException(Exception): - pass +from . import DataFormat +from .auth import Authenticator +from .cache import DataCache +from . import config as mod_config +from .pipeline import Pipeline +from . import query as mod_query +from .telemetry import print_info -class SalesforceApiException(Exception): - err_code = 0 - def __init__(self, err_code: int, *args: object) -> None: - self.err_code = err_code - super().__init__(*args) - pass +CSV_SLICE_SIZE = 1000 SALESFORCE_CREATED_DATE_QUERY = \ "SELECT Id,EventType,CreatedDate,LogDate,Interval,LogFile,Sequence From EventLogFile Where CreatedDate>={" \ "from_timestamp} AND CreatedDate<{to_timestamp} AND Interval='{log_interval_type}'" @@ -31,548 +18,120 @@ def __init__(self, err_code: int, *args: object) -> None: "SELECT Id,EventType,CreatedDate,LogDate,Interval,LogFile,Sequence From EventLogFile Where LogDate>={" \ "from_timestamp} AND LogDate<{to_timestamp} AND Interval='{log_interval_type}'" -CSV_SLICE_SIZE = 1000 - -def base64_url_encode(json_obj): - json_str = json.dumps(json_obj) - encoded_bytes = base64.urlsafe_b64encode(json_str.encode('utf-8')) - encoded_str = str(encoded_bytes, 'utf-8') - return encoded_str class SalesForce: - auth = None - oauth_type = None - token_url = '' - query_template = None - data_cache = None - default_api_ver = '' - - def __init__(self, auth_env: AuthEnv, instance_name, config, event_type_fields_mapping, initial_delay, queries=[]): + def __init__( + self, + instance_name: str, + config: mod_config.Config, + data_cache: DataCache, + authenticator: Authenticator, + pipeline: Pipeline, + initial_delay: int, + queries=None, + ): self.instance_name = instance_name self.default_api_ver = config.get('api_ver', '52.0') - if 'auth' in config: - self.auth_data = config['auth'] - else: - self.auth_data = {'grant_type': auth_env.get_grant_type()} - if self.auth_data['grant_type'] == 'password': - # user/pass flow - try: - self.auth_data["client_id"] = auth_env.get_client_id() - self.auth_data["client_secret"] = auth_env.get_client_secret() - self.auth_data["username"] = auth_env.get_username() - self.auth_data["password"] = auth_env.get_password() - except: - print_err(f'Missing credentials for user/pass flow') - sys.exit(1) - elif self.auth_data['grant_type'] == 'urn:ietf:params:oauth:grant-type:jwt-bearer': - # jwt flow - try: - self.auth_data["client_id"] = auth_env.get_client_id() - self.auth_data["private_key"] = auth_env.get_private_key() - self.auth_data["subject"] = auth_env.get_subject() - self.auth_data["audience"] = auth_env.get_audience() - except: - print_err(f'Missing credentials for JWT flow') - sys.exit(1) - else: - print_err(f'Wrong or missing grant_type') - sys.exit(1) - - if 'token_url' in config: - self.token_url = config['token_url'] - else: - self.token_url = auth_env.get_token_url() - - try: - self.time_lag_minutes = config['time_lag_minutes'] - self.generation_interval = config['generation_interval'] - self.date_field = config['date_field'] - except KeyError as e: - print_err(f'Please specify a "{e.args[0]}" parameter for sfdc instance "{instance_name}" in config.yml') - sys.exit(1) - + self.data_cache = data_cache + self.auth = authenticator + self.time_lag_minutes = config.get( + mod_config.CONFIG_TIME_LAG_MINUTES, + mod_config.DEFAULT_TIME_LAG_MINUTES if not self.data_cache else 0, + ) + self.date_field = config.get( + mod_config.CONFIG_DATE_FIELD, + mod_config.DATE_FIELD_LOG_DATE if not self.data_cache \ + else mod_config.DATE_FIELD_CREATE_DATE, + ) + self.generation_interval = config.get( + mod_config.CONFIG_GENERATION_INTERVAL, + mod_config.DEFAULT_GENERATION_INTERVAL, + ) self.last_to_timestamp = (datetime.utcnow() - timedelta( - minutes=self.time_lag_minutes + initial_delay)).isoformat(timespec='milliseconds') + "Z" - - if len(queries) > 0: - self.query_template = queries - else: - if self.date_field.lower() == "logdate": - self.query_template = SALESFORCE_LOG_DATE_QUERY - else: - self.query_template = SALESFORCE_CREATED_DATE_QUERY - - self.data_cache = make_cache(config) - self.event_type_fields_mapping = event_type_fields_mapping + minutes=self.time_lag_minutes + initial_delay + )).isoformat(timespec='milliseconds') + 'Z' - def clear_auth(self): - if self.data_cache: - try: - self.data_cache.redis.delete("auth") - except Exception as e: - print_err(f"Failed deleting 'auth' key from Redis: {e}") - exit(1) - self.auth = None - - def store_auth(self, auth_resp): - access_token = auth_resp['access_token'] - instance_url = auth_resp['instance_url'] - token_type = auth_resp['token_type'] - if self.data_cache: - print_info("Storing credentials on Redis.") - auth = { - "access_token": access_token, - "instance_url": instance_url, - "token_type": token_type - } - try: - self.data_cache.redis.hmset("auth", auth) - except Exception as e: - print_err(f"Failed setting 'auth' key: {e}") - exit(1) - self.auth = Auth(access_token, instance_url, token_type) - - def authenticate(self, oauth_type, session): - self.oauth_type = oauth_type - if self.data_cache: - try: - auth_exists = self.data_cache.redis.exists("auth") - except Exception as e: - print_err(f"Failed checking 'auth' key: {e}") - exit(1) - if auth_exists: - print_info("Retrieving credentials from Redis.") - #NOTE: hmget and hgetall both return byte arrays, not strings. We have to convert. - # We could fix it by adding the argument "decode_responses=True" to Redis constructor, - # but then we would have to change all places where we assume a byte array instead of a string, - # and refactoring in a language without static types is a pain. - try: - auth = self.data_cache.redis.hmget("auth", ["access_token", "instance_url", "token_type"]) - auth = { - "access_token": auth[0].decode("utf-8"), - "instance_url": auth[1].decode("utf-8"), - "token_type": auth[2].decode("utf-8") - } - self.store_auth(auth) - except Exception as e: - print_err(f"Failed getting 'auth' key: {e}") - exit(1) - - return True - - if oauth_type == 'password': - if not self.authenticate_with_password(session): - print_err(f"Error authenticating with {self.token_url}") - return False - print_info("Correctly authenticated with user/pass flow") + if queries: + self.queries = queries else: - if not self.authenticate_with_jwt(session): - print_err(f"Error authenticating with {self.token_url}") - return False - print_info("Correctly authenticated with JWT flow") - return True - - def authenticate_with_jwt(self, session): - try: - private_key_file = self.auth_data['private_key'] - client_id = self.auth_data['client_id'] - subject = self.auth_data['subject'] - audience = self.auth_data['audience'] - except KeyError as e: - print_err(f'Please specify a "{e.args[0]}" parameter under "auth" section ' - 'of salesforce instance in config.yml') - sys.exit(1) + self.queries = [{ + 'query': SALESFORCE_LOG_DATE_QUERY \ + if self.date_field.lower() == 'logdate' \ + else SALESFORCE_CREATED_DATE_QUERY + }] - exp = int((datetime.utcnow() - timedelta(minutes=5)).timestamp()) + self.pipeline = pipeline - private_key = open(private_key_file, 'r').read() - try: - key = serialization.load_ssh_private_key(private_key.encode(), password=b'') - except ValueError as e: - print_err(f'Authentication failed for {self.instance_name}. error message: {str(e)}') - return False - - jwt_claim_set = {"iss": client_id, - "sub": subject, - "aud": audience, - "exp": exp} - - signed_token = jwt.encode( - jwt_claim_set, - key, - algorithm='RS256', - ) - - params = { - "grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer", - "assertion": signed_token, - "format": "json" - } - - headers = { - "Content-Type": "application/x-www-form-urlencoded", - "Accept": "application/json" - } - - try: - print_info(f'retrieving salesforce token at {self.token_url}') - resp = session.post(self.token_url, params=params, - headers=headers) - if resp.status_code != 200: - error_message = f'sfdc token request failed. http-status-code:{resp.status_code}, reason: {resp.text}' - print_err(f'Authentication failed for {self.instance_name}. message: {error_message}', file=sys.stderr) - return False - - self.store_auth(resp.json()) - return True - except ConnectionError as e: - print_err(f"SFDC auth failed for instance {self.instance_name}: {repr(e)}") - raise LoginException(f'authentication failed for sfdc instance {self.instance_name}') from e - except RequestException as e: - print_err(f"SFDC auth failed for instance {self.instance_name}: {repr(e)}") - raise LoginException(f'authentication failed for sfdc instance {self.instance_name}') from e - - def authenticate_with_password(self, session): - client_id = self.auth_data['client_id'] - client_secret = self.auth_data['client_secret'] - username = self.auth_data['username'] - password = self.auth_data['password'] - - params = { - "grant_type": "password", - "client_id": client_id, - "client_secret": client_secret, - "username": username, - "password": password - } - - headers = { - "Content-Type": "application/x-www-form-urlencoded", - "Accept": "application/json" - } - - try: - print_info(f'retrieving salesforce token at {self.token_url}') - resp = session.post(self.token_url, params=params, - headers=headers) - if resp.status_code != 200: - error_message = f'salesforce token request failed. status-code:{resp.status_code}, reason: {resp.reason}' - print_err(error_message) - return False - - self.store_auth(resp.json()) - return True - except ConnectionError as e: - print_err(f"SFDC auth failed for instance {self.instance_name}: {repr(e)}") - raise LoginException(f'authentication failed for sfdc instance {self.instance_name}') from e - except RequestException as e: - print_err(f"SFDC auth failed for instance {self.instance_name}: {repr(e)}") - raise LoginException(f'authentication failed for sfdc instance {self.instance_name}') from e - - def make_multiple_queries(self, query_objects) -> list[Query]: - return [self.make_single_query(Query(obj)) for obj in query_objects] - - def make_single_query(self, query_obj: Query) -> Query: - to_timestamp = (datetime.utcnow() - timedelta(minutes=self.time_lag_minutes)).isoformat( - timespec='milliseconds') + "Z" - from_timestamp = self.last_to_timestamp - - env = copy.deepcopy(query_obj.get_env().get('env', {})) - args = { - 'to_timestamp': to_timestamp, - 'from_timestamp': from_timestamp, - 'log_interval_type': self.generation_interval - } - query = substitute(args, query_obj.get_query(), env) - query = query.replace(' ', '+') - - query_obj.set_query(query) - return query_obj + def authenticate(self, sfdc_session: Session): + self.auth.authenticate(sfdc_session) def slide_time_range(self): - self.last_to_timestamp = (datetime.utcnow() - timedelta(minutes=self.time_lag_minutes)).isoformat( - timespec='milliseconds') + "Z" - - def execute_query(self, query: Query, session): - api_ver = query.get_env().get("api_ver", self.default_api_ver) - url = f'{self.auth.get_instance_url()}/services/data/v{api_ver}/query?q={query.get_query()}' - - try: - headers = { - 'Authorization': f'Bearer {self.auth.get_access_token()}' - } - query_response = session.get(url, headers=headers) - if query_response.status_code != 200: - error_message = f'salesforce event log query failed. ' \ - f'status-code:{query_response.status_code}, ' \ - f'reason: {query_response.reason} ' \ - f'response: {query_response.text} ' - - print_err(f"SOQL query failed with code {query_response.status_code}: {error_message}") - raise SalesforceApiException(query_response.status_code, f'error when trying to run SOQL query. message: {error_message}') - return query_response.json() - except RequestException as e: - print_err(f"Error while trying SOQL query: {repr(e)}") - raise SalesforceApiException(-1, f'error when trying to run SOQL query. cause: {e}') from e + self.last_to_timestamp = ( + datetime.utcnow() - timedelta(minutes=self.time_lag_minutes)) \ + .isoformat(timespec='milliseconds') + "Z" # NOTE: Is it possible that different SF orgs have overlapping IDs? If this is possible, we should use a different # database for each org, or add a prefix to keys to avoid conflicts. - def download_file(self, session, url): - print_info(f"Downloading CSV file: {url}") - - headers = { - 'Authorization': f'Bearer {self.auth.get_access_token()}' - } - response = session.get(url, headers=headers) - if response.status_code != 200: - error_message = f'salesforce event log file download failed. ' \ - f'status-code: {response.status_code}, ' \ - f'reason: {response.reason} ' \ - f'response: {response.text}' - print_err(error_message) - raise SalesforceApiException(response.status_code, error_message) - return response - - def parse_csv(self, download_response, record_id, record_event_type, cached_messages): - content = download_response.content.decode('utf-8') - reader = csv.DictReader(content.splitlines()) - rows = [] - for row in reader: - if self.data_cache and self.data_cache.record_or_skip_row(record_id, row, cached_messages): - continue - rows.append(row) - return rows - - def fetch_logs(self, session): - print_info(f"Query object = {self.query_template}") - - if type(self.query_template) is list: - # "query_template" contains a list of objects, each one is a Query object - queries = self.make_multiple_queries(copy.deepcopy(self.query_template)) - response = self.fetch_logs_from_multiple_req(session, queries) - self.slide_time_range() - return response - else: - # "query_template" contains a string with the SOQL to run. - query = self.make_single_query(Query(self.query_template)) - response = self.fetch_logs_from_single_req(session, query) - self.slide_time_range() - return response - - def fetch_logs_from_multiple_req(self, session, queries: list[Query]): - logs = [] - for query in queries: - part_logs = self.fetch_logs_from_single_req(session, query) - logs.extend(part_logs) - return logs - - def fetch_logs_from_single_req(self, session, query: Query): - print_info(f'Running query {query.get_query()}') - response = self.execute_query(query, session) - - # Show query response - #print("Response = ", response) - - records = response['records'] - if self.is_logfile_response(records): - logs = [] - for record in records: - if 'LogFile' in record: - log = self.build_log_from_logfile(True, session, record, query) - if log is not None: - logs.extend(log) - else: - logs = self.build_log_from_event(records, query) - - return logs - - def is_logfile_response(self, records): - if len(records) > 0: - return 'LogFile' in records[0] - else: - return True - - def build_log_from_event(self, records, query: Query): - logs = [] - while True: - part_rows = self.extract_row_slice(records) - if len(part_rows) > 0: - logs.append(self.pack_event_into_log(part_rows, query)) - else: - break - return logs - - def pack_event_into_log(self, rows, query: Query): - log_entries = [] - for row in rows: - if 'Id' in row: - record_id = row['Id'] - if self.data_cache and self.data_cache.check_cached_id(record_id): - # Record cached, skip it - continue - else: - id_keys = query.get_env().get("id", []) - compound_id = "" - for key in id_keys: - if key not in row: - print_err(f"Error building compound id, key '{key}' not found") - raise Exception(f"Error building compound id, key '{key}' not found") - compound_id = compound_id + str(row.get(key, "")) - if compound_id != "": - m = hashlib.sha3_256() - m.update(compound_id.encode('utf-8')) - row['Id'] = m.hexdigest() - record_id = row['Id'] - if self.data_cache and self.data_cache.check_cached_id(record_id): - # Record cached, skip it - continue - - timestamp_attr = query.get_env().get("timestamp_attr", "CreatedDate") - if timestamp_attr in row: - created_date = row[timestamp_attr] - timestamp = int(datetime.strptime(created_date, '%Y-%m-%dT%H:%M:%S.%f%z').timestamp() * 1000) - else: - created_date = "" - timestamp = int(datetime.now().timestamp() * 1000) - - message = query.get_env().get("event_type", "SFEvent") - if 'attributes' in row and type(row['attributes']) == dict: - attributes = row.pop('attributes', []) - if 'type' in attributes and type(attributes['type']) == str: - event_type_attr_name = query.get_env().get("event_type", attributes['type']) - message = event_type_attr_name - row['EVENT_TYPE'] = event_type_attr_name - - if created_date != "": - message = message + " " + created_date - - timestamp_field_name = query.get_env().get("rename_timestamp", "timestamp") - row[timestamp_field_name] = int(timestamp) - - log_entry = { - 'message': message, - 'attributes': row, - } - - if timestamp_field_name == 'timestamp': - log_entry[timestamp_field_name] = timestamp - - log_entries.append(log_entry) - return { - 'log_entries': log_entries - } - - def build_log_from_logfile(self, retry, session, record, query: Query): - record_file_name = record['LogFile'] - record_id = str(record['Id']) - interval = record['Interval'] - record_event_type = query.get_env().get("event_type", record['EventType']) - - # NOTE: only Hourly logs can be skipped, because Daily logs can change and the same record_id can contain different data. - if interval == 'Hourly' and self.data_cache and \ - self.data_cache.can_skip_downloading_record(record_id): - print_info(f"Record {record_id} already cached, skip downloading CSV") - return None - - cached_messages = None if not self.data_cache else \ - self.data_cache.retrieve_cached_message_list(record_id) - - try: - download_response = self.download_file(session, f'{self.auth.get_instance_url()}{record_file_name}') - if download_response is None: - return None - except SalesforceApiException as e: - if e.err_code == 401: - if retry: - print_err("invalid token while downloading CSV file, retry auth and download...") - self.clear_auth() - if self.authenticate(self.oauth_type, session): - return self.build_log_from_logfile(False, session, record, query) - else: - return None - else: - print_err(f'salesforce event log file "{record_file_name}" download failed: {e}') - return None - else: - print_err(f'salesforce event log file "{record_file_name}" download failed: {e}') - return None - except RequestException as e: - print_err(f'salesforce event log file "{record_file_name}" download failed: {e}') - return None - - csv_rows = self.parse_csv(download_response, record_id, record_event_type, cached_messages) - - print_info(f"CSV rows = {len(csv_rows)}") - - # Split CSV rows into smaller chunks to avoid hitting API payload limits - logs = [] - row_offset = 0 - while True: - part_rows = self.extract_row_slice(csv_rows) - part_rows_len = len(part_rows) - if part_rows_len > 0: - logs.append(self.pack_csv_into_log(record, row_offset, part_rows, query)) - row_offset += part_rows_len - else: - break - - return logs - - def pack_csv_into_log(self, record, row_offset, csv_rows, query: Query): - record_id = str(record['Id']) - record_event_type = query.get_env().get("event_type", record['EventType']) - - log_entries = [] - for row_index, row in enumerate(csv_rows): - message = {} - if record_event_type in self.event_type_fields_mapping: - for field in self.event_type_fields_mapping[record_event_type]: - message[field] = row[field] - else: - message = row - - if row.get('TIMESTAMP'): - timestamp_obj = datetime.strptime(row.get('TIMESTAMP'), '%Y%m%d%H%M%S.%f') - timestamp = pytz.utc.localize(timestamp_obj).replace(microsecond=0).timestamp() - else: - timestamp = datetime.utcnow().replace(microsecond=0).timestamp() - - message['LogFileId'] = record_id - message.pop('TIMESTAMP', None) - - actual_event_type = message.pop('EVENT_TYPE', "SFEvent") - new_event_type = query.get_env().get("event_type", actual_event_type) - message['EVENT_TYPE'] = new_event_type - - timestamp_field_name = query.get_env().get("rename_timestamp", "timestamp") - message[timestamp_field_name] = int(timestamp) - - log_entry = { - 'message': "LogFile " + record_id + " row " + str(row_index + row_offset), - 'attributes': message - } - - if timestamp_field_name == 'timestamp': - log_entry[timestamp_field_name] = int(timestamp) - - log_entries.append(log_entry) - - return { - 'log_type': record_event_type, - 'Id': record_id, - 'CreatedDate': record['CreatedDate'], - 'LogDate': record['LogDate'], - 'log_entries': log_entries - } - - # Slice record into smaller chunks - def extract_row_slice(self, rows): - part_rows = [] - i = 0 - while len(rows) > 0: - part_rows.append(rows.pop()) - i += 1 - if i >= CSV_SLICE_SIZE: - break - return part_rows + def fetch_logs(self, session: Session) -> list[dict]: + print_info(f"Queries = {self.queries}") + + for query in self.queries: + response = mod_query.New( + query, + self.time_lag_minutes, + self.last_to_timestamp, + self.generation_interval, + self.default_api_ver, + ).execute( + session, + self.auth.get_instance_url(), + self.auth.get_access_token(), + ) + + # Show query response + #print("Response = ", response) + + self.pipeline.execute( + session, + query, + self.auth.get_instance_url(), + self.auth.get_access_token(), + response['records'], + ) + + self.slide_time_range() + +# @TODO need to handle this logic but only when exporting logfiles and at this +# level we don't make a distinction but in the pipeline we don't have the right +# info from this level to reauth +# +# try: +# download_response = download_file(session, f'{url}{record_file_name}') +# if download_response is None: +# return +# except SalesforceApiException as e: +# pass +# if e.err_code == 401: +# if retry: +# print_err("invalid token while downloading CSV file, retry auth and download...") +# self.clear_auth() +# if self.authenticate(self.oauth_type, session): +# return self.build_log_from_logfile(False, session, record, query) +# else: +# return None +# else: +# print_err(f'salesforce event log file "{record_file_name}" download failed: {e}') +# return None +# else: +# print_err(f'salesforce event log file "{record_file_name}" download failed: {e}') +# return None +# except RequestException as e: +# print_err( +# f'salesforce event log file "{record_file_name}" download failed: {e}' +# ) +# return +# +# csv_rows = self.parse_csv(download_response, record_id, record_event_type, cached_messages) +# +# print_info(f"CSV rows = {len(csv_rows)}") diff --git a/src/newrelic_logging/util.py b/src/newrelic_logging/util.py new file mode 100644 index 0000000..dd66414 --- /dev/null +++ b/src/newrelic_logging/util.py @@ -0,0 +1,79 @@ +from datetime import datetime, timedelta +import hashlib +import pytz + +def is_logfile_response(records): + if len(records) > 0: + return 'LogFile' in records[0] + else: + return True + + +def get_row_timestamp(row): + epoch = row.get('TIMESTAMP') + + if epoch: + return pytz.utc.localize( + datetime.strptime(epoch, '%Y%m%d%H%M%S.%f') + ).replace(microsecond=0).timestamp() + + return datetime.utcnow().replace(microsecond=0).timestamp() + + +def generate_record_id(id_keys: list[str], row: dict) -> str: + compound_id = '' + for key in id_keys: + if key not in row: + raise Exception( + f'error building compound id, key \'{key}\' not found' + ) + + compound_id = compound_id + str(row.get(key, '')) + + if compound_id != '': + m = hashlib.sha3_256() + m.update(compound_id.encode('utf-8')) + return m.hexdigest() + + return '' + + +# NOTE: this sandbox can be jailbroken using the trick to exec statements inside +# an exec block, and run an import (and other tricks): +# https://book.hacktricks.xyz/generic-methodologies-and-resources/python/bypass-python-sandboxes#operators-and-short-tricks +# https://stackoverflow.com/a/3068475/2076108 +# Would be better to use a real sandbox like +# https://pypi.org/project/RestrictedPython/ or https://doc.pypy.org/en/latest/sandbox.html +# or parse a small language that only supports funcion calls and binary +# expressions. +# +# @TODO See if we can do this a different way We shouldn't be executing eval ever. + +def sandbox(code): + __import__ = None + __loader__ = None + __build_class__ = None + exec = None + + + def sf_time(t: datetime): + return t.isoformat(timespec='milliseconds') + "Z" + + def now(delta: timedelta = None): + if delta: + return sf_time(datetime.utcnow() + delta) + else: + return sf_time(datetime.utcnow()) + + try: + return eval(code) + except Exception as e: + return e + + +def substitute(args: dict, template: str, env: dict) -> str: + for key, command in env.items(): + args[key] = sandbox(command) + for key, val in args.items(): + template = template.replace('{' + key + '}', val) + return template From 6644a6a6b815767b55dc10b62ff4f0cd1365e369 Mon Sep 17 00:00:00 2001 From: Scott DeWitt Date: Thu, 7 Mar 2024 15:14:19 -0500 Subject: [PATCH 02/11] feat: optimize memory usage part 2 - add APM agent --- newrelic.ini | 256 +++++++++++++++++++++++++++++++++++++++++++++++ requirements.txt | 3 +- src/__main__.py | 5 +- 3 files changed, 262 insertions(+), 2 deletions(-) create mode 100644 newrelic.ini diff --git a/newrelic.ini b/newrelic.ini new file mode 100644 index 0000000..0de7da5 --- /dev/null +++ b/newrelic.ini @@ -0,0 +1,256 @@ +# --------------------------------------------------------------------------- + +# +# This file configures the New Relic Python Agent. +# +# The path to the configuration file should be supplied to the function +# newrelic.agent.initialize() when the agent is being initialized. +# +# The configuration file follows a structure similar to what you would +# find for Microsoft Windows INI files. For further information on the +# configuration file format see the Python ConfigParser documentation at: +# +# https://docs.python.org/library/configparser.html +# +# For further discussion on the behaviour of the Python agent that can +# be configured via this configuration file see: +# +# https://docs.newrelic.com/docs/apm/agents/python-agent/configuration/python-agent-configuration/ +# + +# --------------------------------------------------------------------------- + +# Here are the settings that are common to all environments. + +[newrelic] + +# You must specify the license key associated with your New +# Relic account. This may also be set using the NEW_RELIC_LICENSE_KEY +# environment variable. This key binds the Python Agent's data to +# your account in the New Relic service. For more information on +# storing and generating license keys, see +# https://docs.newrelic.com/docs/apis/intro-apis/new-relic-api-keys/#ingest-license-key +#license_key = [YOUR LICENSE KEY] + +# The application name. Set this to be the name of your +# application as you would like it to show up in New Relic UI. +# You may also set this using the NEW_RELIC_APP_NAME environment variable. +# The UI will then auto-map instances of your application into a +# entry on your home dashboard page. You can also specify multiple +# app names to group your aggregated data. For further details, +# please see: +# https://docs.newrelic.com/docs/apm/agents/manage-apm-agents/app-naming/use-multiple-names-app/ +app_name = Salesforce Eventlogfile Integration + +# When "true", the agent collects performance data about your +# application and reports this data to the New Relic UI at +# newrelic.com. This global switch is normally overridden for +# each environment below. It may also be set using the +# NEW_RELIC_MONITOR_MODE environment variable. +monitor_mode = true + +# Sets the name of a file to log agent messages to. Whatever you +# set this to, you must ensure that the permissions for the +# containing directory and the file itself are correct, and +# that the user that your web application runs as can write out +# to the file. If not able to out a log file, it is also +# possible to say "stderr" and output to standard error output. +# This would normally result in output appearing in your web +# server log. It can also be set using the NEW_RELIC_LOG +# environment variable. +log_file = stdout + +# Sets the level of detail of messages sent to the log file, if +# a log file location has been provided. Possible values, in +# increasing order of detail, are: "critical", "error", "warning", +# "info" and "debug". When reporting any agent issues to New +# Relic technical support, the most useful setting for the +# support engineers is "debug". However, this can generate a lot +# of information very quickly, so it is best not to keep the +# agent at this level for longer than it takes to reproduce the +# problem you are experiencing. This may also be set using the +# NEW_RELIC_LOG_LEVEL environment variable. +log_level = info + +# High Security Mode enforces certain security settings, and prevents +# them from being overridden, so that no sensitive data is sent to New +# Relic. Enabling High Security Mode means that request parameters are +# not collected and SQL can not be sent to New Relic in its raw form. +# To activate High Security Mode, it must be set to 'true' in this +# local .ini configuration file AND be set to 'true' in the +# server-side configuration in the New Relic user interface. It can +# also be set using the NEW_RELIC_HIGH_SECURITY environment variable. +# For details, see +# https://docs.newrelic.com/docs/subscriptions/high-security +high_security = false + +# The Python Agent will attempt to connect directly to the New +# Relic service. If there is an intermediate firewall between +# your host and the New Relic service that requires you to use a +# HTTP proxy, then you should set both the "proxy_host" and +# "proxy_port" settings to the required values for the HTTP +# proxy. The "proxy_user" and "proxy_pass" settings should +# additionally be set if proxy authentication is implemented by +# the HTTP proxy. The "proxy_scheme" setting dictates what +# protocol scheme is used in talking to the HTTP proxy. This +# would normally always be set as "http" which will result in the +# agent then using a SSL tunnel through the HTTP proxy for end to +# end encryption. +# See https://docs.newrelic.com/docs/apm/agents/python-agent/configuration/python-agent-configuration/#proxy +# for information on proxy configuration via environment variables. +# proxy_scheme = http +# proxy_host = hostname +# proxy_port = 8080 +# proxy_user = +# proxy_pass = + +# Capturing request parameters is off by default. To enable the +# capturing of request parameters, first ensure that the setting +# "attributes.enabled" is set to "true" (the default value), and +# then add "request.parameters.*" to the "attributes.include" +# setting. For details about attributes configuration, please +# consult the documentation. +# attributes.include = request.parameters.* + +# The transaction tracer captures deep information about slow +# transactions and sends this to the UI on a periodic basis. The +# transaction tracer is enabled by default. Set this to "false" +# to turn it off. +transaction_tracer.enabled = true + +# Threshold in seconds for when to collect a transaction trace. +# When the response time of a controller action exceeds this +# threshold, a transaction trace will be recorded and sent to +# the UI. Valid values are any positive float value, or (default) +# "apdex_f", which will use the threshold for a dissatisfying +# Apdex controller action - four times the Apdex T value. +transaction_tracer.transaction_threshold = apdex_f + +# When the transaction tracer is on, SQL statements can +# optionally be recorded. The recorder has three modes, "off" +# which sends no SQL, "raw" which sends the SQL statement in its +# original form, and "obfuscated", which strips out numeric and +# string literals. +transaction_tracer.record_sql = obfuscated + +# Threshold in seconds for when to collect stack trace for a SQL +# call. In other words, when SQL statements exceed this +# threshold, then capture and send to the UI the current stack +# trace. This is helpful for pinpointing where long SQL calls +# originate from in an application. +transaction_tracer.stack_trace_threshold = 0.5 + +# Determines whether the agent will capture query plans for slow +# SQL queries. Only supported in MySQL and PostgreSQL. Set this +# to "false" to turn it off. +transaction_tracer.explain_enabled = true + +# Threshold for query execution time below which query plans +# will not not be captured. Relevant only when "explain_enabled" +# is true. +transaction_tracer.explain_threshold = 0.5 + +# Space separated list of function or method names in form +# 'module:function' or 'module:class.function' for which +# additional function timing instrumentation will be added. +transaction_tracer.function_trace = + +# The error collector captures information about uncaught +# exceptions or logged exceptions and sends them to UI for +# viewing. The error collector is enabled by default. Set this +# to "false" to turn it off. For more details on errors, see +# https://docs.newrelic.com/docs/apm/agents/manage-apm-agents/agent-data/manage-errors-apm-collect-ignore-or-mark-expected/ +error_collector.enabled = true + +# To stop specific errors from reporting to the UI, set this to +# a space separated list of the Python exception type names to +# ignore. The exception name should be of the form 'module:class'. +error_collector.ignore_classes = + +# Expected errors are reported to the UI but will not affect the +# Apdex or error rate. To mark specific errors as expected, set this +# to a space separated list of the Python exception type names to +# expected. The exception name should be of the form 'module:class'. +error_collector.expected_classes = + +# Browser monitoring is the Real User Monitoring feature of the UI. +# For those Python web frameworks that are supported, this +# setting enables the auto-insertion of the browser monitoring +# JavaScript fragments. +browser_monitoring.auto_instrument = true + +# A thread profiling session can be scheduled via the UI when +# this option is enabled. The thread profiler will periodically +# capture a snapshot of the call stack for each active thread in +# the application to construct a statistically representative +# call tree. For more details on the thread profiler tool, see +# https://docs.newrelic.com/docs/apm/apm-ui-pages/events/thread-profiler-tool/ +thread_profiler.enabled = true + +# Your application deployments can be recorded through the +# New Relic REST API. To use this feature provide your API key +# below then use the `newrelic-admin record-deploy` command. +# This can also be set using the NEW_RELIC_API_KEY +# environment variable. +# api_key = + +# Distributed tracing lets you see the path that a request takes +# through your distributed system. For more information, please +# consult our distributed tracing planning guide. +# https://docs.newrelic.com/docs/transition-guide-distributed-tracing +distributed_tracing.enabled = true + +# This setting enables log decoration, the forwarding of log events, +# and the collection of logging metrics if these sub-feature +# configurations are also enabled. If this setting is false, no +# logging instrumentation features are enabled. This can also be +# set using the NEW_RELIC_APPLICATION_LOGGING_ENABLED environment +# variable. +# application_logging.enabled = true + +# If true, the agent captures log records emitted by your application +# and forwards them to New Relic. `application_logging.enabled` must +# also be true for this setting to take effect. You can also set +# this using the NEW_RELIC_APPLICATION_LOGGING_FORWARDING_ENABLED +# environment variable. +# application_logging.forwarding.enabled = true + +# If true, the agent decorates logs with metadata to link to entities, +# hosts, traces, and spans. `application_logging.enabled` must also +# be true for this setting to take effect. This can also be set +# using the NEW_RELIC_APPLICATION_LOGGING_LOCAL_DECORATING_ENABLED +# environment variable. +# application_logging.local_decorating.enabled = true + +# If true, the agent captures metrics related to the log lines +# being sent up by your application. This can also be set +# using the NEW_RELIC_APPLICATION_LOGGING_METRICS_ENABLED +# environment variable. +# application_logging.metrics.enabled = true + +startup_timeout = 10.0 + +# --------------------------------------------------------------------------- + +# +# The application environments. These are specific settings which +# override the common environment settings. The settings related to a +# specific environment will be used when the environment argument to the +# newrelic.agent.initialize() function has been defined to be either +# "development", "test", "staging" or "production". +# + +[newrelic:development] +monitor_mode = false + +[newrelic:test] +monitor_mode = false + +[newrelic:staging] +app_name = Python Application (Staging) +monitor_mode = true + +[newrelic:production] +monitor_mode = true + +# --------------------------------------------------------------------------- diff --git a/requirements.txt b/requirements.txt index 671da52..342b75d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -14,4 +14,5 @@ cryptography==41.0.6 pip==23.3 wheel==0.38.1 setuptools==65.5.1 -future~=0.18.2 \ No newline at end of file +future~=0.18.2 +newrelic==9.7.0 diff --git a/src/__main__.py b/src/__main__.py index 42e0fbd..6d9059a 100644 --- a/src/__main__.py +++ b/src/__main__.py @@ -1,4 +1,7 @@ #!/usr/bin/env python +import newrelic.agent +newrelic.agent.initialize('./newrelic.ini') + import optparse import os import sys @@ -179,7 +182,7 @@ def run( run_as_service(config, event_type_fields_mapping, numeric_fields_list) - +@newrelic.agent.background_task() def main(): print_info(f'Integration start. Using program arguments {sys.argv[1:]}') From 127069868e6e2510b5506b1360eca5d2afc293da Mon Sep 17 00:00:00 2001 From: Scott DeWitt Date: Wed, 13 Mar 2024 13:37:04 -0400 Subject: [PATCH 03/11] feat: optimize memory usage part 3: add unit tests --- src/newrelic_logging/cache.py | 11 +- src/newrelic_logging/pipeline.py | 171 ++-- src/newrelic_logging/util.py | 35 +- src/tests/__init__.py | 80 ++ src/tests/sample_event_records.json | 31 + src/tests/sample_log_lines.csv | 3 + src/tests/sample_log_lines.json | 48 + src/tests/sample_log_records.json | 28 + src/tests/test_pipeline.py | 1429 +++++++++++++++++++++++++++ src/tests/test_util.py | 138 +++ 10 files changed, 1884 insertions(+), 90 deletions(-) create mode 100644 src/tests/__init__.py create mode 100644 src/tests/sample_event_records.json create mode 100644 src/tests/sample_log_lines.csv create mode 100644 src/tests/sample_log_lines.json create mode 100644 src/tests/sample_log_records.json create mode 100644 src/tests/test_pipeline.py create mode 100644 src/tests/test_util.py diff --git a/src/newrelic_logging/cache.py b/src/newrelic_logging/cache.py index a0dc8b6..4438f77 100644 --- a/src/newrelic_logging/cache.py +++ b/src/newrelic_logging/cache.py @@ -62,20 +62,21 @@ def can_skip_downloading_logfile(self, record_id: str) -> bool: def load_cached_log_lines(self, record_id: str) -> None: try: if self.backend.exists(record_id): - self.cached_logs[record_id] = \ - self.backend.list_slice(record_id, 0, -1) - return + self.cached_logs[record_id] = \ + self.backend.list_slice(record_id, 0, -1) + return self.cached_logs[record_id] = ['init'] except Exception as e: raise CacheException(f'failed checking log record {record_id}: {e}') # Cache log + # @TODO this function assumes you have called load_cached_log_lines + # which isn't obvious. def check_and_set_log_line(self, record_id: str, row: dict) -> bool: row_id = row["REQUEST_ID"] - row_id_b = row_id.encode('utf-8') - if row_id_b in self.cached_logs[record_id]: + if row_id.encode('utf-8') in self.cached_logs[record_id]: return True self.cached_logs[record_id].append(row_id) diff --git a/src/newrelic_logging/pipeline.py b/src/newrelic_logging/pipeline.py index 310fb93..c53bfb5 100644 --- a/src/newrelic_logging/pipeline.py +++ b/src/newrelic_logging/pipeline.py @@ -1,7 +1,8 @@ from copy import deepcopy import csv -import datetime +from datetime import datetime import gc +import pytz from requests import Session from . import DataFormat, SalesforceApiException @@ -10,49 +11,76 @@ from .http_session import new_retry_session from .newrelic import NewRelic from .query import Query -from .telemetry import print_err, print_info -from .util import generate_record_id, get_row_timestamp, is_logfile_response +from .telemetry import print_info +from .util import generate_record_id, \ + is_logfile_response, \ + maybe_convert_str_to_num DEFAULT_CHUNK_SIZE = 4096 DEFAULT_MAX_ROWS = 1000 MAX_ROWS = 2000 +def init_fields_from_log_line( + record_event_type: str, + log_line: dict, + event_type_fields_mapping: dict, +) -> dict: + if record_event_type in event_type_fields_mapping: + attrs = {} + + for field in event_type_fields_mapping[record_event_type]: + attrs[field] = log_line[field] + + return attrs -def pack_csv_into_log( + return deepcopy(log_line) + + +def get_log_line_timestamp(log_line: dict) -> float: + epoch = log_line.get('TIMESTAMP') + + if epoch: + return pytz.utc.localize( + datetime.strptime(epoch, '%Y%m%d%H%M%S.%f') + ).replace(microsecond=0).timestamp() + + return datetime.utcnow().replace(microsecond=0).timestamp() + + +def pack_log_line_into_log( query: Query, record_id: str, record_event_type: str, - row: dict, - row_index: int, + log_line: dict, + line_no: int, event_type_fields_mapping: dict, ) -> dict: - attrs = {} - if record_event_type in event_type_fields_mapping: - for field in event_type_fields_mapping[record_event_type]: - attrs[field] = row[field] - else: - attrs = row + attrs = init_fields_from_log_line( + record_event_type, + log_line, + event_type_fields_mapping, + ) - timestamp = get_row_timestamp(row) + timestamp = int(get_log_line_timestamp(log_line)) attrs.pop('TIMESTAMP', None) attrs['LogFileId'] = record_id - actual_event_type = attrs.pop('EVENT_TYPE', "SFEvent") - new_event_type = query.get("event_type", actual_event_type) + actual_event_type = attrs.pop('EVENT_TYPE', 'SFEvent') + new_event_type = query.get('event_type', actual_event_type) attrs['EVENT_TYPE'] = new_event_type - timestamp_field_name = query.get("rename_timestamp", "timestamp") - attrs[timestamp_field_name] = int(timestamp) + timestamp_field_name = query.get('rename_timestamp', 'timestamp') + attrs[timestamp_field_name] = timestamp log_entry = { - 'message': "LogFile " + record_id + " row " + str(row_index), + 'message': f'LogFile {record_id} row {str(line_no)}', 'attributes': attrs } if timestamp_field_name == 'timestamp': - log_entry[timestamp_field_name] = int(timestamp) + log_entry[timestamp_field_name] = timestamp return log_entry @@ -108,7 +136,7 @@ def transform_log_lines( continue # Otherwise, pack it up for shipping and yield it for consumption - yield pack_csv_into_log( + yield pack_log_line_into_log( query, record_id, record_event_type, @@ -120,37 +148,35 @@ def transform_log_lines( row_index += 1 -def pack_event_into_log( +def pack_event_record_into_log( query: Query, record_id: str, - row: dict, -): + record: dict, +) -> dict: # Make a copy of it so we aren't modifying the row passed by the caller, and # set attributes appropriately - attrs = deepcopy(row) + attrs = deepcopy(record) if record_id: attrs['Id'] = record_id + message = query.get('event_type', 'SFEvent') + if 'attributes' in attrs and type(attrs['attributes']) == dict: + attributes = attrs.pop('attributes') + if 'type' in attributes and type(attributes['type']) == str: + attrs['EVENT_TYPE'] = message = \ + query.get('event_type', attributes['type']) + timestamp_attr = query.get('timestamp_attr', 'CreatedDate') if timestamp_attr in attrs: created_date = attrs[timestamp_attr] + message += f' {created_date}' timestamp = int(datetime.strptime( - created_date, '%Y-%m-%dT%H:%M:%S.%f%z').timestamp() * 1000 + created_date, + '%Y-%m-%dT%H:%M:%S.%f%z').timestamp() * 1000, ) else: - created_date = "" timestamp = int(datetime.now().timestamp() * 1000) - message = query.get('event_type', 'SFEvent') - if 'attributes' in attrs and type(attrs['attributes']) == dict: - attributes = attrs.pop('attributes', []) - if 'type' in attributes and type(attributes['type']) == str: - attrs['EVENT_TYPE'] = message = \ - query.get('event_type', attributes['type']) - - if created_date != "": - message = message + " " + created_date - timestamp_field_name = query.get('rename_timestamp', 'timestamp') attrs[timestamp_field_name] = int(timestamp) @@ -167,26 +193,35 @@ def pack_event_into_log( def transform_event_records(iter, query: Query, data_cache: DataCache): # iter here is a list which does mean it's entirely held in memory but these - # are event records not log lines so hopefully it is # not as bad. + # are event records not log lines so hopefully it is not as bad. # @TODO figure out if we can stream event records - for row in iter: - record_id = row['Id'] if 'Id' in row \ - else generate_record_id(query.get('id', []), row) + for record in iter: + config = query.get_config() + + record_id = record['Id'] if 'Id' in record \ + else generate_record_id( + config['id'] if 'id' in config else [], + record, + ) # If we've already seen this event record, skip it. if data_cache and data_cache.check_and_set_event_id(record_id): - return None + continue # Build a New Relic log record from the SF event record - yield pack_event_into_log( + yield pack_event_record_into_log( query, record_id, - row, - data_cache, + record, ) -def load_as_logs(iter, new_relic: NewRelic, labels: dict, max_rows: int): +def load_as_logs( + iter, + new_relic: NewRelic, + labels: dict, + max_rows: int, +) -> None: nr_session = new_retry_session() logs = [] @@ -224,29 +259,24 @@ def send_logs(): gc.collect() -def pack_log_into_event(log: dict, labels: dict, numeric_fields_list: set): +def pack_log_into_event( + log: dict, + labels: dict, + numeric_fields_list: set, +) -> dict: log_event = {} attributes = log['attributes'] - for event_name in attributes: - # currently no need to modify as we did not see any special chars - # that need to be removed - modified_event_name = event_name - event_value = attributes[event_name] - if event_name in numeric_fields_list: - if event_value: - try: - log_event[modified_event_name] = int(event_value) - except (TypeError, ValueError) as _: - try: - log_event[modified_event_name] = float(event_value) - except (TypeError, ValueError) as _: - print_err(f'Type conversion error for {event_name}[{event_value}]') - log_event[modified_event_name] = event_value - else: - log_event[modified_event_name] = 0 - else: - log_event[modified_event_name] = event_value + for key in attributes: + value = attributes[key] + + if key in numeric_fields_list: + log_event[key] = \ + maybe_convert_str_to_num(value) if value \ + else 0 + continue + + log_event[key] = value log_event.update(labels) log_event['eventType'] = log_event.get('EVENT_TYPE', "UnknownSFEvent") @@ -260,7 +290,7 @@ def load_as_events( labels: dict, max_rows: int, numeric_fields_list: set, -): +) -> None: nr_session = new_retry_session() events = [] @@ -304,7 +334,7 @@ def send_events(): def load_data( - logs: dict, + logs, new_relic: NewRelic, data_format: DataFormat, labels: dict, @@ -407,7 +437,6 @@ def process_event_records( records, query, self.data_cache, - self.config.get('max_rows', DEFAULT_MAX_ROWS) ), self.new_relic, self.data_format, @@ -435,12 +464,16 @@ def execute( record, ) + if self.data_cache: + self.data_cache.flush() + return self.process_event_records(query, records) # Flush the cache - self.data_cache.flush() + if self.data_cache: + self.data_cache.flush() def New( config: Config, diff --git a/src/newrelic_logging/util.py b/src/newrelic_logging/util.py index dd66414..639ffab 100644 --- a/src/newrelic_logging/util.py +++ b/src/newrelic_logging/util.py @@ -1,34 +1,26 @@ from datetime import datetime, timedelta import hashlib -import pytz +from typing import Union + +from .telemetry import print_warn + def is_logfile_response(records): if len(records) > 0: return 'LogFile' in records[0] - else: - return True - - -def get_row_timestamp(row): - epoch = row.get('TIMESTAMP') - if epoch: - return pytz.utc.localize( - datetime.strptime(epoch, '%Y%m%d%H%M%S.%f') - ).replace(microsecond=0).timestamp() + return True - return datetime.utcnow().replace(microsecond=0).timestamp() - -def generate_record_id(id_keys: list[str], row: dict) -> str: +def generate_record_id(id_keys: list[str], record: dict) -> str: compound_id = '' for key in id_keys: - if key not in row: + if key not in record: raise Exception( f'error building compound id, key \'{key}\' not found' ) - compound_id = compound_id + str(row.get(key, '')) + compound_id = compound_id + str(record.get(key, '')) if compound_id != '': m = hashlib.sha3_256() @@ -38,6 +30,17 @@ def generate_record_id(id_keys: list[str], row: dict) -> str: return '' +def maybe_convert_str_to_num(val: str) -> Union[int, str, float]: + try: + return int(val) + except (TypeError, ValueError) as _: + try: + return float(val) + except (TypeError, ValueError) as _: + print_warn(f'Type conversion error for "{val}"') + return val + + # NOTE: this sandbox can be jailbroken using the trick to exec statements inside # an exec block, and run an import (and other tricks): # https://book.hacktricks.xyz/generic-methodologies-and-resources/python/bypass-python-sandboxes#operators-and-short-tricks diff --git a/src/tests/__init__.py b/src/tests/__init__.py new file mode 100644 index 0000000..8b46fd7 --- /dev/null +++ b/src/tests/__init__.py @@ -0,0 +1,80 @@ +from requests import Session + +from newrelic_logging.config import Config + + +class QueryStub: + def __init__(self, config: dict): + self.config = Config(config) + + def get(self, key: str, default = None): + return self.config.get(key, default) + + def get_config(self): + return self.config + + def execute(): + pass + + +class ResponseStub: + def __init__(self, status_code, reason, text, lines): + self.status_code = status_code + self.reason = reason + self.text = text + self.lines = lines + + def iter_lines(self, *args, **kwargs): + yield from self.lines + + +class SessionStub: + def __init__(self, lines): + self.response = None + + def get(self, *args, **kwargs): + return self.response + + +class DataCacheStub: + def __init__( + self, + cached_logs = {}, + cached_events = [], + skip_record_ids = [], + cached_log_lines = {}, + ): + self.cached_logs = cached_logs + self.cached_events = cached_events + self.skip_record_ids = skip_record_ids + self.cached_log_lines = cached_log_lines + self.flush_called = False + + def can_skip_downloading_logfile(self, record_id: str) -> bool: + return record_id in self.skip_record_ids + + def load_cached_log_lines(self, record_id: str) -> None: + if record_id in self.cached_log_lines: + self.cached_logs[record_id] = self.cached_log_lines[record_id] + + def check_and_set_log_line(self, record_id: str, row: dict) -> bool: + return record_id in self.cached_logs and \ + row['REQUEST_ID'] in self.cached_logs[record_id] + + def check_and_set_event_id(self, record_id: str) -> bool: + return record_id in self.cached_events + + def flush(self) -> None: + self.flush_called = True + + +class NewRelicStub: + def __init__(self): + self.logs = [] + self.events = [] + + def post_logs(self, session: Session, data: list[dict]) -> None: + self.logs.append(data) + + def post_events(self, session: Session, events: list[dict]) -> None: + self.events.append(events) diff --git a/src/tests/sample_event_records.json b/src/tests/sample_event_records.json new file mode 100644 index 0000000..7f2b206 --- /dev/null +++ b/src/tests/sample_event_records.json @@ -0,0 +1,31 @@ +[ + { + "attributes": { + "type": "Account", + "url": "/services/data/v58.0/sobjects/Account/12345" + }, + "Id": "000012345", + "Name": "My Account", + "BillingCity": null, + "CreatedDate": "2024-03-11T00:00:00.000+0000" + }, + { + "attributes": { + "type": "Account", + "url": "/services/data/v58.0/sobjects/Account/54321" + }, + "Id": "000054321", + "Name": "My Other Account", + "BillingCity": null, + "CreatedDate": "2024-03-10T00:00:00.000+0000" + }, + { + "attributes": { + "type": "Account", + "url": "/services/data/v58.0/sobjects/Account/00000" + }, + "Name": "My Last Account", + "BillingCity": null, + "CreatedDate": "2024-03-09T00:00:00.000+0000" + } +] diff --git a/src/tests/sample_log_lines.csv b/src/tests/sample_log_lines.csv new file mode 100644 index 0000000..f8a120a --- /dev/null +++ b/src/tests/sample_log_lines.csv @@ -0,0 +1,3 @@ +"EVENT_TYPE","TIMESTAMP","REQUEST_ID","ORGANIZATION_ID","USER_ID","RUN_TIME","CPU_TIME","URI","SESSION_KEY","LOGIN_KEY","TYPE","METHOD","SUCCESS","STATUS_CODE","TIME","REQUEST_SIZE","RESPONSE_SIZE","URL","TIMESTAMP_DERIVED","USER_ID_DERIVED","CLIENT_IP","URI_ID_DERIVED" +"ApexCallout","20240311160000.000","YYZ:abcdef123456","001122334455667","000000001111111","2112","10","TEST-LOG-1","","","REST","POST","1","200","1234","8192","4096","""https://test.local.test""","2024-03-11T16:00:00.000Z","001122334455667","","" +"ApexCallout","20240311170000.000","YYZ:fedcba654321","776655443322110","111111110000000","5150","20","TEST-LOG-2","","","REST","POST","1","200","4321","8192","4096","""https://test.local.test""","2024-03-11T17:00:00.000Z","776655443322110","","" diff --git a/src/tests/sample_log_lines.json b/src/tests/sample_log_lines.json new file mode 100644 index 0000000..10d1eff --- /dev/null +++ b/src/tests/sample_log_lines.json @@ -0,0 +1,48 @@ +[{ + "EVENT_TYPE": "ApexCallout", + "TIMESTAMP": "20240311160000.000", + "REQUEST_ID": "YYZ:abcdef123456", + "ORGANIZATION_ID": "000000001111111", + "USER_ID": "001122334455667", + "RUN_TIME": "2112", + "CPU_TIME": "10", + "URI": "TEST-LOG-1", + "SESSION_KEY": "", + "LOGIN_KEY": "", + "TYPE": "REST", + "METHOD": "POST", + "SUCCESS": "1", + "STATUS_CODE": "200", + "TIME": "1234", + "REQUEST_SIZE": "8192", + "RESPONSE_SIZE": "4096", + "URL": "\"https://test.local.test\"", + "TIMESTAMP_DERIVED": "2024-03-11T16:00:00.000Z", + "USER_ID_DERIVED": "001122334455667", + "CLIENT_IP": "", + "URI_ID_DERIVED": "" +}, +{ + "EVENT_TYPE": "ApexCallout", + "TIMESTAMP": "20240311170000.000", + "REQUEST_ID": "YYZ:fedcba654321", + "ORGANIZATION_ID": "111111110000000", + "USER_ID": "776655443322110", + "RUN_TIME": "5150", + "CPU_TIME": "20", + "URI": "TEST-LOG-2", + "SESSION_KEY": "", + "LOGIN_KEY": "", + "TYPE": "REST", + "METHOD": "POST", + "SUCCESS": "1", + "STATUS_CODE": "200", + "TIME": "4321", + "REQUEST_SIZE": "8192", + "RESPONSE_SIZE": "4096", + "URL": "\"https://test.local.test\"", + "TIMESTAMP_DERIVED": "2024-03-11T17:00:00.000Z", + "USER_ID_DERIVED": "776655443322110", + "CLIENT_IP": "", + "URI_ID_DERIVED": "" +}] diff --git a/src/tests/sample_log_records.json b/src/tests/sample_log_records.json new file mode 100644 index 0000000..b48cdd5 --- /dev/null +++ b/src/tests/sample_log_records.json @@ -0,0 +1,28 @@ +[ + { + "attributes": { + "type": "EventLogFile", + "url": "/services/data/v52.0/sobjects/EventLogFile/00001111AAAABBBB" + }, + "Id": "00001111AAAABBBB", + "EventType": "ApexCallout", + "CreatedDate": "2024-03-11T15:00:00.000+0000", + "LogDate": "2024-03-11T02:00:00.000+0000", + "Interval": "Hourly", + "LogFile": "/services/data/v52.0/sobjects/EventLogFile/00001111AAAABBBB/LogFile", + "Sequence": 1 + }, + { + "attributes": { + "type": "EventLogFile", + "url": "/services/data/v52.0/sobjects/EventLogFile/00002222AAAABBBB" + }, + "Id": "00002222AAAABBBB", + "EventType": "ApexCallout", + "CreatedDate": "2024-03-11T16:00:00.000+0000", + "LogDate": "2024-03-11T03:00:00.000+0000", + "Interval": "Hourly", + "LogFile": "/services/data/v52.0/sobjects/EventLogFile/00002222AAAABBBB/LogFile", + "Sequence": 2 + } +] diff --git a/src/tests/test_pipeline.py b/src/tests/test_pipeline.py new file mode 100644 index 0000000..52d4f64 --- /dev/null +++ b/src/tests/test_pipeline.py @@ -0,0 +1,1429 @@ +import copy +from datetime import datetime +import json +import pytz +import unittest + +from newrelic_logging import \ + config, \ + DataFormat, \ + pipeline, \ + util, \ + SalesforceApiException +from . import \ + DataCacheStub, \ + NewRelicStub, \ + QueryStub, \ + ResponseStub, \ + SessionStub + +class TestPipeline(unittest.TestCase): + def setUp(self): + with open('./tests/sample_log_lines.csv') as stream: + self.log_rows = stream.readlines() + + with open('./tests/sample_log_lines.json') as stream: + self.log_lines = json.load(stream) + + with open('./tests/sample_event_records.json') as stream: + self.event_records = json.load(stream) + + with open('./tests/sample_log_records.json') as stream: + self.log_records = json.load(stream) + + def test_init_fields_from_log_line(self): + ''' + given: an event type, log line, and event fields mapping + when: there is no matching mapping for the event type in the event + fields mapping + then: copy all fields in log line + ''' + + # setup + log_line = { + 'foo': 'bar', + 'beep': 'boop', + } + + # execute + attrs = pipeline.init_fields_from_log_line('ApexCallout', log_line, {}) + + # verify + self.assertTrue(len(attrs) == 2) + self.assertTrue(attrs['foo'] == 'bar') + self.assertTrue(attrs['beep'] == 'boop') + + ''' + given: an event type, log line, and event fields mapping + when: there is a matching mapping for the event type in the event + fields mapping + then: copy only the fields in the event fields mapping + ''' + + # execute + attrs = pipeline.init_fields_from_log_line( + 'ApexCallout', + log_line, + { 'ApexCallout': ['foo'] } + ) + + # verify + self.assertTrue(len(attrs) == 1) + self.assertTrue(attrs['foo'] == 'bar') + self.assertTrue(not 'beep' in attrs) + + def test_get_log_line_timestamp(self): + ''' + given: a log line + when: there is no TIMESTAMP attribute + then: return the current timestamp + ''' + + # setup + now = datetime.utcnow().replace(microsecond=0) + + # execute + ts = pipeline.get_log_line_timestamp({}) + + # verify + self.assertEqual(now.timestamp(), ts) + + ''' + given: a log line + when: there is a TIMESTAMP attribute + then: parse the string in the format YYYYMMDDHHmmss.FFF and return + the representative timestamp + ''' + + # setup + epoch = now.strftime('%Y%m%d%H%M%S.%f') + + # execute + ts1 = pytz.utc.localize(now).replace(microsecond=0).timestamp() + ts2 = pipeline.get_log_line_timestamp({ 'TIMESTAMP': epoch }) + + # verify + self.assertEqual(ts1, ts2) + + def test_pack_log_line_into_log(self): + ''' + given: a query object, record ID, event type, log line, line number, + and event fields mapping + when: there is a TIMESTAMP and EVENT_TYPE field, no query options, and + no matching event mapping + then: return a log entry with the message "LogFile $ID row $LINENO", + and attributes dict containing all fields from the log line, + an EVENT_TYPE field with the log line event type, a TIMESTAMP + field with the log line timestamp, and a timestamp field with + the timestamp epoch value + ''' + + # setup + query = QueryStub({}) + + # execute + log = pipeline.pack_log_line_into_log( + query, + '00001111AAAABBBB', + 'ApexCallout', + self.log_lines[0], + 0, + {}, + ) + + # verify + self.assertTrue('message' in log) + self.assertTrue('attributes' in log) + self.assertEqual(log['message'], 'LogFile 00001111AAAABBBB row 0') + + attrs = log['attributes'] + self.assertEqual(attrs['EVENT_TYPE'], 'ApexCallout') + self.assertEqual(attrs['LogFileId'], '00001111AAAABBBB') + self.assertTrue('timestamp' in attrs) + ts = pipeline.get_log_line_timestamp({ + 'TIMESTAMP': '20240311160000.000' + }) + self.assertEqual(int(ts), attrs['timestamp']) + self.assertTrue('REQUEST_ID' in attrs) + self.assertTrue('RUN_TIME' in attrs) + self.assertTrue('CPU_TIME' in attrs) + self.assertEqual('YYZ:abcdef123456', attrs['REQUEST_ID']) + self.assertEqual('2112', attrs['RUN_TIME']) + self.assertEqual('10', attrs['CPU_TIME']) + + ''' + given: a query object, record ID, event type, log line, line number, + and event fields mapping + when: there is a TIMESTAMP and EVENT_TYPE field, no matching event + mapping, and the event_type and rename_timestamp query options + then: return the same as case 1 but with the event type specified + in the query options, the epoch value in the field specified + in the query options and no timestamp field + ''' + + # setup + query = QueryStub({ + 'event_type': 'CustomSFEvent', + 'rename_timestamp': 'custom_timestamp', + }) + + # execute + log = pipeline.pack_log_line_into_log( + query, + '00001111AAAABBBB', + 'ApexCallout', + self.log_lines[0], + 0, + {}, + ) + + # verify + attrs = log['attributes'] + self.assertTrue('custom_timestamp' in attrs) + self.assertEqual(attrs['EVENT_TYPE'], 'CustomSFEvent') + self.assertEqual(attrs['custom_timestamp'], int(ts)) + self.assertTrue(not 'timestamp' in log) + + def test_export_log_lines(self): + ''' + given: an http session, url, access token, and chunk size + when: the response produces a non-200 status code + then: raise a SalesforceApiException + ''' + + # setup + session = SessionStub([]) + session.response = ResponseStub(500, 'Error', '', []) + + # execute/verify + with self.assertRaises(SalesforceApiException): + pipeline.export_log_lines(session, '', '', 100) + + ''' + given: an http session, url, access token, and chunk size + when: the response produces a 200 status code + then: return a generator iterator that yields one line of data at a time + ''' + + # setup + session.response = ResponseStub(200, 'OK', '', self.log_rows) + + #execute + response = pipeline.export_log_lines(session, '', '', 100) + + lines = [] + for line in response: + lines.append(line) + + # verify + self.assertEqual(len(lines), 3) + self.assertEqual(lines[0], self.log_rows[0]) + self.assertEqual(lines[1], self.log_rows[1]) + self.assertEqual(lines[2], self.log_rows[2]) + + def test_transform_log_lines(self): + ''' + given: an iterable of log rows, query, record id, event type, event + types mapping and no data cache + when: the response produces a 200 status code + then: return a generator iterator that yields one New Relic log + object for each row except the header row + ''' + + # execute + logs = pipeline.transform_log_lines( + self.log_rows, + QueryStub({}), + '00001111AAAABBBB', + 'ApexCallout', + None, + {}, + ) + + l = [] + + for log in logs: + l.append(log) + + # verify + self.assertEqual(len(l), 2) + self.assertTrue('message' in l[0]) + self.assertTrue('attributes' in l[0]) + self.assertEqual(l[0]['message'], 'LogFile 00001111AAAABBBB row 0') + attrs = l[0]['attributes'] + self.assertTrue('EVENT_TYPE' in attrs) + self.assertTrue('timestamp' in attrs) + self.assertEqual(1710172800, attrs['timestamp']) + + self.assertTrue('message' in l[1]) + self.assertTrue('attributes', l[1]) + self.assertEqual(l[1]['message'], 'LogFile 00001111AAAABBBB row 1') + attrs = l[1]['attributes'] + self.assertTrue('EVENT_TYPE' in attrs) + self.assertTrue('timestamp' in attrs) + self.assertEqual(1710176400, attrs['timestamp']) + + ''' + given: an iterable of log rows, query, record id, event type, event + types mapping and a data cache + when: the data cache contains the REQUEST_ID for some of the log lines + then: return a generator iterator that yields one New Relic log + object for each row with a REQUEST_ID + ''' + + # execute + logs = pipeline.transform_log_lines( + self.log_rows, + QueryStub({}), + '00001111AAAABBBB', + 'ApexCallout', + DataCacheStub({ + '00001111AAAABBBB': [ 'YYZ:abcdef123456' ] + }), + {}, + ) + + l = [] + + for log in logs: + l.append(log) + + # verify + self.assertEqual(len(l), 1) + self.assertEqual( + l[0]['attributes']['REQUEST_ID'], + 'YYZ:fedcba654321' + ) + + + def test_pack_event_record_into_log(self): + ''' + given: a query, record id and event record + when: there are no query options and the event record contains a 'type' + field in 'attributes' + then: return a log with the 'message' attribute set to the event type + specified in the 'type' field + the created date, all attributes + from the original event record minus the 'attributes' field set in + the log 'attributes' field as well as the passed record id, and a + 'timestamp' field with the epoch value representing the + 'CreatedDate' field + ''' + + # setup + query = QueryStub({}) + + # execute + log = pipeline.pack_event_record_into_log( + query, + '00001111AAAABBBB', + self.event_records[0] + ) + + # verify + created_date = self.event_records[0]['CreatedDate'] + timestamp = int(datetime.strptime( + created_date, + '%Y-%m-%dT%H:%M:%S.%f%z').timestamp() * 1000, + ) + self.assertTrue('message' in log) + self.assertTrue('attributes' in log) + self.assertTrue('timestamp' in log) + self.assertEqual(log['message'], f'Account {created_date}') + self.assertEqual(timestamp, log['timestamp']) + + attrs = log['attributes'] + self.assertTrue(not 'attributes' in attrs) + self.assertTrue('Id' in attrs) + self.assertTrue('Name' in attrs) + self.assertTrue('BillingCity' in attrs) + self.assertTrue('CreatedDate' in attrs) + self.assertEqual('00001111AAAABBBB', attrs['Id']) + self.assertEqual('My Account', attrs['Name']) + self.assertEqual(None, attrs['BillingCity']) + self.assertEqual('2024-03-11T00:00:00.000+0000', attrs['CreatedDate']) + + ''' + given: a query, record id, and an event record + when: the record id is empty and there are no query options and the + event record contains a 'type' field in 'attributes' + then: return a log as in use case 1 but with no 'Id' value in the + log 'attributes' field + ''' + + # setup + event_record = copy.deepcopy(self.event_records[0]) + event_record.pop('Id') + + # execute + log = pipeline.pack_event_record_into_log( + query, + None, + event_record + ) + + # verify + self.assertTrue(not 'Id' in log['attributes']) + + ''' + given: a query, record id, and an event record + when: the 'event_type' query option is specified + then: return a log as in use case 1 but with the event type in the log + message set to the custom event type specified in the 'event_type' + query option plus the created date. + ''' + + # setup + event_record = copy.deepcopy(self.event_records[0]) + + # execute + log = pipeline.pack_event_record_into_log( + QueryStub({ 'event_type': 'CustomEvent' }), + '00001111AAAABBBB', + event_record + ) + + # verify + self.assertEqual(log['message'], f'CustomEvent {created_date}') + + ''' + given: a query, record id, and an event record + when: the event record does not contain an 'attributes' field + then: return a log as in use case 1 but with the event type in the log + message set to the default event type specified in the + 'event_type' query option plus the created date. + ''' + + # setup + event_record = copy.deepcopy(self.event_records[0]) + event_record.pop('attributes') + + # execute + log = pipeline.pack_event_record_into_log( + query, + '00001111AAAABBBB', + event_record + ) + + # verify + self.assertEqual(log['message'], f'SFEvent {created_date}') + + ''' + given: a query, record id, and an event record + when: the event record does contains an 'attributes' field but it is not + a dictionary + then: return a log as in the previous use case + ''' + + # setup + event_record = copy.deepcopy(self.event_records[0]) + event_record['attributes'] = 'test' + + # execute + log = pipeline.pack_event_record_into_log( + query, + '00001111AAAABBBB', + event_record + ) + + # verify + self.assertEqual(log['message'], f'SFEvent {created_date}') + + ''' + given: a query, record id, and an event record + when: the event record does contains an 'type' field in the 'attributes' + field + then: return a log as in the previous use case + ''' + + # setup + event_record = copy.deepcopy(self.event_records[0]) + event_record['attributes'].pop('type') + + + # execute + log = pipeline.pack_event_record_into_log( + query, + '00001111AAAABBBB', + event_record + ) + + # verify + self.assertEqual(log['message'], f'SFEvent {created_date}') + + ''' + given: a query, record id, and an event record + when: the event record contains a 'type' field in the 'attributes' + field but it is not a string + then: return a log as in the previous use case + ''' + + # setup + event_record = copy.deepcopy(self.event_records[0]) + event_record['attributes']['type'] = 12345 + + # execute + log = pipeline.pack_event_record_into_log( + query, + '00001111AAAABBBB', + event_record + ) + + # verify + self.assertEqual(log['message'], f'SFEvent {created_date}') + + ''' + given: a query, record id, and an event record + when: the 'timestamp_attr' query option is specified but the specified + attribute name is not in the event record + then: return a log as in use case 1 but the message does not contain a + created date and contains a 'timestamp' field set to the current + time. + ''' + + # setup + event_record = copy.deepcopy(self.event_records[0]) + + # execute + log = pipeline.pack_event_record_into_log( + QueryStub({ 'timestamp_attr': 'NotPresent' }), + '00001111AAAABBBB', + event_record + ) + + # verify + timestamp = int(datetime.now().timestamp() * 1000) + self.assertEqual(log['message'], f'Account') + self.assertTrue('timestamp' in log) + self.assertTrue(log['timestamp'] <= timestamp) + + ''' + given: a query, record id, and an event record + when: no query options are specified and the event record does not + contain a 'CreatedDate' field + then: return the same as the previous use case + ''' + + # setup + event_record = copy.deepcopy(self.event_records[0]) + event_record.pop('CreatedDate') + + # execute + log = pipeline.pack_event_record_into_log( + query, + '00001111AAAABBBB', + event_record + ) + + # verify + timestamp = int(datetime.now().timestamp() * 1000) + self.assertEqual(log['message'], f'Account') + self.assertTrue('timestamp' in log) + self.assertTrue(log['timestamp'] <= timestamp) + + ''' + given: a query, record id, and an event record + when: the 'rename_timestamp' query options is set + then: return the same as use case 1 but with a field with the name + specified in the 'rename_timestamp' query option set to the + current time and no 'timestamp' field + ''' + + # setup + event_record = copy.deepcopy(self.event_records[0]) + + # execute + log = pipeline.pack_event_record_into_log( + QueryStub({ 'rename_timestamp': 'custom_timestamp' }), + '00001111AAAABBBB', + event_record + ) + + # verify + timestamp = int(datetime.now().timestamp() * 1000) + self.assertTrue('custom_timestamp' in log['attributes']) + self.assertTrue(not 'timestamp' in log) + self.assertTrue(log['attributes']['custom_timestamp'] <= timestamp) + + def test_transform_event_records(self): + ''' + given: an event record, query, and no data cache + when: the record contains an 'Id' field + then: return a log with the 'Id' attribute in the log 'attributes' set + to the value of the 'Id' attribute + when: the record does not contain an 'Id' field and the 'id' query + option is not set + then: return a log with no 'Id' attribute in the log 'attributes' field + ''' + + # execute + logs = pipeline.transform_event_records( + self.event_records, + QueryStub({}), + None, + ) + + l = [] + + for log in logs: + l.append(log) + + # verify + self.assertEqual(len(l), 3) + self.assertTrue('Id' in l[0]['attributes']) + self.assertTrue('Id' in l[1]['attributes']) + self.assertEqual(l[0]['attributes']['Id'], '000012345') + self.assertEqual(l[1]['attributes']['Id'], '000054321') + self.assertTrue(not 'Id' in l[2]['attributes']) + + ''' + given: an event record, query, and no data cache + when: the record does not contain an 'Id' field, the 'id' query option + is set and the record contains the field set in the 'id' query + option + then: return a log with a generated id + ''' + + # execute + logs = pipeline.transform_event_records( + self.event_records[2:], + QueryStub({ 'id': ['Name'] }), + None, + ) + + l = [] + + for log in logs: + l.append(log) + + # verify + customId = util.generate_record_id([ 'Name' ], self.event_records[2]) + self.assertTrue('Id' in l[0]['attributes']) + self.assertEqual(l[0]['attributes']['Id'], customId) + + ''' + given: an event record, query, and a data cache + when: the data cache contains cached record IDs + then: return only logs with record ids that do not match those in the + cache + ''' + + # execute + logs = pipeline.transform_event_records( + self.event_records, + QueryStub({}), + DataCacheStub( + {}, + [ '000012345', '000054321' ] + ), + ) + + l = [] + + for log in logs: + l.append(log) + + # verify + self.assertEqual(len(l), 1) + self.assertEqual( + l[0]['attributes']['Name'], + 'My Last Account', + ) + + def test_load_as_logs(self): + def logs(n = 50): + for i in range(0, n): + yield { + 'message': f'log {i}', + 'attributes': { + 'EVENT_TYPE': f'SFEvent{i}', + 'REQUEST_ID': f'abcdef-{i}', + }, + } + + ''' + given: an generator iterator of logs, newrelic instance, set of labels, + and a max rows value + when: their are less than the maximum number of rows, n + then: a single Logs API post should be made with a 'common' property + that contains all the labels and a 'logs' property that contains + n log entries + ''' + + # setup + labels = { 'foo': 'bar' } + newrelic = NewRelicStub() + + # preconditions + self.assertEqual(len(newrelic.logs), 0) + + # execute + pipeline.load_as_logs( + logs(), + newrelic, + labels, + pipeline.DEFAULT_MAX_ROWS, + ) + + # verify + self.assertEqual(len(newrelic.logs), 1) + l = newrelic.logs[0] + self.assertEqual(len(l), 1) + l = l[0] + self.assertTrue('logs' in l) + self.assertTrue('common' in l) + self.assertTrue(type(l['common']) is dict) + self.assertTrue('foo' in l['common']) + self.assertEqual(l['common']['foo'], 'bar') + self.assertEqual(len(l['logs']), 50) + for i, log in enumerate(l['logs']): + self.assertTrue('message' in log) + self.assertEqual(log['message'], f'log {i}') + self.assertTrue('attributes' in log) + self.assertTrue('EVENT_TYPE' in log['attributes']) + self.assertEqual(log['attributes']['EVENT_TYPE'], f'SFEvent{i}') + self.assertTrue('REQUEST_ID' in log['attributes']) + self.assertEqual(log['attributes']['REQUEST_ID'], f'abcdef-{i}') + + ''' + given: an generator iterator of logs, newrelic instance, set of labels, + and a max rows value + when: their are more than the maximum number of rows, n + then: floor(n / max) Logs API posts should be made each containing max + logs in the 'logs' property and a 'common' property that contains + all the labels. IFF n % max > 0, an additional Logs API post is + made containing n % max logs in the 'logs' property and a 'common' + property that contains all the labels. + ''' + + # setup + newrelic = NewRelicStub() + + # preconditions + self.assertEqual(len(newrelic.logs), 0) + + # execute + pipeline.load_as_logs( + logs(150), + newrelic, + labels, + 50, + ) + + # verify + self.assertEqual(len(newrelic.logs), 3) + for i in range(0, 3): + l = newrelic.logs[i] + self.assertEqual(len(l), 1) + l = l[0] + self.assertTrue('logs' in l) + self.assertTrue('common' in l) + self.assertEqual(len(l['logs']), 50) + + # setup + newrelic = NewRelicStub() + + # preconditions + self.assertEqual(len(newrelic.logs), 0) + + # execute + pipeline.load_as_logs( + logs(53), + newrelic, + labels, + 50, + ) + + # verify + self.assertEqual(len(newrelic.logs), 2) + l = newrelic.logs[0] + self.assertEqual(len(l), 1) + l = l[0] + self.assertTrue('logs' in l) + self.assertTrue('common' in l) + self.assertEqual(len(l['logs']), 50) + l = newrelic.logs[1] + self.assertEqual(len(l), 1) + l = l[0] + self.assertTrue('logs' in l) + self.assertTrue('common' in l) + self.assertEqual(len(l['logs']), 3) + + def test_pack_log_into_event(self): + ''' + given: a single log, set of labels, and a set of numeric field names + when: the set of numeric field names is the empty set and the log + contains an 'EVENT_TYPE' property + then: return a single event with a property for each attribute specified + in the 'attributes' field of the log, a property for each label, + and an 'eventType' property set to the value of the 'EVENT_TYPE' + property of the log entry + ''' + + # setup + log = { + 'message': 'Foo and Bar', + 'attributes': self.log_lines[0] + } + + # execute + event = pipeline.pack_log_into_event( + log, + { 'foo': 'bar' }, + set(), + ) + + # verify + self.assertTrue('eventType' in event) + self.assertEqual(event['eventType'], self.log_lines[0]['EVENT_TYPE']) + self.assertTrue('foo' in event) + self.assertEqual(event['foo'], 'bar') + self.assertEqual(len(event), len(self.log_lines[0]) + 2) + for k in self.log_lines[0]: + self.assertTrue(k in event) + self.assertEqual(event[k], self.log_lines[0][k]) + + ''' + given: a single log, set of labels, and a set of numeric field names + when: the set of numeric field names is not empty + then: return the same as use case 1 except each property in the returned + event matching a property in the numeric field names set is + converted to a number and non-numeric values are left as is. + ''' + + # execute + event = pipeline.pack_log_into_event( + log, + { 'foo': 'bar' }, + set(['RUN_TIME', 'CPU_TIME', 'SUCCESS', 'URI']), + ) + + # verify + self.assertEqual(len(event), len(self.log_lines[0]) + 2) + self.assertTrue(type(event['RUN_TIME']) == int) + self.assertTrue(type(event['CPU_TIME']) == int) + self.assertTrue(type(event['SUCCESS']) == int) + self.assertTrue(type(event['URI']) == str) + + ''' + given: a single log, set of labels, and a set of numeric field names + when: the set of numeric field names is the empty set and the log + does not contain an 'EVENT_TYPE' property + then: return the same as use case 1 except the 'eventType' attribute is + set to the default event name. + ''' + + # setup + log_lines = copy.deepcopy(self.log_lines) + + del log_lines[0]['EVENT_TYPE'] + + log['attributes'] = log_lines[0] + + # execute + event = pipeline.pack_log_into_event( + log, + { 'foo': 'bar' }, + set(), + ) + + # verify + self.assertEqual(len(event), len(log_lines[0]) + 2) + self.assertTrue('eventType' in event) + self.assertEqual(event['eventType'], 'UnknownSFEvent') + + def test_load_as_events(self): + def logs(n = 50): + for i in range(0, n): + yield { + 'message': f'log {i}', + 'attributes': { + 'EVENT_TYPE': f'SFEvent{i}', + 'REQUEST_ID': f'abcdef-{i}', + }, + } + + ''' + given: an generator iterator of logs, newrelic instance, set of labels, + max rows value, and a set of numeric field names + when: their are less than the maximum number of rows, n + then: a single Events API post should be made with n events + ''' + + # setup + labels = { 'foo': 'bar' } + newrelic = NewRelicStub() + + # preconditions + self.assertEqual(len(newrelic.events), 0) + + # execute + pipeline.load_as_events( + logs(), + newrelic, + labels, + pipeline.DEFAULT_MAX_ROWS, + set(), + ) + + # verify + self.assertEqual(len(newrelic.events), 1) + l = newrelic.events[0] + self.assertEqual(len(l), 50) + for i, event in enumerate(l): + self.assertTrue('foo' in event) + self.assertEqual(event['foo'], 'bar') + self.assertTrue('EVENT_TYPE' in event) + self.assertEqual(event['EVENT_TYPE'], f'SFEvent{i}') + self.assertTrue('REQUEST_ID' in event) + self.assertEqual(event['REQUEST_ID'], f'abcdef-{i}') + + ''' + given: an generator iterator of logs, newrelic instance, set of labels, + max rows value, and a set of numeric field names + when: their are more than the maximum number of rows, n + then: floor(n / max) Events API posts should be made each containing max + events. IFF n % max > 0, an additional Events API post is made + containing n % max events. + ''' + + # setup + newrelic = NewRelicStub() + + # preconditions + self.assertEqual(len(newrelic.logs), 0) + + # execute + pipeline.load_as_events( + logs(150), + newrelic, + labels, + 50, + set(), + ) + + # verify + self.assertEqual(len(newrelic.events), 3) + for i in range(0, 3): + l = newrelic.events[i] + self.assertEqual(len(l), 50) + + # setup + newrelic = NewRelicStub() + + # preconditions + self.assertEqual(len(newrelic.logs), 0) + + # execute + pipeline.load_as_events( + logs(53), + newrelic, + labels, + 50, + set(), + ) + + # verify + self.assertEqual(len(newrelic.events), 2) + l = newrelic.events[0] + self.assertEqual(len(l), 50) + l = newrelic.events[1] + self.assertEqual(len(l), 3) + + def test_load_data(self): + def logs(n = 50): + for i in range(0, n): + yield { + 'message': f'log {i}', + 'attributes': { + 'EVENT_TYPE': f'SFEvent{i}', + 'REQUEST_ID': f'abcdef-{i}', + }, + } + + ''' + given: a generator iterator of logs, newrelic instance, data format, + set of labels, max rows value, and a set of numeric field names + when: the data format is set to DataFormat.LOGS + then: log data is sent via the New Relic Logs API. + ''' + + # setup + newrelic = NewRelicStub() + + # preconditions + self.assertEqual(len(newrelic.logs), 0) + self.assertEqual(len(newrelic.events), 0) + + # execute + pipeline.load_data( + logs(), + newrelic, + DataFormat.LOGS, + { 'foo': 'bar' }, + 50, + set() + ) + + # verify + self.assertEqual(len(newrelic.logs), 1) + self.assertEqual(len(newrelic.events), 0) + self.assertEqual(len(newrelic.logs[0]), 1) + self.assertTrue('logs' in newrelic.logs[0][0]) + self.assertEqual(len(newrelic.logs[0][0]['logs']), 50) + + ''' + given: a generator iterator of logs, newrelic instance, data format, + set of labels, max rows value, and a set of numeric field names + when: the data format is set to DataFormat.EVENTS + then: log data is sent via the New Relic Events API. + ''' + + # setup + newrelic = NewRelicStub() + + # preconditions + self.assertEqual(len(newrelic.logs), 0) + self.assertEqual(len(newrelic.events), 0) + + # execute + pipeline.load_data( + logs(), + newrelic, + DataFormat.EVENTS, + { 'foo': 'bar' }, + 50, + set() + ) + + # verify + self.assertEqual(len(newrelic.logs), 0) + self.assertEqual(len(newrelic.events), 1) + self.assertEqual(len(newrelic.events[0]), 50) + + def test_pipeline_process_log_record(self): + ''' + given: an instance configuration, data cache, http session, newrelic + instance, data format, set of labels, event type fields mapping, + set of numeric field names, query, instance url, access token and + log record + when: the pipeline is configured with the configuration, session, + newrelic instance, data format, labels, event type fields mapping, + and numeric field names + and when: the data format is set to DataFormat.LOGS + and when: a log record is being processed + and when: the number of log lines to be processed is less than the + maximum number of rows + and when: no data cache is specified + then: a single Logs API post is made containing all labels in the + 'common' property of the logs post and one log for each exported + and transformed log line with the correct attributes from the + corresponding log line using the record ID, event type, and file + name from the given record + ''' + + # setup + cfg = config.Config({}) + session = SessionStub([]) + session.response = ResponseStub(200, 'OK', '', self.log_rows) + newrelic = NewRelicStub() + query = QueryStub({}) + + p = pipeline.Pipeline( + cfg, + None, + newrelic, + DataFormat.LOGS, + { 'foo': 'bar' }, + {}, + set(), + ) + + record = self.log_records[0] + + # preconditions + self.assertEqual(len(newrelic.logs), 0) + + # execute + p.process_log_record( + session, + query, + 'https://test.local.test', + '12345', + record, + ) + + # verify + self.assertEqual(len(newrelic.logs), 1) + l = newrelic.logs[0] + self.assertEqual(len(l), 1) + l = l[0] + self.assertTrue('logs' in l) + self.assertTrue('common' in l) + self.assertTrue(type(l['common']) is dict) + self.assertTrue('foo' in l['common']) + self.assertEqual(l['common']['foo'], 'bar') + self.assertEqual(len(l['logs']), 2) + + logs = l['logs'] + log0 = logs[0] + log1 = logs[1] + + self.assertTrue('message' in log0) + self.assertEqual(log0['message'], 'LogFile 00001111AAAABBBB row 0') + self.assertTrue('attributes' in log0) + self.assertTrue('EVENT_TYPE' in log0['attributes']) + self.assertEqual(log0['attributes']['EVENT_TYPE'], f'ApexCallout') + self.assertTrue('REQUEST_ID' in log0['attributes']) + self.assertEqual(log0['attributes']['REQUEST_ID'], f'YYZ:abcdef123456') + + self.assertTrue('message' in log1) + self.assertEqual(log1['message'], 'LogFile 00001111AAAABBBB row 1') + self.assertTrue('attributes' in log1) + self.assertTrue('EVENT_TYPE' in log1['attributes']) + self.assertEqual(log1['attributes']['EVENT_TYPE'], f'ApexCallout') + self.assertTrue('REQUEST_ID' in log0['attributes']) + self.assertEqual(log1['attributes']['REQUEST_ID'], f'YYZ:fedcba654321') + + ''' + given: the values from use case 1 + when: the pipeline is configured as in use case 1 + and when: the data format is set to DataFormat.LOGS + and when: the number of log lines to be processed is less than the + maximum number of rows, + and when: a data cache is specified + and when: the record ID matches a record ID in the data cache + then: no log entries are sent + ''' + + # setup + data_cache = DataCacheStub(skip_record_ids=['00001111AAAABBBB']) + newrelic = NewRelicStub() + + p = pipeline.Pipeline( + cfg, + data_cache, + newrelic, + DataFormat.LOGS, + { 'foo': 'bar' }, + {}, + set(), + ) + + record = self.log_records[0] + + # preconditions + self.assertEqual(len(newrelic.logs), 0) + + # execute + p.process_log_record( + session, + query, + 'https://test.local.test', + '12345', + record, + ) + + # verify + self.assertEqual(len(newrelic.logs), 0) + + ''' + given: the values from use case 2 + when: the pipeline is configured as in use case 2 + and when: the data format is set to DataFormat.LOGS + and when: the number of log lines to be processed is less than the + maximum number of rows, + and when: a data cache is specified + and when: the record ID matches a record ID in the data cache + and when: the 'Interval' value of the record is set to 'Daily' + then: A Logs API post is made for the record anyway + ''' + + # setup + newrelic = NewRelicStub() + + p = pipeline.Pipeline( + cfg, + data_cache, + newrelic, + DataFormat.LOGS, + { 'foo': 'bar' }, + {}, + set(), + ) + + new_record = copy.deepcopy(record) + new_record['Interval'] = 'Daily' + + # preconditions + self.assertEqual(len(newrelic.logs), 0) + + # execute + p.process_log_record( + session, + query, + 'https://test.local.test', + '12345', + new_record, + ) + + # verify + self.assertEqual(len(newrelic.logs), 1) + l = newrelic.logs[0] + self.assertEqual(len(l), 1) + l = l[0] + self.assertTrue('logs' in l) + self.assertTrue('common' in l) + self.assertTrue(type(l['common']) is dict) + self.assertTrue('foo' in l['common']) + self.assertEqual(l['common']['foo'], 'bar') + self.assertEqual(len(l['logs']), 2) + + ''' + given: the values from use case 3 + when: the pipeline is configured as in use case 3 + and when: the data format is set to DataFormat.LOGS + and when: the number of log lines to be processed is less than the + maximum number of rows, + and when: a data cache is specified + and when: the cache contains a list of log line IDs for the record ID + then: A Logs API post is made for the record containing log entries only + for log lines that have log line IDs that are not in the list of + cached log lines for the record ID + ''' + + # setup + data_cache = DataCacheStub( + cached_log_lines={ + '00001111AAAABBBB': ['YYZ:abcdef123456', 'YYZ:fedcba654321'] + } + ) + newrelic = NewRelicStub() + + p = pipeline.Pipeline( + cfg, + data_cache, + newrelic, + DataFormat.LOGS, + { 'foo': 'bar' }, + {}, + set(), + ) + + record = self.log_records[0] + + # preconditions + self.assertEqual(len(newrelic.logs), 0) + + # execute + p.process_log_record( + session, + query, + 'https://test.local.test', + '12345', + record, + ) + + # verify + self.assertEqual(len(newrelic.logs), 0) + + def test_pipeline_process_event_records(self): + ''' + given: an instance configuration, data cache, http session, newrelic + instance, data format, set of labels, event type fields mapping, + set of numeric field names, query, and a set of event records + when: the pipeline is configured with the configuration, session, + newrelic instance, data format, labels, event type fields mapping, + and numeric field names + and when: the data format is set to DataFormat.LOGS + and when: event records are being processed + and when: the number of event records to be processed is less than the + maximum number of rows + and when: no data cache is specified + then: a single Events API post is made containing all labels in the + 'common' property of the logs post and one log for each exported + and transformed event record with the correct attributes from the + corresponding event record + ''' + + # setup + cfg = config.Config({}) + newrelic = NewRelicStub() + query = QueryStub({ 'id': ['Name'] }) + + p = pipeline.Pipeline( + cfg, + None, + newrelic, + DataFormat.LOGS, + { 'foo': 'bar' }, + {}, + set(), + ) + + # preconditions + self.assertEqual(len(newrelic.logs), 0) + + # execute + p.process_event_records(query, self.event_records) + + # verify + self.assertEqual(len(newrelic.logs), 1) + l = newrelic.logs[0] + self.assertEqual(len(l), 1) + l = l[0] + self.assertTrue('logs' in l) + self.assertTrue('common' in l) + self.assertTrue(type(l['common']) is dict) + self.assertTrue('foo' in l['common']) + self.assertEqual(l['common']['foo'], 'bar') + self.assertEqual(len(l['logs']), 3) + + logs = l['logs'] + log0 = logs[0] + log1 = logs[1] + log2 = logs[2] + + self.assertTrue('message' in log0) + self.assertEqual(log0['message'], 'Account 2024-03-11T00:00:00.000+0000') + self.assertTrue('attributes' in log0) + self.assertTrue('Id' in log0['attributes']) + self.assertEqual(log0['attributes']['Id'], f'000012345') + self.assertTrue('Name' in log0['attributes']) + self.assertEqual(log0['attributes']['Name'], f'My Account') + + self.assertTrue('message' in log1) + self.assertEqual(log1['message'], 'Account 2024-03-10T00:00:00.000+0000') + self.assertTrue('attributes' in log1) + self.assertTrue('Id' in log1['attributes']) + self.assertEqual(log1['attributes']['Id'], f'000054321') + self.assertTrue('Name' in log1['attributes']) + self.assertEqual(log1['attributes']['Name'], f'My Other Account') + + customId = util.generate_record_id([ 'Name' ], self.event_records[2]) + + self.assertTrue('message' in log2) + self.assertEqual(log2['message'], 'Account 2024-03-09T00:00:00.000+0000') + self.assertTrue('attributes' in log2) + self.assertTrue('Id' in log2['attributes']) + self.assertEqual(log2['attributes']['Id'], customId) + self.assertTrue('Name' in log2['attributes']) + self.assertEqual(log2['attributes']['Name'], f'My Last Account') + + def test_pipeline_execute(self): + ''' + given: an instance configuration, data cache, http session, newrelic + instance, data format, set of labels, event type fields mapping, + set of numeric field names, query, and a set of query result + records + when: the pipeline is configured with the configuration, session, + newrelic instance, data format, labels, event type fields mapping, + and numeric field names + and when: the first record in the result set contains a 'LogFile' + attribute + and when: the number of log lines to be processed is less than the + maximum number of rows + and when: a data cache is specified + then: a single Logs API post is made containing all labels in the + 'common' property of the logs post and one log for each exported + and transformed log line, and the cache is flushed + ''' + + # setup + cfg = config.Config({}) + session = SessionStub([]) + session.response = ResponseStub(200, 'OK', '', self.log_rows) + newrelic = NewRelicStub() + query = QueryStub({}) + data_cache = DataCacheStub() + + p = pipeline.Pipeline( + cfg, + data_cache, + newrelic, + DataFormat.LOGS, + { 'foo': 'bar' }, + {}, + set(), + ) + + # preconditions + self.assertEqual(len(newrelic.logs), 0) + + # execute + p.execute( + session, + query, + 'https://test.local.test', + '12345', + self.log_records, + ) + + # verify + self.assertEqual(len(newrelic.logs), 2) + + for _, l in enumerate(newrelic.logs): + self.assertEqual(len(l), 1) + l = l[0] + self.assertTrue('logs' in l) + self.assertTrue('common' in l) + self.assertTrue(type(l['common']) is dict) + self.assertTrue('foo' in l['common']) + self.assertEqual(l['common']['foo'], 'bar') + self.assertEqual(len(l['logs']), 2) + + self.assertTrue(data_cache.flush_called) + + ''' + given: the values from use case 1 + when: the pipeline is configured as in use case 1 + and when: the first record in the result set does not contain a + 'LogFile' attribute + and when: a data cache is specified + and when: the number of event records to be processed is less than the + maximum number of rows + then: a single Events API post is made containing all labels in the + 'common' property of the logs post and one log for each exported + and transformed event record, and the cache is flushed + ''' + + cfg = config.Config({}) + newrelic = NewRelicStub() + query = QueryStub({ 'id': ['Name'] }) + data_cache = DataCacheStub() + + p = pipeline.Pipeline( + cfg, + data_cache, + newrelic, + DataFormat.LOGS, + { 'foo': 'bar' }, + {}, + set(), + ) + + self.assertEqual(len(newrelic.logs), 0) + + p.execute( + session, + query, + 'https://test.local.test', + '12345', + self.event_records, + ) + + self.assertEqual(len(newrelic.logs), 1) + l = newrelic.logs[0] + self.assertEqual(len(l), 1) + l = l[0] + self.assertTrue('logs' in l) + self.assertTrue('common' in l) + self.assertTrue(type(l['common']) is dict) + self.assertTrue('foo' in l['common']) + self.assertEqual(l['common']['foo'], 'bar') + self.assertEqual(len(l['logs']), 3) + + self.assertTrue(data_cache.flush_called) + + +if __name__ == '__main__': + unittest.main() diff --git a/src/tests/test_util.py b/src/tests/test_util.py new file mode 100644 index 0000000..7fc960f --- /dev/null +++ b/src/tests/test_util.py @@ -0,0 +1,138 @@ +import hashlib +import unittest + +from newrelic_logging import util + +class TestUtilities(unittest.TestCase): + def test_is_logfile_response(self): + ''' + given: a set of query result records + when: the set is the empty set + then: return true + ''' + + # execute/verify + self.assertTrue(util.is_logfile_response([])) + + ''' + given: a set of query result records + when: the first record in the set contains a 'LogFile' property + then: return true + ''' + + # execute/verify + self.assertTrue(util.is_logfile_response([ + { 'LogFile': 'example' } + ])) + + ''' + given: a set of query result records + when: the first record in the set does not contain a 'LogFile' property + then: return false + ''' + + # execute/verify + self.assertFalse(util.is_logfile_response([{}])) + + def test_generate_record_id(self): + ''' + given: a set of id keys and a query result record + when: the set of id keys is the empty set + then: return the empty string + ''' + + # execute + record_id = util.generate_record_id([], { 'Name': 'foo' }) + + # verify + self.assertEqual(record_id, '') + + ''' + given: a set of id keys and a query result record + when: the set of id keys is not empty + and when: there is some key for which the query result record does not + have a property + then: an exception is raised + ''' + + # execute/verify + with self.assertRaises(Exception): + util.generate_record_id([ 'EventType' ], { 'Name': 'foo' }) + + ''' + given: a set of id keys and a query result record + when: the set of id keys is not empty + and when: the query result record has a property for the key but the + value for that key is the empty string + then: return the empty string + ''' + + # execute + record_id = util.generate_record_id([ 'Name' ], { 'Name': '' }) + + # verify + self.assertEqual(record_id, '') + + ''' + given: a set of id keys and a query result record + when: the set of id keys is not empty + and when: the query result record has a property for the key with a + value that is not the emptry string + then: return a value obtained by concatenating the values for all id + keys and creating a sha3 256 message digest over that value + ''' + + # execute + record_id = util.generate_record_id([ 'Name' ], { 'Name': 'foo' }) + + # verify + m = hashlib.sha3_256() + m.update('foo'.encode('utf-8')) + expected = m.hexdigest() + + self.assertEqual(expected, record_id) + + def test_maybe_convert_str_to_num(self): + ''' + given: a string + when: the string contains a valid integer + then: the string value is converted to an integer + ''' + + # execute + val = util.maybe_convert_str_to_num('2') + + # verify + self.assertTrue(type(val) is int) + self.assertEqual(val, 2) + + ''' + given: a string + when: the string contains a valid floating point number + then: the string value is converted to a float + ''' + + # execute + val = util.maybe_convert_str_to_num('3.14') + + # verify + self.assertTrue(type(val) is float) + self.assertEqual(val, 3.14) + + ''' + given: a string + when: the string contains neither a valid integer nor a valid floating + point number + then: the string value is returned + ''' + + # execute + val = util.maybe_convert_str_to_num('not a number') + + # verify + self.assertTrue(type(val) is str) + self.assertEqual(val, 'not a number') + + +if __name__ == '__main__': + unittest.main() From d3d827e955fbe4f981275f1992c60c0175e3f62d Mon Sep 17 00:00:00 2001 From: Scott DeWitt Date: Fri, 15 Mar 2024 17:06:59 -0400 Subject: [PATCH 04/11] feat: optimize memory usage part 4: add additional unit tests --- src/__main__.py | 20 + src/newrelic_logging/auth.py | 34 +- src/newrelic_logging/cache.py | 60 +-- src/newrelic_logging/integration.py | 148 ++++--- src/newrelic_logging/newrelic.py | 54 +-- src/newrelic_logging/pipeline.py | 41 +- src/newrelic_logging/query.py | 53 +-- src/newrelic_logging/salesforce.py | 74 ++-- src/newrelic_logging/telemetry.py | 17 +- src/newrelic_logging/util.py | 19 + src/tests/__init__.py | 275 +++++++++++-- src/tests/test_integration.py | 452 ++++++++++++++++++++++ src/tests/test_pipeline.py | 12 +- src/tests/test_salesforce.py | 580 ++++++++++++++++++++++++++++ src/tests/test_util.py | 82 ++++ 15 files changed, 1700 insertions(+), 221 deletions(-) create mode 100644 src/tests/test_integration.py create mode 100644 src/tests/test_salesforce.py diff --git a/src/__main__.py b/src/__main__.py index 6d9059a..6ad67a6 100644 --- a/src/__main__.py +++ b/src/__main__.py @@ -13,7 +13,14 @@ from apscheduler.schedulers.background import BlockingScheduler from pytz import utc from yaml import Loader, load +from newrelic_logging.auth import AuthenticatorFactory +from newrelic_logging.cache import CacheFactory from newrelic_logging.config import Config, getenv +from newrelic_logging.newrelic import NewRelicFactory +from newrelic_logging.pipeline import PipelineFactory +from newrelic_logging.query import QueryFactory +from newrelic_logging.salesforce import SalesForceFactory + from newrelic_logging.integration import Integration from newrelic_logging.telemetry import print_info, print_warn @@ -127,8 +134,15 @@ def run_once( event_type_fields_mapping: dict, numeric_fields_list: set ): + Integration( config, + AuthenticatorFactory(), + CacheFactory(), + PipelineFactory(), + SalesForceFactory(), + QueryFactory(), + NewRelicFactory(), event_type_fields_mapping, numeric_fields_list, config.get_int(CRON_INTERVAL_MINUTES, 60), @@ -157,6 +171,12 @@ def run_as_service( scheduler.add_job( Integration( config, + AuthenticatorFactory(), + CacheFactory(), + PipelineFactory(), + SalesForceFactory(), + QueryFactory(), + NewRelicFactory(), event_type_fields_mapping, numeric_fields_list, 0 diff --git a/src/newrelic_logging/auth.py b/src/newrelic_logging/auth.py index 8080691..f1731e2 100644 --- a/src/newrelic_logging/auth.py +++ b/src/newrelic_logging/auth.py @@ -47,7 +47,7 @@ def set_auth_data(self, access_token: str, instance_url: str) -> None: self.access_token = access_token self.instance_url = instance_url - def clear_auth(self): + def clear_auth(self) -> None: self.set_auth_data(None, None) if self.data_cache: @@ -84,7 +84,7 @@ def load_auth_from_cache(self) -> bool: return False - def store_auth(self, auth_resp: dict): + def store_auth(self, auth_resp: dict) -> None: self.access_token = auth_resp['access_token'] self.instance_url = auth_resp['instance_url'] @@ -291,21 +291,25 @@ def make_auth_from_env(config: Config) -> dict: raise Exception(f'Wrong or missing grant_type') -def New(config: Config, data_cache: DataCache) -> Authenticator: - token_url = config.get('token_url', env_var_name=SF_TOKEN_URL) +class AuthenticatorFactory: + def __init__(self): + pass - if not token_url: - raise ConfigException('token_url', 'missing token URL') + def new(self, config: Config, data_cache: DataCache) -> Authenticator: + token_url = config.get('token_url', env_var_name=SF_TOKEN_URL) + + if not token_url: + raise ConfigException('token_url', 'missing token URL') + + if 'auth' in config: + return Authenticator( + token_url, + make_auth_from_config(config.sub('auth')), + data_cache, + ) - if 'auth' in config: return Authenticator( token_url, - make_auth_from_config(config.sub('auth')), - data_cache, + make_auth_from_env(config), + data_cache ) - - return Authenticator( - token_url, - make_auth_from_env(config), - data_cache - ) diff --git a/src/newrelic_logging/cache.py b/src/newrelic_logging/cache.py index 4438f77..b9a49b5 100644 --- a/src/newrelic_logging/cache.py +++ b/src/newrelic_logging/cache.py @@ -129,31 +129,35 @@ def flush(self) -> None: gc.collect() -def New(config: Config): - if config.get_bool(CONFIG_CACHE_ENABLED, DEFAULT_CACHE_ENABLED): - host = config.get(CONFIG_REDIS_HOST, DEFAULT_REDIS_HOST) - port = config.get_int(CONFIG_REDIS_PORT, DEFAULT_REDIS_PORT) - db = config.get_int(CONFIG_REDIS_DB_NUMBER, DEFAULT_REDIS_DB_NUMBER) - password = config.get(CONFIG_REDIS_PASSWORD) - ssl = config.get_bool(CONFIG_REDIS_USE_SSL, DEFAULT_REDIS_SSL) - expire_days = config.get_int(CONFIG_REDIS_EXPIRE_DAYS) - password_display = "XXXXXX" if password != None else None - - print_info( - f'Cache enabled, connecting to redis instance {host}:{port}:{db}, ssl={ssl}, password={password_display}' - ) - - return DataCache( - RedisBackend( - redis.Redis( - host=host, - port=port, - db=db, - password=password, - ssl=ssl - ), expire_days) - ) - - print_info('Cache disabled') - - return None +class CacheFactory: + def __init__(self): + pass + + def new(self, config: Config): + if config.get_bool(CONFIG_CACHE_ENABLED, DEFAULT_CACHE_ENABLED): + host = config.get(CONFIG_REDIS_HOST, DEFAULT_REDIS_HOST) + port = config.get_int(CONFIG_REDIS_PORT, DEFAULT_REDIS_PORT) + db = config.get_int(CONFIG_REDIS_DB_NUMBER, DEFAULT_REDIS_DB_NUMBER) + password = config.get(CONFIG_REDIS_PASSWORD) + ssl = config.get_bool(CONFIG_REDIS_USE_SSL, DEFAULT_REDIS_SSL) + expire_days = config.get_int(CONFIG_REDIS_EXPIRE_DAYS) + password_display = "XXXXXX" if password != None else None + + print_info( + f'Cache enabled, connecting to redis instance {host}:{port}:{db}, ssl={ssl}, password={password_display}' + ) + + return DataCache( + RedisBackend( + redis.Redis( + host=host, + port=port, + db=db, + password=password, + ssl=ssl + ), expire_days) + ) + + print_info('Cache disabled') + + return None diff --git a/src/newrelic_logging/integration.py b/src/newrelic_logging/integration.py index 295b444..b4fd53e 100644 --- a/src/newrelic_logging/integration.py +++ b/src/newrelic_logging/integration.py @@ -11,26 +11,84 @@ from . import cache from . import config as mod_config from . import newrelic +from . import query from . import pipeline +from . import salesforce from .http_session import new_retry_session -from .salesforce import SalesForce -from .telemetry import Telemetry, print_info, print_err +from .telemetry import print_err, print_info, print_warn, Telemetry # @TODO: move queries to the instance level, so we can have different queries for # each instance. # @TODO: also keep general queries that apply to all instances. +def build_instance( + config: mod_config.Config, + auth_factory: auth.AuthenticatorFactory, + cache_factory: cache.CacheFactory, + pipeline_factory: pipeline.PipelineFactory, + salesforce_factory: salesforce.SalesForceFactory, + query_factory: query.QueryFactory, + new_relic: newrelic.NewRelic, + data_format: DataFormat, + event_type_fields_mapping: dict, + numeric_fields_list: set, + initial_delay: int, + instance: dict, + index: int, +): + instance_name = instance['name'] + labels = instance['labels'] + labels['nr-labs'] = 'data' + instance_config = config.sub(f'instances.{index}.arguments') + instance_config.set_prefix( + instance_config['auth_env_prefix'] \ + if 'auth_env_prefix' in instance_config else '' + ) + + data_cache = cache_factory.new(instance_config) + authenticator = auth_factory.new(instance_config, data_cache) + + return { + 'client': salesforce_factory.new( + instance_name, + instance_config, + data_cache, + authenticator, + pipeline_factory.new( + instance_config, + data_cache, + new_relic, + data_format, + labels, + event_type_fields_mapping, + numeric_fields_list, + ), + query_factory, + initial_delay, + config['queries'] if 'queries' in config else None, + ), + 'name': instance_name, + } class Integration: def __init__( self, config: mod_config.Config, + auth_factory: auth.AuthenticatorFactory, + cache_factory: cache.CacheFactory, + pipeline_factory: pipeline.PipelineFactory, + salesforce_factory: salesforce.SalesForceFactory, + query_factory: query.QueryFactory, + newrelic_factory: newrelic.NewRelicFactory, event_type_fields_mapping: dict = {}, numeric_fields_list: set = set(), initial_delay: int = 0, ): - Telemetry(config["integration_name"]) + Telemetry( + config['integration_name'] if 'integration_name' in config \ + else 'com.newrelic.labs.sfdc.eventlogfiles' + ) data_format = config.get('newrelic.data_format', 'logs').lower() if data_format == 'logs': @@ -40,56 +98,42 @@ def __init__( else: raise ConfigException(f'invalid data format {data_format}') - # Fill credentials for NR APIs - - new_relic = newrelic.New(config) - + self.new_relic = newrelic_factory.new(config) self.instances = [] - for count, instance in enumerate(config['instances']): - instance_name = instance['name'] - labels = instance['labels'] - labels['nr-labs'] = 'data' - instance_config = config.sub(f'instances.{count}.arguments') - instance_config.set_prefix( - instance_config['auth_env_prefix'] \ - if 'auth_env_prefix' in instance_config else '' - ) - - data_cache = cache.New(instance_config) - authenticator = auth.New(instance_config, data_cache) - - self.instances.append({ - 'client': SalesForce( - instance_name, - instance_config, - data_cache, - authenticator, - pipeline.New( - instance_config, - data_cache, - new_relic, - data_format, - labels, - event_type_fields_mapping, - numeric_fields_list, - ), - initial_delay, - config['queries'] if 'queries' in config else None, - ), - 'name': instance_name, - }) - - def process_telemetry(self): - if not Telemetry().is_empty(): - print_info("Sending telemetry data") - self.process_logs(Telemetry().build_model(), {}, None) - Telemetry().clear() - else: + + if not 'instances' in config or len(config['instances']) == 0: + print_warn('no instances found to run') + return + + for index, instance in enumerate(config['instances']): + self.instances.append(build_instance( + config, + auth_factory, + cache_factory, + pipeline_factory, + salesforce_factory, + query_factory, + self.new_relic, + data_format, + event_type_fields_mapping, + numeric_fields_list, + initial_delay, + instance, + index, + )) + + def process_telemetry(self, session: Session): + if Telemetry().is_empty(): print_info("No telemetry data") + return + + print_info("Sending telemetry data") + self.new_relic.post_logs(session, Telemetry().build_model()) + Telemetry().clear() def auth_and_fetch( self, - client: SalesForce, + client: salesforce.SalesForce, session: Session, retry: bool = True, ) -> None: @@ -123,7 +167,9 @@ def auth_and_fetch( print_err(f'unknown exception occurred: {e}') def run(self): + session = new_retry_session() + for instance in self.instances: - print_info(f"Running instance '{instance['name']}'") - self.auth_and_fetch(instance['client'], new_retry_session()) - self.process_telemetry() + print_info(f'Running instance "{instance["name"]}"') + self.auth_and_fetch(instance['client'], session) + self.process_telemetry(session) diff --git a/src/newrelic_logging/newrelic.py b/src/newrelic_logging/newrelic.py index ed9d7e1..3d3bba1 100644 --- a/src/newrelic_logging/newrelic.py +++ b/src/newrelic_logging/newrelic.py @@ -107,29 +107,31 @@ def post_events(self, session: Session, events: list[dict]) -> None: raise NewRelicApiException('newrelic events api request failed') -def New( - config: Config, -): - license_key = config.get( - 'newrelic.license_key', - env_var_name=NR_LICENSE_KEY, - ) - - region = config.get('newrelic.api_endpoint') - account_id = config.get('newrelic.account_id', env_var_name=NR_ACCOUNT_ID) - - if region == "US": - logs_api_endpoint = US_LOGGING_ENDPOINT - events_api_endpoint = US_EVENTS_ENDPOINT.format(account_id=account_id) - elif region == "EU": - logs_api_endpoint = EU_LOGGING_ENDPOINT - events_api_endpoint = EU_EVENTS_ENDPOINT.format(account_id=account_id) - else: - raise NewRelicApiException(f'Invalid region {region}') - - return NewRelic( - logs_api_endpoint, - license_key, - events_api_endpoint, - license_key, - ) +class NewRelicFactory: + def __init__(self): + pass + + def new(self, config: Config): + license_key = config.get( + 'newrelic.license_key', + env_var_name=NR_LICENSE_KEY, + ) + + region = config.get('newrelic.api_endpoint') + account_id = config.get('newrelic.account_id', env_var_name=NR_ACCOUNT_ID) + + if region == "US": + logs_api_endpoint = US_LOGGING_ENDPOINT + events_api_endpoint = US_EVENTS_ENDPOINT.format(account_id=account_id) + elif region == "EU": + logs_api_endpoint = EU_LOGGING_ENDPOINT + events_api_endpoint = EU_EVENTS_ENDPOINT.format(account_id=account_id) + else: + raise NewRelicApiException(f'Invalid region {region}') + + return NewRelic( + logs_api_endpoint, + license_key, + events_api_endpoint, + license_key, + ) diff --git a/src/newrelic_logging/pipeline.py b/src/newrelic_logging/pipeline.py index c53bfb5..0dd32a3 100644 --- a/src/newrelic_logging/pipeline.py +++ b/src/newrelic_logging/pipeline.py @@ -475,21 +475,26 @@ def execute( if self.data_cache: self.data_cache.flush() -def New( - config: Config, - data_cache: DataCache, - new_relic: NewRelic, - data_format: DataFormat, - labels: dict, - event_type_field_mappings: dict, - numeric_fields_list: set, -): - return Pipeline( - config, - data_cache, - new_relic, - data_format, - labels, - event_type_field_mappings, - numeric_fields_list, - ) +class PipelineFactory: + def __init__(self): + pass + + def new( + self, + config: Config, + data_cache: DataCache, + new_relic: NewRelic, + data_format: DataFormat, + labels: dict, + event_type_field_mappings: dict, + numeric_fields_list: set, + ): + return Pipeline( + config, + data_cache, + new_relic, + data_format, + labels, + event_type_field_mappings, + numeric_fields_list, + ) diff --git a/src/newrelic_logging/query.py b/src/newrelic_logging/query.py index 9299b53..5e886a8 100644 --- a/src/newrelic_logging/query.py +++ b/src/newrelic_logging/query.py @@ -5,7 +5,7 @@ from . import SalesforceApiException from .config import Config from .telemetry import print_info -from .util import substitute +from .util import get_iso_date_with_offset, substitute class Query: def __init__( @@ -55,29 +55,34 @@ def execute( ) from e -def New( - q: dict, - time_lag_minutes: int, - last_to_timestamp: str, - generation_interval: str, - default_api_ver: str, -) -> Query: - to_timestamp = ( - datetime.utcnow() - timedelta(minutes=time_lag_minutes) - ).isoformat(timespec='milliseconds') + "Z" - from_timestamp = last_to_timestamp +class QueryFactory: + def __init__(self): + pass - qp = copy.deepcopy(q) - qq = qp.pop('query', '') + def new( + self, + q: dict, + time_lag_minutes: int, + last_to_timestamp: str, + generation_interval: str, + default_api_ver: str, + ) -> Query: + to_timestamp = get_iso_date_with_offset(time_lag_minutes) + from_timestamp = last_to_timestamp + + qp = copy.deepcopy(q) + qq = qp.pop('query', '') + + args = { + 'to_timestamp': to_timestamp, + 'from_timestamp': from_timestamp, + 'log_interval_type': generation_interval, + } - args = { - 'to_timestamp': to_timestamp, - 'from_timestamp': from_timestamp, - 'log_interval_type': generation_interval, - } + env = qp['env'] if 'env' in qp and type(qp['env']) is dict else {} - return Query( - substitute(args, qq, qp).replace(' ', '+'), - Config(qp), - qp.get('api_ver', default_api_ver) - ) + return Query( + substitute(args, qq, env).replace(' ', '+'), + Config(qp), + qp.get('api_ver', default_api_ver) + ) diff --git a/src/newrelic_logging/salesforce.py b/src/newrelic_logging/salesforce.py index c8d15cb..1afa0db 100644 --- a/src/newrelic_logging/salesforce.py +++ b/src/newrelic_logging/salesforce.py @@ -1,13 +1,13 @@ from datetime import datetime, timedelta from requests import Session -from . import DataFormat from .auth import Authenticator from .cache import DataCache from . import config as mod_config from .pipeline import Pipeline from . import query as mod_query -from .telemetry import print_info +from .telemetry import print_info, print_warn +from .util import get_iso_date_with_offset CSV_SLICE_SIZE = 1000 @@ -27,13 +27,16 @@ def __init__( data_cache: DataCache, authenticator: Authenticator, pipeline: Pipeline, + query_factory: mod_query.QueryFactory, initial_delay: int, - queries=None, + queries: list[dict] = None, ): self.instance_name = instance_name - self.default_api_ver = config.get('api_ver', '52.0') self.data_cache = data_cache self.auth = authenticator + self.pipeline = pipeline + self.query_factory = query_factory + self.default_api_ver = config.get('api_ver', '52.0') self.time_lag_minutes = config.get( mod_config.CONFIG_TIME_LAG_MINUTES, mod_config.DEFAULT_TIME_LAG_MINUTES if not self.data_cache else 0, @@ -47,28 +50,24 @@ def __init__( mod_config.CONFIG_GENERATION_INTERVAL, mod_config.DEFAULT_GENERATION_INTERVAL, ) - self.last_to_timestamp = (datetime.utcnow() - timedelta( - minutes=self.time_lag_minutes + initial_delay - )).isoformat(timespec='milliseconds') + 'Z' - - if queries: - self.queries = queries - else: - self.queries = [{ + self.last_to_timestamp = get_iso_date_with_offset( + self.time_lag_minutes, + initial_delay, + ) + self.queries = queries if queries else \ + [{ 'query': SALESFORCE_LOG_DATE_QUERY \ if self.date_field.lower() == 'logdate' \ else SALESFORCE_CREATED_DATE_QUERY }] - self.pipeline = pipeline - def authenticate(self, sfdc_session: Session): self.auth.authenticate(sfdc_session) def slide_time_range(self): - self.last_to_timestamp = ( - datetime.utcnow() - timedelta(minutes=self.time_lag_minutes)) \ - .isoformat(timespec='milliseconds') + "Z" + self.last_to_timestamp = get_iso_date_with_offset( + self.time_lag_minutes + ) # NOTE: Is it possible that different SF orgs have overlapping IDs? If this is possible, we should use a different # database for each org, or add a prefix to keys to avoid conflicts. @@ -76,21 +75,24 @@ def slide_time_range(self): def fetch_logs(self, session: Session) -> list[dict]: print_info(f"Queries = {self.queries}") - for query in self.queries: - response = mod_query.New( - query, + for q in self.queries: + query = self.query_factory.new( + q, self.time_lag_minutes, self.last_to_timestamp, self.generation_interval, self.default_api_ver, - ).execute( + ) + + response = query.execute( session, self.auth.get_instance_url(), self.auth.get_access_token(), ) - # Show query response - #print("Response = ", response) + if not response or not 'records' in response: + print_warn(f'no records returned for query {query.query}') + continue self.pipeline.execute( session, @@ -135,3 +137,29 @@ def fetch_logs(self, session: Session) -> list[dict]: # csv_rows = self.parse_csv(download_response, record_id, record_event_type, cached_messages) # # print_info(f"CSV rows = {len(csv_rows)}") + +class SalesForceFactory: + def __init__(self): + pass + + def new( + self, + instance_name: str, + config: mod_config.Config, + data_cache: DataCache, + authenticator: Authenticator, + pipeline: Pipeline, + query_factory: mod_query.QueryFactory, + initial_delay: int, + queries: list[dict] = None, + ): + return SalesForce( + instance_name, + config, + data_cache, + authenticator, + pipeline, + query_factory, + initial_delay, + queries, + ) diff --git a/src/newrelic_logging/telemetry.py b/src/newrelic_logging/telemetry.py index 3e4174e..f8984f3 100644 --- a/src/newrelic_logging/telemetry.py +++ b/src/newrelic_logging/telemetry.py @@ -18,19 +18,19 @@ class Telemetry: def __init__(self, integration_name: str) -> None: self.integration_name = integration_name - + def is_empty(self): return len(self.logs) == 0 - + def log_info(self, msg: str): self.record_log(msg, "info") - + def log_err(self, msg: str): self.record_log(msg, "error") - + def log_warn(self, msg: str): self.record_log(msg, "warn") - + def record_log(self, msg: str, level: str): log = { "timestamp": round(time.time() * 1000), @@ -41,15 +41,16 @@ def record_log(self, msg: str, level: str): } } self.logs.append(log) - + def clear(self): self.logs = [] def build_model(self): return [{ - "log_entries": self.logs + "common": {}, + "logs": self.logs, }] - + def print_log(msg: str, level: str): print(json.dumps({ "message": msg, diff --git a/src/newrelic_logging/util.py b/src/newrelic_logging/util.py index 639ffab..d77864b 100644 --- a/src/newrelic_logging/util.py +++ b/src/newrelic_logging/util.py @@ -41,6 +41,25 @@ def maybe_convert_str_to_num(val: str) -> Union[int, str, float]: return val +# Make testing easier +def _utcnow(): + return datetime.utcnow() + +_UTCNOW = _utcnow + +def get_iso_date_with_offset( + time_lag_minutes: int = 0, + initial_delay: int = 0, +) -> str: + return ( + _UTCNOW() - timedelta( + minutes=(time_lag_minutes + initial_delay) + ) + ).isoformat( + timespec='milliseconds' + ) + 'Z' + + # NOTE: this sandbox can be jailbroken using the trick to exec statements inside # an exec block, and run an import (and other tricks): # https://book.hacktricks.xyz/generic-methodologies-and-resources/python/bypass-python-sandboxes#operators-and-short-tricks diff --git a/src/tests/__init__.py b/src/tests/__init__.py index 8b46fd7..50688f1 100644 --- a/src/tests/__init__.py +++ b/src/tests/__init__.py @@ -1,49 +1,85 @@ from requests import Session +from newrelic_logging import DataFormat +from newrelic_logging.auth import Authenticator +from newrelic_logging.cache import DataCache from newrelic_logging.config import Config +from newrelic_logging.newrelic import NewRelic +from newrelic_logging.pipeline import Pipeline +from newrelic_logging.query import Query, QueryFactory -class QueryStub: - def __init__(self, config: dict): - self.config = Config(config) +class AuthenticatorStub: + def __init__( + self, + config: Config = None, + data_cache: DataCache = None, + token_url: str = '', + access_token: str = '', + instance_url: str = '', + grant_type: str = '', + authenticate_called: bool = False, + ): + self.config = config + self.data_cache = data_cache + self.token_url = token_url + self.access_token = access_token + self.instance_url = instance_url + self.grant_type = grant_type + self.authenticate_called = authenticate_called - def get(self, key: str, default = None): - return self.config.get(key, default) + def get_access_token(self) -> str: + return self.access_token - def get_config(self): - return self.config + def get_instance_url(self) -> str: + return self.instance_url + + def get_grant_type(self) -> str: + return self.grant_type - def execute(): + def set_auth_data(self, access_token: str, instance_url: str) -> None: pass + def clear_auth(self) -> None: + pass -class ResponseStub: - def __init__(self, status_code, reason, text, lines): - self.status_code = status_code - self.reason = reason - self.text = text - self.lines = lines + def load_auth_from_cache(self) -> bool: + return False - def iter_lines(self, *args, **kwargs): - yield from self.lines + def store_auth(self, auth_resp: dict) -> None: + pass + def authenticate( + self, + session: Session, + ) -> None: + self.authenticate_called = True -class SessionStub: - def __init__(self, lines): - self.response = None + def authenticate_with_jwt(self, session: Session) -> None: + pass - def get(self, *args, **kwargs): - return self.response + def authenticate_with_password(self, session: Session) -> None: + pass + + +class AuthenticatorFactoryStub: + def __init__(self): + pass + + def new(self, config: Config, data_cache: DataCache) -> Authenticator: + return AuthenticatorStub(config, data_cache) class DataCacheStub: def __init__( self, + config: Config = None, cached_logs = {}, cached_events = [], skip_record_ids = [], cached_log_lines = {}, ): + self.config = config self.cached_logs = cached_logs self.cached_events = cached_events self.skip_record_ids = skip_record_ids @@ -68,8 +104,17 @@ def flush(self) -> None: self.flush_called = True -class NewRelicStub: +class CacheFactoryStub: def __init__(self): + pass + + def new(self, config: Config): + return DataCacheStub(config) + + +class NewRelicStub: + def __init__(self, config: Config = None): + self.config = config self.logs = [] self.events = [] @@ -78,3 +123,189 @@ def post_logs(self, session: Session, data: list[dict]) -> None: def post_events(self, session: Session, events: list[dict]) -> None: self.events.append(events) + + +class NewRelicFactoryStub: + def __init__(self): + pass + + def new(self, config: Config): + return NewRelicStub(config) + + +class QueryStub: + def __init__( + self, + config: Config = Config({}), + api_ver: str = '', + result: dict = { 'records': [] }, + query: str = '', + ): + self.query = query + self.config = config + self.api_ver = api_ver + self.executed = False + self.result = result + + def get(self, key: str, default = None): + return self.config.get(key, default) + + def get_config(self): + return self.config + + def execute( + self, + session: Session = None, + instance_url: str = '', + access_token: str = '', + ): + self.executed = True + return self.result + + +class QueryFactoryStub: + def __init__(self, query: QueryStub = None ): + self.query = query + self.queries = [] if not query else None + pass + + def new( + self, + q: dict, + time_lag_minutes: int = 0, + last_to_timestamp: str = '', + generation_interval: str = '', + default_api_ver: str = '', + ) -> Query: + if self.query: + return self.query + + qq = QueryStub(q, default_api_ver, query=q['query']) + self.queries.append(qq) + return qq + + +class PipelineStub: + def __init__( + self, + config: Config = Config({}), + data_cache: DataCache = None, + new_relic: NewRelic = None, + data_format: DataFormat = DataFormat.LOGS, + labels: dict = {}, + event_type_fields_mapping: dict = {}, + numeric_fields_list: set = set(), + ): + self.config = config + self.data_cache = data_cache + self.new_relic = new_relic + self.data_format = data_format + self.labels = labels + self.event_type_fields_mapping = event_type_fields_mapping + self.numeric_fields_list = numeric_fields_list + self.queries = [] + self.executed = False + + def execute( + self, + session: Session, + query: Query, + instance_url: str, + access_token: str, + records: list[dict], + ): + self.queries.append(query) + self.executed = True + + +class PipelineFactoryStub: + def __init__(self): + pass + + def new( + self, + config: Config, + data_cache: DataCache, + new_relic: NewRelic, + data_format: DataFormat, + labels: dict, + event_type_fields_mapping: dict, + numeric_fields_list: set, + ): + return PipelineStub( + config, + data_cache, + new_relic, + data_format, + labels, + event_type_fields_mapping, + numeric_fields_list, + ) + + +class ResponseStub: + def __init__(self, status_code, reason, text, lines): + self.status_code = status_code + self.reason = reason + self.text = text + self.lines = lines + + def iter_lines(self, *args, **kwargs): + yield from self.lines + + +class SalesForceStub: + def __init__( + self, + instance_name: str, + config: Config, + data_cache: DataCache, + authenticator: Authenticator, + pipeline: Pipeline, + query_factory: QueryFactory, + initial_delay: int, + queries: list[dict] = None, + ): + self.instance_name = instance_name + self.config = config + self.data_cache = data_cache + self.authenticator = authenticator + self.pipeline = pipeline + self.query_factory = query_factory + self.initial_delay = initial_delay + self.queries = queries + + +class SalesForceFactoryStub: + def __init__(self): + pass + + def new( + self, + instance_name: str, + config: Config, + data_cache: DataCache, + authenticator: Authenticator, + pipeline: Pipeline, + query_factory: QueryFactory, + initial_delay: int, + queries: list[dict] = None, + ): + return SalesForceStub( + instance_name, + config, + data_cache, + authenticator, + pipeline, + query_factory, + initial_delay, + queries, + ) + + +class SessionStub: + def __init__(self): + self.response = None + + def get(self, *args, **kwargs): + return self.response diff --git a/src/tests/test_integration.py b/src/tests/test_integration.py new file mode 100644 index 0000000..4d7a60a --- /dev/null +++ b/src/tests/test_integration.py @@ -0,0 +1,452 @@ +import unittest + +from . import AuthenticatorFactoryStub, \ + CacheFactoryStub, \ + NewRelicStub, \ + NewRelicFactoryStub, \ + PipelineFactoryStub, \ + QueryFactoryStub, \ + SalesForceFactoryStub +from newrelic_logging import ConfigException, \ + DataFormat, \ + config as mod_config, \ + integration + +class TestIntegration(unittest.TestCase): + def test_build_instance(self): + ''' + given: a Config instance, a NewRelic instance, a data format, a set of + event type field mappings, a set of numeric fields an initial + delay, an instance config, and a instance index + when: no prefix or queries are provided in the instance config + then: return a Salesforce instance configured with appropriate values + ''' + + # setup + config = mod_config.Config({ + 'instances': [ + { + 'name': 'test-inst-1', + 'labels': { + 'foo': 'bar', + 'beep': 'boop', + }, + 'arguments': { + 'token_url': 'https://my.salesforce.test/token', + 'cache_enabled': False, + }, + } + ] + }) + auth_factory = AuthenticatorFactoryStub() + cache_factory = CacheFactoryStub() + pipeline_factory = PipelineFactoryStub() + salesforce_factory = SalesForceFactoryStub() + query_factory = QueryFactoryStub() + new_relic = NewRelicStub(config) + event_type_fields_mapping = { 'event': ['field1'] } + numeric_fields_list = set(['field1', 'field2']) + + # execute + instance = integration.build_instance( + config, + auth_factory, + cache_factory, + pipeline_factory, + salesforce_factory, + query_factory, + new_relic, + DataFormat.EVENTS, + event_type_fields_mapping, + numeric_fields_list, + 603, + config['instances'][0], + 0, + ) + + # verify + self.assertTrue('client' in instance) + self.assertTrue('name' in instance) + self.assertEqual(instance['name'], 'test-inst-1') + client = instance['client'] + self.assertEqual(client.instance_name, 'test-inst-1') + inst_config = client.config + self.assertIsNotNone(inst_config) + self.assertEqual(inst_config.prefix, '') + self.assertTrue('token_url' in inst_config) + data_cache = client.data_cache + self.assertIsNotNone(data_cache) + self.assertIsNotNone(data_cache.config) + self.assertTrue('cache_enabled' in data_cache.config) + self.assertFalse(data_cache.config['cache_enabled']) + authenticator = client.authenticator + self.assertIsNotNone(authenticator) + self.assertEqual(authenticator.data_cache, data_cache) + self.assertIsNotNone(authenticator.config) + self.assertEqual( + authenticator.config['token_url'], + 'https://my.salesforce.test/token', + ) + p = client.pipeline + self.assertIsNotNone(p) + self.assertEqual(p.data_cache, data_cache) + self.assertEqual(p.new_relic, new_relic) + self.assertEqual(p.data_format, DataFormat.EVENTS) + self.assertIsNotNone(p.labels) + self.assertTrue(type(p.labels) is dict) + labels = p.labels + self.assertTrue('foo' in labels) + self.assertEqual(labels['foo'], 'bar') + self.assertTrue('beep' in labels) + self.assertEqual(labels['beep'], 'boop') + self.assertIsNotNone(p.event_type_fields_mapping) + self.assertTrue(type(p.event_type_fields_mapping) is dict) + event_type_fields_mapping = p.event_type_fields_mapping + self.assertTrue('event' in event_type_fields_mapping) + self.assertTrue(type(event_type_fields_mapping['event']) is list) + self.assertEqual(len(event_type_fields_mapping['event']), 1) + self.assertEqual(event_type_fields_mapping['event'][0], 'field1') + self.assertIsNotNone(p.numeric_fields_list) + self.assertTrue(type(p.numeric_fields_list) is set) + numeric_fields_list = p.numeric_fields_list + self.assertEqual(len(numeric_fields_list), 2) + self.assertTrue('field1' in numeric_fields_list) + self.assertTrue('field2' in numeric_fields_list) + self.assertEqual(client.query_factory, query_factory) + self.assertEqual(client.initial_delay, 603) + self.assertIsNone(client.queries) + + ''' + given: a Config instance, a NewRelic instance, a data format, a set of + event type field mappings, a set of numeric fields an initial + delay, an instance config, and a instance index + when: prefix provided in the instance config + then: the prefix should be set on the instance config object + ''' + + # setup + config = mod_config.Config({ + 'instances': [ + { + 'name': 'test-inst-1', + 'labels': { + 'foo': 'bar', + 'beep': 'boop', + }, + 'arguments': { + 'token_url': 'https://my.salesforce.test/token', + 'cache_enabled': False, + 'auth_env_prefix': 'ABCDEF_' + }, + } + ] + }) + auth_factory = AuthenticatorFactoryStub() + cache_factory = CacheFactoryStub() + pipeline_factory = PipelineFactoryStub() + salesforce_factory = SalesForceFactoryStub() + query_factory = QueryFactoryStub() + + # execute + instance = integration.build_instance( + config, + auth_factory, + cache_factory, + pipeline_factory, + salesforce_factory, + query_factory, + new_relic, + DataFormat.EVENTS, + event_type_fields_mapping, + numeric_fields_list, + 603, + config['instances'][0], + 0, + ) + + # verify + self.assertTrue('client' in instance) + client = instance['client'] + inst_config = client.config + self.assertIsNotNone(inst_config) + self.assertEqual(inst_config.prefix, 'ABCDEF_') + + ''' + given: a Config instance, a NewRelic instance, a data format, a set of + event type field mappings, a set of numeric fields an initial + delay, an instance config, and a instance index + when: queries provide in the config + then: queries should be set in the Salesforce client instance + ''' + + # setup + config = mod_config.Config({ + 'instances': [ + { + 'name': 'test-inst-1', + 'labels': { + 'foo': 'bar', + 'beep': 'boop', + }, + 'arguments': { + 'token_url': 'https://my.salesforce.test/token', + 'cache_enabled': False, + }, + } + ], + 'queries': [ + { + 'query': 'SELECT foo FROM Account' + } + ] + }) + auth_factory = AuthenticatorFactoryStub() + cache_factory = CacheFactoryStub() + pipeline_factory = PipelineFactoryStub() + salesforce_factory = SalesForceFactoryStub() + query_factory = QueryFactoryStub() + + # execute + instance = integration.build_instance( + config, + auth_factory, + cache_factory, + pipeline_factory, + salesforce_factory, + query_factory, + new_relic, + DataFormat.EVENTS, + event_type_fields_mapping, + numeric_fields_list, + 603, + config['instances'][0], + 0, + ) + + # verify + self.assertTrue('client' in instance) + client = instance['client'] + self.assertIsNotNone(client.queries) + self.assertEqual(len(client.queries), 1) + self.assertTrue('query' in client.queries[0]) + self.assertEqual(client.queries[0]['query'], 'SELECT foo FROM Account') + + def test_init(self): + ''' + given: a Config instance, set of factories, a data format, an event + type fields mapping, a set of numeric fields and an initial + delay + when: an integration instance is created + then: instances should be created with pipelines that use the correct + data format and a newrelic instance should be created with the + correct config instance + ''' + + # setup + config = mod_config.Config({ + 'instances': [ + { + 'name': 'test-inst-1', + 'labels': { + 'foo': 'bar', + 'beep': 'boop', + }, + 'arguments': { + 'token_url': 'https://my.salesforce.test/token', + 'cache_enabled': False, + }, + }, + { + 'name': 'test-inst-2', + 'labels': { + 'foo': 'bar', + 'beep': 'boop', + }, + 'arguments': { + 'token_url': 'https://my.salesforce.test/token', + 'cache_enabled': False, + }, + } + ], + 'newrelic': { + 'data_format': 'events', + } + }) + + auth_factory = AuthenticatorFactoryStub() + cache_factory = CacheFactoryStub() + pipeline_factory = PipelineFactoryStub() + salesforce_factory = SalesForceFactoryStub() + query_factory = QueryFactoryStub() + newrelic_factory = NewRelicFactoryStub() + event_type_fields_mapping = { 'event': ['field1'] } + numeric_fields_list = set(['field1', 'field2']) + + # execute + + i = integration.Integration( + config, + auth_factory, + cache_factory, + pipeline_factory, + salesforce_factory, + query_factory, + newrelic_factory, + event_type_fields_mapping, + numeric_fields_list, + 603 + ) + + # verify + + self.assertIsNotNone(i.instances) + self.assertTrue(type(i.instances) is list) + self.assertEqual(len(i.instances), 2) + self.assertTrue('client' in i.instances[0]) + self.assertTrue('name' in i.instances[0]) + self.assertEqual(i.instances[0]['name'], 'test-inst-1') + self.assertTrue('client' in i.instances[1]) + self.assertTrue('name' in i.instances[1]) + self.assertEqual(i.instances[1]['name'], 'test-inst-2') + client = i.instances[0]['client'] + self.assertEqual(client.pipeline.data_format, DataFormat.EVENTS) + client = i.instances[1]['client'] + self.assertEqual(client.pipeline.data_format, DataFormat.EVENTS) + self.assertIsNotNone(i.new_relic) + self.assertIsNotNone(i.new_relic.config) + self.assertEqual(i.new_relic.config, config) + + ''' + given: a Config instance, set of factories, a data format, an event + type fields mapping, a set of numeric fields and an initial + delay + when: an integration instance is created + then: an exception should be raised if an invalid data format is + passed + ''' + + # setup + config = mod_config.Config({ + 'instances': [ + { + 'name': 'test-inst-1', + 'labels': { + 'foo': 'bar', + 'beep': 'boop', + }, + 'arguments': { + 'token_url': 'https://my.salesforce.test/token', + 'cache_enabled': False, + }, + }, + ], + 'newrelic': { + 'data_format': 'invalid', + } + }) + + auth_factory = AuthenticatorFactoryStub() + cache_factory = CacheFactoryStub() + pipeline_factory = PipelineFactoryStub() + salesforce_factory = SalesForceFactoryStub() + query_factory = QueryFactoryStub() + newrelic_factory = NewRelicFactoryStub() + + # execute/verify + + with self.assertRaises(ConfigException): + i = integration.Integration( + config, + auth_factory, + cache_factory, + pipeline_factory, + salesforce_factory, + query_factory, + newrelic_factory, + event_type_fields_mapping, + numeric_fields_list, + 603 + ) + + ''' + given: a Config instance, set of factories, a data format, an event + type fields mapping, a set of numeric fields and an initial + delay + when: no instance configurations are provided + then: integration instances should be the empty set + ''' + + # setup + config = mod_config.Config({ + 'instances': [], + 'newrelic': { + 'data_format': 'logs', + } + }) + + auth_factory = AuthenticatorFactoryStub() + cache_factory = CacheFactoryStub() + pipeline_factory = PipelineFactoryStub() + salesforce_factory = SalesForceFactoryStub() + query_factory = QueryFactoryStub() + newrelic_factory = NewRelicFactoryStub() + + # execute + + i = integration.Integration( + config, + auth_factory, + cache_factory, + pipeline_factory, + salesforce_factory, + query_factory, + newrelic_factory, + event_type_fields_mapping, + numeric_fields_list, + 603 + ) + + # verify + + self.assertEqual(len(i.instances), 0) + + ''' + given: a Config instance, set of factories, a data format, an event + type fields mapping, a set of numeric fields and an initial + delay + when: instances property is missing + then: integration instances should be the empty set + ''' + + # setup + config = mod_config.Config({ + 'newrelic': { + 'data_format': 'logs', + } + }) + + auth_factory = AuthenticatorFactoryStub() + cache_factory = CacheFactoryStub() + pipeline_factory = PipelineFactoryStub() + salesforce_factory = SalesForceFactoryStub() + query_factory = QueryFactoryStub() + newrelic_factory = NewRelicFactoryStub() + + # execute + + i = integration.Integration( + config, + auth_factory, + cache_factory, + pipeline_factory, + salesforce_factory, + query_factory, + newrelic_factory, + event_type_fields_mapping, + numeric_fields_list, + 603 + ) + + # verify + + self.assertEqual(len(i.instances), 0) diff --git a/src/tests/test_pipeline.py b/src/tests/test_pipeline.py index 52d4f64..1f5eb3c 100644 --- a/src/tests/test_pipeline.py +++ b/src/tests/test_pipeline.py @@ -192,7 +192,7 @@ def test_export_log_lines(self): ''' # setup - session = SessionStub([]) + session = SessionStub() session.response = ResponseStub(500, 'Error', '', []) # execute/verify @@ -277,7 +277,7 @@ def test_transform_log_lines(self): QueryStub({}), '00001111AAAABBBB', 'ApexCallout', - DataCacheStub({ + DataCacheStub(cached_logs={ '00001111AAAABBBB': [ 'YYZ:abcdef123456' ] }), {}, @@ -612,8 +612,8 @@ def test_transform_event_records(self): self.event_records, QueryStub({}), DataCacheStub( - {}, - [ '000012345', '000054321' ] + cached_logs={}, + cached_events=[ '000012345', '000054321' ] ), ) @@ -1023,7 +1023,7 @@ def test_pipeline_process_log_record(self): # setup cfg = config.Config({}) - session = SessionStub([]) + session = SessionStub() session.response = ResponseStub(200, 'OK', '', self.log_rows) newrelic = NewRelicStub() query = QueryStub({}) @@ -1330,7 +1330,7 @@ def test_pipeline_execute(self): # setup cfg = config.Config({}) - session = SessionStub([]) + session = SessionStub() session.response = ResponseStub(200, 'OK', '', self.log_rows) newrelic = NewRelicStub() query = QueryStub({}) diff --git a/src/tests/test_salesforce.py b/src/tests/test_salesforce.py new file mode 100644 index 0000000..4f2a02d --- /dev/null +++ b/src/tests/test_salesforce.py @@ -0,0 +1,580 @@ +from datetime import datetime, timedelta +import unittest + +from . import \ + AuthenticatorStub, \ + DataCacheStub, \ + PipelineStub, \ + QueryStub, \ + QueryFactoryStub, \ + SessionStub +from newrelic_logging import config, salesforce, util + +class TestSalesforce(unittest.TestCase): + def test_init(self): + _now = datetime.utcnow() + + def _utcnow(): + nonlocal _now + return _now + + util._UTCNOW = _utcnow + + ''' + given: an instance name and configuration, a data cache, authenticator, + pipeline, query factory, initial delay value, and a set of + queries + when: values are present in the configuration for all relevant config + properties + and when: no data cache is specified + and when: no queries are specified + then: a new Salesforce instance is created with the correct values + ''' + + # setup + time_lag_minutes = 603 + initial_delay = 5 + cfg = config.Config({ + 'api_ver': '55.0', + 'time_lag_minutes': time_lag_minutes, + 'date_field': 'CreateDate', + 'generation_interval': 'Hourly', + }) + auth = AuthenticatorStub() + pipeline = PipelineStub() + data_cache = DataCacheStub() + query_factory = QueryFactoryStub() + last_to_timestamp = util.get_iso_date_with_offset( + time_lag_minutes, + initial_delay, + ) + + # execute + client = salesforce.SalesForce( + 'test_instance', + cfg, + None, + auth, + pipeline, + query_factory, + initial_delay, + ) + + # verify + self.assertEqual(client.instance_name, 'test_instance') + self.assertEqual(client.default_api_ver, '55.0') + self.assertEqual(client.data_cache, None) + self.assertEqual(client.auth, auth) + self.assertEqual(client.time_lag_minutes, time_lag_minutes) + self.assertEqual(client.date_field, 'CreateDate') + self.assertEqual(client.generation_interval, 'Hourly') + self.assertEqual(client.last_to_timestamp, last_to_timestamp) + self.assertTrue(len(client.queries) == 1) + self.assertTrue('query' in client.queries[0]) + self.assertEqual( + client.queries[0]['query'], + salesforce.SALESFORCE_CREATED_DATE_QUERY, + ) + + ''' + given: an instance name and configuration, a data cache, authenticator, + pipeline, query factory, initial delay value, and a set of + queries + and when: no data cache is specified + and when: no queries are specified + and when: no lag time is specified in the config + then: a new Salesforce instance is created with the default lag time + ''' + + # setup + cfg = config.Config({ + 'api_ver': '55.0', + 'date_field': 'CreateDate', + 'generation_interval': 'Hourly', + }) + last_to_timestamp = util.get_iso_date_with_offset( + config.DEFAULT_TIME_LAG_MINUTES, + initial_delay, + ) + + # execute + client = salesforce.SalesForce( + 'test_instance', + cfg, + None, + auth, + pipeline, + query_factory, + initial_delay, + ) + + # verify + self.assertEqual( + client.time_lag_minutes, + config.DEFAULT_TIME_LAG_MINUTES, + ) + self.assertEqual(client.last_to_timestamp, last_to_timestamp) + + ''' + given: an instance name and configuration, a data cache, authenticator, + pipeline, query factory, initial delay value, and a set of + queries + and when: a data cache is specified + and when: no queries are specified + and when: no lag time is specified in the config + then: a new Salesforce instance is created with the lag time set to 0 + ''' + + # setup + last_to_timestamp = util.get_iso_date_with_offset( + 0, + initial_delay, + ) + + # execute + client = salesforce.SalesForce( + 'test_instance', + cfg, + data_cache, + auth, + pipeline, + query_factory, + initial_delay, + ) + + # verify + self.assertEqual( + client.time_lag_minutes, + 0, + ) + self.assertEqual(client.last_to_timestamp, last_to_timestamp) + + ''' + given: an instance name and configuration, a data cache, authenticator, + pipeline, query factory, initial delay value, and a set of + queries + and when: no data cache is specified + and when: no queries are specified + and when: no date field is specified in the config + then: a new Salesforce instance is created with the default date field + ''' + + # setup + cfg = config.Config({ + 'api_ver': '55.0', + 'time_lag_minutes': time_lag_minutes, + 'generation_interval': 'Hourly', + }) + + # execute + client = salesforce.SalesForce( + 'test_instance', + cfg, + None, + auth, + pipeline, + query_factory, + initial_delay, + ) + + # verify + self.assertEqual(client.date_field, config.DATE_FIELD_LOG_DATE) + + ''' + given: an instance name and configuration, a data cache, authenticator, + pipeline, query factory, initial delay value, and a set of + queries + and when: a data cache is specified + and when: no queries are specified + and when: no date field is specified in the config + then: a new Salesforce instance is created with the default date field + ''' + + # execute + client = salesforce.SalesForce( + 'test_instance', + cfg, + data_cache, + auth, + pipeline, + query_factory, + initial_delay, + ) + + # verify + self.assertEqual(client.date_field, config.DATE_FIELD_CREATE_DATE) + + ''' + given: an instance name and configuration, a data cache, authenticator, + pipeline, query factory, initial delay value, and a set of + queries + and when: no data cache is specified + and when: no queries are specified + and when: no generation interval is specified + then: a new Salesforce instance is created with the default generation + interval + ''' + + # setup + cfg = config.Config({ + 'api_ver': '55.0', + 'time_lag_minutes': time_lag_minutes, + 'date_field': 'CreateDate' + }) + + # execute + client = salesforce.SalesForce( + 'test_instance', + cfg, + None, + auth, + pipeline, + query_factory, + initial_delay, + ) + + # verify + self.assertEqual( + client.generation_interval, + config.DEFAULT_GENERATION_INTERVAL, + ) + + ''' + given: an instance name and configuration, a data cache, authenticator, + pipeline, query factory, initial delay value, and a set of + queries + and when: no data cache is specified + and when: no queries are specified + and when: the date field option is set to 'LogDate' + then: a new Salesforce instance is created with the default log date + query + ''' + + # setup + cfg = config.Config({ + 'api_ver': '55.0', + 'time_lag_minutes': time_lag_minutes, + 'date_field': 'LogDate', + 'generation_interval': 'Hourly', + }) + + # execute + client = salesforce.SalesForce( + 'test_instance', + cfg, + None, + auth, + pipeline, + query_factory, + initial_delay, + ) + + # verify + self.assertTrue(len(client.queries) == 1) + self.assertTrue('query' in client.queries[0]) + self.assertEqual( + client.queries[0]['query'], + salesforce.SALESFORCE_LOG_DATE_QUERY, + ) + + ''' + given: an instance name and configuration, a data cache, authenticator, + pipeline, query factory, initial delay value, and a set of + queries + and when: no data cache is specified + and when: queries are specified + then: a new Salesforce instance is created with the specified queries + ''' + + # setup + queries = [ { 'query': 'foo' }, { 'query': 'bar' }] + + # execute + client = salesforce.SalesForce( + 'test_instance', + cfg, + None, + auth, + pipeline, + query_factory, + initial_delay, + queries + ) + + # verify + self.assertTrue(len(client.queries) == 2) + self.assertTrue('query' in client.queries[0]) + self.assertEqual(client.queries[0]['query'], 'foo') + self.assertTrue('query' in client.queries[1]) + self.assertEqual(client.queries[1]['query'], 'bar') + + def test_authenticate(self): + ''' + given: an instance name and configuration, a data cache, authenticator, + pipeline, query factory, initial delay value, set of queries, + and an http session + when: called + then: the underlying authenticator is called + ''' + + # setup + time_lag_minutes = 603 + initial_delay = 5 + cfg = config.Config({ + 'api_ver': '55.0', + 'time_lag_minutes': time_lag_minutes, + 'date_field': 'CreateDate', + 'generation_interval': 'Hourly', + }) + auth = AuthenticatorStub() + pipeline = PipelineStub() + query_factory = QueryFactoryStub() + session = SessionStub() + + client = salesforce.SalesForce( + 'test_instance', + cfg, + None, + auth, + pipeline, + query_factory, + initial_delay, + ) + + # execute + client.authenticate(session) + + # verify + self.assertTrue(auth.authenticate_called) + + def test_slide_time_range(self): + _now = datetime.utcnow() + + def _utcnow(): + nonlocal _now + return _now + + util._UTCNOW = _utcnow + + ''' + given: an instance name and configuration, data cache, authenticator, + pipeline, query factory, initial delay value, and a set of + queries + when: called + then: the 'last_to_timestamp' is updated + ''' + + # setup + time_lag_minutes = 603 + initial_delay = 5 + cfg = config.Config({ + 'api_ver': '55.0', + 'time_lag_minutes': time_lag_minutes, + 'date_field': 'CreateDate', + 'generation_interval': 'Hourly', + }) + auth = AuthenticatorStub() + pipeline = PipelineStub() + query_factory = QueryFactoryStub() + + client = salesforce.SalesForce( + 'test_instance', + cfg, + None, + auth, + pipeline, + query_factory, + initial_delay, + ) + + last_to_before = client.last_to_timestamp + + # pretend it's 10 minutes from now to ensure this is different from + # the timestamp calculated during object creation + _now = datetime.utcnow() + timedelta(minutes=10) + + # execute + client.slide_time_range() + + # verify + last_to_after = client.last_to_timestamp + + self.assertNotEqual(last_to_after, last_to_before) + + def test_fetch_logs(self): + ''' + given: an instance name and configuration, data cache, authenticator, + pipeline, query factory, initial delay value, a set of queries, + and an http session + when: the set of queries is the empty set + then: a default query should be executed + ''' + + # setup + time_lag_minutes = 603 + initial_delay = 5 + cfg = config.Config({ + 'api_ver': '55.0', + 'time_lag_minutes': time_lag_minutes, + 'date_field': 'CreateDate', + 'generation_interval': 'Hourly', + }) + auth = AuthenticatorStub() + pipeline = PipelineStub() + query_factory = QueryFactoryStub() + session = SessionStub() + + client = salesforce.SalesForce( + 'test_instance', + cfg, + None, + auth, + pipeline, + query_factory, + initial_delay, + ) + + # execute + client.fetch_logs(session) + + # verify + + self.assertEqual(len(query_factory.queries), 1) + query = query_factory.queries[0] + self.assertTrue(query.executed) + self.assertTrue(pipeline.executed) + self.assertEqual(len(pipeline.queries), 1) + self.assertEqual(query, pipeline.queries[0]) + + ''' + given: an instance name and configuration, data cache, authenticator, + pipeline, query factory, initial delay value, a set of queries, + and an http session + when: the set of queries is not the empty set + then: each query should be executed + ''' + + auth = AuthenticatorStub() + pipeline = PipelineStub() + query_factory = QueryFactoryStub() + session = SessionStub() + queries = [ + { + 'query': 'foo', + }, + { + 'query': 'bar', + }, + { + 'query': 'beep', + }, + { + 'query': 'boop', + }, + ] + + client = salesforce.SalesForce( + 'test_instance', + cfg, + None, + auth, + pipeline, + query_factory, + initial_delay, + queries, + ) + + # execute + client.fetch_logs(session) + + # verify + + self.assertEqual(len(query_factory.queries), 4) + + query1 = query_factory.queries[0] + self.assertEqual(query1.query, 'foo') + self.assertTrue(query1.executed) + query2 = query_factory.queries[1] + self.assertEqual(query2.query, 'bar') + self.assertTrue(query2.executed) + query3 = query_factory.queries[2] + self.assertEqual(query3.query, 'beep') + self.assertTrue(query3.executed) + query4 = query_factory.queries[3] + self.assertEqual(query4.query, 'boop') + self.assertTrue(query4.executed) + + self.assertTrue(pipeline.executed) + self.assertEqual(len(pipeline.queries), 4) + self.assertEqual(query1, pipeline.queries[0]) + self.assertEqual(query2, pipeline.queries[1]) + self.assertEqual(query3, pipeline.queries[2]) + self.assertEqual(query4, pipeline.queries[3]) + + ''' + given: an instance name and configuration, data cache, authenticator, + pipeline, query factory, initial delay value, a set of queries, + and an http session + when: no response is returned from a query + then: query should be executed and pipeline should not be executed + ''' + + auth = AuthenticatorStub() + pipeline = PipelineStub() + query = QueryStub(result=None) + query_factory = QueryFactoryStub(query) + session = SessionStub() + + client = salesforce.SalesForce( + 'test_instance', + cfg, + None, + auth, + pipeline, + query_factory, + initial_delay, + ) + + # execute + client.fetch_logs(session) + + # verify + + self.assertTrue(query.executed) + self.assertFalse(pipeline.executed) + + ''' + given: an instance name and configuration, data cache, authenticator, + pipeline, query factory, initial delay value, a set of queries, + and an http session + when: there is no 'records' attribute in response returned from query + then: query should be executed and pipeline should not be executed + ''' + + auth = AuthenticatorStub() + pipeline = PipelineStub() + query = QueryStub(result={ 'foo': 'bar' }) + query_factory = QueryFactoryStub(query) + session = SessionStub() + + client = salesforce.SalesForce( + 'test_instance', + cfg, + None, + auth, + pipeline, + query_factory, + initial_delay, + ) + + # execute + client.fetch_logs(session) + + # verify + + self.assertTrue(query.executed) + self.assertFalse(pipeline.executed) + + +if __name__ == '__main__': + unittest.main() diff --git a/src/tests/test_util.py b/src/tests/test_util.py index 7fc960f..4f9b923 100644 --- a/src/tests/test_util.py +++ b/src/tests/test_util.py @@ -1,3 +1,4 @@ +from datetime import datetime, timedelta import hashlib import unittest @@ -133,6 +134,87 @@ def test_maybe_convert_str_to_num(self): self.assertTrue(type(val) is str) self.assertEqual(val, 'not a number') + def test_get_iso_date_with_offset(self): + _now = datetime.utcnow() + + def _utcnow(): + nonlocal _now + return _now + + util._UTCNOW = _utcnow + + ''' + given: a time lag and initial delay + when: neither are specified (both default to 0) + then: return the current time in iso format + ''' + + # setup + val = _now.isoformat(timespec='milliseconds') + 'Z' + + # execute + isonow = util.get_iso_date_with_offset() + + # verify + self.assertEqual(val, isonow) + + ''' + given: a time lag and initial delay + when: time lag is specified + then: return the current time minus the time lag in iso format + ''' + + # setup + time_lag_minutes = 412 + val = (_now - timedelta(minutes=time_lag_minutes)) \ + .isoformat(timespec='milliseconds') + 'Z' + + # execute + isonow = util.get_iso_date_with_offset( + time_lag_minutes=time_lag_minutes + ) + + # verify + self.assertEqual(val, isonow) + + ''' + given: a time lag and initial delay + when: initial delay is specified + then: return the current time minus the initial delay in iso format + ''' + + # setup + initial_delay = 678 + val = (_now - timedelta(minutes=initial_delay)) \ + .isoformat(timespec='milliseconds') + 'Z' + + # execute + isonow = util.get_iso_date_with_offset(initial_delay=initial_delay) + + # verify + self.assertEqual(val, isonow) + + ''' + given: a time lag and initial delay + when: both are specified + then: return the current time minus the sum of the time lag and initial + delay in iso format + ''' + + # setup + initial_delay = 678 + val = (_now - timedelta(minutes=(time_lag_minutes + initial_delay))) \ + .isoformat(timespec='milliseconds') + 'Z' + + # execute + isonow = util.get_iso_date_with_offset( + time_lag_minutes, + initial_delay, + ) + + # verify + self.assertEqual(val, isonow) + if __name__ == '__main__': unittest.main() From bfa449b7820acc3676ac9226ddee1be87fcbc55d Mon Sep 17 00:00:00 2001 From: Scott DeWitt Date: Mon, 18 Mar 2024 09:30:04 -0400 Subject: [PATCH 05/11] fix: argument error calling DataCache constructor --- src/newrelic_logging/cache.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/newrelic_logging/cache.py b/src/newrelic_logging/cache.py index b9a49b5..33ee647 100644 --- a/src/newrelic_logging/cache.py +++ b/src/newrelic_logging/cache.py @@ -150,13 +150,13 @@ def new(self, config: Config): return DataCache( RedisBackend( redis.Redis( - host=host, - port=port, - db=db, - password=password, - ssl=ssl + host=host, + port=port, + db=db, + password=password, + ssl=ssl + ), ), expire_days) - ) print_info('Cache disabled') From 2529cf4d5efc6b90c3d7b67e07d07ebfd2394488 Mon Sep 17 00:00:00 2001 From: Scott DeWitt Date: Fri, 22 Mar 2024 09:17:16 -0400 Subject: [PATCH 06/11] feat: optimize memory usage parts 5 and 6: rework cache and add unit tests for cache and query --- src/__main__.py | 6 +- src/newrelic_logging/auth.py | 20 +- src/newrelic_logging/cache.py | 194 ++++---- src/newrelic_logging/pipeline.py | 7 +- src/newrelic_logging/query.py | 40 +- src/newrelic_logging/util.py | 4 +- src/tests/__init__.py | 110 ++++- src/tests/test_cache.py | 728 +++++++++++++++++++++++++++++++ src/tests/test_pipeline.py | 2 +- src/tests/test_query.py | 384 ++++++++++++++++ 10 files changed, 1359 insertions(+), 136 deletions(-) create mode 100644 src/tests/test_cache.py create mode 100644 src/tests/test_query.py diff --git a/src/__main__.py b/src/__main__.py index 6ad67a6..98c0823 100644 --- a/src/__main__.py +++ b/src/__main__.py @@ -14,7 +14,7 @@ from pytz import utc from yaml import Loader, load from newrelic_logging.auth import AuthenticatorFactory -from newrelic_logging.cache import CacheFactory +from newrelic_logging.cache import CacheFactory, BackendFactory from newrelic_logging.config import Config, getenv from newrelic_logging.newrelic import NewRelicFactory from newrelic_logging.pipeline import PipelineFactory @@ -138,7 +138,7 @@ def run_once( Integration( config, AuthenticatorFactory(), - CacheFactory(), + CacheFactory(BackendFactory()), PipelineFactory(), SalesForceFactory(), QueryFactory(), @@ -172,7 +172,7 @@ def run_as_service( Integration( config, AuthenticatorFactory(), - CacheFactory(), + CacheFactory(BackendFactory()), PipelineFactory(), SalesForceFactory(), QueryFactory(), diff --git a/src/newrelic_logging/auth.py b/src/newrelic_logging/auth.py index f1731e2..4803b8d 100644 --- a/src/newrelic_logging/auth.py +++ b/src/newrelic_logging/auth.py @@ -52,28 +52,26 @@ def clear_auth(self) -> None: if self.data_cache: try: - self.data_cache.redis.delete(AUTH_CACHE_KEY) + # @TODO need to change all the places where redis is explicitly + # referenced since this breaks encapsulation. + self.data_cache.backend.redis.delete(AUTH_CACHE_KEY) except Exception as e: print_warn(f'Failed deleting data from cache: {e}') def load_auth_from_cache(self) -> bool: try: - auth_exists = self.data_cache.redis.exists(AUTH_CACHE_KEY) + auth_exists = self.data_cache.backend.redis.exists(AUTH_CACHE_KEY) if auth_exists: print_info('Retrieving credentials from Redis.') - #NOTE: hmget and hgetall both return byte arrays, not strings. We have to convert. - # We could fix it by adding the argument "decode_responses=True" to Redis constructor, - # but then we would have to change all places where we assume a byte array instead of a string, - # and refactoring in a language without static types is a pain. try: - auth = self.data_cache.redis.hmget( + auth = self.data_cache.backend.redis.hmget( AUTH_CACHE_KEY, ['access_token', 'instance_url'], ) - self.set_auth( - auth[0].decode("utf-8"), - auth[1].decode("utf-8"), + self.set_auth_data( + auth[0], + auth[1], ) return True @@ -97,7 +95,7 @@ def store_auth(self, auth_resp: dict) -> None: } try: - self.data_cache.redis.hmset(AUTH_CACHE_KEY, auth) + self.data_cache.backend.redis.hmset(AUTH_CACHE_KEY, auth) except Exception as e: print_warn(f"Failed storing data in cache: {e}") diff --git a/src/newrelic_logging/cache.py b/src/newrelic_logging/cache.py index 33ee647..6f7a197 100644 --- a/src/newrelic_logging/cache.py +++ b/src/newrelic_logging/cache.py @@ -32,132 +32,144 @@ def exists(self, key): def put(self, key, item): self.redis.set(key, item) - def list_length(self, key): - return self.redis.llen(key) + def get_set(self, key): + return self.redis.smembers(key) - def list_slice(self, key, start, end): - return self.redis.lrange(key, start, end) - - def list_append(self, key, item): - self.redis.rpush(key, item) + def set_add(self, key, *values): + self.redis.sadd(key, *values) def set_expiry(self, key, days): self.redis.expire(key, timedelta(days=days)) +class BufferedAddSetCache: + def __init__(self, s: set): + self.s = s + self.buffer = set() + + def check_or_set(self, item: str) -> bool: + if item in self.s or item in self.buffer: + return True + + self.buffer.add(item) + + return False + + def get_buffer(self) -> set: + return self.buffer + + class DataCache: def __init__(self, backend, expiry): self.backend = backend self.expiry = expiry - self.cached_events = {} - self.cached_logs = {} + self.log_records = {} + self.event_records = None def can_skip_downloading_logfile(self, record_id: str) -> bool: try: - return self.backend.exists(record_id) and \ - self.backend.list_length(record_id) > 1 + return self.backend.exists(record_id) except Exception as e: raise CacheException(f'failed checking record {record_id}: {e}') - def load_cached_log_lines(self, record_id: str) -> None: + def check_or_set_log_line(self, record_id: str, line: dict) -> bool: try: - if self.backend.exists(record_id): - self.cached_logs[record_id] = \ - self.backend.list_slice(record_id, 0, -1) - return + if not record_id in self.log_records: + self.log_records[record_id] = BufferedAddSetCache( + self.backend.get_set(record_id), + ) - self.cached_logs[record_id] = ['init'] + return self.log_records[record_id].check_or_set(line['REQUEST_ID']) except Exception as e: - raise CacheException(f'failed checking log record {record_id}: {e}') + raise CacheException(f'failed checking record {record_id}: {e}') - # Cache log - # @TODO this function assumes you have called load_cached_log_lines - # which isn't obvious. - def check_and_set_log_line(self, record_id: str, row: dict) -> bool: - row_id = row["REQUEST_ID"] + def check_or_set_event_id(self, record_id: str) -> bool: + try: + if not self.event_records: + self.event_records = BufferedAddSetCache( + self.backend.get_set('event_ids'), + ) - if row_id.encode('utf-8') in self.cached_logs[record_id]: - return True + return self.event_records.check_or_set(record_id) + except Exception as e: + raise CacheException(f'failed checking record {record_id}: {e}') - self.cached_logs[record_id].append(row_id) + def flush(self) -> None: + try: + for record_id in self.log_records: + buf = self.log_records[record_id].get_buffer() + if len(buf) > 0: + self.backend.set_add(record_id, *buf) - return False + self.backend.set_expiry(record_id, self.expiry) - # Cache event - def check_and_set_event_id(self, record_id: str) -> bool: - try: - if self.backend.exists(record_id): - return True + if self.event_records: + buf = self.event_records.get_buffer() + if len(buf) > 0: + for id in buf: + self.backend.put(id, 1) + self.backend.set_expiry(id, self.expiry) + + self.backend.set_add('event_ids', *buf) + self.backend.set_expiry('event_ids', self.expiry) + + # attempt to reclaim memory + for record_id in self.log_records: + self.log_records[record_id] = None - self.cached_events[record_id] = '' + self.log_records = {} + self.event_records = None - return False + gc.collect() except Exception as e: - raise CacheException(f'failed checking record {record_id}: {e}') + raise CacheException(f'failed flushing cache: {e}') - def flush(self) -> None: - # Flush cached log line ids for each log record - for record_id in self.cached_logs: - for row_id in self.cached_logs[record_id]: - try: - self.backend.list_append(record_id, row_id) - - # Set expire date for the whole list only once, when we find - # the first entry ('init') - if row_id == 'init': - self.backend.set_expiry(record_id, self.expiry) - except Exception as e: - raise CacheException( - f'failed pushing row {row_id} for record {record_id}: {e}' - ) - - # Attempt to release memory - del self.cached_logs[record_id] - - # Flush any cached event record ids - for record_id in self.cached_events: - try: - self.backend.put(record_id, '') - self.backend.set_expiry(record_id, self.expiry) - # Attempt to release memory - del self.cached_events[record_id] - except Exception as e: - raise CacheException(f"failed setting record {record_id}: {e}") +class BackendFactory: + def __init__(self): + pass - # Run a gc in an attempt to reclaim memory - gc.collect() + def new(self, config: Config): + host = config.get(CONFIG_REDIS_HOST, DEFAULT_REDIS_HOST) + port = config.get_int(CONFIG_REDIS_PORT, DEFAULT_REDIS_PORT) + db = config.get_int(CONFIG_REDIS_DB_NUMBER, DEFAULT_REDIS_DB_NUMBER) + password = config.get(CONFIG_REDIS_PASSWORD) + ssl = config.get_bool(CONFIG_REDIS_USE_SSL, DEFAULT_REDIS_SSL) + password_display = "XXXXXX" if password != None else None + + print_info( + f'connecting to redis instance {host}:{port}:{db}, ssl={ssl}, password={password_display}' + ) + + return RedisBackend( + redis.Redis( + host=host, + port=port, + db=db, + password=password, + ssl=ssl, + decode_responses=True, + ), + + ) class CacheFactory: - def __init__(self): + def __init__(self, backend_factory): + self.backend_factory = backend_factory pass def new(self, config: Config): - if config.get_bool(CONFIG_CACHE_ENABLED, DEFAULT_CACHE_ENABLED): - host = config.get(CONFIG_REDIS_HOST, DEFAULT_REDIS_HOST) - port = config.get_int(CONFIG_REDIS_PORT, DEFAULT_REDIS_PORT) - db = config.get_int(CONFIG_REDIS_DB_NUMBER, DEFAULT_REDIS_DB_NUMBER) - password = config.get(CONFIG_REDIS_PASSWORD) - ssl = config.get_bool(CONFIG_REDIS_USE_SSL, DEFAULT_REDIS_SSL) - expire_days = config.get_int(CONFIG_REDIS_EXPIRE_DAYS) - password_display = "XXXXXX" if password != None else None - - print_info( - f'Cache enabled, connecting to redis instance {host}:{port}:{db}, ssl={ssl}, password={password_display}' - ) + if not config.get_bool(CONFIG_CACHE_ENABLED, DEFAULT_CACHE_ENABLED): + print_info('Cache disabled') + return None + print_info('Cache enabled') + + try: return DataCache( - RedisBackend( - redis.Redis( - host=host, - port=port, - db=db, - password=password, - ssl=ssl - ), - ), expire_days) - - print_info('Cache disabled') - - return None + self.backend_factory.new(config), + config.get_int(CONFIG_REDIS_EXPIRE_DAYS) + ) + except Exception as e: + raise CacheException(f'failed creating backend: {e}') diff --git a/src/newrelic_logging/pipeline.py b/src/newrelic_logging/pipeline.py index 0dd32a3..27622b7 100644 --- a/src/newrelic_logging/pipeline.py +++ b/src/newrelic_logging/pipeline.py @@ -132,7 +132,7 @@ def transform_log_lines( for row in reader: # If we've already seen this log line, skip it - if data_cache and data_cache.check_and_set_log_line(record_id, row): + if data_cache and data_cache.check_or_set_log_line(record_id, row): continue # Otherwise, pack it up for shipping and yield it for consumption @@ -205,7 +205,7 @@ def transform_event_records(iter, query: Query, data_cache: DataCache): ) # If we've already seen this event record, skip it. - if data_cache and data_cache.check_and_set_event_id(record_id): + if data_cache and data_cache.check_or_set_event_id(record_id): continue # Build a New Relic log record from the SF event record @@ -403,9 +403,6 @@ def process_log_record( ) return None - if self.data_cache: - self.data_cache.load_cached_log_lines(record_id) - load_data( transform_log_lines( export_log_lines( diff --git a/src/newrelic_logging/query.py b/src/newrelic_logging/query.py index 5e886a8..ce75a2a 100644 --- a/src/newrelic_logging/query.py +++ b/src/newrelic_logging/query.py @@ -1,5 +1,4 @@ import copy -from datetime import datetime, timedelta from requests import RequestException, Session from . import SalesforceApiException @@ -59,6 +58,24 @@ class QueryFactory: def __init__(self): pass + def build_args( + self, + time_lag_minutes: int, + last_to_timestamp: str, + generation_interval: str, + ): + return { + 'to_timestamp': get_iso_date_with_offset(time_lag_minutes), + 'from_timestamp': last_to_timestamp, + 'log_interval_type': generation_interval, + } + + def get_env(self, q: dict) -> dict: + if 'env' in q and type(q['env']) is dict: + return q['env'] + + return {} + def new( self, q: dict, @@ -67,22 +84,19 @@ def new( generation_interval: str, default_api_ver: str, ) -> Query: - to_timestamp = get_iso_date_with_offset(time_lag_minutes) - from_timestamp = last_to_timestamp - qp = copy.deepcopy(q) qq = qp.pop('query', '') - args = { - 'to_timestamp': to_timestamp, - 'from_timestamp': from_timestamp, - 'log_interval_type': generation_interval, - } - - env = qp['env'] if 'env' in qp and type(qp['env']) is dict else {} - return Query( - substitute(args, qq, env).replace(' ', '+'), + substitute( + self.build_args( + time_lag_minutes, + last_to_timestamp, + generation_interval, + ), + qq, + self.get_env(qp), + ).replace(' ', '+'), Config(qp), qp.get('api_ver', default_api_ver) ) diff --git a/src/newrelic_logging/util.py b/src/newrelic_logging/util.py index d77864b..bc6932a 100644 --- a/src/newrelic_logging/util.py +++ b/src/newrelic_logging/util.py @@ -83,9 +83,9 @@ def sf_time(t: datetime): def now(delta: timedelta = None): if delta: - return sf_time(datetime.utcnow() + delta) + return sf_time(_UTCNOW() + delta) else: - return sf_time(datetime.utcnow()) + return sf_time(_UTCNOW()) try: return eval(code) diff --git a/src/tests/__init__.py b/src/tests/__init__.py index 50688f1..297f4d3 100644 --- a/src/tests/__init__.py +++ b/src/tests/__init__.py @@ -1,4 +1,7 @@ -from requests import Session +from datetime import timedelta +import json +from redis import RedisError +from requests import Session, RequestException from newrelic_logging import DataFormat from newrelic_logging.auth import Authenticator @@ -77,27 +80,21 @@ def __init__( cached_logs = {}, cached_events = [], skip_record_ids = [], - cached_log_lines = {}, ): self.config = config self.cached_logs = cached_logs self.cached_events = cached_events self.skip_record_ids = skip_record_ids - self.cached_log_lines = cached_log_lines self.flush_called = False def can_skip_downloading_logfile(self, record_id: str) -> bool: return record_id in self.skip_record_ids - def load_cached_log_lines(self, record_id: str) -> None: - if record_id in self.cached_log_lines: - self.cached_logs[record_id] = self.cached_log_lines[record_id] - - def check_and_set_log_line(self, record_id: str, row: dict) -> bool: + def check_or_set_log_line(self, record_id: str, row: dict) -> bool: return record_id in self.cached_logs and \ row['REQUEST_ID'] in self.cached_logs[record_id] - def check_and_set_event_id(self, record_id: str) -> bool: + def check_or_set_event_id(self, record_id: str) -> bool: return record_id in self.cached_events def flush(self) -> None: @@ -243,6 +240,87 @@ def new( ) +class RedisStub: + def __init__(self, test_cache, raise_error = False): + self.expiry = {} + self.test_cache = test_cache + self.raise_error = raise_error + + def exists(self, key): + if self.raise_error: + raise RedisError('raise_error set') + + return key in self.test_cache + + def set(self, key, item): + if self.raise_error: + raise RedisError('raise_error set') + + self.test_cache[key] = item + + def smembers(self, key): + if self.raise_error: + raise RedisError('raise_error set') + + if not key in self.test_cache: + return set() + + if not type(self.test_cache[key]) is set: + raise RedisError(f'{key} is not a set') + + return self.test_cache[key] + + def sadd(self, key, *values): + if self.raise_error: + raise RedisError('raise_error set') + + if key in self.test_cache and not type(self.test_cache[key]) is set: + raise RedisError(f'{key} is not a set') + + if not key in self.test_cache: + self.test_cache[key] = set() + + for v in values: + self.test_cache[key].add(v) + + def expire(self, key, time): + if self.raise_error: + raise RedisError('raise_error set') + + self.expiry[key] = time + + +class BackendStub: + def __init__(self, test_cache, raise_error = False): + self.redis = RedisStub(test_cache, raise_error) + + def exists(self, key): + return self.redis.exists(key) + + def put(self, key, item): + self.redis.set(key, item) + + def get_set(self, key): + return self.redis.smembers(key) + + def set_add(self, key, *values): + self.redis.sadd(key, *values) + + def set_expiry(self, key, days): + self.redis.expire(key, timedelta(days=days)) + + +class BackendFactoryStub: + def __init__(self, raise_error = False): + self.raise_error = raise_error + + def new(self, _: Config): + if self.raise_error: + raise RedisError('raise_error set') + + return BackendStub({}) + + class ResponseStub: def __init__(self, status_code, reason, text, lines): self.status_code = status_code @@ -253,6 +331,9 @@ def __init__(self, status_code, reason, text, lines): def iter_lines(self, *args, **kwargs): yield from self.lines + def json(self, *args, **kwargs): + return json.loads(self.text) + class SalesForceStub: def __init__( @@ -304,8 +385,17 @@ def new( class SessionStub: - def __init__(self): + def __init__(self, raise_error=False): + self.raise_error = raise_error self.response = None + self.headers = None + self.url = None def get(self, *args, **kwargs): + if self.raise_error: + raise RequestException('raise_error set') + + self.url = args[0] + self.headers = kwargs['headers'] + return self.response diff --git a/src/tests/test_cache.py b/src/tests/test_cache.py new file mode 100644 index 0000000..9d75950 --- /dev/null +++ b/src/tests/test_cache.py @@ -0,0 +1,728 @@ +from datetime import timedelta +from redis import RedisError +import unittest + +from newrelic_logging import cache, CacheException, config as mod_config +from . import RedisStub, BackendStub, BackendFactoryStub + +class TestRedisBackend(unittest.TestCase): + def test_exists(self): + ''' + backend exists returns redis exists + given: a redis instance + when: exists is called + then: the redis instance exists command result is returned + ''' + + # setup + redis = RedisStub({ 'foo': 'bar', 'beep': 'boop' }) + + # execute + backend = cache.RedisBackend(redis) + foo_exists = backend.exists('foo') + baz_exists = backend.exists('baz') + + # verify + self.assertTrue(foo_exists) + self.assertFalse(baz_exists) + + def test_exists_raises_when_redis_does(self): + ''' + backend exists raises error if redis exists does + given: a redis instance + when: exists is called + and when: the redis instance raises an exception + then: the redis instance exception is raised + ''' + + # setup + redis = RedisStub({ 'foo': 'bar', 'beep': 'boop' }, raise_error=True) + + # execute + backend = cache.RedisBackend(redis) + + # verify + with self.assertRaises(RedisError): + backend.exists('foo') + + def test_put(self): + ''' + backend put calls redis set + given: a redis instance + when: put is called + then: the redis instance set command is called + ''' + + # setup + redis = RedisStub({ 'foo': 'bar', 'beep': 'boop' }) + + # preconditions + self.assertFalse(redis.exists('r2')) + + # execute + backend = cache.RedisBackend(redis) + backend.put('r2', 'd2') + + # verify + self.assertTrue(redis.exists('r2')) + self.assertEqual(redis.test_cache['r2'], 'd2') + + def test_put_raises_if_redis_does(self): + ''' + backend put raises error if redis set does + given: a redis instance + when: put is called + and when: the redis instance raises an exception + then: the redis instance exception is raised + ''' + + # setup + redis = RedisStub({ 'foo': 'bar', 'beep': 'boop' }, raise_error=True) + + # execute + backend = cache.RedisBackend(redis) + + # verify + with self.assertRaises(RedisError) as _: + backend.put('r2', 'd2') + + + def test_get_set(self): + ''' + backend calls redis smembers + given: a redis instance + when: get_set is called + and when: key does not exist + or when: key exists and is set + then: the redis instance smembers command result is returned + ''' + + # setup + redis = RedisStub({ 'foo': set(['bar']) }) + + # execute + backend = cache.RedisBackend(redis) + foo_set = backend.get_set('foo') + beep_set = backend.get_set('beep') + + # verify + self.assertEqual(len(foo_set), 1) + self.assertEqual(foo_set, set(['bar'])) + self.assertEqual(len(beep_set), 0) + self.assertEqual(beep_set, set()) + + def test_get_set_raises_if_redis_does(self): + ''' + backend raises exception if redis smembers does + given: a redis instance + when: get_set is called + and when: redis instance raises an exception + then: the redis instance exception is raised + ''' + + # setup + redis = RedisStub({ 'foo': 'bar' }) + + # execute + backend = cache.RedisBackend(redis) + + # verify + with self.assertRaises(RedisError) as _: + backend.get_set('foo') + + def test_set_add(self): + ''' + backend set_add calls sadd + given: a redis instance + when: set_add is called + and when: key does not exist + or when: key exists and is set + then: the redis instance sadd command is called + ''' + + # setup + redis = RedisStub({ 'foo': set(['bar', 'beep', 'boop']) }) + + # execute + backend = cache.RedisBackend(redis) + backend.set_add('foo', 'baz') + backend.set_add('bop', 'biz') + + args = [1, 2, 3] + backend.set_add('bip', *args) + + # verify + self.assertEqual( + redis.test_cache['foo'], set(['bar', 'beep', 'boop', 'baz']), + ) + self.assertEqual(redis.test_cache['bop'], set(['biz'])) + self.assertEqual(redis.test_cache['bip'], set([1, 2, 3])) + + def test_set_add_raises_if_redis_does(self): + ''' + backend raises exception if redis sadd does + given: a redis instance + when: set_add is called + and when: redis instance raises an exception + then: the redis instance exception is raised + ''' + + # setup + redis = RedisStub({ 'foo': 'bar' }) + + # execute + backend = cache.RedisBackend(redis) + + # verify + with self.assertRaises(RedisError) as _: + backend.set_add('foo', 'beep') + + def test_set_expiry(self): + ''' + backend set_expiry calls expire + given: a redis instance + when: set_expiry is called + then: the redis instance expire command is called + ''' + + # setup + redis = RedisStub({ 'foo': set('bar') }) + + # execute + backend = cache.RedisBackend(redis) + backend.set_expiry('foo', 5) + + # verify + self.assertTrue('foo' in redis.expiry) + self.assertEqual(redis.expiry['foo'], timedelta(5)) + + def test_set_expiry_raises_if_redis_does(self): + ''' + backend set_expiry raises exception if redis expire does + given: a redis instance + when: set_expiry is called + and when: redis instance raises an exception + then: the redis instance exception is raised + ''' + + # setup + redis = RedisStub({ 'foo': set('bar') }, raise_error=True) + + # execute + backend = cache.RedisBackend(redis) + + # verify + with self.assertRaises(RedisError) as _: + backend.set_expiry('foo', 5) + + +class TestBufferedAddSetCache(unittest.TestCase): + def test_check_or_set_true_when_item_exists(self): + ''' + check_or_set returns true when item exists + given: a set + and when: the set contains the item to be checked + then: returns true + ''' + + # execute + s = cache.BufferedAddSetCache(set(['foo'])) + contains_foo = s.check_or_set('foo') + + # verify + self.assertTrue(contains_foo) + + def test_check_or_set_false_and_adds_item_when_item_missing(self): + ''' + check_or_set returns false and adds item when item is not in set + given: a set + when: the set does not contain the item to be checked + then: returns false + and then: the item is in the set + ''' + + # execute + s = cache.BufferedAddSetCache(set()) + contains_foo = s.check_or_set('foo') + + # verify + self.assertFalse(contains_foo) + self.assertTrue('foo' in s.get_buffer()) + + def test_check_or_set_checks_both_sets(self): + ''' + check_or_set checks both the given set and buffer for each check_or_set + given: a set + when: the set does not contain the item to be checked + then: returns false + and then: the item is in the set + and when: the item is added again + then: returns true + ''' + + # execute + s = cache.BufferedAddSetCache(set()) + contains_foo = s.check_or_set('foo') + + # verify + self.assertFalse(contains_foo) + self.assertTrue('foo' in s.get_buffer()) + + # execute + contains_foo = s.check_or_set('foo') + + # verify + # this verifies that check_or_set also checks the buffer + self.assertTrue(contains_foo) + + +class TestDataCache(unittest.TestCase): + def test_can_skip_download_logfile_true_when_key_exists(self): + ''' + dl logfile returns true if key exists in cache + given: a backend instance + when: can_skip_download_logfile is called + and when: the key exists in the cache + then: return true + ''' + + # setup + backend = BackendStub({ 'foo': ['bar'] }) + + # execute + data_cache = cache.DataCache(backend, 5) + can = data_cache.can_skip_downloading_logfile('foo') + + # verify + self.assertTrue(can) + + def test_can_skip_download_logfile_false_when_key_missing(self): + ''' + dl logfile returns false if key does not exist + given: a backend instance + when: can_skip_download_logfile is called + and when: the key does not exist in the backend + then: return false + ''' + + # setup + backend = BackendStub({}) + + # execute + data_cache = cache.DataCache(backend, 5) + can = data_cache.can_skip_downloading_logfile('foo') + + # verify + self.assertFalse(can) + + def test_can_skip_download_logfile_raises_if_backend_does(self): + ''' + dl logfile raises CacheException if backend raises any Exception + given: a backend instance + when: can_skip_download_logfile is called + and when: backend raises any exception + then: a CacheException is raised + ''' + + # setup + backend = BackendStub({}, raise_error=True) + + # execute + data_cache = cache.DataCache(backend, 5) + + # verify + with self.assertRaises(CacheException) as _: + data_cache.can_skip_downloading_logfile('foo') + + def test_check_or_set_log_line_true_when_exists(self): + ''' + check_or_set_log_line returns true when line ID is in the cached set + given: a backend instance + when: check_or_set_log_line is called + and when: row['REQUEST_ID'] is in the set for key + then: returns true + ''' + + # setup + backend = BackendStub({ 'foo': set(['bar']) }) + line = { 'REQUEST_ID': 'bar' } + + # execute + data_cache = cache.DataCache(backend, 5) + val = data_cache.check_or_set_log_line('foo', line) + + # verify + self.assertTrue(val) + + def test_check_or_set_log_line_false_and_adds_when_missing(self): + ''' + check_or_set_log_line returns false and adds line ID when line ID is not in the cached set + given: a backend instance + when: check_or_set_log_line is called + and when: row['REQUEST_ID'] is not in set for key + then: the line ID is added to the set and false is returned + ''' + + # setup + backend = BackendStub({ 'foo': set() }) + row = { 'REQUEST_ID': 'bar' } + + # preconditions + self.assertFalse('bar' in backend.redis.test_cache['foo']) + + # execute + data_cache = cache.DataCache(backend, 5) + val = data_cache.check_or_set_log_line('foo', row) + + # verify + self.assertFalse(val) + + # Need to flush before we can check the cache as how it is stored + # in memory is an implementation detail + data_cache.flush() + self.assertTrue('bar' in backend.redis.test_cache['foo']) + + def test_check_or_set_log_line_raises_if_backend_does(self): + ''' + check_or_set_log_line raises CacheException if backend raises any Exception + given: a backend instance + when: check_or_set_log_line is called + and when: backend raises any exception + then: a CacheException is raised + ''' + + # setup + backend = BackendStub({}, raise_error=True) + line = { 'REQUEST_ID': 'bar' } + + # execute / verify + data_cache = cache.DataCache(backend, 5) + + with self.assertRaises(CacheException) as _: + data_cache.check_or_set_log_line('foo', line) + + def test_check_or_set_event_id_true_when_exists(self): + ''' + check_or_set_event_id returns true when event ID is in the cached set + given: a backend instance + when: check_or_set_event_id is called + and when: event ID is in the set 'event_ids' + then: returns true + ''' + + # setup + backend = BackendStub({ 'event_ids': set(['foo']) }) + + # execute + data_cache = cache.DataCache(backend, 5) + val = data_cache.check_or_set_event_id('foo') + + # verify + self.assertTrue(val) + + def test_check_or_set_event_id_false_and_adds_when_missing(self): + ''' + check_or_set_event_id returns false and adds event ID when event ID is not in the cached set + given: a backend instance + when: check_or_set_event_id is called + and when: event ID is not in set 'event_ids' + then: the event ID is added to the set and False is returned + ''' + + # setup + backend = BackendStub({ 'event_ids': set() }) + + # preconditions + self.assertFalse('foo' in backend.redis.test_cache['event_ids']) + + # execute + data_cache = cache.DataCache(backend, 5) + val = data_cache.check_or_set_event_id('foo') + + # verify + self.assertFalse(val) + + # Need to flush before we can check the cache as how it is stored + # in memory is an implementation detail + data_cache.flush() + self.assertTrue('foo' in backend.redis.test_cache['event_ids']) + + def test_check_or_set_event_id_raises_if_backend_does(self): + ''' + check_or_set_event_id raises CacheException if backend raises any Exception + given: a backend instance + when: check_or_set_event_id is called + and when: backend raises any exception + then: a CacheException is raised + ''' + + # setup + backend = BackendStub({}, raise_error=True) + + # execute / verify + data_cache = cache.DataCache(backend, 5) + + with self.assertRaises(CacheException) as _: + data_cache.check_or_set_event_id('foo') + + def test_flush_does_not_affect_cache_when_add_buffers_empty(self): + ''' + backend cache is empty if flush is called when BufferedAddSet buffers are empty + given: a backend instance + when: flush is called + and when: the log lines and event ID add buffers are empty + then: the backend cache remains empty + ''' + + # setup + backend = BackendStub({}) + + # preconditions + self.assertEqual(len(backend.redis.test_cache), 0) + + # execute + data_cache = cache.DataCache(backend, 5) + data_cache.flush() + + # verify + self.assertEqual(len(backend.redis.test_cache), 0) + + def test_flush_writes_log_lines_when_add_buffer_not_empty(self): + ''' + flush writes any buffered log lines when log lines add buffer is not empty + given: a backend instance + when: flush is called + and when: the log lines add buffer is not empty + then: the backend cache is updated with items from the add buffer + ''' + + # setup + backend = BackendStub({}) + line1 = { 'REQUEST_ID': 'bar1' } + line2 = { 'REQUEST_ID': 'bar2' } + line3 = { 'REQUEST_ID': 'boop' } + + # preconditions + self.assertEqual(len(backend.redis.test_cache), 0) + + # execute + data_cache = cache.DataCache(backend, 5) + data_cache.check_or_set_log_line('foo', line1) + data_cache.check_or_set_log_line('foo', line2) + data_cache.check_or_set_log_line('beep', line3) + data_cache.flush() + + # verify + self.assertEqual(len(backend.redis.test_cache), 2) + self.assertTrue('foo' in backend.redis.test_cache) + self.assertEqual(backend.redis.test_cache['foo'], set(['bar1', 'bar2'])) + self.assertTrue('beep' in backend.redis.test_cache) + self.assertEqual(backend.redis.test_cache['beep'], set(['boop'])) + + def test_flush_writes_event_ids_when_add_buffer_not_empty(self): + ''' + flush writes any buffered event IDs when event IDs add buffer is not empty + given: a backend instance + when: flush is called + and when: the event IDs add buffer is not empty + then: the backend cache is updated with items from the add buffer + ''' + + # setup + backend = BackendStub({}) + + # preconditions + self.assertEqual(len(backend.redis.test_cache), 0) + + # execute + data_cache = cache.DataCache(backend, 5) + data_cache.check_or_set_event_id('foo') + data_cache.check_or_set_event_id('bar') + data_cache.check_or_set_event_id('beep') + data_cache.check_or_set_event_id('boop') + data_cache.flush() + + # verify + self.assertEqual(len(backend.redis.test_cache), 5) + self.assertTrue('event_ids' in backend.redis.test_cache) + self.assertEqual( + backend.redis.test_cache['event_ids'], + set(['foo', 'bar', 'beep', 'boop']) + ) + self.assertTrue('foo' in backend.redis.test_cache) + self.assertEqual(backend.redis.test_cache['foo'], 1) + self.assertTrue('bar' in backend.redis.test_cache) + self.assertEqual(backend.redis.test_cache['bar'], 1) + self.assertTrue('beep' in backend.redis.test_cache) + self.assertEqual(backend.redis.test_cache['beep'], 1) + self.assertTrue('boop' in backend.redis.test_cache) + self.assertEqual(backend.redis.test_cache['boop'], 1) + + def test_flush_sets_expiry_on_write(self): + ''' + flush sets the expiry time of any keys it writes + given: a backend instance and expiry + when: flush is called + and when: add buffers are not empty + then: the backend cache is updated with new cached sets with specified expiration time + ''' + + # setup + backend = BackendStub({}) + line1 = { 'REQUEST_ID': 'bar' } + + # preconditions + self.assertEqual(len(backend.redis.test_cache), 0) + self.assertEqual(len(backend.redis.expiry), 0) + + # execute + data_cache = cache.DataCache(backend, 5) + data_cache.check_or_set_log_line('foo', line1) + data_cache.check_or_set_event_id('bar') + data_cache.flush() + + # verify + self.assertEqual(len(backend.redis.test_cache), 3) + self.assertTrue('foo' in backend.redis.test_cache) + self.assertEqual(backend.redis.test_cache['foo'], set(['bar'])) + self.assertTrue('event_ids' in backend.redis.test_cache) + self.assertEqual(backend.redis.test_cache['event_ids'], set(['bar'])) + self.assertTrue('bar' in backend.redis.test_cache) + self.assertEqual(backend.redis.test_cache['bar'], 1) + self.assertEqual(len(backend.redis.expiry), 3) + self.assertTrue('foo' in backend.redis.expiry) + self.assertEqual(backend.redis.expiry['foo'], timedelta(days=5)) + self.assertTrue('bar' in backend.redis.expiry) + self.assertEqual(backend.redis.expiry['bar'], timedelta(days=5)) + self.assertTrue('event_ids' in backend.redis.expiry) + self.assertEqual(backend.redis.expiry['event_ids'], timedelta(days=5)) + + def test_flush_does_not_write_dups(self): + ''' + flush only writes items from add buffers + given: a backend instance + when: flush is called + and when: add buffers are not empty + then: the backend cache is updated with ONLY the items from the add buffers + ''' + + # setup + backend = BackendStub({ + 'foo': set(['bar', 'baz']), + 'beep': 1, + 'boop': 1, + 'event_ids': set(['beep', 'boop']) + }) + line1 = { 'REQUEST_ID': 'bip' } + line2 = { 'REQUEST_ID': 'bop' } + + # execute + data_cache = cache.DataCache(backend, 5) + data_cache.check_or_set_log_line('foo', line1) + data_cache.check_or_set_log_line('foo', line2) + data_cache.check_or_set_event_id('bim') + data_cache.check_or_set_event_id('bam') + data_cache.flush() + + # verify + self.assertEqual(len(backend.redis.test_cache), 6) + self.assertTrue('foo' in backend.redis.test_cache) + self.assertEqual( + backend.redis.test_cache['foo'], + set(['bar', 'baz', 'bip', 'bop']), + ) + self.assertTrue('event_ids' in backend.redis.test_cache) + self.assertEqual( + backend.redis.test_cache['event_ids'], + set(['beep', 'boop', 'bim', 'bam']), + ) + self.assertTrue('bim' in backend.redis.expiry) + self.assertEqual(backend.redis.expiry['bim'], timedelta(days=5)) + self.assertTrue('bam' in backend.redis.expiry) + self.assertEqual(backend.redis.expiry['bam'], timedelta(days=5)) + + def test_flush_raises_if_backend_does(self): + ''' + flush raises CacheException if backend raises any Exception + given: a backend instance + when: flush is called + and when: backend raises any exception + then: a CacheException is raised + ''' + + # setup + backend = BackendStub({}) + + # execute / verify + data_cache = cache.DataCache(backend, 5) + + # have to add data to be cached before setting raise_error + data_cache.check_or_set_event_id('foo') + + # now set error and execute/verify + backend.redis.raise_error = True + + with self.assertRaises(CacheException) as _: + data_cache.flush() + + +class TestCacheFactory(unittest.TestCase): + def test_new_returns_none_if_disabled(self): + ''' + new returns None if cache disabled + given: a backend factory and a configuration + when: new is called + and when: cache_enabled is False in config + then: None is returned + ''' + + # setup + config = mod_config.Config({ 'cache_enabled': 'false' }) + backend_factory = BackendFactoryStub() + + # execute + cache_factory = cache.CacheFactory(backend_factory) + data_cache = cache_factory.new(config) + + # verify + self.assertIsNone(data_cache) + + def test_new_returns_data_cache_with_backend_if_enabled(self): + ''' + new returns DataCache instance with backend from specified backend factory if cache enabled + given: a backend factory and a configuration + when: new is called + and when: cache_enabled is True in config + then: a DataCache is returned with the a backend from the specified backend factory + ''' + + # setup + config = mod_config.Config({ 'cache_enabled': 'true' }) + backend_factory = BackendFactoryStub() + + # execute + cache_factory = cache.CacheFactory(backend_factory) + data_cache = cache_factory.new(config) + + # verify + self.assertIsNotNone(data_cache) + self.assertTrue(type(data_cache.backend) is BackendStub) + + def test_new_raises_if_backend_factory_does(self): + ''' + new raises CacheException if backend factory does + given: a backend factory and a configuration + when: new is called + and when: backend factory new() raises any exception + then: a CacheException is raised + ''' + + # setup + config = mod_config.Config({ 'cache_enabled': 'true' }) + backend_factory = BackendFactoryStub(raise_error=True) + + # execute / verify + cache_factory = cache.CacheFactory(backend_factory) + + with self.assertRaises(CacheException): + _ = cache_factory.new(config) diff --git a/src/tests/test_pipeline.py b/src/tests/test_pipeline.py index 1f5eb3c..f0642d1 100644 --- a/src/tests/test_pipeline.py +++ b/src/tests/test_pipeline.py @@ -1193,7 +1193,7 @@ def test_pipeline_process_log_record(self): # setup data_cache = DataCacheStub( - cached_log_lines={ + cached_logs={ '00001111AAAABBBB': ['YYZ:abcdef123456', 'YYZ:fedcba654321'] } ) diff --git a/src/tests/test_query.py b/src/tests/test_query.py new file mode 100644 index 0000000..74fc5c7 --- /dev/null +++ b/src/tests/test_query.py @@ -0,0 +1,384 @@ +from datetime import datetime +import unittest + +from newrelic_logging import SalesforceApiException +from newrelic_logging import config as mod_config, query, util +from . import \ + ResponseStub, \ + SessionStub + +class TestQuery(unittest.TestCase): + def test_get_returns_backing_config_value_when_key_exists(self): + ''' + get() returns the value of the key in the backing config when the key exists + given: a query string, a configuration, and an api version + and given: a key + when: get is called with the key + then: returns value of the key from backing config + ''' + + # setup + config = mod_config.Config({ 'foo': 'bar' }) + + # execute + q = query.Query( + 'SELECT+LogFile+FROM+EventLogFile', + config, + '55.0', + ) + val = q.get('foo') + + # verify + self.assertEqual(val, 'bar') + + def test_get_returns_backing_config_default_when_key_missing(self): + ''' + get() returns the default value passed to the backing config.get when key does not exist in the backing config + given: a query string, a configuration, and an api version + when: get is called with a key and a default value + and when: the key does not exist in the backing config + then: returns default value passed to the backing config.get + ''' + + # setup + config = mod_config.Config({}) + + # execute + q = query.Query( + 'SELECT+LogFile+FROM+EventLogFile', + config, + '55.0', + ) + val = q.get('foo', 'beep') + + # verify + self.assertEqual(val, 'beep') + + def test_execute_raises_exception_on_non_200_response(self): + ''' + execute() raises exception on non-200 status code from Salesforce API + given: a query string, a configuration, and an api version + when: execute() is called with an http session, an instance url, and an access token + and when: the response produces a non-200 status code + then: raise a SalesforceApiException + ''' + + # setup + config = mod_config.Config({}) + session = SessionStub() + session.response = ResponseStub(500, 'Error', '', []) + + # execute/verify + q = query.Query( + 'SELECT+LogFile+FROM+EventLogFile', + config, + '55.0', + ) + + with self.assertRaises(SalesforceApiException) as _: + q.execute(session, 'https://my.salesforce.test', '123456') + + def test_execute_raises_exception_if_session_get_does(self): + ''' + execute() raises exception if session.get() raises a RequestException + given: a query string, a configuration, and an api version + when: execute() is called with an http session, an instance url, and an access token + and when: session.get() raises a RequestException + then: raise a SalesforceApiException + ''' + + # setup + config = mod_config.Config({}) + session = SessionStub(raise_error=True) + session.response = ResponseStub(200, 'OK', '[]', [] ) + + # execute/verify + q = query.Query( + 'SELECT+LogFile+FROM+EventLogFile', + config, + '55.0', + ) + + with self.assertRaises(SalesforceApiException) as _: + q.execute(session, 'https://my.salesforce.test', '123456') + + def test_execute_calls_query_api_url_with_token_and_returns_json_response(self): + ''' + execute() calls the correct query API url with the access token and returns a json response + given: a query string, a configuration, and an api version + when: execute() is called with an http session, an instance url, and an access token + then: a get request is made to the correct API url with the given access token and returns a json response + ''' + + # setup + config = mod_config.Config({}) + session = SessionStub() + session.response = ResponseStub(200, 'OK', '{"foo": "bar"}', [] ) + + # execute + q = query.Query( + 'SELECT+LogFile+FROM+EventLogFile', + config, + '55.0', + ) + + resp = q.execute(session, 'https://my.salesforce.test', '123456') + + # verify + self.assertEqual( + session.url, + f'https://my.salesforce.test/services/data/v55.0/query?q=SELECT+LogFile+FROM+EventLogFile', + ) + self.assertTrue('Authorization' in session.headers) + self.assertEqual(session.headers['Authorization'], 'Bearer 123456') + self.assertIsNotNone(resp) + self.assertTrue(type(resp) is dict) + self.assertTrue('foo' in resp) + self.assertEqual(resp['foo'], 'bar') + + +class TestQueryFactory(unittest.TestCase): + def test_build_args_creates_expected_dict(self): + ''' + build_args() returns dictionary with expected properties + given: a query factory + when: build_args() is called with a time lag, timestamp, and generation interval + then: returns dict with expected properties + ''' + + # setup + _now = datetime.utcnow() + + def _utcnow(): + nonlocal _now + return _now + + util._UTCNOW = _utcnow + + time_lag_minutes = 500 + last_to_timestamp = util.get_iso_date_with_offset( + time_lag_minutes=time_lag_minutes * 2, + ) + to_timestamp = util.get_iso_date_with_offset( + time_lag_minutes=time_lag_minutes, + ) + + # execute + f = query.QueryFactory() + args = f.build_args(time_lag_minutes, last_to_timestamp, 'Daily') + + # verify + self.assertIsNotNone(args) + self.assertTrue(type(args) is dict) + self.assertTrue('to_timestamp' in args) + self.assertEqual(args['to_timestamp'], to_timestamp) + self.assertTrue('from_timestamp' in args) + self.assertEqual(args['from_timestamp'], last_to_timestamp) + self.assertTrue('log_interval_type' in args) + self.assertEqual(args['log_interval_type'], 'Daily') + + def test_get_env_returns_empty_dict_if_no_env(self): + ''' + get_env() returns an empty dict if env is not in the passed query dict + given: a query factory + when: get_env() is called with a query dict + and when: there is no env property in the query dict + then: returns an empty dict + ''' + + # setup + q = { 'query': 'SELECT LogFile FROM EventLogFile' } + + # execute + f = query.QueryFactory() + env = f.get_env(q) + + # verify + self.assertIsNotNone(env) + self.assertTrue(type(env) is dict) + self.assertEqual(len(env), 0) + + def test_get_env_returns_empty_dict_if_env_not_dict(self): + ''' + get_env() returns an empty dict if the passed query dict has an env property but it is not a dict + given: a query factory + when: get_env() is called with a query dict + and when: there is an env property in the query dict + and when: the env property is not a dict + then: returns an empty dict + ''' + + # setup + q = { 'query': 'SELECT LogFile FROM EventLogFile', 'env': 'foo' } + + # execute + f = query.QueryFactory() + env = f.get_env(q) + + # verify + self.assertIsNotNone(env) + self.assertTrue(type(env) is dict) + self.assertEqual(len(env), 0) + + def test_get_env_returns_env_dict_from_query_dict(self): + ''' + get_env() returns the env dict from the query dict when one is present + given: a query factory + when: get_env() is called with a query dict + and when: there is an env property in the query dict + and when: the env property is a dict + then: returns the env dict + ''' + + # setup + q = { + 'query': 'SELECT LogFile FROM EventLogFile', + 'env': { 'foo': 'bar' }, + } + + # execute + f = query.QueryFactory() + env = f.get_env(q) + + # verify + self.assertIsNotNone(env) + self.assertTrue(type(env) is dict) + self.assertEqual(len(env), 1) + self.assertTrue('foo' in env) + self.assertEqual(env['foo'], 'bar') + + def test_new_returns_query_obj_with_encoded_query_with_args_replaced(self): + ''' + new() returns a query instance with the given query with arguments replaced and URL encoded + given: a query factory + when: new() is called with a query dict, lag time, timestamp, generation interval, and default api version + then: returns a query instance with the input query with arguments replaced and URL encoded + ''' + + # setup + _now = datetime.utcnow() + + def _utcnow(): + nonlocal _now + return _now + + util._UTCNOW = _utcnow + + to_timestamp = util.get_iso_date_with_offset(time_lag_minutes=500) + last_to_timestamp = util.get_iso_date_with_offset(time_lag_minutes=1000) + now = _now.isoformat(timespec='milliseconds') + "Z" + env = { 'foo': 'now()' } + + # execute + f = query.QueryFactory() + q = f.new( + { + 'query': 'SELECT LogFile FROM EventLogFile WHERE CreatedDate>={from_timestamp} AND CreatedDate<{to_timestamp} AND LogIntervalType={log_interval_type} AND Foo={foo}', + 'env': env, + }, + 500, + last_to_timestamp, + 'Daily', + '55.0', + ) + + # verify + self.assertEqual( + q.query, + f'SELECT+LogFile+FROM+EventLogFile+WHERE+CreatedDate>={last_to_timestamp}+AND+CreatedDate<{to_timestamp}+AND+LogIntervalType=Daily+AND+Foo={now}' + ) + + def test_new_returns_query_obj_with_expected_config(self): + ''' + new() returns a query instance with the input query dict minus the query property + given: a query factory + when: new() is called with a query dict, lag time, timestamp, generation interval, and default api version + then: returns a query instance with a config equal to the input query dict minus the query property + ''' + + # setup + last_to_timestamp = util.get_iso_date_with_offset(time_lag_minutes=1000) + + # execute + f = query.QueryFactory() + q = f.new( + { + 'query': 'SELECT LogFile FROM EventLogFile', + 'foo': 'bar', + 'beep': 'boop', + 'bip': 0, + 'bop': 5, + }, + 500, + last_to_timestamp, + 'Daily', + '55.0', + ) + + # verify + config = q.get_config() + + self.assertIsNotNone(config) + self.assertTrue(type(config) is mod_config.Config) + self.assertFalse('query' in config) + self.assertTrue('foo' in config) + self.assertEqual(config['foo'], 'bar') + self.assertTrue('beep' in config) + self.assertEqual(config['beep'], 'boop') + self.assertTrue('bip' in config) + self.assertEqual(config['bip'], 0) + self.assertTrue('bop' in config) + self.assertEqual(config['bop'], 5) + + def test_new_returns_query_obj_with_given_api_ver(self): + ''' + new() returns a query instance with the api version specified in the query dict + given: a query factory + when: new() is called with a query dict, lag time, timestamp, generation interval, and default api version + then: returns a query instance with the api version specified in the query dict + ''' + + # setup + last_to_timestamp = util.get_iso_date_with_offset(time_lag_minutes=1000) + + # execute + f = query.QueryFactory() + q = f.new( + { + 'query': 'SELECT LogFile FROM EventLogFile', + 'api_ver': '58.0' + }, + 500, + last_to_timestamp, + 'Daily', + '53.0', + ) + + # verify + self.assertEqual(q.api_ver, '58.0') + + def test_new_returns_query_obj_with_default_api_ver(self): + ''' + new() returns a query instance with the default api version specified on the new() call + given: a query factory + when: new() is called with a query dict, lag time, timestamp, generation interval, and default api version + then: returns a query instance with the default api version specified on the the new() call + ''' + + # setup + last_to_timestamp = util.get_iso_date_with_offset(time_lag_minutes=1000) + + # execute + f = query.QueryFactory() + q = f.new( + { + 'query': 'SELECT LogFile FROM EventLogFile', + }, + 500, + last_to_timestamp, + 'Daily', + '53.0', + ) + + # verify + self.assertEqual(q.api_ver, '53.0') From c23588c815c109862edb9be69111d37e87305b69 Mon Sep 17 00:00:00 2001 From: Scott DeWitt Date: Thu, 28 Mar 2024 09:05:12 -0400 Subject: [PATCH 07/11] feat: support nested fields in event records --- src/newrelic_logging/pipeline.py | 21 +- src/newrelic_logging/util.py | 69 ++++++- src/tests/sample_event_records.json | 51 ++++- src/tests/test_pipeline.py | 238 +++++++++++++++++------ src/tests/test_util.py | 289 ++++++++++++++++++++++++++++ 5 files changed, 593 insertions(+), 75 deletions(-) diff --git a/src/newrelic_logging/pipeline.py b/src/newrelic_logging/pipeline.py index 27622b7..51caecc 100644 --- a/src/newrelic_logging/pipeline.py +++ b/src/newrelic_logging/pipeline.py @@ -14,7 +14,9 @@ from .telemetry import print_info from .util import generate_record_id, \ is_logfile_response, \ - maybe_convert_str_to_num + maybe_convert_str_to_num, \ + process_query_result, \ + get_timestamp DEFAULT_CHUNK_SIZE = 4096 @@ -153,15 +155,13 @@ def pack_event_record_into_log( record_id: str, record: dict, ) -> dict: - # Make a copy of it so we aren't modifying the row passed by the caller, and - # set attributes appropriately - attrs = deepcopy(record) + attrs = process_query_result(record) if record_id: attrs['Id'] = record_id message = query.get('event_type', 'SFEvent') - if 'attributes' in attrs and type(attrs['attributes']) == dict: - attributes = attrs.pop('attributes') + if 'attributes' in record and type(record['attributes']) == dict: + attributes = record['attributes'] if 'type' in attributes and type(attributes['type']) == str: attrs['EVENT_TYPE'] = message = \ query.get('event_type', attributes['type']) @@ -170,15 +170,12 @@ def pack_event_record_into_log( if timestamp_attr in attrs: created_date = attrs[timestamp_attr] message += f' {created_date}' - timestamp = int(datetime.strptime( - created_date, - '%Y-%m-%dT%H:%M:%S.%f%z').timestamp() * 1000, - ) + timestamp = get_timestamp(created_date) else: - timestamp = int(datetime.now().timestamp() * 1000) + timestamp = get_timestamp() timestamp_field_name = query.get('rename_timestamp', 'timestamp') - attrs[timestamp_field_name] = int(timestamp) + attrs[timestamp_field_name] = timestamp log_entry = { 'message': message, diff --git a/src/newrelic_logging/util.py b/src/newrelic_logging/util.py index bc6932a..d12d74d 100644 --- a/src/newrelic_logging/util.py +++ b/src/newrelic_logging/util.py @@ -1,9 +1,12 @@ +from copy import deepcopy from datetime import datetime, timedelta import hashlib -from typing import Union +from typing import Any, Union from .telemetry import print_warn +PRIMITIVE_TYPES = (str, int, float, bool, type(None)) + def is_logfile_response(records): if len(records) > 0: @@ -41,6 +44,51 @@ def maybe_convert_str_to_num(val: str) -> Union[int, str, float]: return val +def is_primitive(val: Any) -> bool: + vt = type(val) + + for t in PRIMITIVE_TYPES: + if vt == t: + return True + + return False + + +def process_query_result_helper( + item: tuple[str, Any], + name: list[str] = [], + vals: list[tuple[str, Any]] = [], +) -> list[tuple[str, Any]]: + (k, v) = item + + if k == 'attributes': + return vals + + if is_primitive(v): + return vals + [('.'.join(name + [k]), v)] + + if not type(v) is dict: + print_warn(f'ignoring structured element {k} in query result') + return vals + + new_vals = vals + + for item0 in v.items(): + new_vals = process_query_result_helper(item0, name + [k], new_vals) + + return new_vals + + +def process_query_result(query_result: dict) -> dict: + out = {} + + for item in query_result.items(): + for (k, v) in process_query_result_helper(item): + out[k] = v + + return out + + # Make testing easier def _utcnow(): return datetime.utcnow() @@ -60,6 +108,25 @@ def get_iso_date_with_offset( ) + 'Z' +# Make testing easier +def _now(): + return datetime.now() + +_NOW = _now + + +def get_timestamp(date_string: str = None): + if not date_string: + return int(_NOW().timestamp() * 1000) + + return int( + datetime.strptime( + date_string, + '%Y-%m-%dT%H:%M:%S.%f%z' + ).timestamp() * 1000 + ) + + # NOTE: this sandbox can be jailbroken using the trick to exec statements inside # an exec block, and run an import (and other tricks): # https://book.hacktricks.xyz/generic-methodologies-and-resources/python/bypass-python-sandboxes#operators-and-short-tricks diff --git a/src/tests/sample_event_records.json b/src/tests/sample_event_records.json index 7f2b206..5f7b1a1 100644 --- a/src/tests/sample_event_records.json +++ b/src/tests/sample_event_records.json @@ -7,7 +7,22 @@ "Id": "000012345", "Name": "My Account", "BillingCity": null, - "CreatedDate": "2024-03-11T00:00:00.000+0000" + "CreatedDate": "2024-03-11T00:00:00.000+0000", + "CreatedBy": { + "attributes": { + "type": "User", + "url": "/services/data/v55.0/sobjects/User/12345" + }, + "Name": "Foo Bar", + "Profile": { + "attributes": { + "type": "Profile", + "url": "/services/data/v55.0/sobjects/Profile/12345" + }, + "Name": "Beep Boop" + }, + "UserType": "Bip Bop" + } }, { "attributes": { @@ -17,7 +32,22 @@ "Id": "000054321", "Name": "My Other Account", "BillingCity": null, - "CreatedDate": "2024-03-10T00:00:00.000+0000" + "CreatedDate": "2024-03-10T00:00:00.000+0000", + "CreatedBy": { + "attributes": { + "type": "User", + "url": "/services/data/v55.0/sobjects/User/12345" + }, + "Name": "Foo Bar", + "Profile": { + "attributes": { + "type": "Profile", + "url": "/services/data/v55.0/sobjects/Profile/12345" + }, + "Name": "Beep Boop" + }, + "UserType": "Bip Bop" + } }, { "attributes": { @@ -26,6 +56,21 @@ }, "Name": "My Last Account", "BillingCity": null, - "CreatedDate": "2024-03-09T00:00:00.000+0000" + "CreatedDate": "2024-03-09T00:00:00.000+0000", + "CreatedBy": { + "attributes": { + "type": "User", + "url": "/services/data/v55.0/sobjects/User/12345" + }, + "Name": "Foo Bar", + "Profile": { + "attributes": { + "type": "Profile", + "url": "/services/data/v55.0/sobjects/Profile/12345" + }, + "Name": "Beep Boop" + }, + "UserType": "Bip Bop" + } } ] diff --git a/src/tests/test_pipeline.py b/src/tests/test_pipeline.py index f0642d1..c81d459 100644 --- a/src/tests/test_pipeline.py +++ b/src/tests/test_pipeline.py @@ -297,16 +297,35 @@ def test_transform_log_lines(self): def test_pack_event_record_into_log(self): + created_date = self.event_records[0]['CreatedDate'] + timestamp = util.get_timestamp(created_date) + + base_expected_attrs = { + 'Id': '00001111AAAABBBB', + 'Name': 'My Account', + 'BillingCity': None, + 'CreatedDate': created_date, + 'CreatedBy.Name': 'Foo Bar', + 'CreatedBy.Profile.Name': 'Beep Boop', + 'CreatedBy.UserType': 'Bip Bop', + 'EVENT_TYPE': 'Account', + 'timestamp': timestamp, + } + ''' given: a query, record id and event record - when: there are no query options and the event record contains a 'type' - field in 'attributes' + when: there are no query options, the record id is not None, and the + event record contains a 'type' field in the 'attributes' field then: return a log with the 'message' attribute set to the event type - specified in the 'type' field + the created date, all attributes - from the original event record minus the 'attributes' field set in - the log 'attributes' field as well as the passed record id, and a - 'timestamp' field with the epoch value representing the - 'CreatedDate' field + specified in the record's 'attributes.type' field + the created + date; where the 'attributes' attribute contains all attributes + according to process_query_result, an 'Id' attribute set to the + passed record ID, an 'EVENT_TYPE' attribute set to the event type + specified in the record's 'attributes.type' field, and a + 'timestamp' attribute set to the epoch value representing the + record's 'CreatedDate' field; and with the 'timestamp' attribute + also set to the epoch value representing the record's + 'CreatedDate' field. ''' # setup @@ -320,39 +339,26 @@ def test_pack_event_record_into_log(self): ) # verify - created_date = self.event_records[0]['CreatedDate'] - timestamp = int(datetime.strptime( - created_date, - '%Y-%m-%dT%H:%M:%S.%f%z').timestamp() * 1000, - ) self.assertTrue('message' in log) self.assertTrue('attributes' in log) self.assertTrue('timestamp' in log) self.assertEqual(log['message'], f'Account {created_date}') - self.assertEqual(timestamp, log['timestamp']) - - attrs = log['attributes'] - self.assertTrue(not 'attributes' in attrs) - self.assertTrue('Id' in attrs) - self.assertTrue('Name' in attrs) - self.assertTrue('BillingCity' in attrs) - self.assertTrue('CreatedDate' in attrs) - self.assertEqual('00001111AAAABBBB', attrs['Id']) - self.assertEqual('My Account', attrs['Name']) - self.assertEqual(None, attrs['BillingCity']) - self.assertEqual('2024-03-11T00:00:00.000+0000', attrs['CreatedDate']) + self.assertEqual(log['attributes'], base_expected_attrs) + self.assertEqual(log['timestamp'], timestamp) ''' given: a query, record id, and an event record - when: the record id is empty and there are no query options and the - event record contains a 'type' field in 'attributes' - then: return a log as in use case 1 but with no 'Id' value in the - log 'attributes' field + when: there are no query options, the record id is None, and the + event record contains a 'type' field in the 'attributes' field + then: return a log as in use case 1 but with the 'Id' value in the + log 'attributes' attribute set to the 'Id' value from the event + record ''' # setup event_record = copy.deepcopy(self.event_records[0]) - event_record.pop('Id') + expected_attrs = copy.deepcopy(base_expected_attrs) + expected_attrs['Id'] = '000012345' # execute log = pipeline.pack_event_record_into_log( @@ -362,18 +368,29 @@ def test_pack_event_record_into_log(self): ) # verify - self.assertTrue(not 'Id' in log['attributes']) + self.assertTrue('message' in log) + self.assertTrue('attributes' in log) + self.assertTrue('timestamp' in log) + self.assertEqual(log['message'], f'Account {created_date}') + self.assertEqual(log['attributes'], expected_attrs) + self.assertEqual(log['timestamp'], timestamp) ''' given: a query, record id, and an event record - when: the 'event_type' query option is specified + when: the 'event_type' query option is specified, the record id is not + None, and the event record contains a 'type' field in the + 'attributes' field then: return a log as in use case 1 but with the event type in the log - message set to the custom event type specified in the 'event_type' - query option plus the created date. + 'message' attribute set to the custom event type specified in the + 'event_type' query option plus the created date, and with the + 'EVENT_TYPE' attribute in the log 'attributes' attribute set to + the custom event type ''' # setup event_record = copy.deepcopy(self.event_records[0]) + expected_attrs = copy.deepcopy(base_expected_attrs) + expected_attrs['EVENT_TYPE'] = 'CustomEvent' # execute log = pipeline.pack_event_record_into_log( @@ -383,19 +400,28 @@ def test_pack_event_record_into_log(self): ) # verify + self.assertTrue('message' in log) + self.assertTrue('attributes' in log) + self.assertTrue('timestamp' in log) self.assertEqual(log['message'], f'CustomEvent {created_date}') + self.assertEqual(log['attributes'], expected_attrs) + self.assertEqual(log['timestamp'], timestamp) ''' given: a query, record id, and an event record - when: the event record does not contain an 'attributes' field + when: there are no query options, the record id is not None, and the + event record does not contain an 'attributes' field then: return a log as in use case 1 but with the event type in the log - message set to the default event type specified in the - 'event_type' query option plus the created date. + 'message' attribute set to the default event type plus the created + date, and with no 'EVENT_TYPE' attribute in the log 'attributes' + attribute ''' # setup event_record = copy.deepcopy(self.event_records[0]) event_record.pop('attributes') + expected_attrs = copy.deepcopy(base_expected_attrs) + expected_attrs.pop('EVENT_TYPE') # execute log = pipeline.pack_event_record_into_log( @@ -405,18 +431,26 @@ def test_pack_event_record_into_log(self): ) # verify + self.assertTrue('message' in log) + self.assertTrue('attributes' in log) + self.assertTrue('timestamp' in log) self.assertEqual(log['message'], f'SFEvent {created_date}') + self.assertEqual(log['attributes'], expected_attrs) + self.assertEqual(log['timestamp'], timestamp) ''' given: a query, record id, and an event record - when: the event record does contains an 'attributes' field but it is not - a dictionary + when: there are no query options, the record id is not None, and the + event record does contain an 'attributes' field but it is not a + dictionary then: return a log as in the previous use case ''' # setup event_record = copy.deepcopy(self.event_records[0]) event_record['attributes'] = 'test' + expected_attrs = copy.deepcopy(base_expected_attrs) + expected_attrs.pop('EVENT_TYPE') # execute log = pipeline.pack_event_record_into_log( @@ -426,11 +460,17 @@ def test_pack_event_record_into_log(self): ) # verify + self.assertTrue('message' in log) + self.assertTrue('attributes' in log) + self.assertTrue('timestamp' in log) self.assertEqual(log['message'], f'SFEvent {created_date}') + self.assertEqual(log['attributes'], expected_attrs) + self.assertEqual(log['timestamp'], timestamp) ''' given: a query, record id, and an event record - when: the event record does contains an 'type' field in the 'attributes' + when: there are no query options, the record id is not None, and the + event record does not contain a 'type' field in the 'attributes' field then: return a log as in the previous use case ''' @@ -438,7 +478,8 @@ def test_pack_event_record_into_log(self): # setup event_record = copy.deepcopy(self.event_records[0]) event_record['attributes'].pop('type') - + expected_attrs = copy.deepcopy(base_expected_attrs) + expected_attrs.pop('EVENT_TYPE') # execute log = pipeline.pack_event_record_into_log( @@ -448,11 +489,17 @@ def test_pack_event_record_into_log(self): ) # verify + self.assertTrue('message' in log) + self.assertTrue('attributes' in log) + self.assertTrue('timestamp' in log) self.assertEqual(log['message'], f'SFEvent {created_date}') + self.assertEqual(log['attributes'], expected_attrs) + self.assertEqual(log['timestamp'], timestamp) ''' given: a query, record id, and an event record - when: the event record contains a 'type' field in the 'attributes' + when: there are no query options, the record id is not None, and the + event record does contain a 'type' field in the 'attributes' field but it is not a string then: return a log as in the previous use case ''' @@ -460,6 +507,8 @@ def test_pack_event_record_into_log(self): # setup event_record = copy.deepcopy(self.event_records[0]) event_record['attributes']['type'] = 12345 + expected_attrs = copy.deepcopy(base_expected_attrs) + expected_attrs.pop('EVENT_TYPE') # execute log = pipeline.pack_event_record_into_log( @@ -469,19 +518,73 @@ def test_pack_event_record_into_log(self): ) # verify + self.assertTrue('message' in log) + self.assertTrue('attributes' in log) + self.assertTrue('timestamp' in log) self.assertEqual(log['message'], f'SFEvent {created_date}') + self.assertEqual(log['attributes'], expected_attrs) + self.assertEqual(log['timestamp'], timestamp) + + ''' + given: a query, record id, and an event record + when: the 'timestamp_attr' query option is specified, the field + specified in the 'timestamp_attr' query option exists in the event + record, the record id is not None, and the event record contains a + 'type' field in the 'attributes' field + then: return a log as in use case 1 but using the timestamp from the + field specified in the 'timestamp_attr' query option. + ''' + + # setup + __now = datetime.now() + + def _now(): + nonlocal __now + return __now + + util._NOW = _now + + created_date_2 = self.event_records[1]['CreatedDate'] + timestamp = util.get_timestamp(created_date_2) + + event_record = copy.deepcopy(self.event_records[0]) + event_record['CustomDate'] = created_date_2 + expected_attrs = copy.deepcopy(base_expected_attrs) + expected_attrs['CustomDate'] = created_date_2 + expected_attrs['timestamp'] = timestamp + + # execute + log = pipeline.pack_event_record_into_log( + QueryStub({ 'timestamp_attr': 'CustomDate' }), + '00001111AAAABBBB', + event_record + ) + + # verify + self.assertTrue('message' in log) + self.assertTrue('attributes' in log) + self.assertTrue('timestamp' in log) + self.assertEqual(log['message'], f'Account {created_date_2}') + self.assertEqual(log['attributes'], expected_attrs) + self.assertEqual(log['timestamp'], timestamp) ''' given: a query, record id, and an event record when: the 'timestamp_attr' query option is specified but the specified - attribute name is not in the event record - then: return a log as in use case 1 but the message does not contain a - created date and contains a 'timestamp' field set to the current - time. + attribute name is not in the event record, the record id is not + None, and the event record contains a 'type' field in the + 'attributes' field + then: return a log as in use case 1 but the log 'message' attribute does + not contain a date, the 'timestamp' attribute of the log + 'attributes' attribute is set to the current time, and with the + log 'timestamp' attribute set to the current time. ''' # setup event_record = copy.deepcopy(self.event_records[0]) + expected_attrs = copy.deepcopy(base_expected_attrs) + timestamp = util.get_timestamp() + expected_attrs['timestamp'] = timestamp # execute log = pipeline.pack_event_record_into_log( @@ -491,21 +594,27 @@ def test_pack_event_record_into_log(self): ) # verify - timestamp = int(datetime.now().timestamp() * 1000) - self.assertEqual(log['message'], f'Account') + self.assertTrue('message' in log) + self.assertTrue('attributes' in log) self.assertTrue('timestamp' in log) - self.assertTrue(log['timestamp'] <= timestamp) + self.assertEqual(log['message'], f'Account') + self.assertEqual(log['attributes'], expected_attrs) + self.assertEqual(log['timestamp'], timestamp) ''' given: a query, record id, and an event record - when: no query options are specified and the event record does not - contain a 'CreatedDate' field + when: no query options are specified, the record id is not None, and the + event record does not contain a 'CreatedDate' field then: return the same as the previous use case ''' # setup event_record = copy.deepcopy(self.event_records[0]) event_record.pop('CreatedDate') + expected_attrs = copy.deepcopy(base_expected_attrs) + expected_attrs.pop('CreatedDate') + timestamp = util.get_timestamp() + expected_attrs['timestamp'] = timestamp # execute log = pipeline.pack_event_record_into_log( @@ -515,21 +624,31 @@ def test_pack_event_record_into_log(self): ) # verify - timestamp = int(datetime.now().timestamp() * 1000) - self.assertEqual(log['message'], f'Account') + self.assertTrue('message' in log) + self.assertTrue('attributes' in log) self.assertTrue('timestamp' in log) - self.assertTrue(log['timestamp'] <= timestamp) + self.assertEqual(log['message'], f'Account') + self.assertEqual(log['attributes'], expected_attrs) + self.assertEqual(log['timestamp'], timestamp) ''' given: a query, record id, and an event record - when: the 'rename_timestamp' query options is set - then: return the same as use case 1 but with a field with the name + when: the 'rename_timestamp' query option is specified, the record id is + not None, and the event record contains a 'type' field in the + 'attributes' field + then: return the same as use case 1 but with an attribute with the name specified in the 'rename_timestamp' query option set to the - current time and no 'timestamp' field + created date in the log 'attributes' attribute, no 'timestamp' + attribute in the log 'attributes' attribute, and with no log + 'timestamp' attribute ''' # setup event_record = copy.deepcopy(self.event_records[0]) + expected_attrs = copy.deepcopy(base_expected_attrs) + timestamp = util.get_timestamp(created_date) + expected_attrs['custom_timestamp'] = timestamp + expected_attrs.pop('timestamp') # execute log = pipeline.pack_event_record_into_log( @@ -539,10 +658,11 @@ def test_pack_event_record_into_log(self): ) # verify - timestamp = int(datetime.now().timestamp() * 1000) - self.assertTrue('custom_timestamp' in log['attributes']) + self.assertTrue('message' in log) + self.assertTrue('attributes' in log) self.assertTrue(not 'timestamp' in log) - self.assertTrue(log['attributes']['custom_timestamp'] <= timestamp) + self.assertEqual(log['message'], f'Account {created_date}') + self.assertEqual(log['attributes'], expected_attrs) def test_transform_event_records(self): ''' diff --git a/src/tests/test_util.py b/src/tests/test_util.py index 4f9b923..4ad247f 100644 --- a/src/tests/test_util.py +++ b/src/tests/test_util.py @@ -1,5 +1,6 @@ from datetime import datetime, timedelta import hashlib +import json import unittest from newrelic_logging import util @@ -215,6 +216,294 @@ def _utcnow(): # verify self.assertEqual(val, isonow) + def test_is_primitive_true_for_primitive_types(self): + ''' + is_primitive() returns true for types considered to be "primitive" (str, int, float, bool, None) + given: a set of "primitive" values + when: is_primitive() is called + then: returns True + ''' + + # setup + vals = ('string', 100, 62.1, False, None) + + # execute / verify + for v in vals: + b = util.is_primitive(v) + self.assertTrue(b) + + def test_is_primitive_false_for_non_primitive_types(self): + ''' + is_primitive() returns false for types not considered to be "primitive" + given: a set of "non-primitive" values + when: is_primitive() is called + then: returns False + ''' + + # setup + vals = (['list'], (1, 2), { 'foo': 'bar' }, Exception()) + + # execute / verify + for v in vals: + b = util.is_primitive(v) + self.assertFalse(b) + + def test_process_query_result_copies_primitive_fields(self): + ''' + process_query_result() copies all primitive fields to the new dict + given: JSON result from an SOQL query + when: the SOQL query result contains only primitive (non-nested) fields + then: returns dict containing all primitive fields + ''' + + # setup + query_result = json.loads('''{ + "Action": "PermSetFlsChanged", + "CreatedByContext": null, + "CreatedById": "0058W00000A7LvTQAV", + "CreatedByIssuer": 1, + "CreatedDate": "2023-11-30T17:33:08.000+0000", + "DelegateUser": 2.0, + "Display": "Changed permission set 00e1U000000XFwxQAG: field-level security for Task: Related To was changed from Read/Write to No Access", + "Id": "0Ym7c00001RGP6MCAX", + "ResponsibleNamespacePrefix": false, + "Section": "Manage Users" + }''') + + expected_result = { + 'Action': 'PermSetFlsChanged', + 'CreatedByContext': None, + 'CreatedById': '0058W00000A7LvTQAV', + 'CreatedByIssuer': 1, + 'CreatedDate': '2023-11-30T17:33:08.000+0000', + 'DelegateUser': 2.0, + 'Display': 'Changed permission set 00e1U000000XFwxQAG: field-level security for Task: Related To was changed from Read/Write to No Access', + 'Id': '0Ym7c00001RGP6MCAX', + 'ResponsibleNamespacePrefix': False, + 'Section': 'Manage Users' + } + + # execute + result = util.process_query_result(query_result) + + # verify + self.assertEqual(expected_result, result) + + def test_process_query_result_flattens_and_copies_nested_fields(self): + ''' + process_query_result() flattens all nested fields and copies them to the new dict with keys that use the syntax "field1.nestedfield1.nestedfield1" and so on + given: JSON result from an SOQL query + when: the SOQL query result contains primitive and nested fields + then: returns dict containing all primitive fields + and: contains primitives from all nested fields + and: the keys for all nested fields use the syntax 'field1.nestedfield1', 'field2.nestedfield2.nestedfield1' and so on + ''' + + # setup + query_result = json.loads('''{ + "Action": "PermSetFlsChanged", + "CreatedByContext": null, + "CreatedById": "0058W00000A7LvTQAV", + "CreatedBy": { + "Name": "Chetan Gupta", + "Profile": { + "Name": "System Administrator" + }, + "UserType": "Standard" + }, + "CreatedByIssuer": null, + "CreatedDate": "2023-11-30T17:33:08.000+0000", + "DelegateUser": null, + "Display": "Changed permission set 00e1U000000XFwxQAG: field-level security for Task: Related To was changed from Read/Write to No Access", + "Id": "0Ym7c00001RGP6MCAX", + "ResponsibleNamespacePrefix": null, + "Section": "Manage Users" + }''') + + expected_result = { + 'Action': 'PermSetFlsChanged', + 'CreatedBy.Name': 'Chetan Gupta', + 'CreatedBy.Profile.Name': 'System Administrator', + 'CreatedBy.UserType': 'Standard', + 'CreatedByContext': None, + 'CreatedById': '0058W00000A7LvTQAV', + 'CreatedByIssuer': None, + 'CreatedDate': '2023-11-30T17:33:08.000+0000', + 'DelegateUser': None, + 'Display': 'Changed permission set 00e1U000000XFwxQAG: field-level security for Task: Related To was changed from Read/Write to No Access', + 'Id': '0Ym7c00001RGP6MCAX', + 'ResponsibleNamespacePrefix': None, + 'Section': 'Manage Users' + } + + # execute + result = util.process_query_result(query_result) + + # verify + self.assertEqual(expected_result, result) + + def test_process_query_result_ignores_attributes_fields(self): + ''' + process_query_result() ignores all fields and nested fields named 'attributes' + given: JSON result from an SOQL query + when: the SOQL query result contains primitive and nested fields + and when: the SOQL query result contains fields and nested fields named 'attributes' + then: returns dict containing all primitive fields + and: contains primitives from all nested fields + and: the keys for all nested fields use the syntax 'field1.nestedfield1', 'field2.nestedfield2.nestedfield1' and so on + and: fields and nested fields named 'attributes' are ignored + ''' + + # setup + query_result = json.loads('''{ + "attributes": { + "type": "SetupAuditTrail", + "url": "/services/data/v55.0/sobjects/SetupAuditTrail/0Ym7c00001RGP6MCAX" + }, + "Action": "PermSetFlsChanged", + "CreatedByContext": null, + "CreatedById": "0058W00000A7LvTQAV", + "CreatedBy": { + "attributes": { + "type": "User", + "url": "/services/data/v55.0/sobjects/User/0058W00000A7LvTQAV" + }, + "Name": "Chetan Gupta", + "Profile": { + "attributes": { + "type": "Profile", + "url": "/services/data/v55.0/sobjects/Profile/00e1U000001wRS1QAM" + }, + "Name": "System Administrator" + }, + "UserType": "Standard" + }, + "CreatedByIssuer": null, + "CreatedDate": "2023-11-30T17:33:08.000+0000", + "DelegateUser": null, + "Display": "Changed permission set 00e1U000000XFwxQAG: field-level security for Task: Related To was changed from Read/Write to No Access", + "Id": "0Ym7c00001RGP6MCAX", + "ResponsibleNamespacePrefix": null, + "Section": "Manage Users" + }''') + + expected_result = { + 'Action': 'PermSetFlsChanged', + 'CreatedBy.Name': 'Chetan Gupta', + 'CreatedBy.Profile.Name': 'System Administrator', + 'CreatedBy.UserType': 'Standard', + 'CreatedByContext': None, + 'CreatedById': '0058W00000A7LvTQAV', + 'CreatedByIssuer': None, + 'CreatedDate': '2023-11-30T17:33:08.000+0000', + 'DelegateUser': None, + 'Display': 'Changed permission set 00e1U000000XFwxQAG: field-level security for Task: Related To was changed from Read/Write to No Access', + 'Id': '0Ym7c00001RGP6MCAX', + 'ResponsibleNamespacePrefix': None, + 'Section': 'Manage Users' + } + + # execute + result = util.process_query_result(query_result) + + # verify + self.assertEqual(expected_result, result) + + def test_process_query_result_ignores_non_dict_structured_fields(self): + ''' + process_query_result() ignores all fields and nested fields that are neither "primitive" nor type dict + given: JSON result from an SOQL query + when: the SOQL query result contains primitive and nested fields + and when: the SOQL query result contains fields and nested fields that are neither "primitive" nor type dict + then: returns dict containing all primitive fields + and: contains primitives from all nested fields + and: the keys for all nested fields use the syntax 'field1.nestedfield1', 'field2.nestedfield2.nestedfield1' and so on + and: fields and nested fields that are neither "primitive" nor type dict are ignored + ''' + + # setup + query_result = json.loads('''{ + "Action": "PermSetFlsChanged", + "CreatedByContext": null, + "CreatedById": "0058W00000A7LvTQAV", + "CreatedByIssuer": null, + "CreatedDate": "2023-11-30T17:33:08.000+0000", + "DelegateUser": null, + "Display": "Changed permission set 00e1U000000XFwxQAG: field-level security for Task: Related To was changed from Read/Write to No Access", + "Id": "0Ym7c00001RGP6MCAX", + "ResponsibleNamespacePrefix": null, + "Section": "Manage Users", + "RandomNested": { + "RandomNestedArray": ["beep", "boop"] + }, + "RandomArray": ["foo", "bar"] + }''') + + expected_result = { + 'Action': 'PermSetFlsChanged', + 'CreatedByContext': None, + 'CreatedById': '0058W00000A7LvTQAV', + 'CreatedByIssuer': None, + 'CreatedDate': '2023-11-30T17:33:08.000+0000', + 'DelegateUser': None, + 'Display': 'Changed permission set 00e1U000000XFwxQAG: field-level security for Task: Related To was changed from Read/Write to No Access', + 'Id': '0Ym7c00001RGP6MCAX', + 'ResponsibleNamespacePrefix': None, + 'Section': 'Manage Users' + } + + # execute + result = util.process_query_result(query_result) + + # verify + self.assertEqual(expected_result, result) + + def test_get_timestamp_returns_current_posix_ms_as_int(self): + ''' + get_timestamp() returns current posix time in ms as an integer + given: a date string + when: the date string is None + then: returns the current posix time in ms as an integer + ''' + + # setup + + __now = datetime.now() + + def _now(): + nonlocal __now + return __now + + util._NOW = _now + + expected = int(__now.timestamp() * 1000) + + # execute + timestamp = util.get_timestamp() + + # verify + self.assertEqual(expected, timestamp) + + def test_get_timestamp_returns_posix_ms_as_int_for_date_string(self): + ''' + get_timestamp() returns the posix time in ms as an integer for the time specified in the date string + given: a date string + when: the date string is not None + and when: the date string is of the form %Y-%m-%dT%H:%M:%S.%f%z + then: returns the posix time in ms as an integer for the date string + ''' + + # setup + date_string = '2024-03-11T00:00:00.000+0000' + time = datetime.strptime(date_string, '%Y-%m-%dT%H:%M:%S.%f%z') + expected = int(time.timestamp() * 1000) + + # execute + timestamp = util.get_timestamp(date_string) + + # verify + self.assertEqual(expected, timestamp) + if __name__ == '__main__': unittest.main() From f0477fff64fc34676f8748268b1579d7dce9b70a Mon Sep 17 00:00:00 2001 From: Scott DeWitt Date: Wed, 3 Apr 2024 10:26:24 -0400 Subject: [PATCH 08/11] fix: AttributeError on clear_auth, refactor to better support reauthentication on api calls --- src/newrelic_logging/api.py | 110 ++++ src/newrelic_logging/auth.py | 32 +- src/newrelic_logging/integration.py | 18 +- src/newrelic_logging/pipeline.py | 43 +- src/newrelic_logging/query.py | 37 +- src/newrelic_logging/salesforce.py | 54 +- src/tests/__init__.py | 173 +++++- src/tests/test_api.py | 838 ++++++++++++++++++++++++++++ src/tests/test_pipeline.py | 246 ++++++-- src/tests/test_query.py | 133 +++-- src/tests/test_salesforce.py | 147 ++++- 11 files changed, 1582 insertions(+), 249 deletions(-) create mode 100644 src/newrelic_logging/api.py create mode 100644 src/tests/test_api.py diff --git a/src/newrelic_logging/api.py b/src/newrelic_logging/api.py new file mode 100644 index 0000000..7b94172 --- /dev/null +++ b/src/newrelic_logging/api.py @@ -0,0 +1,110 @@ +from requests import RequestException, Session, Response +from typing import Any + +from . import SalesforceApiException +from .auth import Authenticator +from .telemetry import print_warn + +def get( + auth: Authenticator, + session: Session, + serviceUrl: str, + cb, + stream: bool = False, +) -> Any: + url = f'{auth.get_instance_url()}{serviceUrl}' + + try: + headers = { + 'Authorization': f'Bearer {auth.get_access_token()}' + } + + response = session.get(url, headers=headers, stream=stream) + + status_code = response.status_code + + if status_code == 200: + return cb(response) + + if status_code == 401: + print_warn( + f'invalid token while executing api operation: {url}', + ) + auth.reauthenticate(session) + + response = session.get(url, headers=headers, stream=stream) + if response.status_code == 200: + return cb(response) + + raise SalesforceApiException( + status_code, + f'error executing api operation: {url}, ' \ + f'status-code: {status_code}, ' \ + f'reason: {response.reason}, ' + ) + except ConnectionError as e: + raise SalesforceApiException( + -1, + f'connection error executing api operation: {url}', + ) from e + except RequestException as e: + raise SalesforceApiException( + -1, + f'request error executing api operation: {url}', + ) from e + + +def stream_lines(response: Response, chunk_size: int): + if response.encoding is None: + response.encoding = 'utf-8' + + # Stream the response as a set of lines. This function will return an + # iterator that yields one line at a time holding only the minimum + # amount of data chunks in memory to make up a single line + + return response.iter_lines( + decode_unicode=True, + chunk_size=chunk_size, + ) + + +class Api: + def __init__(self, authenticator: Authenticator, api_ver: str): + self.authenticator = authenticator + self.api_ver = api_ver + + def authenticate(self, session: Session) -> None: + self.authenticator.authenticate(session) + + def query(self, session: Session, soql: str, api_ver: str = None) -> dict: + ver = self.api_ver + if not api_ver is None: + ver = api_ver + + return get( + self.authenticator, + session, + f'/services/data/v{ver}/query?q={soql}', + lambda response : response.json() + ) + + def get_log_file( + self, + session: Session, + log_file_path: str, + chunk_size: int, + ): + return get( + self.authenticator, + session, + log_file_path, + lambda response : stream_lines(response, chunk_size), + stream=True, + ) + +class ApiFactory: + def __init__(self): + pass + + def new(self, authenticator: Authenticator, api_ver: str) -> Api: + return Api(authenticator, api_ver) diff --git a/src/newrelic_logging/auth.py b/src/newrelic_logging/auth.py index 4803b8d..ea72a1c 100644 --- a/src/newrelic_logging/auth.py +++ b/src/newrelic_logging/auth.py @@ -99,22 +99,6 @@ def store_auth(self, auth_resp: dict) -> None: except Exception as e: print_warn(f"Failed storing data in cache: {e}") - def authenticate( - self, - session: Session, - ) -> None: - if self.data_cache and self.load_auth_from_cache(): - return - - oauth_type = self.get_grant_type() - if oauth_type == 'password': - self.authenticate_with_password(session) - print_info('Correctly authenticated with user/pass flow') - return - - self.authenticate_with_jwt(session) - print_info('Correctly authenticated with JWT flow') - def authenticate_with_jwt(self, session: Session) -> None: private_key_file = self.auth_data['private_key'] client_id = self.auth_data['client_id'] @@ -198,6 +182,22 @@ def authenticate_with_password(self, session: Session) -> None: except RequestException as e: raise LoginException(f'authentication failed for sfdc instance {self.instance_name}') from e + def authenticate(self, session: Session) -> None: + if self.data_cache and self.load_auth_from_cache(): + return + + oauth_type = self.get_grant_type() + if oauth_type == 'password': + self.authenticate_with_password(session) + print_info('Correctly authenticated with user/pass flow') + return + + self.authenticate_with_jwt(session) + print_info('Correctly authenticated with JWT flow') + + def reauthenticate(self, session: Session) -> None: + self.clear_auth() + self.authenticate(session) def validate_oauth_config(auth: dict) -> dict: if not auth['client_id']: diff --git a/src/newrelic_logging/integration.py b/src/newrelic_logging/integration.py index b4fd53e..5b4e59b 100644 --- a/src/newrelic_logging/integration.py +++ b/src/newrelic_logging/integration.py @@ -135,29 +135,13 @@ def auth_and_fetch( self, client: salesforce.SalesForce, session: Session, - retry: bool = True, ) -> None: - try: client.authenticate(session) - return client.fetch_logs(session) + client.fetch_logs(session) except LoginException as e: print_err(f'authentication failed: {e}') except SalesforceApiException as e: - if e.err_code == 401: - if retry: - print_err('authentication failed, retrying...') - client.clear_auth() - self.auth_and_fetch( - client, - session, - False, - ) - return - - print_err(f'exception while fetching data from SF: {e}') - return - print_err(f'exception while fetching data from SF: {e}') except CacheException as e: print_err(f'exception while accessing Redis cache: {e}') diff --git a/src/newrelic_logging/pipeline.py b/src/newrelic_logging/pipeline.py index 51caecc..cecd341 100644 --- a/src/newrelic_logging/pipeline.py +++ b/src/newrelic_logging/pipeline.py @@ -6,6 +6,7 @@ from requests import Session from . import DataFormat, SalesforceApiException +from .api import Api from .cache import DataCache from .config import Config from .http_session import new_retry_session @@ -88,32 +89,13 @@ def pack_log_line_into_log( def export_log_lines( + api: Api, session: Session, - url: str, - access_token: str, + log_file_path: str, chunk_size: int, ): - print_info(f'Downloading log lines for log file: {url}') - - # Request the log lines for the log file record url - response = session.get( - url, - headers={ - 'Authorization': f'Bearer {access_token}' - }, - stream=True, - ) - if response.status_code != 200: - error_message = f'salesforce event log file download failed. ' \ - f'status-code: {response.status_code}, ' \ - f'reason: {response.reason} ' \ - f'response: {response.text}' - raise SalesforceApiException(response.status_code, error_message) - - # Stream the response as a set of lines. This function will return an - # iterator that yields one line at a time holding only the minimum - # amount of data chunks in memory to make up a single line - return response.iter_lines(chunk_size=chunk_size, decode_unicode=True) + print_info(f'Downloading log lines for log file: {log_file_path}') + return api.get_log_file(session, log_file_path, chunk_size) def transform_log_lines( @@ -380,15 +362,14 @@ def __init__( def process_log_record( self, + api: Api, session: Session, query: Query, - instance_url: str, - access_token: str, record: dict, ): record_id = str(record['Id']) record_event_type = query.get("event_type", record['EventType']) - record_file_name = record['LogFile'] + log_file_path = record['LogFile'] interval = record['Interval'] # NOTE: only Hourly logs can be skipped, because Daily logs can change @@ -403,9 +384,9 @@ def process_log_record( load_data( transform_log_lines( export_log_lines( + api, session, - f'{instance_url}{record_file_name}', - access_token, + log_file_path, self.config.get('chunk_size', DEFAULT_CHUNK_SIZE) ), query, @@ -441,20 +422,18 @@ def process_event_records( def execute( self, + api: Api, session: Session, query: Query, - instance_url: str, - access_token: str, records: list[dict], ): if is_logfile_response(records): for record in records: if 'LogFile' in record: self.process_log_record( + api, session, query, - instance_url, - access_token, record, ) diff --git a/src/newrelic_logging/query.py b/src/newrelic_logging/query.py index ce75a2a..e67223c 100644 --- a/src/newrelic_logging/query.py +++ b/src/newrelic_logging/query.py @@ -2,6 +2,7 @@ from requests import RequestException, Session from . import SalesforceApiException +from .api import Api from .config import Config from .telemetry import print_info from .util import get_iso_date_with_offset, substitute @@ -9,10 +10,12 @@ class Query: def __init__( self, + api: Api, query: str, config: Config, - api_ver: str, + api_ver: str = None, ): + self.api = api self.query = query self.config = config self.api_ver = api_ver @@ -26,32 +29,9 @@ def get_config(self): def execute( self, session: Session, - instance_url: str, - access_token: str, ): - url = f'{instance_url}/services/data/v{self.api_ver}/query?q={self.query}' - - try: - print_info(f'Running query {self.query} using url {url}') - - query_response = session.get(url, headers={ - 'Authorization': f'Bearer {access_token}' - }) - if query_response.status_code != 200: - raise SalesforceApiException( - query_response.status_code, - f'error when trying to run SOQL query. ' \ - f'status-code:{query_response.status_code}, ' \ - f'reason: {query_response.reason} ' \ - f'response: {query_response.text} ' - ) - - return query_response.json() - except RequestException as e: - raise SalesforceApiException( - -1, - f'error when trying to run SOQL query. cause: {e}', - ) from e + print_info(f'Running query {self.query}...') + return self.api.query(session, self.query, self.api_ver) class QueryFactory: @@ -78,16 +58,17 @@ def get_env(self, q: dict) -> dict: def new( self, + api: Api, q: dict, time_lag_minutes: int, last_to_timestamp: str, generation_interval: str, - default_api_ver: str, ) -> Query: qp = copy.deepcopy(q) qq = qp.pop('query', '') return Query( + api, substitute( self.build_args( time_lag_minutes, @@ -98,5 +79,5 @@ def new( self.get_env(qp), ).replace(' ', '+'), Config(qp), - qp.get('api_ver', default_api_ver) + qp.get('api_ver', None) ) diff --git a/src/newrelic_logging/salesforce.py b/src/newrelic_logging/salesforce.py index 1afa0db..cb3556f 100644 --- a/src/newrelic_logging/salesforce.py +++ b/src/newrelic_logging/salesforce.py @@ -1,6 +1,7 @@ from datetime import datetime, timedelta from requests import Session +from .api import ApiFactory from .auth import Authenticator from .cache import DataCache from . import config as mod_config @@ -27,16 +28,15 @@ def __init__( data_cache: DataCache, authenticator: Authenticator, pipeline: Pipeline, + api_factory: ApiFactory, query_factory: mod_query.QueryFactory, initial_delay: int, queries: list[dict] = None, ): self.instance_name = instance_name self.data_cache = data_cache - self.auth = authenticator self.pipeline = pipeline self.query_factory = query_factory - self.default_api_ver = config.get('api_ver', '52.0') self.time_lag_minutes = config.get( mod_config.CONFIG_TIME_LAG_MINUTES, mod_config.DEFAULT_TIME_LAG_MINUTES if not self.data_cache else 0, @@ -54,6 +54,10 @@ def __init__( self.time_lag_minutes, initial_delay, ) + self.api = api_factory.new( + authenticator, + config.get('api_ver', '52.0'), + ) self.queries = queries if queries else \ [{ 'query': SALESFORCE_LOG_DATE_QUERY \ @@ -62,7 +66,7 @@ def __init__( }] def authenticate(self, sfdc_session: Session): - self.auth.authenticate(sfdc_session) + self.api.authenticate(sfdc_session) def slide_time_range(self): self.last_to_timestamp = get_iso_date_with_offset( @@ -77,66 +81,28 @@ def fetch_logs(self, session: Session) -> list[dict]: for q in self.queries: query = self.query_factory.new( + self.api, q, self.time_lag_minutes, self.last_to_timestamp, self.generation_interval, - self.default_api_ver, ) - response = query.execute( - session, - self.auth.get_instance_url(), - self.auth.get_access_token(), - ) + response = query.execute(session) if not response or not 'records' in response: print_warn(f'no records returned for query {query.query}') continue self.pipeline.execute( + self.api, session, query, - self.auth.get_instance_url(), - self.auth.get_access_token(), response['records'], ) self.slide_time_range() -# @TODO need to handle this logic but only when exporting logfiles and at this -# level we don't make a distinction but in the pipeline we don't have the right -# info from this level to reauth -# -# try: -# download_response = download_file(session, f'{url}{record_file_name}') -# if download_response is None: -# return -# except SalesforceApiException as e: -# pass -# if e.err_code == 401: -# if retry: -# print_err("invalid token while downloading CSV file, retry auth and download...") -# self.clear_auth() -# if self.authenticate(self.oauth_type, session): -# return self.build_log_from_logfile(False, session, record, query) -# else: -# return None -# else: -# print_err(f'salesforce event log file "{record_file_name}" download failed: {e}') -# return None -# else: -# print_err(f'salesforce event log file "{record_file_name}" download failed: {e}') -# return None -# except RequestException as e: -# print_err( -# f'salesforce event log file "{record_file_name}" download failed: {e}' -# ) -# return -# -# csv_rows = self.parse_csv(download_response, record_id, record_event_type, cached_messages) -# -# print_info(f"CSV rows = {len(csv_rows)}") class SalesForceFactory: def __init__(self): diff --git a/src/tests/__init__.py b/src/tests/__init__.py index 297f4d3..06b3689 100644 --- a/src/tests/__init__.py +++ b/src/tests/__init__.py @@ -3,7 +3,8 @@ from redis import RedisError from requests import Session, RequestException -from newrelic_logging import DataFormat +from newrelic_logging import DataFormat, LoginException, SalesforceApiException +from newrelic_logging.api import Api from newrelic_logging.auth import Authenticator from newrelic_logging.cache import DataCache from newrelic_logging.config import Config @@ -12,6 +13,70 @@ from newrelic_logging.query import Query, QueryFactory +class ApiStub: + def __init__( + self, + authenticator: Authenticator = None, + api_ver: str = None, + query_result: dict = None, + lines: list[str] = None, + raise_error = False, + raise_login_error = False, + ): + self.authenticator = authenticator + self.api_ver = api_ver + self.query_result = query_result + self.lines = lines + self.soql = None + self.query_api_ver = None + self.log_file_path = None + self.chunk_size = None + self.raise_error = raise_error + self.raise_login_error = raise_login_error + + def authenticate(self, session: Session): + if self.raise_login_error: + raise LoginException() + + self.authenticator.authenticate(session) + + def query(self, session: Session, soql: str, api_ver: str = None) -> dict: + self.soql = soql + self.query_api_ver = api_ver + + if self.raise_error: + raise SalesforceApiException() + + if self.raise_login_error: + raise LoginException() + + return self.query_result + + def get_log_file( + self, + session: Session, + log_file_path: str, + chunk_size: int, + ): + self.log_file_path = log_file_path + self.chunk_size = chunk_size + + if self.raise_error: + raise SalesforceApiException() + + if self.raise_login_error: + raise LoginException() + + yield from self.lines + +class ApiFactoryStub: + def __init__(self): + pass + + def new(self, authenticator: Authenticator, api_ver: str): + return ApiStub(authenticator, api_ver) + + class AuthenticatorStub: def __init__( self, @@ -22,6 +87,8 @@ def __init__( instance_url: str = '', grant_type: str = '', authenticate_called: bool = False, + reauthenticate_called: bool = False, + raise_login_error = False, ): self.config = config self.data_cache = data_cache @@ -30,6 +97,8 @@ def __init__( self.instance_url = instance_url self.grant_type = grant_type self.authenticate_called = authenticate_called + self.reauthenticate_called = reauthenticate_called + self.raise_login_error = raise_login_error def get_access_token(self) -> str: return self.access_token @@ -52,17 +121,27 @@ def load_auth_from_cache(self) -> bool: def store_auth(self, auth_resp: dict) -> None: pass + def authenticate_with_jwt(self, session: Session) -> None: + pass + + def authenticate_with_password(self, session: Session) -> None: + pass + def authenticate( self, session: Session, ) -> None: self.authenticate_called = True + if self.raise_login_error: + raise LoginException('Unauthorized') - def authenticate_with_jwt(self, session: Session) -> None: - pass - - def authenticate_with_password(self, session: Session) -> None: - pass + def reauthenticate( + self, + session: Session, + ) -> None: + self.reauthenticate_called = True + if self.raise_login_error: + raise LoginException('Unauthorized') class AuthenticatorFactoryStub: @@ -133,11 +212,13 @@ def new(self, config: Config): class QueryStub: def __init__( self, + api: Api = None, + query: str = '', config: Config = Config({}), - api_ver: str = '', + api_ver: str = None, result: dict = { 'records': [] }, - query: str = '', ): + self.api = api self.query = query self.config = config self.api_ver = api_ver @@ -150,12 +231,7 @@ def get(self, key: str, default = None): def get_config(self): return self.config - def execute( - self, - session: Session = None, - instance_url: str = '', - access_token: str = '', - ): + def execute(self, session: Session = None): self.executed = True return self.result @@ -168,16 +244,16 @@ def __init__(self, query: QueryStub = None ): def new( self, + api: Api, q: dict, time_lag_minutes: int = 0, last_to_timestamp: str = '', generation_interval: str = '', - default_api_ver: str = '', ) -> Query: if self.query: return self.query - qq = QueryStub(q, default_api_ver, query=q['query']) + qq = QueryStub(api, q['query'], q) self.queries.append(qq) return qq @@ -192,6 +268,8 @@ def __init__( labels: dict = {}, event_type_fields_mapping: dict = {}, numeric_fields_list: set = set(), + raise_error: bool = False, + raise_login_error: bool = False, ): self.config = config self.data_cache = data_cache @@ -202,15 +280,22 @@ def __init__( self.numeric_fields_list = numeric_fields_list self.queries = [] self.executed = False + self.raise_error = raise_error + self.raise_login_error = raise_login_error def execute( self, + api: Api, session: Session, query: Query, - instance_url: str, - access_token: str, records: list[dict], ): + if self.raise_error: + raise SalesforceApiException() + + if self.raise_login_error: + raise LoginException() + self.queries.append(query) self.executed = True @@ -321,14 +406,49 @@ def new(self, _: Config): return BackendStub({}) +class MultiRequestSessionStub: + def __init__(self, responses=[], raise_error=False): + self.raise_error = raise_error + self.requests = [] + self.responses = responses + self.count = 0 + + def get(self, *args, **kwargs): + self.requests.append({ + 'url': args[0], + 'headers': kwargs['headers'], + 'stream': kwargs['stream'] if 'stream' in kwargs else None, + }) + + if self.raise_error: + raise RequestException('raise_error set') + + if self.count < len(self.responses): + self.count += 1 + + return self.responses[self.count - 1] + + class ResponseStub: - def __init__(self, status_code, reason, text, lines): + def __init__(self, status_code, reason, text, lines, encoding=None): self.status_code = status_code self.reason = reason self.text = text self.lines = lines + self.chunk_size = None + self.decode_unicode = None + self.encoding = encoding + self.iter_lines_called = False def iter_lines(self, *args, **kwargs): + self.iter_lines_called = True + + if 'chunk_size' in kwargs: + self.chunk_size = kwargs['chunk_size'] + + if 'decode_unicode' in kwargs: + self.decode_unicode = kwargs['decode_unicode'] + yield from self.lines def json(self, *args, **kwargs): @@ -385,17 +505,24 @@ def new( class SessionStub: - def __init__(self, raise_error=False): + def __init__(self, raise_error=False, raise_connection_error=False): self.raise_error = raise_error + self.raise_connection_error = raise_connection_error self.response = None self.headers = None self.url = None + self.stream = None def get(self, *args, **kwargs): - if self.raise_error: - raise RequestException('raise_error set') - self.url = args[0] self.headers = kwargs['headers'] + if 'stream' in kwargs: + self.stream = kwargs['stream'] + + if self.raise_connection_error: + raise ConnectionError('raise_connection_error set') + + if self.raise_error: + raise RequestException('raise_error set') return self.response diff --git a/src/tests/test_api.py b/src/tests/test_api.py new file mode 100644 index 0000000..6459a7d --- /dev/null +++ b/src/tests/test_api.py @@ -0,0 +1,838 @@ +from requests import Session +import unittest + +from newrelic_logging import api, SalesforceApiException, LoginException +from . import \ + AuthenticatorStub, \ + ResponseStub, \ + SessionStub, \ + MultiRequestSessionStub + + +class TestApi(unittest.TestCase): + def test_get_calls_session_get_with_url_access_token_and_default_stream_flag_and_invokes_cb_and_returns_result_on_200(self): + ''' + get() makes a request with the given URL, access token, and default stream flag and invokes the callback with response and returns the result on a 200 status code + given: an authenticator + and given: a session + and given: a url + and given: a service URL + and given: a callback + when: get() is called + then: session.get() is called with full URL, access token, and default stream flag + and when: response status code is 200 + then: invokes callback with response and returns result + ''' + + # setup + auth = AuthenticatorStub( + instance_url='https://my.salesforce.test', + access_token='123456', + ) + response = ResponseStub(200, 'OK', 'OK', []) + session = SessionStub() + session.response = response + + def cb(response): + return response.text + + # execute + val = api.get(auth, session, '/foo', cb) + + # verify + self.assertEqual(session.url, 'https://my.salesforce.test/foo') + self.assertTrue('Authorization' in session.headers) + self.assertEqual(session.headers['Authorization'], 'Bearer 123456') + self.assertFalse(session.stream) + self.assertIsNotNone(val) + self.assertEqual(val, 'OK') + + def test_get_calls_session_get_with_url_access_token_and_given_stream_flag_and_invokes_cb_and_returns_result_on_200(self): + ''' + get() makes a request with the given URL, access token, and stream flag and invokes the callback with response and returns the result on a 200 status code + given: an authenticator + and given: a session + and given: a url + and given: a callback + when: get() is called + then: session.get() is called with full URL, access token, and stream flag + and when: response status code is 200 + then: invokes callback with response and returns result + ''' + + # setup + auth = AuthenticatorStub( + instance_url='https://my.salesforce.test', + access_token='123456', + ) + response = ResponseStub(200, 'OK', 'OK', []) + session = SessionStub() + session.response = response + + def cb(response): + return response.text + + # execute + val = api.get(auth, session, '/foo', cb, stream=True) + + # verify + self.assertEqual(session.url, 'https://my.salesforce.test/foo') + self.assertTrue('Authorization' in session.headers) + self.assertEqual(session.headers['Authorization'], 'Bearer 123456') + self.assertTrue(session.stream) + self.assertIsNotNone(val) + self.assertEqual(val, 'OK') + + def test_get_raises_on_response_not_200_or_401(self): + ''' + get() raises a SalesforceApiException when the response status code is not 200 or 401 + given: an authenticator + and given: a session + and given: a service url + and given: a callback + when: get() is called + then: session.get() is called with full URL, access token, and default stream flag + and when: response status code is not 200 + and when: response status code is not 401 + then: raises a SalesforceApiException + ''' + + # setup + auth = AuthenticatorStub( + instance_url='https://my.salesforce.test', + access_token='123456', + ) + response = ResponseStub(500, 'Server Error', 'Server Error', []) + session = SessionStub() + session.response = response + + def cb(response): + return response + + # execute / verify + with self.assertRaises(SalesforceApiException) as _: + _ = api.get( + auth, + session, + '/foo', + cb, + ) + + self.assertEqual(session.url, 'https://my.salesforce.test/foo') + self.assertTrue('Authorization' in session.headers) + self.assertEqual(session.headers['Authorization'], 'Bearer 123456') + self.assertFalse(session.stream) + + def test_get_raises_on_connection_error(self): + ''' + get() raises a SalesforceApiException when session.get() raises a ConnectionError + given: an authenticator + and given: a session + and given: a service url + and given: a callback + when: get() is called + then: session.get() is called with full URL, access token, and default stream flag + and when: session.get() raises a ConnectionError + then: raises a SalesforceApiException + ''' + + # setup + auth = AuthenticatorStub( + instance_url='https://my.salesforce.test', + access_token='123456', + ) + response = ResponseStub(500, 'Server Error', 'Server Error', []) + session = SessionStub(raise_connection_error=True) + session.response = response + + def cb(response): + return response + + # execute / verify + with self.assertRaises(SalesforceApiException) as _: + _ = api.get( + auth, + session, + '/foo', + cb, + ) + + self.assertEqual(session.url, 'https://my.salesforce.test/foo') + self.assertTrue('Authorization' in session.headers) + self.assertEqual(session.headers['Authorization'], 'Bearer 123456') + self.assertFalse(session.stream) + + def test_get_raises_on_request_exception(self): + ''' + get() raises a SalesforceApiException when session.get() raises a RequestException + given: an authenticator + and given: a session + and given: a service url + and given: a callback + when: get() is called + then: session.get() is called with full URL, access token, and default stream flag + and when: session.get() raises a RequestException + then: raises a SalesforceApiException + ''' + + # setup + auth = AuthenticatorStub( + instance_url='https://my.salesforce.test', + access_token='123456', + ) + response = ResponseStub(500, 'Server Error', 'Server Error', []) + session = SessionStub(raise_error=True) + session.response = response + + def cb(response): + return response + + # execute / verify + with self.assertRaises(SalesforceApiException) as _: + _ = api.get( + auth, + session, + '/foo', + cb, + ) + + self.assertEqual(session.url, 'https://my.salesforce.test/foo') + self.assertTrue('Authorization' in session.headers) + self.assertEqual(session.headers['Authorization'], 'Bearer 123456') + self.assertFalse(session.stream) + + def test_get_raises_login_exception_if_reauthenticate_does(self): + ''' + get() raises a LoginException when the status code is 401 and reauthenticate() raises a LoginException + given: an authenticator + and given: a session + and given: a service url + and given: a callback + when: get() is called + then: session.get() is called with full URL, access token, and default stream flag + and when: response status code is 401 + then: reauthenticate() is called + and when: reauthenticate() raises a LoginException + then: raises a LoginException + ''' + + # setup + auth = AuthenticatorStub( + instance_url='https://my.salesforce.test', + access_token='123456', + raise_login_error=True + ) + response = ResponseStub(401, 'Unauthorized', 'Unauthorized', []) + session = SessionStub() + session.response = response + + def cb(response): + return response + + # execute / verify + with self.assertRaises(LoginException) as _: + _ = api.get( + auth, + session, + '/foo', + cb, + ) + + self.assertEqual(session.url, 'https://my.salesforce.test/foo') + self.assertTrue('Authorization' in session.headers) + self.assertEqual(session.headers['Authorization'], 'Bearer 123456') + self.assertFalse(session.stream) + self.assertTrue(auth.reauthenticate_called) + + def test_get_calls_reauthenticate_on_401_and_invokes_cb_with_response_on_200(self): + ''' + get() calls reauthenticate() on a 401 and then invokes the callback with response and returns the result on a 200 status code + given: an authenticator + and given: a session + and given: a service url + and given: a callback + when: get() is called + then: session.get() is called with given URL, access token, and default stream flag + and when: response status code is 401 + then: reauthenticate() is called + and when: reauthenticate() does not throw a LoginException + then: request is executed again with the same parameters + and when: it returns a 200 + then: calls callback with response and returns result + ''' + + # setup + auth = AuthenticatorStub( + instance_url='https://my.salesforce.test', + access_token='123456', + ) + response1 = ResponseStub(401, 'Unauthorized', 'Unauthorized', []) + response2 = ResponseStub(200, 'OK', 'OK', []) + session = MultiRequestSessionStub(responses=[response1, response2]) + + def cb(response): + return response.text + + # execute + val = api.get(auth, session, '/foo', cb) + + # verify + self.assertEqual(len(session.requests), 2) + self.assertEqual( + session.requests[0]['url'], + 'https://my.salesforce.test/foo', + ) + self.assertTrue('Authorization' in session.requests[0]['headers']) + self.assertEqual( + session.requests[0]['headers']['Authorization'], + 'Bearer 123456', + ) + self.assertEqual( + session.requests[0]['stream'], + False, + ) + self.assertTrue(auth.reauthenticate_called) + self.assertEqual( + session.requests[1]['url'], + 'https://my.salesforce.test/foo', + ) + self.assertTrue('Authorization' in session.requests[1]['headers']) + self.assertEqual( + session.requests[1]['headers']['Authorization'], + 'Bearer 123456', + ) + self.assertEqual( + session.requests[1]['stream'], + False, + ) + self.assertIsNotNone(val) + self.assertEqual(val, 'OK') + + def test_get_passes_same_params_to_get_on_reauthenticate(self): + ''' + get() receives the same set of parameters on the second call after reauthenticate() succeeds + given: an authenticator + and given: a session + and given: a service url + and given: a callback + when: get() is called + then: session.get() is called with full URL, access token, and default stream flag + and when: response status code is 401 + then: reauthenticate() is called + and when: reauthenticate() does not throw a LoginException + then: request is executed again with the same set of parameters as the first call to session.get() + ''' + + # setup + auth = AuthenticatorStub( + instance_url='https://my.salesforce.test', + access_token='123456', + ) + response1 = ResponseStub(401, 'Unauthorized', 'Unauthorized', []) + response2 = ResponseStub(200, 'OK', 'OK', []) + session = MultiRequestSessionStub(responses=[response1, response2]) + + def cb(response): + return response.text + + # execute + val = api.get(auth, session, '/foo', cb, stream=True) + + # verify + self.assertEqual(len(session.requests), 2) + self.assertEqual(session.requests[0], session.requests[1]) + self.assertTrue(auth.reauthenticate_called) + self.assertIsNotNone(val) + self.assertEqual(val, 'OK') + + def test_get_calls_reauthenticate_on_401_and_raises_on_non_200(self): + ''' + get() function calls reauthenticate() on a 401 and then throws a SalesforceApiException on a non-200 status code + given: an authenticator + and given: a session + and given: a service url + and given: a callback + when: get() is called + then: session.get() is called with full URL, access token, and default stream flag + and when: response status code is 401 + then: reauthenticate() is called + and when: reauthenticate() does not throw a LoginException + then: request is executed again with the same parameters + and when: it returns a non-200 status code + then: throws a SalesforceApiException + ''' + + # setup + auth = AuthenticatorStub( + instance_url='https://my.salesforce.test', + access_token='123456', + ) + response1 = ResponseStub(401, 'Unauthorized', 'Unauthorized', []) + response2 = ResponseStub(401, 'Unauthorized', 'Unauthorized 2', []) + session = MultiRequestSessionStub(responses=[response1, response2]) + + def cb(response): + return response + + # execute / verify + with self.assertRaises(SalesforceApiException) as _: + _ = api.get( + auth, + session, + '/foo', + cb, + ) + + self.assertEqual(len(session.requests), 2) + self.assertEqual(session.requests[0], session.requests[1]) + self.assertTrue(auth.reauthenticate_called) + + def test_stream_lines_sets_fallback_encoding_and_calls_iter_lines_with_chunk_size_and_decode_unicode(self): + ''' + stream_lines() sets a default encoding on the response and calls iter_lines with the given chunk size and the decode_unicode flag = True + given: a response + and given: a chunk size + when: encoding on response is None + then: fallback utf-8 is used and iter_lines is called with given chunk size and decode_unicode flag = True + ''' + + # setup + response = ResponseStub(200, 'OK', 'OK', ['foo lines', 'bar line']) + + # execute + lines = api.stream_lines(response, 1024) + + # verify + next(lines) + next(lines) + # NOTE: this has to be done _after_ the generator iterator is called at + # least once since the generator function is not run until the first + # call to next() + self.assertEqual(response.encoding, 'utf-8') + self.assertEqual(response.chunk_size, 1024) + self.assertTrue(response.decode_unicode) + self.assertTrue(response.iter_lines_called) + + def test_stream_lines_uses_default_encoding_and_calls_iter_lines_with_chunk_size_and_decode_unicode(self): + ''' + stream_lines() sets a default encoding on the response and calls iter_lines with the given chunk size and the decode_unicode flag = True + given: a response + and given: a chunk size + when: encoding on response is set + then: response encoding is used and iter_lines is called with given chunk size and decode_unicode flag = True + ''' + + # setup + response = ResponseStub( + 200, + 'OK', + 'OK', + ['foo lines', 'bar line'], + encoding='iso-8859-1', + ) + + # execute + lines = api.stream_lines(response, 1024) + + # verify + next(lines) + next(lines) + # NOTE: this has to be done _after_ the generator iterator is called at + # least once since the generator function is not run until the first + # call to next() + self.assertEqual(response.encoding, 'iso-8859-1') + self.assertEqual(response.chunk_size, 1024) + self.assertTrue(response.decode_unicode) + self.assertTrue(response.iter_lines_called) + + def test_authenticate_calls_authenticator_authenticate(self): + ''' + authenticate() calls authenticate() on the backing authenticator + given: an authenticator + and given: an api version + and given: a session + when: authenticate() is called + then: authenticator.authenticate() is called + ''' + + # setup + auth = AuthenticatorStub( + instance_url='https://my.salesforce.test', + access_token='123456', + ) + session = SessionStub() + + # execute + sf_api = api.Api(auth, '55.0') + sf_api.authenticate(session) + + # verify + self.assertTrue(auth.authenticate_called) + + def test_authenticate_raises_login_exception_if_authenticator_authenticate_does(self): + ''' + authenticate() calls authenticate() on the backing authenticator and raises a LoginException if authenticate() does + given: an authenticator + and given: an api version + and given: a session + when: authenticate() is called + then: authenticator.authenticate() is called + and when: authenticator.authenticate() raises a LoginException + then: authenticate() raises a LoginException + ''' + + # setup + auth = AuthenticatorStub( + instance_url='https://my.salesforce.test', + access_token='123456', + raise_login_error=True, + ) + session = SessionStub() + + # execute / verify + with self.assertRaises(LoginException) as _: + sf_api = api.Api(auth, '55.0') + sf_api.authenticate(session) + + self.assertTrue(auth.authenticate_called) + + def test_query_requests_correct_url_with_access_token_and_returns_json_response_on_success(self): + ''' + query() calls the correct query API url with the access token and returns a JSON response when no errors occur + given: an authenticator + and given: an api version + and given: a session + and given: a query + when: query() is called + then: session.get() is called with correct URL and access token + and: stream is set to False + and when: session.get() response status code is 200 + then: calls callback with response and returns a JSON response + ''' + + # setup + auth = AuthenticatorStub( + instance_url='https://my.salesforce.test', + access_token='123456', + ) + session = SessionStub() + session.response = ResponseStub(200, 'OK', '{"foo": "bar"}', [] ) + + # execute + sf_api = api.Api(auth, '55.0') + resp = sf_api.query( + session, + 'SELECT+LogFile+FROM+EventLogFile', + ) + + # verify + + self.assertEqual( + session.url, + f'https://my.salesforce.test/services/data/v55.0/query?q=SELECT+LogFile+FROM+EventLogFile', + ) + self.assertTrue('Authorization' in session.headers) + self.assertEqual(session.headers['Authorization'], 'Bearer 123456') + self.assertFalse(session.stream) + self.assertIsNotNone(resp) + self.assertTrue(type(resp) is dict) + self.assertTrue('foo' in resp) + self.assertEqual(resp['foo'], 'bar') + + def test_query_requests_correct_url_with_access_token_given_api_version_and_returns_json_response_on_success(self): + ''' + query() calls the correct query API url with the access token when a specific api version is given and returns a JSON response + given: an authenticator + and given: an api version + and given: a session + and given: a query + when: query() is called + and when: the api version parameter is specified + then: session.get() is called with correct URL and access token + and: stream is set to False + and when: response status code is 200 + then: calls callback with response and returns a JSON response + ''' + + # setup + auth = AuthenticatorStub( + instance_url='https://my.salesforce.test', + access_token='123456', + ) + session = SessionStub() + session.response = ResponseStub(200, 'OK', '{"foo": "bar"}', [] ) + + # execute + sf_api = api.Api(auth, '55.0') + resp = sf_api.query( + session, + 'SELECT+LogFile+FROM+EventLogFile', + '52.0', + ) + + # verify + + self.assertEqual( + session.url, + f'https://my.salesforce.test/services/data/v52.0/query?q=SELECT+LogFile+FROM+EventLogFile', + ) + self.assertTrue('Authorization' in session.headers) + self.assertEqual(session.headers['Authorization'], 'Bearer 123456') + self.assertFalse(session.stream) + self.assertIsNotNone(resp) + self.assertTrue(type(resp) is dict) + self.assertTrue('foo' in resp) + self.assertEqual(resp['foo'], 'bar') + + def test_query_raises_login_exception_if_get_does(self): + ''' + query() calls the correct query API url with the access token and raises LoginException if get does + given: an authenticator + and given: an api version + and given: a session + and given: a query + when: query() is called + then: session.get() is called with correct URL and access token + and: stream is set to False + and when: session.get() raises a LoginException + then: query() raises a LoginException + ''' + + # setup + auth = AuthenticatorStub( + instance_url='https://my.salesforce.test', + access_token='123456', + raise_login_error=True + ) + session = SessionStub() + session.response = ResponseStub(401, 'Unauthorized', '{"foo": "bar"}', [] ) + + # execute / verify + with self.assertRaises(LoginException) as _: + sf_api = api.Api(auth, '55.0') + _ = sf_api.query( + session, + 'SELECT+LogFile+FROM+EventLogFile', + ) + + self.assertEqual( + session.url, + f'https://my.salesforce.test/services/data/v55.0/query?q=SELECT+LogFile+FROM+EventLogFile', + ) + self.assertTrue('Authorization' in session.headers) + self.assertEqual(session.headers['Authorization'], 'Bearer 123456') + self.assertFalse(session.stream) + + def test_query_raises_salesforce_exception_if_get_does(self): + ''' + query() calls the correct query API url with the access token and raises SalesforceApiException if get does + given: an authenticator + and given: an api version + and given: a session + and given: a query + when: query() is called + then: session.get() is called with correct URL and access token + and: stream is set to False + and when: session.get() raises a SalesforceApiException + then: query() raises a SalesforceApiException + ''' + + # setup + auth = AuthenticatorStub( + instance_url='https://my.salesforce.test', + access_token='123456', + ) + session = SessionStub() + session.response = ResponseStub(500, 'ServerError', '{"foo": "bar"}', [] ) + + # execute / verify + with self.assertRaises(SalesforceApiException) as _: + sf_api = api.Api(auth, '55.0') + _ = sf_api.query( + session, + 'SELECT+LogFile+FROM+EventLogFile', + ) + + self.assertEqual( + session.url, + f'https://my.salesforce.test/services/data/v55.0/query?q=SELECT+LogFile+FROM+EventLogFile', + ) + self.assertTrue('Authorization' in session.headers) + self.assertEqual(session.headers['Authorization'], 'Bearer 123456') + self.assertFalse(session.stream) + + def test_get_log_file_requests_correct_url_with_access_token_and_returns_generator_on_success(self): + ''' + get_log_file() calls the correct url with the access token and returns a generator iterator + given: an authenticator + and given: an api version + and given: a session + and given: a log file path + and given: a chunk size + when: get_log_file() is called + then: session.get() is called with correct URL and access token + and: stream is set to True + and when: response status code is 200 + and when: get() returns a response + then: calls callback with response + and: iter_lines is called with the correct chunk size + and: returns a generator iterator + ''' + + # setup + auth = AuthenticatorStub( + instance_url='https://my.salesforce.test', + access_token='123456', + ) + session = SessionStub() + session.response = ResponseStub( + 200, + 'OK', + '', + [ 'COL1,COL2,COL3', 'foo,bar,baz' ], + ) + + # execute + sf_api = api.Api(auth, '55.0') + resp = sf_api.get_log_file( + session, + '/services/data/v52.0/sobjects/EventLogFile/00001111AAAABBBB/LogFile', + chunk_size=8192, + ) + + # verify + self.assertEqual( + session.url, + f'https://my.salesforce.test/services/data/v52.0/sobjects/EventLogFile/00001111AAAABBBB/LogFile', + ) + self.assertTrue('Authorization' in session.headers) + self.assertEqual(session.headers['Authorization'], 'Bearer 123456') + self.assertTrue(session.stream) + self.assertIsNotNone(resp) + line = next(resp) + self.assertEqual('COL1,COL2,COL3', line) + line = next(resp) + self.assertEqual('foo,bar,baz', line) + line = next(resp, None) + self.assertIsNone(line) + # NOTE: this has to be done _after_ the generator iterator is called at + # least once since the generator function is not run until the first + # call to next() + self.assertTrue(session.response.iter_lines_called) + self.assertEqual(session.response.chunk_size, 8192) + + def test_get_log_file_raises_login_exception_if_get_does(self): + ''' + get_log_file() calls the correct query API url with the access token and raises a LoginException if get does + given: an authenticator + and given: an api version + and given: a session + and given: a log file path + and given: a chunk size + when: get_log_file() is called + then: session.get() is called with correct URL and access token + and: stream is set to True + and when: get() raises a LoginException + then: get_log_file() raises a LoginException + ''' + + # setup + auth = AuthenticatorStub( + instance_url='https://my.salesforce.test', + access_token='123456', + raise_login_error=True, + ) + session = SessionStub() + session.response = ResponseStub( + 401, + 'Unauthorized', + 'Unauthorized', + [ 'COL1,COL2,COL3', 'foo,bar,baz' ], + ) + + # execute / verify + with self.assertRaises(LoginException) as _: + sf_api = api.Api(auth, '55.0') + _ = sf_api.get_log_file( + session, + '/services/data/v52.0/sobjects/EventLogFile/00001111AAAABBBB/LogFile', + chunk_size=8192, + ) + + self.assertEqual( + session.url, + f'https://my.salesforce.test/services/data/v52.0/sobjects/EventLogFile/00001111AAAABBBB/LogFile', + ) + self.assertTrue('Authorization' in session.headers) + self.assertEqual(session.headers['Authorization'], 'Bearer 123456') + self.assertTrue(session.stream) + + def test_get_log_file_raises_salesforce_exception_if_get_does(self): + ''' + get_log_file() calls the correct query API url with the access token and raises a SalesforceApiException if get does + given: an authenticator + and given: an api version + and given: a session + and given: a log file path + and given: a chunk size + when: get_log_file() is called + then: session.get() is called with correct URL and access token + and: stream is set to True + and when: get() raises a SalesforceApiException + then: get_log_file() raises a SalesforceApiException + ''' + + # setup + auth = AuthenticatorStub( + instance_url='https://my.salesforce.test', + access_token='123456', + ) + session = SessionStub(raise_error=True) + session.response = ResponseStub( + 200, + 'OK', + '', + [ 'COL1,COL2,COL3', 'foo,bar,baz' ], + ) + + # execute / verify + with self.assertRaises(SalesforceApiException) as _: + sf_api = api.Api(auth, '55.0') + _ = sf_api.get_log_file( + session, + '/services/data/v52.0/sobjects/EventLogFile/00001111AAAABBBB/LogFile', + chunk_size=8192, + ) + + self.assertEqual( + session.url, + f'https://my.salesforce.test/services/data/v52.0/sobjects/EventLogFile/00001111AAAABBBB/LogFile', + ) + self.assertTrue('Authorization' in session.headers) + self.assertEqual(session.headers['Authorization'], 'Bearer 123456') + self.assertTrue(session.stream) + + +class TestApiFactory(unittest.TestCase): + def test_new_returns_api_with_correct_authenticator_and_version(self): + ''' + new() returns a new Api instance with the given authenticator and version + given: an authenticator + and given: an api version + when: new() is called + then: returns a new Api instance with the given authenticator and version + ''' + + # setup + auth = AuthenticatorStub( + instance_url='https://my.salesforce.test', + access_token='123456', + ) + + # execute + api_factory = api.ApiFactory() + sf_api = api_factory.new(auth, '55.0') + + # verify + self.assertEqual(sf_api.authenticator, auth) + self.assertEqual(sf_api.api_ver, '55.0') diff --git a/src/tests/test_pipeline.py b/src/tests/test_pipeline.py index c81d459..d79c6f7 100644 --- a/src/tests/test_pipeline.py +++ b/src/tests/test_pipeline.py @@ -4,19 +4,22 @@ import pytz import unittest + from newrelic_logging import \ config, \ DataFormat, \ + LoginException, \ pipeline, \ util, \ SalesforceApiException from . import \ + ApiStub, \ DataCacheStub, \ NewRelicStub, \ QueryStub, \ - ResponseStub, \ SessionStub + class TestPipeline(unittest.TestCase): def setUp(self): with open('./tests/sample_log_lines.csv') as stream: @@ -162,7 +165,7 @@ def test_pack_log_line_into_log(self): ''' # setup - query = QueryStub({ + query = QueryStub(config={ 'event_type': 'CustomSFEvent', 'rename_timestamp': 'custom_timestamp', }) @@ -186,30 +189,55 @@ def test_pack_log_line_into_log(self): def test_export_log_lines(self): ''' - given: an http session, url, access token, and chunk size - when: the response produces a non-200 status code + given: an Api instance, an http session, log file path, and chunk size + when: api.get_log_file() is called + and when: api.get_log_file() raises a LoginException + then: raise a LoginException + ''' + + # setup + api = ApiStub(raise_login_error=True) + session = SessionStub() + + # execute/verify + with self.assertRaises(LoginException): + lines = pipeline.export_log_lines(api, session, '', 100) + # Have to use next to cause the generator to execute the function + # else get_log_file() won't get executed and our stub won't have + # a chance to throw the fake exception. + next(lines) + + ''' + given: an Api instance, an http session, log file path, and chunk size + when: api.get_log_file() is called + and when: api.get_log_file() raises a SalesforceApiException then: raise a SalesforceApiException ''' # setup + api = ApiStub(raise_error=True) session = SessionStub() - session.response = ResponseStub(500, 'Error', '', []) # execute/verify with self.assertRaises(SalesforceApiException): - pipeline.export_log_lines(session, '', '', 100) + lines = pipeline.export_log_lines(api, session, '', 100) + # Have to use next to cause the generator to execute the function + # else get_log_file() won't get executed and our stub won't have + # a chance to throw the fake exception. + next(lines) ''' - given: an http session, url, access token, and chunk size + given: an Api instance, an http session, log file path, and chunk size when: the response produces a 200 status code then: return a generator iterator that yields one line of data at a time ''' # setup - session.response = ResponseStub(200, 'OK', '', self.log_rows) + api = ApiStub(lines=self.log_rows) + session = SessionStub() - #execute - response = pipeline.export_log_lines(session, '', '', 100) + # execute + response = pipeline.export_log_lines(api, session, '', 100) lines = [] for line in response: @@ -394,7 +422,7 @@ def test_pack_event_record_into_log(self): # execute log = pipeline.pack_event_record_into_log( - QueryStub({ 'event_type': 'CustomEvent' }), + QueryStub(config={ 'event_type': 'CustomEvent' }), '00001111AAAABBBB', event_record ) @@ -555,7 +583,7 @@ def _now(): # execute log = pipeline.pack_event_record_into_log( - QueryStub({ 'timestamp_attr': 'CustomDate' }), + QueryStub(config={ 'timestamp_attr': 'CustomDate' }), '00001111AAAABBBB', event_record ) @@ -588,7 +616,7 @@ def _now(): # execute log = pipeline.pack_event_record_into_log( - QueryStub({ 'timestamp_attr': 'NotPresent' }), + QueryStub(config={ 'timestamp_attr': 'NotPresent' }), '00001111AAAABBBB', event_record ) @@ -652,7 +680,7 @@ def _now(): # execute log = pipeline.pack_event_record_into_log( - QueryStub({ 'rename_timestamp': 'custom_timestamp' }), + QueryStub(config={ 'rename_timestamp': 'custom_timestamp' }), '00001111AAAABBBB', event_record ) @@ -706,7 +734,7 @@ def test_transform_event_records(self): # execute logs = pipeline.transform_event_records( self.event_records[2:], - QueryStub({ 'id': ['Name'] }), + QueryStub(config={ 'id': ['Name'] }), None, ) @@ -1142,9 +1170,9 @@ def test_pipeline_process_log_record(self): ''' # setup + api = ApiStub(lines=self.log_rows) cfg = config.Config({}) session = SessionStub() - session.response = ResponseStub(200, 'OK', '', self.log_rows) newrelic = NewRelicStub() query = QueryStub({}) @@ -1165,10 +1193,9 @@ def test_pipeline_process_log_record(self): # execute p.process_log_record( + api, session, query, - 'https://test.local.test', - '12345', record, ) @@ -1207,15 +1234,77 @@ def test_pipeline_process_log_record(self): ''' given: the values from use case 1 when: the pipeline is configured as in use case 1 - and when: the data format is set to DataFormat.LOGS - and when: the number of log lines to be processed is less than the - maximum number of rows, + and when: export_log_lines() raises a LoginException + then: raise a LoginException + ''' + + # setup + api = ApiStub(raise_login_error=True) + newrelic = NewRelicStub() + + p = pipeline.Pipeline( + cfg, + None, + newrelic, + DataFormat.LOGS, + { 'foo': 'bar' }, + {}, + set(), + ) + + record = self.log_records[0] + + # execute / verify + with self.assertRaises(LoginException) as _: + p.process_log_record( + api, + session, + query, + record, + ) + + ''' + given: the values from use case 1 + when: the pipeline is configured as in use case 1 + and when: export_log_lines() raises a SalesforceApiException + then: raise a SalesforceApiException + ''' + + # setup + api = ApiStub(raise_error=True) + newrelic = NewRelicStub() + + p = pipeline.Pipeline( + cfg, + None, + newrelic, + DataFormat.LOGS, + { 'foo': 'bar' }, + {}, + set(), + ) + + record = self.log_records[0] + + # execute / verify + with self.assertRaises(SalesforceApiException) as _: + p.process_log_record( + api, + session, + query, + record, + ) + + ''' + given: the values from use case 1 + when: the pipeline is configured as in use case 1 and when: a data cache is specified and when: the record ID matches a record ID in the data cache then: no log entries are sent ''' # setup + api = ApiStub() data_cache = DataCacheStub(skip_record_ids=['00001111AAAABBBB']) newrelic = NewRelicStub() @@ -1236,10 +1325,9 @@ def test_pipeline_process_log_record(self): # execute p.process_log_record( + api, session, query, - 'https://test.local.test', - '12345', record, ) @@ -1259,6 +1347,8 @@ def test_pipeline_process_log_record(self): ''' # setup + api = ApiStub(lines=self.log_rows) + data_cache = DataCacheStub(skip_record_ids=['00001111AAAABBBB']) newrelic = NewRelicStub() p = pipeline.Pipeline( @@ -1279,10 +1369,9 @@ def test_pipeline_process_log_record(self): # execute p.process_log_record( + api, session, query, - 'https://test.local.test', - '12345', new_record, ) @@ -1312,9 +1401,10 @@ def test_pipeline_process_log_record(self): ''' # setup + api = ApiStub(lines=self.log_rows) data_cache = DataCacheStub( cached_logs={ - '00001111AAAABBBB': ['YYZ:abcdef123456', 'YYZ:fedcba654321'] + '00001111AAAABBBB': ['YYZ:abcdef123456'] } ) newrelic = NewRelicStub() @@ -1336,15 +1426,23 @@ def test_pipeline_process_log_record(self): # execute p.process_log_record( + api, session, query, - 'https://test.local.test', - '12345', record, ) # verify - self.assertEqual(len(newrelic.logs), 0) + self.assertEqual(len(newrelic.logs), 1) + l = newrelic.logs[0] + self.assertEqual(len(l), 1) + l = l[0] + self.assertTrue('logs' in l) + self.assertTrue('common' in l) + self.assertTrue(type(l['common']) is dict) + self.assertTrue('foo' in l['common']) + self.assertEqual(l['common']['foo'], 'bar') + self.assertEqual(len(l['logs']), 1) def test_pipeline_process_event_records(self): ''' @@ -1368,7 +1466,7 @@ def test_pipeline_process_event_records(self): # setup cfg = config.Config({}) newrelic = NewRelicStub() - query = QueryStub({ 'id': ['Name'] }) + query = QueryStub(config={ 'id': ['Name'] }) p = pipeline.Pipeline( cfg, @@ -1449,9 +1547,9 @@ def test_pipeline_execute(self): ''' # setup + api = ApiStub(lines=self.log_rows) cfg = config.Config({}) session = SessionStub() - session.response = ResponseStub(200, 'OK', '', self.log_rows) newrelic = NewRelicStub() query = QueryStub({}) data_cache = DataCacheStub() @@ -1471,10 +1569,9 @@ def test_pipeline_execute(self): # execute p.execute( + api, session, query, - 'https://test.local.test', - '12345', self.log_records, ) @@ -1493,6 +1590,82 @@ def test_pipeline_execute(self): self.assertTrue(data_cache.flush_called) + ''' + given: the values from use case 1 + when: the pipeline is configured with the configuration, session, + newrelic instance, data format, labels, event type fields mapping, + and numeric field names + and when: the first record in the result set contains a 'LogFile' + attribute + and when: process_log_record() raises a LoginException + then: raise a LoginException + ''' + + # setup + api = ApiStub(raise_login_error=True) + cfg = config.Config({}) + session = SessionStub() + newrelic = NewRelicStub() + query = QueryStub({}) + data_cache = DataCacheStub() + + p = pipeline.Pipeline( + cfg, + data_cache, + newrelic, + DataFormat.LOGS, + { 'foo': 'bar' }, + {}, + set(), + ) + + # execute / verify + with self.assertRaises(LoginException) as _: + p.execute( + api, + session, + query, + self.log_records, + ) + + ''' + given: the values from use case 1 + when: the pipeline is configured with the configuration, session, + newrelic instance, data format, labels, event type fields mapping, + and numeric field names + and when: the first record in the result set contains a 'LogFile' + attribute + and when: process_log_record() raises a SalesforceApiException + then: raise a SalesforceApiException + ''' + + # setup + api = ApiStub(raise_error=True) + cfg = config.Config({}) + session = SessionStub() + newrelic = NewRelicStub() + query = QueryStub({}) + data_cache = DataCacheStub() + + p = pipeline.Pipeline( + cfg, + data_cache, + newrelic, + DataFormat.LOGS, + { 'foo': 'bar' }, + {}, + set(), + ) + + # execute / verify + with self.assertRaises(SalesforceApiException) as _: + p.execute( + api, + session, + query, + self.log_records, + ) + ''' given: the values from use case 1 when: the pipeline is configured as in use case 1 @@ -1501,11 +1674,11 @@ def test_pipeline_execute(self): and when: a data cache is specified and when: the number of event records to be processed is less than the maximum number of rows - then: a single Events API post is made containing all labels in the + then: a single Logs API post is made containing all labels in the 'common' property of the logs post and one log for each exported and transformed event record, and the cache is flushed ''' - + api = ApiStub() cfg = config.Config({}) newrelic = NewRelicStub() query = QueryStub({ 'id': ['Name'] }) @@ -1524,10 +1697,9 @@ def test_pipeline_execute(self): self.assertEqual(len(newrelic.logs), 0) p.execute( + api, session, query, - 'https://test.local.test', - '12345', self.event_records, ) diff --git a/src/tests/test_query.py b/src/tests/test_query.py index 74fc5c7..2cc7645 100644 --- a/src/tests/test_query.py +++ b/src/tests/test_query.py @@ -1,10 +1,10 @@ from datetime import datetime import unittest -from newrelic_logging import SalesforceApiException +from newrelic_logging import LoginException, SalesforceApiException from newrelic_logging import config as mod_config, query, util from . import \ - ResponseStub, \ + ApiStub, \ SessionStub class TestQuery(unittest.TestCase): @@ -18,13 +18,14 @@ def test_get_returns_backing_config_value_when_key_exists(self): ''' # setup + api = ApiStub() config = mod_config.Config({ 'foo': 'bar' }) # execute q = query.Query( + api, 'SELECT+LogFile+FROM+EventLogFile', config, - '55.0', ) val = q.get('foo') @@ -41,96 +42,128 @@ def test_get_returns_backing_config_default_when_key_missing(self): ''' # setup + api = ApiStub() config = mod_config.Config({}) # execute q = query.Query( + api, 'SELECT+LogFile+FROM+EventLogFile', config, - '55.0', ) val = q.get('foo', 'beep') # verify self.assertEqual(val, 'beep') - def test_execute_raises_exception_on_non_200_response(self): + def test_execute_raises_login_exception_if_api_query_does(self): ''' - execute() raises exception on non-200 status code from Salesforce API - given: a query string, a configuration, and an api version - when: execute() is called with an http session, an instance url, and an access token - and when: the response produces a non-200 status code - then: raise a SalesforceApiException + execute() raises a LoginException if api.query() raises a LoginException + given: an api instance, a query string, and a configuration + when: execute() is called with an http session + then: calls api.query() with the given session, query string, and no api version + and when: api.query() raises a LoginException (as a result of a reauthenticate) + then: raise a LoginException ''' # setup + api = ApiStub(raise_login_error=True) config = mod_config.Config({}) session = SessionStub() - session.response = ResponseStub(500, 'Error', '', []) # execute/verify q = query.Query( + api, 'SELECT+LogFile+FROM+EventLogFile', config, - '55.0', ) - with self.assertRaises(SalesforceApiException) as _: - q.execute(session, 'https://my.salesforce.test', '123456') + with self.assertRaises(LoginException) as _: + q.execute(session) - def test_execute_raises_exception_if_session_get_does(self): + def test_execute_raises_salesforce_api_exception_if_api_query_does(self): ''' - execute() raises exception if session.get() raises a RequestException - given: a query string, a configuration, and an api version - when: execute() is called with an http session, an instance url, and an access token - and when: session.get() raises a RequestException + execute() raises a SalesforceApiException if api.query() raises a SalesforceApiException + given: an api instance, a query string, and a configuration + when: execute() is called with an http session + then: calls api.query() with the given session, query string, and no api version + and when: api.query() raises a SalesforceApiException then: raise a SalesforceApiException ''' # setup + api = ApiStub(raise_error=True) config = mod_config.Config({}) - session = SessionStub(raise_error=True) - session.response = ResponseStub(200, 'OK', '[]', [] ) + session = SessionStub() # execute/verify q = query.Query( + api, 'SELECT+LogFile+FROM+EventLogFile', config, - '55.0', ) with self.assertRaises(SalesforceApiException) as _: - q.execute(session, 'https://my.salesforce.test', '123456') + q.execute(session) - def test_execute_calls_query_api_url_with_token_and_returns_json_response(self): + def test_execute_calls_query_api_with_query_and_returns_result(self): ''' - execute() calls the correct query API url with the access token and returns a json response - given: a query string, a configuration, and an api version - when: execute() is called with an http session, an instance url, and an access token - then: a get request is made to the correct API url with the given access token and returns a json response + execute() calls api.query() with the given session, query string, and no api version and returns the result + given: an api instance, a query string, and a configuration + when: execute() is called with an http session + then: calls api.query() with the given session, query string, and no api version + and: returns query result ''' # setup + api = ApiStub(query_result={ 'foo': 'bar' }) config = mod_config.Config({}) session = SessionStub() - session.response = ResponseStub(200, 'OK', '{"foo": "bar"}', [] ) # execute q = query.Query( + api, 'SELECT+LogFile+FROM+EventLogFile', config, - '55.0', ) - resp = q.execute(session, 'https://my.salesforce.test', '123456') + resp = q.execute(session) # verify - self.assertEqual( - session.url, - f'https://my.salesforce.test/services/data/v55.0/query?q=SELECT+LogFile+FROM+EventLogFile', + self.assertEqual('SELECT+LogFile+FROM+EventLogFile', api.soql) + self.assertIsNone(api.query_api_ver) + self.assertIsNotNone(resp) + self.assertTrue(type(resp) is dict) + self.assertTrue('foo' in resp) + self.assertEqual(resp['foo'], 'bar') + + def test_execute_calls_query_api_with_query_and_api_ver_and_returns_result(self): + ''' + execute() calls api.query() with the given session, query string, and api version and returns the result + given: an api instance, a query string, a configuration, and an api version + when: execute() is called with an http session + then: calls api.query() with the given session, query string, and api version + and: returns query result + ''' + + # setup + api = ApiStub(query_result={ 'foo': 'bar' }) + config = mod_config.Config({}) + session = SessionStub() + + # execute + q = query.Query( + api, + 'SELECT+LogFile+FROM+EventLogFile', + config, + '52.0', ) - self.assertTrue('Authorization' in session.headers) - self.assertEqual(session.headers['Authorization'], 'Bearer 123456') + + resp = q.execute(session) + + # verify + self.assertEqual('SELECT+LogFile+FROM+EventLogFile', api.soql) + self.assertEqual(api.query_api_ver, '52.0') self.assertIsNotNone(resp) self.assertTrue(type(resp) is dict) self.assertTrue('foo' in resp) @@ -251,7 +284,7 @@ def test_new_returns_query_obj_with_encoded_query_with_args_replaced(self): ''' new() returns a query instance with the given query with arguments replaced and URL encoded given: a query factory - when: new() is called with a query dict, lag time, timestamp, generation interval, and default api version + when: new() is called with an Api instance, query dict, lag time, timestamp, and generation interval then: returns a query instance with the input query with arguments replaced and URL encoded ''' @@ -264,6 +297,7 @@ def _utcnow(): util._UTCNOW = _utcnow + api = ApiStub() to_timestamp = util.get_iso_date_with_offset(time_lag_minutes=500) last_to_timestamp = util.get_iso_date_with_offset(time_lag_minutes=1000) now = _now.isoformat(timespec='milliseconds') + "Z" @@ -272,6 +306,7 @@ def _utcnow(): # execute f = query.QueryFactory() q = f.new( + api, { 'query': 'SELECT LogFile FROM EventLogFile WHERE CreatedDate>={from_timestamp} AND CreatedDate<{to_timestamp} AND LogIntervalType={log_interval_type} AND Foo={foo}', 'env': env, @@ -279,7 +314,6 @@ def _utcnow(): 500, last_to_timestamp, 'Daily', - '55.0', ) # verify @@ -292,7 +326,7 @@ def test_new_returns_query_obj_with_expected_config(self): ''' new() returns a query instance with the input query dict minus the query property given: a query factory - when: new() is called with a query dict, lag time, timestamp, generation interval, and default api version + when: new() is called with an Api instance, query dict, lag time, timestamp, and generation interval then: returns a query instance with a config equal to the input query dict minus the query property ''' @@ -300,8 +334,10 @@ def test_new_returns_query_obj_with_expected_config(self): last_to_timestamp = util.get_iso_date_with_offset(time_lag_minutes=1000) # execute + api = ApiStub() f = query.QueryFactory() q = f.new( + api, { 'query': 'SELECT LogFile FROM EventLogFile', 'foo': 'bar', @@ -312,7 +348,6 @@ def test_new_returns_query_obj_with_expected_config(self): 500, last_to_timestamp, 'Daily', - '55.0', ) # verify @@ -334,16 +369,19 @@ def test_new_returns_query_obj_with_given_api_ver(self): ''' new() returns a query instance with the api version specified in the query dict given: a query factory - when: new() is called with a query dict, lag time, timestamp, generation interval, and default api version + when: new() is called with an Api instance, query dict, lag time, timestamp, and generation interval + and when: an api version is specified in the query dict then: returns a query instance with the api version specified in the query dict ''' # setup + api = ApiStub(api_ver='54.0') last_to_timestamp = util.get_iso_date_with_offset(time_lag_minutes=1000) # execute f = query.QueryFactory() q = f.new( + api, { 'query': 'SELECT LogFile FROM EventLogFile', 'api_ver': '58.0' @@ -351,7 +389,6 @@ def test_new_returns_query_obj_with_given_api_ver(self): 500, last_to_timestamp, 'Daily', - '53.0', ) # verify @@ -359,26 +396,28 @@ def test_new_returns_query_obj_with_given_api_ver(self): def test_new_returns_query_obj_with_default_api_ver(self): ''' - new() returns a query instance with the default api version specified on the new() call + new() returns a query instance without an api version given: a query factory - when: new() is called with a query dict, lag time, timestamp, generation interval, and default api version - then: returns a query instance with the default api version specified on the the new() call + when: new() is called with an Api instance, query dict, lag time, timestamp, and generation interval + and when: no api version is specified in the query dict + then: returns a query instance without an api version ''' # setup + api = ApiStub(api_ver='54.0') last_to_timestamp = util.get_iso_date_with_offset(time_lag_minutes=1000) # execute f = query.QueryFactory() q = f.new( + api, { 'query': 'SELECT LogFile FROM EventLogFile', }, 500, last_to_timestamp, 'Daily', - '53.0', ) # verify - self.assertEqual(q.api_ver, '53.0') + self.assertIsNone(q.api_ver) diff --git a/src/tests/test_salesforce.py b/src/tests/test_salesforce.py index 4f2a02d..2e61137 100644 --- a/src/tests/test_salesforce.py +++ b/src/tests/test_salesforce.py @@ -1,14 +1,22 @@ from datetime import datetime, timedelta import unittest + from . import \ + ApiFactoryStub, \ AuthenticatorStub, \ DataCacheStub, \ PipelineStub, \ QueryStub, \ QueryFactoryStub, \ SessionStub -from newrelic_logging import config, salesforce, util +from newrelic_logging import \ + config, \ + salesforce, \ + util, \ + LoginException, \ + SalesforceApiException + class TestSalesforce(unittest.TestCase): def test_init(self): @@ -43,6 +51,7 @@ def _utcnow(): auth = AuthenticatorStub() pipeline = PipelineStub() data_cache = DataCacheStub() + api_factory = ApiFactoryStub() query_factory = QueryFactoryStub() last_to_timestamp = util.get_iso_date_with_offset( time_lag_minutes, @@ -56,19 +65,21 @@ def _utcnow(): None, auth, pipeline, + api_factory, query_factory, initial_delay, ) # verify self.assertEqual(client.instance_name, 'test_instance') - self.assertEqual(client.default_api_ver, '55.0') self.assertEqual(client.data_cache, None) - self.assertEqual(client.auth, auth) self.assertEqual(client.time_lag_minutes, time_lag_minutes) self.assertEqual(client.date_field, 'CreateDate') self.assertEqual(client.generation_interval, 'Hourly') self.assertEqual(client.last_to_timestamp, last_to_timestamp) + self.assertIsNotNone(client.api) + self.assertEqual(client.api.authenticator, auth) + self.assertEqual(client.api.api_ver, '55.0') self.assertTrue(len(client.queries) == 1) self.assertTrue('query' in client.queries[0]) self.assertEqual( @@ -104,6 +115,7 @@ def _utcnow(): None, auth, pipeline, + api_factory, query_factory, initial_delay, ) @@ -138,6 +150,7 @@ def _utcnow(): data_cache, auth, pipeline, + api_factory, query_factory, initial_delay, ) @@ -173,6 +186,7 @@ def _utcnow(): None, auth, pipeline, + api_factory, query_factory, initial_delay, ) @@ -197,6 +211,7 @@ def _utcnow(): data_cache, auth, pipeline, + api_factory, query_factory, initial_delay, ) @@ -229,6 +244,7 @@ def _utcnow(): None, auth, pipeline, + api_factory, query_factory, initial_delay, ) @@ -265,6 +281,7 @@ def _utcnow(): None, auth, pipeline, + api_factory, query_factory, initial_delay, ) @@ -296,6 +313,7 @@ def _utcnow(): None, auth, pipeline, + api_factory, query_factory, initial_delay, queries @@ -308,13 +326,13 @@ def _utcnow(): self.assertTrue('query' in client.queries[1]) self.assertEqual(client.queries[1]['query'], 'bar') - def test_authenticate(self): + def test_authenticate_is_called(self): ''' given: an instance name and configuration, a data cache, authenticator, pipeline, query factory, initial delay value, set of queries, and an http session when: called - then: the underlying authenticator is called + then: the underlying api is called ''' # setup @@ -328,6 +346,7 @@ def test_authenticate(self): }) auth = AuthenticatorStub() pipeline = PipelineStub() + api_factory = ApiFactoryStub() query_factory = QueryFactoryStub() session = SessionStub() @@ -337,6 +356,7 @@ def test_authenticate(self): None, auth, pipeline, + api_factory, query_factory, initial_delay, ) @@ -345,6 +365,53 @@ def test_authenticate(self): client.authenticate(session) # verify + self.assertIsNotNone(client.api) + self.assertEqual(client.api.authenticator, auth) + self.assertTrue(auth.authenticate_called) + + def test_authenticate_raises_login_exception_if_authenticate_does(self): + ''' + given: an instance name and configuration, a data cache, authenticator, + pipeline, query factory, initial delay value, set of queries, + and an http session + when: called + and when: authenticator.authenticate() raises a LoginException + then: raise a LoginException + ''' + + # setup + time_lag_minutes = 603 + initial_delay = 5 + cfg = config.Config({ + 'api_ver': '55.0', + 'time_lag_minutes': time_lag_minutes, + 'date_field': 'CreateDate', + 'generation_interval': 'Hourly', + }) + auth = AuthenticatorStub(raise_login_error=True) + pipeline = PipelineStub() + api_factory = ApiFactoryStub() + query_factory = QueryFactoryStub() + session = SessionStub() + + client = salesforce.SalesForce( + 'test_instance', + cfg, + None, + auth, + pipeline, + api_factory, + query_factory, + initial_delay, + ) + + # execute / verify + with self.assertRaises(LoginException) as _: + client.authenticate(session) + + # verify + self.assertIsNotNone(client.api) + self.assertEqual(client.api.authenticator, auth) self.assertTrue(auth.authenticate_called) def test_slide_time_range(self): @@ -375,6 +442,7 @@ def _utcnow(): }) auth = AuthenticatorStub() pipeline = PipelineStub() + api_factory = ApiFactoryStub() query_factory = QueryFactoryStub() client = salesforce.SalesForce( @@ -383,6 +451,7 @@ def _utcnow(): None, auth, pipeline, + api_factory, query_factory, initial_delay, ) @@ -421,6 +490,7 @@ def test_fetch_logs(self): }) auth = AuthenticatorStub() pipeline = PipelineStub() + api_factory = ApiFactoryStub() query_factory = QueryFactoryStub() session = SessionStub() @@ -430,6 +500,7 @@ def test_fetch_logs(self): None, auth, pipeline, + api_factory, query_factory, initial_delay, ) @@ -457,6 +528,7 @@ def test_fetch_logs(self): auth = AuthenticatorStub() pipeline = PipelineStub() query_factory = QueryFactoryStub() + api_factory = ApiFactoryStub() session = SessionStub() queries = [ { @@ -479,6 +551,7 @@ def test_fetch_logs(self): None, auth, pipeline, + api_factory, query_factory, initial_delay, queries, @@ -523,6 +596,7 @@ def test_fetch_logs(self): pipeline = PipelineStub() query = QueryStub(result=None) query_factory = QueryFactoryStub(query) + api_factory = ApiFactoryStub() session = SessionStub() client = salesforce.SalesForce( @@ -531,6 +605,7 @@ def test_fetch_logs(self): None, auth, pipeline, + api_factory, query_factory, initial_delay, ) @@ -555,6 +630,7 @@ def test_fetch_logs(self): pipeline = PipelineStub() query = QueryStub(result={ 'foo': 'bar' }) query_factory = QueryFactoryStub(query) + api_factory = ApiFactoryStub() session = SessionStub() client = salesforce.SalesForce( @@ -563,6 +639,7 @@ def test_fetch_logs(self): None, auth, pipeline, + api_factory, query_factory, initial_delay, ) @@ -575,6 +652,66 @@ def test_fetch_logs(self): self.assertTrue(query.executed) self.assertFalse(pipeline.executed) + ''' + given: an instance name and configuration, data cache, authenticator, + pipeline, api factory, query factory, initial delay value, and an + http session + when: pipeline.execute() raises a LoginException + then: raise a LoginException + ''' + + auth = AuthenticatorStub() + pipeline = PipelineStub(raise_login_error=True) + query = QueryStub() + query_factory = QueryFactoryStub(query) + api_factory = ApiFactoryStub() + session = SessionStub() + + client = salesforce.SalesForce( + 'test_instance', + cfg, + None, + auth, + pipeline, + api_factory, + query_factory, + initial_delay, + ) + + # execute / verify + with self.assertRaises(LoginException) as _: + client.fetch_logs(session) + + ''' + given: an instance name and configuration, data cache, authenticator, + pipeline, api factory, query factory, initial delay value, and an + http session + when: pipeline.execute() raises a SalesforceApiException + then: raise a SalesforceApiException + ''' + + auth = AuthenticatorStub() + pipeline = PipelineStub(raise_error=True) + query = QueryStub() + query_factory = QueryFactoryStub(query) + api_factory = ApiFactoryStub() + session = SessionStub() + + client = salesforce.SalesForce( + 'test_instance', + cfg, + None, + auth, + pipeline, + api_factory, + query_factory, + initial_delay, + ) + + # execute / verify + with self.assertRaises(SalesforceApiException) as _: + client.fetch_logs(session) + if __name__ == '__main__': unittest.main() From 1eddda425d072d67a1418b116eed5d3a0af27808 Mon Sep 17 00:00:00 2001 From: Scott DeWitt Date: Wed, 3 Apr 2024 11:11:04 -0400 Subject: [PATCH 09/11] fix: TypeError in get_iso_date_with_offset due to missing ApiFactory parameter in SalesforceFactory.new --- src/__main__.py | 2 ++ src/newrelic_logging/integration.py | 5 +++++ src/newrelic_logging/salesforce.py | 2 ++ 3 files changed, 9 insertions(+) diff --git a/src/__main__.py b/src/__main__.py index 98c0823..99d16c7 100644 --- a/src/__main__.py +++ b/src/__main__.py @@ -13,6 +13,7 @@ from apscheduler.schedulers.background import BlockingScheduler from pytz import utc from yaml import Loader, load +from newrelic_logging.api import ApiFactory from newrelic_logging.auth import AuthenticatorFactory from newrelic_logging.cache import CacheFactory, BackendFactory from newrelic_logging.config import Config, getenv @@ -141,6 +142,7 @@ def run_once( CacheFactory(BackendFactory()), PipelineFactory(), SalesForceFactory(), + ApiFactory(), QueryFactory(), NewRelicFactory(), event_type_fields_mapping, diff --git a/src/newrelic_logging/integration.py b/src/newrelic_logging/integration.py index 5b4e59b..afe2310 100644 --- a/src/newrelic_logging/integration.py +++ b/src/newrelic_logging/integration.py @@ -7,6 +7,7 @@ LoginException, \ NewRelicApiException, \ SalesforceApiException +from . import api from . import auth from . import cache from . import config as mod_config @@ -28,6 +29,7 @@ def build_instance( cache_factory: cache.CacheFactory, pipeline_factory: pipeline.PipelineFactory, salesforce_factory: salesforce.SalesForceFactory, + api_factory: api.ApiFactory, query_factory: query.QueryFactory, new_relic: newrelic.NewRelic, data_format: DataFormat, @@ -64,6 +66,7 @@ def build_instance( event_type_fields_mapping, numeric_fields_list, ), + api_factory, query_factory, initial_delay, config['queries'] if 'queries' in config else None, @@ -79,6 +82,7 @@ def __init__( cache_factory: cache.CacheFactory, pipeline_factory: pipeline.PipelineFactory, salesforce_factory: salesforce.SalesForceFactory, + api_factory: api.ApiFactory, query_factory: query.QueryFactory, newrelic_factory: newrelic.NewRelicFactory, event_type_fields_mapping: dict = {}, @@ -112,6 +116,7 @@ def __init__( cache_factory, pipeline_factory, salesforce_factory, + api_factory, query_factory, self.new_relic, data_format, diff --git a/src/newrelic_logging/salesforce.py b/src/newrelic_logging/salesforce.py index cb3556f..a117c72 100644 --- a/src/newrelic_logging/salesforce.py +++ b/src/newrelic_logging/salesforce.py @@ -115,6 +115,7 @@ def new( data_cache: DataCache, authenticator: Authenticator, pipeline: Pipeline, + api_factory: ApiFactory, query_factory: mod_query.QueryFactory, initial_delay: int, queries: list[dict] = None, @@ -125,6 +126,7 @@ def new( data_cache, authenticator, pipeline, + api_factory, query_factory, initial_delay, queries, From bb3836415b5e820c49e9409b2c52fb4e48e4703d Mon Sep 17 00:00:00 2001 From: Scott DeWitt Date: Thu, 4 Apr 2024 16:46:17 -0400 Subject: [PATCH 10/11] fix: use new auth token on get() after reauthenticate() --- src/newrelic_logging/api.py | 6 ++- src/tests/__init__.py | 10 ++++- src/tests/test_api.py | 69 +++++++++++++++++++++++++++++++---- src/tests/test_integration.py | 17 ++++++++- 4 files changed, 91 insertions(+), 11 deletions(-) diff --git a/src/newrelic_logging/api.py b/src/newrelic_logging/api.py index 7b94172..9c0f000 100644 --- a/src/newrelic_logging/api.py +++ b/src/newrelic_logging/api.py @@ -32,7 +32,11 @@ def get( ) auth.reauthenticate(session) - response = session.get(url, headers=headers, stream=stream) + new_headers = { + 'Authorization': f'Bearer {auth.get_access_token()}' + } + + response = session.get(url, headers=new_headers, stream=stream) if response.status_code == 200: return cb(response) diff --git a/src/tests/__init__.py b/src/tests/__init__.py index 06b3689..30540fb 100644 --- a/src/tests/__init__.py +++ b/src/tests/__init__.py @@ -4,7 +4,7 @@ from requests import Session, RequestException from newrelic_logging import DataFormat, LoginException, SalesforceApiException -from newrelic_logging.api import Api +from newrelic_logging.api import Api, ApiFactory from newrelic_logging.auth import Authenticator from newrelic_logging.cache import DataCache from newrelic_logging.config import Config @@ -84,6 +84,7 @@ def __init__( data_cache: DataCache = None, token_url: str = '', access_token: str = '', + access_token_2: str = '', instance_url: str = '', grant_type: str = '', authenticate_called: bool = False, @@ -94,6 +95,7 @@ def __init__( self.data_cache = data_cache self.token_url = token_url self.access_token = access_token + self.access_token_2 = access_token_2 self.instance_url = instance_url self.grant_type = grant_type self.authenticate_called = authenticate_called @@ -143,6 +145,8 @@ def reauthenticate( if self.raise_login_error: raise LoginException('Unauthorized') + self.access_token = self.access_token_2 + class AuthenticatorFactoryStub: def __init__(self): @@ -463,6 +467,7 @@ def __init__( data_cache: DataCache, authenticator: Authenticator, pipeline: Pipeline, + api_factory: ApiFactory, query_factory: QueryFactory, initial_delay: int, queries: list[dict] = None, @@ -472,6 +477,7 @@ def __init__( self.data_cache = data_cache self.authenticator = authenticator self.pipeline = pipeline + self.api_factory = api_factory self.query_factory = query_factory self.initial_delay = initial_delay self.queries = queries @@ -488,6 +494,7 @@ def new( data_cache: DataCache, authenticator: Authenticator, pipeline: Pipeline, + api_factory: ApiFactory, query_factory: QueryFactory, initial_delay: int, queries: list[dict] = None, @@ -498,6 +505,7 @@ def new( data_cache, authenticator, pipeline, + api_factory, query_factory, initial_delay, queries, diff --git a/src/tests/test_api.py b/src/tests/test_api.py index 6459a7d..9fd7371 100644 --- a/src/tests/test_api.py +++ b/src/tests/test_api.py @@ -256,7 +256,7 @@ def test_get_calls_reauthenticate_on_401_and_invokes_cb_with_response_on_200(sel and when: response status code is 401 then: reauthenticate() is called and when: reauthenticate() does not throw a LoginException - then: request is executed again with the same parameters + then: request is executed again with the same URL and stream setting as the first call to session.get() and the second access token and when: it returns a 200 then: calls callback with response and returns result ''' @@ -265,6 +265,7 @@ def test_get_calls_reauthenticate_on_401_and_invokes_cb_with_response_on_200(sel auth = AuthenticatorStub( instance_url='https://my.salesforce.test', access_token='123456', + access_token_2='567890', ) response1 = ResponseStub(401, 'Unauthorized', 'Unauthorized', []) response2 = ResponseStub(200, 'OK', 'OK', []) @@ -299,7 +300,7 @@ def cb(response): self.assertTrue('Authorization' in session.requests[1]['headers']) self.assertEqual( session.requests[1]['headers']['Authorization'], - 'Bearer 123456', + 'Bearer 567890', ) self.assertEqual( session.requests[1]['stream'], @@ -308,9 +309,9 @@ def cb(response): self.assertIsNotNone(val) self.assertEqual(val, 'OK') - def test_get_passes_same_params_to_get_on_reauthenticate(self): + def test_get_passed_correct_params_after_reauthenticate(self): ''' - get() receives the same set of parameters on the second call after reauthenticate() succeeds + get() receives the correct set of parameters when it is called after reauthenticate() succeeds given: an authenticator and given: a session and given: a service url @@ -320,13 +321,14 @@ def test_get_passes_same_params_to_get_on_reauthenticate(self): and when: response status code is 401 then: reauthenticate() is called and when: reauthenticate() does not throw a LoginException - then: request is executed again with the same set of parameters as the first call to session.get() + then: request is executed again with the same URL and stream setting as the first call to session.get() and the second access token ''' # setup auth = AuthenticatorStub( instance_url='https://my.salesforce.test', access_token='123456', + access_token_2='567890' ) response1 = ResponseStub(401, 'Unauthorized', 'Unauthorized', []) response2 = ResponseStub(200, 'OK', 'OK', []) @@ -340,8 +342,33 @@ def cb(response): # verify self.assertEqual(len(session.requests), 2) - self.assertEqual(session.requests[0], session.requests[1]) self.assertTrue(auth.reauthenticate_called) + self.assertEqual( + session.requests[0]['url'], + 'https://my.salesforce.test/foo', + ) + self.assertTrue('Authorization' in session.requests[0]['headers']) + self.assertEqual( + session.requests[0]['headers']['Authorization'], + 'Bearer 123456', + ) + self.assertEqual( + session.requests[0]['stream'], + True, + ) + self.assertEqual( + session.requests[1]['url'], + 'https://my.salesforce.test/foo', + ) + self.assertTrue('Authorization' in session.requests[1]['headers']) + self.assertEqual( + session.requests[1]['headers']['Authorization'], + 'Bearer 567890', + ) + self.assertEqual( + session.requests[1]['stream'], + True, + ) self.assertIsNotNone(val) self.assertEqual(val, 'OK') @@ -357,7 +384,7 @@ def test_get_calls_reauthenticate_on_401_and_raises_on_non_200(self): and when: response status code is 401 then: reauthenticate() is called and when: reauthenticate() does not throw a LoginException - then: request is executed again with the same parameters + then: request is executed again with the same URL and stream setting as the first call to session.get() and the second access token and when: it returns a non-200 status code then: throws a SalesforceApiException ''' @@ -366,6 +393,7 @@ def test_get_calls_reauthenticate_on_401_and_raises_on_non_200(self): auth = AuthenticatorStub( instance_url='https://my.salesforce.test', access_token='123456', + access_token_2='567890', ) response1 = ResponseStub(401, 'Unauthorized', 'Unauthorized', []) response2 = ResponseStub(401, 'Unauthorized', 'Unauthorized 2', []) @@ -384,8 +412,33 @@ def cb(response): ) self.assertEqual(len(session.requests), 2) - self.assertEqual(session.requests[0], session.requests[1]) self.assertTrue(auth.reauthenticate_called) + self.assertEqual( + session.requests[0]['url'], + 'https://my.salesforce.test/foo', + ) + self.assertTrue('Authorization' in session.requests[0]['headers']) + self.assertEqual( + session.requests[0]['headers']['Authorization'], + 'Bearer 123456', + ) + self.assertEqual( + session.requests[0]['stream'], + False, + ) + self.assertEqual( + session.requests[1]['url'], + 'https://my.salesforce.test/foo', + ) + self.assertTrue('Authorization' in session.requests[1]['headers']) + self.assertEqual( + session.requests[1]['headers']['Authorization'], + 'Bearer 567890', + ) + self.assertEqual( + session.requests[1]['stream'], + False, + ) def test_stream_lines_sets_fallback_encoding_and_calls_iter_lines_with_chunk_size_and_decode_unicode(self): ''' diff --git a/src/tests/test_integration.py b/src/tests/test_integration.py index 4d7a60a..96e9666 100644 --- a/src/tests/test_integration.py +++ b/src/tests/test_integration.py @@ -1,6 +1,7 @@ import unittest -from . import AuthenticatorFactoryStub, \ +from . import ApiFactoryStub, \ + AuthenticatorFactoryStub, \ CacheFactoryStub, \ NewRelicStub, \ NewRelicFactoryStub, \ @@ -38,6 +39,7 @@ def test_build_instance(self): } ] }) + api_factory = ApiFactoryStub() auth_factory = AuthenticatorFactoryStub() cache_factory = CacheFactoryStub() pipeline_factory = PipelineFactoryStub() @@ -54,6 +56,7 @@ def test_build_instance(self): cache_factory, pipeline_factory, salesforce_factory, + api_factory, query_factory, new_relic, DataFormat.EVENTS, @@ -141,6 +144,7 @@ def test_build_instance(self): } ] }) + api_factory = ApiFactoryStub() auth_factory = AuthenticatorFactoryStub() cache_factory = CacheFactoryStub() pipeline_factory = PipelineFactoryStub() @@ -154,6 +158,7 @@ def test_build_instance(self): cache_factory, pipeline_factory, salesforce_factory, + api_factory, query_factory, new_relic, DataFormat.EVENTS, @@ -200,6 +205,7 @@ def test_build_instance(self): } ] }) + api_factory = ApiFactoryStub() auth_factory = AuthenticatorFactoryStub() cache_factory = CacheFactoryStub() pipeline_factory = PipelineFactoryStub() @@ -213,6 +219,7 @@ def test_build_instance(self): cache_factory, pipeline_factory, salesforce_factory, + api_factory, query_factory, new_relic, DataFormat.EVENTS, @@ -273,6 +280,7 @@ def test_init(self): } }) + api_factory = ApiFactoryStub() auth_factory = AuthenticatorFactoryStub() cache_factory = CacheFactoryStub() pipeline_factory = PipelineFactoryStub() @@ -290,6 +298,7 @@ def test_init(self): cache_factory, pipeline_factory, salesforce_factory, + api_factory, query_factory, newrelic_factory, event_type_fields_mapping, @@ -345,6 +354,7 @@ def test_init(self): } }) + api_factory = ApiFactoryStub() auth_factory = AuthenticatorFactoryStub() cache_factory = CacheFactoryStub() pipeline_factory = PipelineFactoryStub() @@ -361,6 +371,7 @@ def test_init(self): cache_factory, pipeline_factory, salesforce_factory, + api_factory, query_factory, newrelic_factory, event_type_fields_mapping, @@ -384,6 +395,7 @@ def test_init(self): } }) + api_factory = ApiFactoryStub() auth_factory = AuthenticatorFactoryStub() cache_factory = CacheFactoryStub() pipeline_factory = PipelineFactoryStub() @@ -399,6 +411,7 @@ def test_init(self): cache_factory, pipeline_factory, salesforce_factory, + api_factory, query_factory, newrelic_factory, event_type_fields_mapping, @@ -425,6 +438,7 @@ def test_init(self): } }) + api_factory = ApiFactoryStub() auth_factory = AuthenticatorFactoryStub() cache_factory = CacheFactoryStub() pipeline_factory = PipelineFactoryStub() @@ -440,6 +454,7 @@ def test_init(self): cache_factory, pipeline_factory, salesforce_factory, + api_factory, query_factory, newrelic_factory, event_type_fields_mapping, From 0fc3dc4fe378b562818b0bcd5ddc57290f0cba22 Mon Sep 17 00:00:00 2001 From: Scott DeWitt Date: Thu, 18 Apr 2024 16:58:27 -0400 Subject: [PATCH 11/11] chore: change name to New Relic Salesforce Exporter --- newrelic.ini | 2 +- src/newrelic_logging/__init__.py | 6 +++--- src/newrelic_logging/integration.py | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/newrelic.ini b/newrelic.ini index 0de7da5..7092cd7 100644 --- a/newrelic.ini +++ b/newrelic.ini @@ -40,7 +40,7 @@ # app names to group your aggregated data. For further details, # please see: # https://docs.newrelic.com/docs/apm/agents/manage-apm-agents/app-naming/use-multiple-names-app/ -app_name = Salesforce Eventlogfile Integration +app_name = New Relic Salesforce Exporter # When "true", the agent collects performance data about your # application and reports this data to the New Relic UI at diff --git a/src/newrelic_logging/__init__.py b/src/newrelic_logging/__init__.py index 91deb11..f9a3907 100644 --- a/src/newrelic_logging/__init__.py +++ b/src/newrelic_logging/__init__.py @@ -3,10 +3,10 @@ # Integration definitions -VERSION = "1.0.0" -NAME = "salesforce-eventlogfile" +VERSION = "2.0.0" +NAME = "salesforce-exporter" PROVIDER = "newrelic-labs" -COLLECTOR_NAME = "newrelic-logs-salesforce-eventlogfile" +COLLECTOR_NAME = "newrelic-salesforce-exporter" class DataFormat(Enum): diff --git a/src/newrelic_logging/integration.py b/src/newrelic_logging/integration.py index afe2310..fd54871 100644 --- a/src/newrelic_logging/integration.py +++ b/src/newrelic_logging/integration.py @@ -91,7 +91,7 @@ def __init__( ): Telemetry( config['integration_name'] if 'integration_name' in config \ - else 'com.newrelic.labs.sfdc.eventlogfiles' + else 'com.newrelic.labs.salesforce.exporter' ) data_format = config.get('newrelic.data_format', 'logs').lower()