diff --git a/newrelic.ini b/newrelic.ini new file mode 100644 index 0000000..7092cd7 --- /dev/null +++ b/newrelic.ini @@ -0,0 +1,256 @@ +# --------------------------------------------------------------------------- + +# +# This file configures the New Relic Python Agent. +# +# The path to the configuration file should be supplied to the function +# newrelic.agent.initialize() when the agent is being initialized. +# +# The configuration file follows a structure similar to what you would +# find for Microsoft Windows INI files. For further information on the +# configuration file format see the Python ConfigParser documentation at: +# +# https://docs.python.org/library/configparser.html +# +# For further discussion on the behaviour of the Python agent that can +# be configured via this configuration file see: +# +# https://docs.newrelic.com/docs/apm/agents/python-agent/configuration/python-agent-configuration/ +# + +# --------------------------------------------------------------------------- + +# Here are the settings that are common to all environments. + +[newrelic] + +# You must specify the license key associated with your New +# Relic account. This may also be set using the NEW_RELIC_LICENSE_KEY +# environment variable. This key binds the Python Agent's data to +# your account in the New Relic service. For more information on +# storing and generating license keys, see +# https://docs.newrelic.com/docs/apis/intro-apis/new-relic-api-keys/#ingest-license-key +#license_key = [YOUR LICENSE KEY] + +# The application name. Set this to be the name of your +# application as you would like it to show up in New Relic UI. +# You may also set this using the NEW_RELIC_APP_NAME environment variable. +# The UI will then auto-map instances of your application into a +# entry on your home dashboard page. You can also specify multiple +# app names to group your aggregated data. For further details, +# please see: +# https://docs.newrelic.com/docs/apm/agents/manage-apm-agents/app-naming/use-multiple-names-app/ +app_name = New Relic Salesforce Exporter + +# When "true", the agent collects performance data about your +# application and reports this data to the New Relic UI at +# newrelic.com. This global switch is normally overridden for +# each environment below. It may also be set using the +# NEW_RELIC_MONITOR_MODE environment variable. +monitor_mode = true + +# Sets the name of a file to log agent messages to. Whatever you +# set this to, you must ensure that the permissions for the +# containing directory and the file itself are correct, and +# that the user that your web application runs as can write out +# to the file. If not able to out a log file, it is also +# possible to say "stderr" and output to standard error output. +# This would normally result in output appearing in your web +# server log. It can also be set using the NEW_RELIC_LOG +# environment variable. +log_file = stdout + +# Sets the level of detail of messages sent to the log file, if +# a log file location has been provided. Possible values, in +# increasing order of detail, are: "critical", "error", "warning", +# "info" and "debug". When reporting any agent issues to New +# Relic technical support, the most useful setting for the +# support engineers is "debug". However, this can generate a lot +# of information very quickly, so it is best not to keep the +# agent at this level for longer than it takes to reproduce the +# problem you are experiencing. This may also be set using the +# NEW_RELIC_LOG_LEVEL environment variable. 
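As the header comments above note, the path to this file is handed to newrelic.agent.initialize(). For reference, a minimal sketch of that call follows; the environment name is optional and must match one of the [newrelic:...] sections defined at the end of this file (the call added in src/__main__.py later in this diff omits it).

import newrelic.agent

# The second argument selects one of the [newrelic:...] environment sections
# defined at the bottom of this file; it can be left out, as the exporter does.
newrelic.agent.initialize('./newrelic.ini', 'production')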
+log_level = info + +# High Security Mode enforces certain security settings, and prevents +# them from being overridden, so that no sensitive data is sent to New +# Relic. Enabling High Security Mode means that request parameters are +# not collected and SQL can not be sent to New Relic in its raw form. +# To activate High Security Mode, it must be set to 'true' in this +# local .ini configuration file AND be set to 'true' in the +# server-side configuration in the New Relic user interface. It can +# also be set using the NEW_RELIC_HIGH_SECURITY environment variable. +# For details, see +# https://docs.newrelic.com/docs/subscriptions/high-security +high_security = false + +# The Python Agent will attempt to connect directly to the New +# Relic service. If there is an intermediate firewall between +# your host and the New Relic service that requires you to use a +# HTTP proxy, then you should set both the "proxy_host" and +# "proxy_port" settings to the required values for the HTTP +# proxy. The "proxy_user" and "proxy_pass" settings should +# additionally be set if proxy authentication is implemented by +# the HTTP proxy. The "proxy_scheme" setting dictates what +# protocol scheme is used in talking to the HTTP proxy. This +# would normally always be set as "http" which will result in the +# agent then using a SSL tunnel through the HTTP proxy for end to +# end encryption. +# See https://docs.newrelic.com/docs/apm/agents/python-agent/configuration/python-agent-configuration/#proxy +# for information on proxy configuration via environment variables. +# proxy_scheme = http +# proxy_host = hostname +# proxy_port = 8080 +# proxy_user = +# proxy_pass = + +# Capturing request parameters is off by default. To enable the +# capturing of request parameters, first ensure that the setting +# "attributes.enabled" is set to "true" (the default value), and +# then add "request.parameters.*" to the "attributes.include" +# setting. For details about attributes configuration, please +# consult the documentation. +# attributes.include = request.parameters.* + +# The transaction tracer captures deep information about slow +# transactions and sends this to the UI on a periodic basis. The +# transaction tracer is enabled by default. Set this to "false" +# to turn it off. +transaction_tracer.enabled = true + +# Threshold in seconds for when to collect a transaction trace. +# When the response time of a controller action exceeds this +# threshold, a transaction trace will be recorded and sent to +# the UI. Valid values are any positive float value, or (default) +# "apdex_f", which will use the threshold for a dissatisfying +# Apdex controller action - four times the Apdex T value. +transaction_tracer.transaction_threshold = apdex_f + +# When the transaction tracer is on, SQL statements can +# optionally be recorded. The recorder has three modes, "off" +# which sends no SQL, "raw" which sends the SQL statement in its +# original form, and "obfuscated", which strips out numeric and +# string literals. +transaction_tracer.record_sql = obfuscated + +# Threshold in seconds for when to collect stack trace for a SQL +# call. In other words, when SQL statements exceed this +# threshold, then capture and send to the UI the current stack +# trace. This is helpful for pinpointing where long SQL calls +# originate from in an application. +transaction_tracer.stack_trace_threshold = 0.5 + +# Determines whether the agent will capture query plans for slow +# SQL queries. Only supported in MySQL and PostgreSQL. 
Set this +# to "false" to turn it off. +transaction_tracer.explain_enabled = true + +# Threshold for query execution time below which query plans +# will not not be captured. Relevant only when "explain_enabled" +# is true. +transaction_tracer.explain_threshold = 0.5 + +# Space separated list of function or method names in form +# 'module:function' or 'module:class.function' for which +# additional function timing instrumentation will be added. +transaction_tracer.function_trace = + +# The error collector captures information about uncaught +# exceptions or logged exceptions and sends them to UI for +# viewing. The error collector is enabled by default. Set this +# to "false" to turn it off. For more details on errors, see +# https://docs.newrelic.com/docs/apm/agents/manage-apm-agents/agent-data/manage-errors-apm-collect-ignore-or-mark-expected/ +error_collector.enabled = true + +# To stop specific errors from reporting to the UI, set this to +# a space separated list of the Python exception type names to +# ignore. The exception name should be of the form 'module:class'. +error_collector.ignore_classes = + +# Expected errors are reported to the UI but will not affect the +# Apdex or error rate. To mark specific errors as expected, set this +# to a space separated list of the Python exception type names to +# expected. The exception name should be of the form 'module:class'. +error_collector.expected_classes = + +# Browser monitoring is the Real User Monitoring feature of the UI. +# For those Python web frameworks that are supported, this +# setting enables the auto-insertion of the browser monitoring +# JavaScript fragments. +browser_monitoring.auto_instrument = true + +# A thread profiling session can be scheduled via the UI when +# this option is enabled. The thread profiler will periodically +# capture a snapshot of the call stack for each active thread in +# the application to construct a statistically representative +# call tree. For more details on the thread profiler tool, see +# https://docs.newrelic.com/docs/apm/apm-ui-pages/events/thread-profiler-tool/ +thread_profiler.enabled = true + +# Your application deployments can be recorded through the +# New Relic REST API. To use this feature provide your API key +# below then use the `newrelic-admin record-deploy` command. +# This can also be set using the NEW_RELIC_API_KEY +# environment variable. +# api_key = + +# Distributed tracing lets you see the path that a request takes +# through your distributed system. For more information, please +# consult our distributed tracing planning guide. +# https://docs.newrelic.com/docs/transition-guide-distributed-tracing +distributed_tracing.enabled = true + +# This setting enables log decoration, the forwarding of log events, +# and the collection of logging metrics if these sub-feature +# configurations are also enabled. If this setting is false, no +# logging instrumentation features are enabled. This can also be +# set using the NEW_RELIC_APPLICATION_LOGGING_ENABLED environment +# variable. +# application_logging.enabled = true + +# If true, the agent captures log records emitted by your application +# and forwards them to New Relic. `application_logging.enabled` must +# also be true for this setting to take effect. You can also set +# this using the NEW_RELIC_APPLICATION_LOGGING_FORWARDING_ENABLED +# environment variable. +# application_logging.forwarding.enabled = true + +# If true, the agent decorates logs with metadata to link to entities, +# hosts, traces, and spans. 
`application_logging.enabled` must also +# be true for this setting to take effect. This can also be set +# using the NEW_RELIC_APPLICATION_LOGGING_LOCAL_DECORATING_ENABLED +# environment variable. +# application_logging.local_decorating.enabled = true + +# If true, the agent captures metrics related to the log lines +# being sent up by your application. This can also be set +# using the NEW_RELIC_APPLICATION_LOGGING_METRICS_ENABLED +# environment variable. +# application_logging.metrics.enabled = true + +startup_timeout = 10.0 + +# --------------------------------------------------------------------------- + +# +# The application environments. These are specific settings which +# override the common environment settings. The settings related to a +# specific environment will be used when the environment argument to the +# newrelic.agent.initialize() function has been defined to be either +# "development", "test", "staging" or "production". +# + +[newrelic:development] +monitor_mode = false + +[newrelic:test] +monitor_mode = false + +[newrelic:staging] +app_name = Python Application (Staging) +monitor_mode = true + +[newrelic:production] +monitor_mode = true + +# --------------------------------------------------------------------------- diff --git a/requirements.txt b/requirements.txt index 671da52..342b75d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -14,4 +14,5 @@ cryptography==41.0.6 pip==23.3 wheel==0.38.1 setuptools==65.5.1 -future~=0.18.2 \ No newline at end of file +future~=0.18.2 +newrelic==9.7.0 diff --git a/src/__main__.py b/src/__main__.py index 42e0fbd..99d16c7 100644 --- a/src/__main__.py +++ b/src/__main__.py @@ -1,4 +1,7 @@ #!/usr/bin/env python +import newrelic.agent +newrelic.agent.initialize('./newrelic.ini') + import optparse import os import sys @@ -10,7 +13,15 @@ from apscheduler.schedulers.background import BlockingScheduler from pytz import utc from yaml import Loader, load +from newrelic_logging.api import ApiFactory +from newrelic_logging.auth import AuthenticatorFactory +from newrelic_logging.cache import CacheFactory, BackendFactory from newrelic_logging.config import Config, getenv +from newrelic_logging.newrelic import NewRelicFactory +from newrelic_logging.pipeline import PipelineFactory +from newrelic_logging.query import QueryFactory +from newrelic_logging.salesforce import SalesForceFactory + from newrelic_logging.integration import Integration from newrelic_logging.telemetry import print_info, print_warn @@ -124,8 +135,16 @@ def run_once( event_type_fields_mapping: dict, numeric_fields_list: set ): + Integration( config, + AuthenticatorFactory(), + CacheFactory(BackendFactory()), + PipelineFactory(), + SalesForceFactory(), + ApiFactory(), + QueryFactory(), + NewRelicFactory(), event_type_fields_mapping, numeric_fields_list, config.get_int(CRON_INTERVAL_MINUTES, 60), @@ -154,6 +173,12 @@ def run_as_service( scheduler.add_job( Integration( config, + AuthenticatorFactory(), + CacheFactory(BackendFactory()), + PipelineFactory(), + SalesForceFactory(), + QueryFactory(), + NewRelicFactory(), event_type_fields_mapping, numeric_fields_list, 0 @@ -179,7 +204,7 @@ def run( run_as_service(config, event_type_fields_mapping, numeric_fields_list) - +@newrelic.agent.background_task() def main(): print_info(f'Integration start. 
Using program arguments {sys.argv[1:]}') diff --git a/src/newrelic_logging/__init__.py b/src/newrelic_logging/__init__.py index 0770ec8..f9a3907 100644 --- a/src/newrelic_logging/__init__.py +++ b/src/newrelic_logging/__init__.py @@ -1,6 +1,38 @@ +from enum import Enum + + # Integration definitions -VERSION = "1.0.0" -NAME = "salesforce-eventlogfile" +VERSION = "2.0.0" +NAME = "salesforce-exporter" PROVIDER = "newrelic-labs" -COLLECTOR_NAME = "newrelic-logs-salesforce-eventlogfile" +COLLECTOR_NAME = "newrelic-salesforce-exporter" + + +class DataFormat(Enum): + LOGS = 1 + EVENTS = 2 + + +class ConfigException(Exception): + def __init__(self, prop_name: str = None, *args: object): + self.prop_name = prop_name + super().__init__(*args) + + +class LoginException(Exception): + pass + + +class SalesforceApiException(Exception): + def __init__(self, err_code: int = 0, *args: object): + self.err_code = err_code + super().__init__(*args) + + +class CacheException(Exception): + pass + + +class NewRelicApiException(Exception): + pass diff --git a/src/newrelic_logging/api.py b/src/newrelic_logging/api.py new file mode 100644 index 0000000..9c0f000 --- /dev/null +++ b/src/newrelic_logging/api.py @@ -0,0 +1,114 @@ +from requests import RequestException, Session, Response +from typing import Any + +from . import SalesforceApiException +from .auth import Authenticator +from .telemetry import print_warn + +def get( + auth: Authenticator, + session: Session, + serviceUrl: str, + cb, + stream: bool = False, +) -> Any: + url = f'{auth.get_instance_url()}{serviceUrl}' + + try: + headers = { + 'Authorization': f'Bearer {auth.get_access_token()}' + } + + response = session.get(url, headers=headers, stream=stream) + + status_code = response.status_code + + if status_code == 200: + return cb(response) + + if status_code == 401: + print_warn( + f'invalid token while executing api operation: {url}', + ) + auth.reauthenticate(session) + + new_headers = { + 'Authorization': f'Bearer {auth.get_access_token()}' + } + + response = session.get(url, headers=new_headers, stream=stream) + if response.status_code == 200: + return cb(response) + + raise SalesforceApiException( + status_code, + f'error executing api operation: {url}, ' \ + f'status-code: {status_code}, ' \ + f'reason: {response.reason}, ' + ) + except ConnectionError as e: + raise SalesforceApiException( + -1, + f'connection error executing api operation: {url}', + ) from e + except RequestException as e: + raise SalesforceApiException( + -1, + f'request error executing api operation: {url}', + ) from e + + +def stream_lines(response: Response, chunk_size: int): + if response.encoding is None: + response.encoding = 'utf-8' + + # Stream the response as a set of lines. 
This function will return an + # iterator that yields one line at a time holding only the minimum + # amount of data chunks in memory to make up a single line + + return response.iter_lines( + decode_unicode=True, + chunk_size=chunk_size, + ) + + +class Api: + def __init__(self, authenticator: Authenticator, api_ver: str): + self.authenticator = authenticator + self.api_ver = api_ver + + def authenticate(self, session: Session) -> None: + self.authenticator.authenticate(session) + + def query(self, session: Session, soql: str, api_ver: str = None) -> dict: + ver = self.api_ver + if not api_ver is None: + ver = api_ver + + return get( + self.authenticator, + session, + f'/services/data/v{ver}/query?q={soql}', + lambda response : response.json() + ) + + def get_log_file( + self, + session: Session, + log_file_path: str, + chunk_size: int, + ): + return get( + self.authenticator, + session, + log_file_path, + lambda response : stream_lines(response, chunk_size), + stream=True, + ) + +class ApiFactory: + def __init__(self): + pass + + def new(self, authenticator: Authenticator, api_ver: str) -> Api: + return Api(authenticator, api_ver) diff --git a/src/newrelic_logging/auth.py b/src/newrelic_logging/auth.py new file mode 100644 index 0000000..ea72a1c --- /dev/null +++ b/src/newrelic_logging/auth.py @@ -0,0 +1,313 @@ +from cryptography.hazmat.primitives import serialization +from datetime import datetime, timedelta +import jwt +from requests import RequestException, Session + +from . import ConfigException, LoginException +from .cache import DataCache +from .config import Config +from .telemetry import print_err, print_info, print_warn + + +AUTH_CACHE_KEY = 'com.newrelic.labs.sf_auth' +SF_GRANT_TYPE = 'SF_GRANT_TYPE' +SF_CLIENT_ID = 'SF_CLIENT_ID' +SF_CLIENT_SECRET = 'SF_CLIENT_SECRET' +SF_USERNAME = 'SF_USERNAME' +SF_PASSWORD = 'SF_PASSWORD' +SF_PRIVATE_KEY = 'SF_PRIVATE_KEY' +SF_SUBJECT = 'SF_SUBJECT' +SF_AUDIENCE = 'SF_AUDIENCE' +SF_TOKEN_URL = 'SF_TOKEN_URL' + + +class Authenticator: + def __init__( + self, + token_url: str, + auth_data: dict, + data_cache: DataCache + ): + self.token_url = token_url + self.auth_data = auth_data + self.data_cache = data_cache + self.access_token = None + self.instance_url = None + + def get_access_token(self) -> str: + return self.access_token + + def get_instance_url(self) -> str: + return self.instance_url + + def get_grant_type(self) -> str: + return self.auth_data['grant_type'] + + def set_auth_data(self, access_token: str, instance_url: str) -> None: + self.access_token = access_token + self.instance_url = instance_url + + def clear_auth(self) -> None: + self.set_auth_data(None, None) + + if self.data_cache: + try: + # @TODO need to change all the places where redis is explicitly + # referenced since this breaks encapsulation. 
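For orientation, here is an illustrative sketch, not part of the diff, of how the new factories introduced in api.py, auth.py and cache.py are meant to be wired together to run a SOQL query. instance_config stands in for a per-instance Config object, the API version string is a placeholder, and the SOQL text is only an example of the EventLogFile queries the exporter issues.

from newrelic_logging.api import ApiFactory
from newrelic_logging.auth import AuthenticatorFactory
from newrelic_logging.cache import BackendFactory, CacheFactory
from newrelic_logging.http_session import new_retry_session

def list_event_log_files(instance_config):
    session = new_retry_session()
    # Returns None when cache_enabled is false; Authenticator accepts that.
    data_cache = CacheFactory(BackendFactory()).new(instance_config)
    authenticator = AuthenticatorFactory().new(instance_config, data_cache)
    api = ApiFactory().new(authenticator, api_ver='52.0')

    api.authenticate(session)
    # The get() helper above retries the request once with fresh credentials
    # if Salesforce answers 401, by calling Authenticator.reauthenticate().
    return api.query(session, 'SELECT Id, EventType, LogFile FROM EventLogFile')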
+ self.data_cache.backend.redis.delete(AUTH_CACHE_KEY) + except Exception as e: + print_warn(f'Failed deleting data from cache: {e}') + + def load_auth_from_cache(self) -> bool: + try: + auth_exists = self.data_cache.backend.redis.exists(AUTH_CACHE_KEY) + if auth_exists: + print_info('Retrieving credentials from Redis.') + try: + auth = self.data_cache.backend.redis.hmget( + AUTH_CACHE_KEY, + ['access_token', 'instance_url'], + ) + + self.set_auth_data( + auth[0], + auth[1], + ) + + return True + except Exception as e: + print_err(f"Failed getting 'auth' key: {e}") + except Exception as e: + print_err(f"Failed checking 'auth' key: {e}") + + return False + + def store_auth(self, auth_resp: dict) -> None: + self.access_token = auth_resp['access_token'] + self.instance_url = auth_resp['instance_url'] + + if self.data_cache: + print_info('Storing credentials in cache.') + + auth = { + 'access_token': self.access_token, + 'instance_url': self.instance_url, + } + + try: + self.data_cache.backend.redis.hmset(AUTH_CACHE_KEY, auth) + except Exception as e: + print_warn(f"Failed storing data in cache: {e}") + + def authenticate_with_jwt(self, session: Session) -> None: + private_key_file = self.auth_data['private_key'] + client_id = self.auth_data['client_id'] + subject = self.auth_data['subject'] + audience = self.auth_data['audience'] + exp = int((datetime.utcnow() - timedelta(minutes=5)).timestamp()) + + with open(private_key_file, 'r') as f: + try: + private_key = f.read() + key = serialization.load_ssh_private_key(private_key.encode(), password=b'') + except ValueError as e: + raise LoginException(f'authentication failed for {self.instance_name}. error message: {str(e)}') + + jwt_claim_set = { + "iss": client_id, + "sub": subject, + "aud": audience, + "exp": exp + } + + signed_token = jwt.encode( + jwt_claim_set, + key, + algorithm='RS256', + ) + + params = { + "grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer", + "assertion": signed_token, + "format": "json" + } + + headers = { + "Content-Type": "application/x-www-form-urlencoded", + "Accept": "application/json" + } + + try: + print_info(f'retrieving salesforce token at {self.token_url}') + resp = session.post(self.token_url, params=params, + headers=headers) + if resp.status_code != 200: + raise LoginException(f'sfdc token request failed. http-status-code:{resp.status_code}, reason: {resp.text}') + + self.store_auth(resp.json()) + except ConnectionError as e: + raise LoginException(f'authentication failed for sfdc instance {self.instance_name}') from e + except RequestException as e: + raise LoginException(f'authentication failed for sfdc instance {self.instance_name}') from e + + def authenticate_with_password(self, session: Session) -> None: + client_id = self.auth_data['client_id'] + client_secret = self.auth_data['client_secret'] + username = self.auth_data['username'] + password = self.auth_data['password'] + + params = { + "grant_type": "password", + "client_id": client_id, + "client_secret": client_secret, + "username": username, + "password": password + } + + headers = { + "Content-Type": "application/x-www-form-urlencoded", + "Accept": "application/json" + } + + try: + print_info(f'retrieving salesforce token at {self.token_url}') + resp = session.post(self.token_url, params=params, + headers=headers) + if resp.status_code != 200: + raise LoginException(f'salesforce token request failed. 
status-code:{resp.status_code}, reason: {resp.reason}') + + self.store_auth(resp.json()) + except ConnectionError as e: + raise LoginException(f'authentication failed for sfdc instance {self.instance_name}') from e + except RequestException as e: + raise LoginException(f'authentication failed for sfdc instance {self.instance_name}') from e + + def authenticate(self, session: Session) -> None: + if self.data_cache and self.load_auth_from_cache(): + return + + oauth_type = self.get_grant_type() + if oauth_type == 'password': + self.authenticate_with_password(session) + print_info('Correctly authenticated with user/pass flow') + return + + self.authenticate_with_jwt(session) + print_info('Correctly authenticated with JWT flow') + + def reauthenticate(self, session: Session) -> None: + self.clear_auth() + self.authenticate(session) + +def validate_oauth_config(auth: dict) -> dict: + if not auth['client_id']: + raise ConfigException('client_id', 'missing OAuth client id') + + if not auth['client_secret']: + raise ConfigException( + 'client_secret', + 'missing OAuth client secret', + ) + + if not auth['username']: + raise ConfigException('username', 'missing OAuth username') + + if not auth['password']: + raise ConfigException('password', 'missing OAuth client secret') + + return auth + + +def validate_jwt_config(auth: dict) -> dict: + if not auth['client_id']: + raise ConfigException('client_id', 'missing JWT client id') + + if not auth['private_key']: + raise ConfigException('private_key', 'missing JWT private key') + + if not auth['subject']: + raise ConfigException('subject', 'missing JWT subject') + + if not auth['audience']: + raise ConfigException('audience', 'missing JWT audience') + + return auth + + +def make_auth_from_config(auth: Config) -> dict: + grant_type = auth.get( + 'grant_type', + 'password', + env_var_name=SF_GRANT_TYPE, + ).lower() + + if grant_type == 'password': + return validate_oauth_config({ + 'grant_type': grant_type, + 'client_id': auth.get('client_id', env_var_name=SF_CLIENT_ID), + 'client_secret': auth.get( + 'client_secret', + env_var_name=SF_CLIENT_SECRET + ), + 'username': auth.get('username', env_var_name=SF_USERNAME), + 'password': auth.get('password', env_var_name=SF_PASSWORD), + }) + + if grant_type == 'urn:ietf:params:oauth:grant-type:jwt-bearer': + return validate_jwt_config({ + 'grant_type': grant_type, + 'client_id': auth.get('client_id', env_var_name=SF_CLIENT_ID), + 'private_key': auth.get('private_key', env_var_name=SF_PRIVATE_KEY), + 'subject': auth.get('subject', env_var_name=SF_SUBJECT), + 'audience': auth.get('audience', env_var_name=SF_AUDIENCE), + }) + + raise Exception(f'Wrong or missing grant_type') + + +def make_auth_from_env(config: Config) -> dict: + grant_type = config.getenv(SF_GRANT_TYPE, 'password').lower() + + if grant_type == 'password': + return validate_oauth_config({ + 'grant_type': grant_type, + 'client_id': config.getenv(SF_CLIENT_ID), + 'client_secret': config.getenv(SF_CLIENT_SECRET), + 'username': config.getenv(SF_USERNAME), + 'password': config.getenv(SF_PASSWORD), + }) + + if grant_type == 'urn:ietf:params:oauth:grant-type:jwt-bearer': + return validate_jwt_config({ + 'grant_type': grant_type, + 'client_id': config.getenv(SF_CLIENT_ID), + 'private_key': config.getenv(SF_PRIVATE_KEY), + 'subject': config.getenv(SF_SUBJECT), + 'audience': config.getenv(SF_AUDIENCE), + }) + + raise Exception(f'Wrong or missing grant_type') + + +class AuthenticatorFactory: + def __init__(self): + pass + + def new(self, config: Config, 
data_cache: DataCache) -> Authenticator: + token_url = config.get('token_url', env_var_name=SF_TOKEN_URL) + + if not token_url: + raise ConfigException('token_url', 'missing token URL') + + if 'auth' in config: + return Authenticator( + token_url, + make_auth_from_config(config.sub('auth')), + data_cache, + ) + + return Authenticator( + token_url, + make_auth_from_env(config), + data_cache + ) diff --git a/src/newrelic_logging/cache.py b/src/newrelic_logging/cache.py index 536f6fc..6f7a197 100644 --- a/src/newrelic_logging/cache.py +++ b/src/newrelic_logging/cache.py @@ -1,8 +1,10 @@ +import gc import redis from datetime import timedelta +from . import CacheException from .config import Config -from .telemetry import print_err, print_info +from .telemetry import print_info CONFIG_CACHE_ENABLED = 'cache_enabled' @@ -20,138 +22,154 @@ DEFAULT_REDIS_SSL = False -# Local cache, to store data before sending it to Redis. -class DataCache: - redis = None - redis_expire = None - cached_events = {} - cached_logs = {} - - def __init__(self, redis, redis_expire) -> None: +class RedisBackend: + def __init__(self, redis): self.redis = redis - self.redis_expire = redis_expire - def set_redis_expire(self, key): - try: - self.redis.expire(key, timedelta(days=self.redis_expire)) - except Exception as e: - print_err(f"Failed setting expire time for key {key}: {e}") - exit(1) - - def persist_logs(self, record_id: str) -> bool: - if record_id in self.cached_logs: - for row_id in self.cached_logs[record_id]: - try: - self.redis.rpush(record_id, row_id) - except Exception as e: - print_err(f"Failed pushing record {record_id}: {e}") - exit(1) - # Set expire date for the whole list only once, when it find the first entry ('init') - if row_id == 'init': - self.set_redis_expire(record_id) - del self.cached_logs[record_id] - return True - else: - return False - - def persist_event(self, record_id: str) -> bool: - if record_id in self.cached_events: - try: - self.redis.set(record_id, '') - except Exception as e: - print_err(f"Failed setting record {record_id}: {e}") - exit(1) - self.set_redis_expire(record_id) - del self.cached_events[record_id] + def exists(self, key): + return self.redis.exists(key) + + def put(self, key, item): + self.redis.set(key, item) + + def get_set(self, key): + return self.redis.smembers(key) + + def set_add(self, key, *values): + self.redis.sadd(key, *values) + + def set_expiry(self, key, days): + self.redis.expire(key, timedelta(days=days)) + + +class BufferedAddSetCache: + def __init__(self, s: set): + self.s = s + self.buffer = set() + + def check_or_set(self, item: str) -> bool: + if item in self.s or item in self.buffer: return True - else: - return False - def can_skip_downloading_record(self, record_id: str) -> bool: - try: - does_exist = self.redis.exists(record_id) - except Exception as e: - print_err(f"Failed checking record {record_id}: {e}") - exit(1) - if does_exist: - try: - return self.redis.llen(record_id) > 1 - except Exception as e: - print_err(f"Failed checking len for record {record_id}: {e}") - exit(1) + self.buffer.add(item) return False - def retrieve_cached_message_list(self, record_id: str): + def get_buffer(self) -> set: + return self.buffer + + +class DataCache: + def __init__(self, backend, expiry): + self.backend = backend + self.expiry = expiry + self.log_records = {} + self.event_records = None + + def can_skip_downloading_logfile(self, record_id: str) -> bool: try: - cache_key_exists = self.redis.exists(record_id) + return self.backend.exists(record_id) 
except Exception as e: - print_err(f"Failed checking record {record_id}: {e}") - exit(1) - - if cache_key_exists: - try: - cached_messages = self.redis.lrange(record_id, 0, -1) - except Exception as e: - print_err(f"Failed getting list range for record {record_id}: {e}") - exit(1) - return cached_messages - else: - self.cached_logs[record_id] = ['init'] - - return None - - # Cache event - def check_cached_id(self, record_id: str): + raise CacheException(f'failed checking record {record_id}: {e}') + + def check_or_set_log_line(self, record_id: str, line: dict) -> bool: try: - does_exist = self.redis.exists(record_id) + if not record_id in self.log_records: + self.log_records[record_id] = BufferedAddSetCache( + self.backend.get_set(record_id), + ) + + return self.log_records[record_id].check_or_set(line['REQUEST_ID']) except Exception as e: - print_err(f"Failed checking record {record_id}: {e}") - exit(1) + raise CacheException(f'failed checking record {record_id}: {e}') - if does_exist: - return True - else: - self.cached_events[record_id] = '' - return False - - # Cache log - def record_or_skip_row(self, record_id: str, row: dict, cached_messages: dict) -> bool: - row_id = row["REQUEST_ID"] - - if cached_messages is not None: - row_id_b = row_id.encode('utf-8') - if row_id_b in cached_messages: - return True - self.cached_logs[record_id].append(row_id) - else: - self.cached_logs[record_id].append(row_id) + def check_or_set_event_id(self, record_id: str) -> bool: + try: + if not self.event_records: + self.event_records = BufferedAddSetCache( + self.backend.get_set('event_ids'), + ) - return False + return self.event_records.check_or_set(record_id) + except Exception as e: + raise CacheException(f'failed checking record {record_id}: {e}') + + def flush(self) -> None: + try: + for record_id in self.log_records: + buf = self.log_records[record_id].get_buffer() + if len(buf) > 0: + self.backend.set_add(record_id, *buf) + + self.backend.set_expiry(record_id, self.expiry) + + if self.event_records: + buf = self.event_records.get_buffer() + if len(buf) > 0: + for id in buf: + self.backend.put(id, 1) + self.backend.set_expiry(id, self.expiry) + + self.backend.set_add('event_ids', *buf) + self.backend.set_expiry('event_ids', self.expiry) + + # attempt to reclaim memory + for record_id in self.log_records: + self.log_records[record_id] = None + + self.log_records = {} + self.event_records = None + gc.collect() + except Exception as e: + raise CacheException(f'failed flushing cache: {e}') + + +class BackendFactory: + def __init__(self): + pass -def make_cache(config: Config): - if config.get_bool(CONFIG_CACHE_ENABLED, DEFAULT_CACHE_ENABLED): + def new(self, config: Config): host = config.get(CONFIG_REDIS_HOST, DEFAULT_REDIS_HOST) port = config.get_int(CONFIG_REDIS_PORT, DEFAULT_REDIS_PORT) db = config.get_int(CONFIG_REDIS_DB_NUMBER, DEFAULT_REDIS_DB_NUMBER) password = config.get(CONFIG_REDIS_PASSWORD) ssl = config.get_bool(CONFIG_REDIS_USE_SSL, DEFAULT_REDIS_SSL) - expire_days = config.get_int(CONFIG_REDIS_EXPIRE_DAYS) password_display = "XXXXXX" if password != None else None print_info( - f'Cache enabled, connecting to redis instance {host}:{port}:{db}, ssl={ssl}, password={password_display}' + f'connecting to redis instance {host}:{port}:{db}, ssl={ssl}, password={password_display}' + ) + + return RedisBackend( + redis.Redis( + host=host, + port=port, + db=db, + password=password, + ssl=ssl, + decode_responses=True, + ), + ) - return DataCache(redis.Redis( - host=host, - port=port, - db=db, - 
password=password, - ssl=ssl - ), expire_days) - print_info('Cache disabled') +class CacheFactory: + def __init__(self, backend_factory): + self.backend_factory = backend_factory + pass + + def new(self, config: Config): + if not config.get_bool(CONFIG_CACHE_ENABLED, DEFAULT_CACHE_ENABLED): + print_info('Cache disabled') + return None + + print_info('Cache enabled') - return None + try: + return DataCache( + self.backend_factory.new(config), + config.get_int(CONFIG_REDIS_EXPIRE_DAYS) + ) + except Exception as e: + raise CacheException(f'failed creating backend: {e}') diff --git a/src/newrelic_logging/config.py b/src/newrelic_logging/config.py index 81a4b42..a8e6d15 100644 --- a/src/newrelic_logging/config.py +++ b/src/newrelic_logging/config.py @@ -4,6 +4,18 @@ from typing import Any +CONFIG_DATE_FIELD = 'date_field' +CONFIG_GENERATION_INTERVAL = 'generation_interval' +CONFIG_TIME_LAG_MINUTES = 'time_lag_minutes' + +DATE_FIELD_LOG_DATE = 'LogDate' +DATE_FIELD_CREATE_DATE = 'CreateDate' +GENERATION_INTERVAL_DAILY = 'Daily' +GENERATION_INTERVAL_HOURLY = 'Hourly' + +DEFAULT_TIME_LAG_MINUTES = 300 +DEFAULT_GENERATION_INTERVAL = GENERATION_INTERVAL_DAILY + BOOL_TRUE_VALS = ['true', '1', 'on', 'yes'] NOT_FOUND = SimpleNamespace() @@ -74,15 +86,21 @@ def set_prefix(self, prefix: str) -> None: def getenv(self, env_var_name: str, default = None) -> str: return getenv(env_var_name, default, self.prefix) - def get(self, key: str, default = None, allow_none = False) -> Any: + def get( + self, + key: str, + default = None, + allow_none = False, + env_var_name = None, + ) -> Any: val = get_nested(self.config, key) if not val == NOT_FOUND and not allow_none and not val == None: return val - return self.getenv( - re.sub(r'[^a-zA-Z0-9_]', '_', key.upper()), - default, - ) + var_name = env_var_name if env_var_name else \ + re.sub(r'[^a-zA-Z0-9_]', '_', key.upper()) + + return self.getenv(var_name, default) def get_int(self, key: str, default = None) -> int: val = self.get(key, default) diff --git a/src/newrelic_logging/env.py b/src/newrelic_logging/env.py deleted file mode 100644 index 9fadebe..0000000 --- a/src/newrelic_logging/env.py +++ /dev/null @@ -1,58 +0,0 @@ -SF_GRANT_TYPE = 'SF_GRANT_TYPE' -SF_CLIENT_ID = 'SF_CLIENT_ID' -SF_CLIENT_SECRET = 'SF_CLIENT_SECRET' -SF_USERNAME = 'SF_USERNAME' -SF_PASSWORD = 'SF_PASSWORD' -SF_PRIVATE_KEY = 'SF_PRIVATE_KEY' -SF_SUBJECT = 'SF_SUBJECT' -SF_AUDIENCE = 'SF_AUDIENCE' -SF_TOKEN_URL = 'SF_TOKEN_URL' - -class AuthEnv: - def __init__(self, config): - self.config = config - - def get_grant_type(self): - return self.config.getenv(SF_GRANT_TYPE) - - def get_client_id(self): - return self.config.getenv(SF_CLIENT_ID) - - def get_client_secret(self): - return self.config.getenv(SF_CLIENT_SECRET) - - def get_username(self): - return self.config.getenv(SF_USERNAME) - - def get_password(self): - return self.config.getenv(SF_PASSWORD) - - def get_private_key(self): - return self.config.getenv(SF_PRIVATE_KEY) - - def get_subject(self): - return self.config.getenv(SF_SUBJECT) - - def get_audience(self): - return self.config.getenv(SF_AUDIENCE) - - def get_token_url(self): - return self.config.getenv(SF_TOKEN_URL) - - -class Auth: - access_token = None - instance_url = None - # Never used, maybe in the future - token_type = None - - def __init__(self, access_token: str, instance_url: str, token_type: str) -> None: - self.access_token = access_token - self.instance_url = instance_url - self.token_type = token_type - - def get_access_token(self) -> str: - return self.access_token 
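To illustrate the new env_var_name parameter added to Config.get() above, here is a sketch that is not part of the diff. The 'MY_' prefix and the URL value are made up, and the prefix is assumed to be prepended to the variable name, which is how the per-instance auth_env_prefix setting is applied via set_prefix().

import os
from newrelic_logging.config import getenv

os.environ['MY_SF_TOKEN_URL'] = 'https://test.salesforce.com/services/oauth2/token'

# Config.get('token_url', env_var_name=SF_TOKEN_URL) first looks for the
# nested 'token_url' key in the YAML arguments; only when that key is missing
# does it fall back to the environment using the explicit name, as below,
# instead of deriving a variable name from the key.
token_url = getenv('SF_TOKEN_URL', None, 'MY_')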
- - def get_instance_url(self) -> str: - return self.instance_url diff --git a/src/newrelic_logging/integration.py b/src/newrelic_logging/integration.py index ba50de3..fd54871 100644 --- a/src/newrelic_logging/integration.py +++ b/src/newrelic_logging/integration.py @@ -1,235 +1,164 @@ -import sys +from requests import Session + +from . import \ + ConfigException, \ + CacheException, \ + DataFormat, \ + LoginException, \ + NewRelicApiException, \ + SalesforceApiException +from . import api +from . import auth +from . import cache +from . import config as mod_config +from . import newrelic +from . import query +from . import pipeline +from . import salesforce from .http_session import new_retry_session -from .newrelic import NewRelic -from .cache import DataCache -from .salesforce import SalesForce, SalesforceApiException -from .env import AuthEnv -from enum import Enum -from .config import Config, getenv -from .telemetry import Telemetry, print_info, print_err +from .telemetry import print_err, print_info, print_warn, Telemetry -NR_LICENSE_KEY = 'NR_LICENSE_KEY' -NR_ACCOUNT_ID = 'NR_ACCOUNT_ID' - - -class DataFormat(Enum): - LOGS = 1 - EVENTS = 2 - -# TODO: move queries to the instance level, so we can have different queries for +# @TODO: move queries to the instance level, so we can have different queries for # each instance. -# TODO: also keep general queries that apply to all instances. +# @TODO: also keep general queries that apply to all instances. + +def build_instance( + config: mod_config.Config, + auth_factory: auth.AuthenticatorFactory, + cache_factory: cache.CacheFactory, + pipeline_factory: pipeline.PipelineFactory, + salesforce_factory: salesforce.SalesForceFactory, + api_factory: api.ApiFactory, + query_factory: query.QueryFactory, + new_relic: newrelic.NewRelic, + data_format: DataFormat, + event_type_fields_mapping: dict, + numeric_fields_list: set, + initial_delay: int, + instance: dict, + index: int, +): + instance_name = instance['name'] + labels = instance['labels'] + labels['nr-labs'] = 'data' + instance_config = config.sub(f'instances.{index}.arguments') + instance_config.set_prefix( + instance_config['auth_env_prefix'] \ + if 'auth_env_prefix' in instance_config else '' + ) + + data_cache = cache_factory.new(instance_config) + authenticator = auth_factory.new(instance_config, data_cache) + + return { + 'client': salesforce_factory.new( + instance_name, + instance_config, + data_cache, + authenticator, + pipeline_factory.new( + instance_config, + data_cache, + new_relic, + data_format, + labels, + event_type_fields_mapping, + numeric_fields_list, + ), + api_factory, + query_factory, + initial_delay, + config['queries'] if 'queries' in config else None, + ), + 'name': instance_name, + } class Integration: - numeric_fields_list = set() - def __init__( self, - config: Config, + config: mod_config.Config, + auth_factory: auth.AuthenticatorFactory, + cache_factory: cache.CacheFactory, + pipeline_factory: pipeline.PipelineFactory, + salesforce_factory: salesforce.SalesForceFactory, + api_factory: api.ApiFactory, + query_factory: query.QueryFactory, + newrelic_factory: newrelic.NewRelicFactory, event_type_fields_mapping: dict = {}, numeric_fields_list: set = set(), initial_delay: int = 0, ): - Integration.numeric_fields_list = numeric_fields_list - self.instances = [] - Telemetry(config["integration_name"]) - for count, instance in enumerate(config['instances']): - instance_name = instance['name'] - labels = instance['labels'] - labels['nr-labs'] = 'data' - instance_config = 
config.sub(f'instances.{count}.arguments') - instance_config.set_prefix( - instance_config['auth_env_prefix'] \ - if 'auth_env_prefix' in instance_config else '' - ) - auth_env = AuthEnv(instance_config) - - if 'queries' in config: - client = SalesForce(auth_env, instance_name, instance_config, event_type_fields_mapping, initial_delay, config['queries']) - else: - client = SalesForce(auth_env, instance_name, instance_config, event_type_fields_mapping, initial_delay) - - if 'auth' in instance_config: - auth = instance_config['auth'] - if 'grant_type' in auth: - oauth_type = auth['grant_type'] - else: - sys.exit("No 'grant_type' specified under 'auth' section in config.yml for instance '" + instance_name + "'") - else: - oauth_type = auth_env.get_grant_type() - - self.instances.append({'labels': labels, 'client': client, "oauth_type": oauth_type, 'name': instance_name}) - - newrelic_config = config['newrelic'] - data_format = newrelic_config['data_format'].lower() \ - if 'data_format' in newrelic_config else 'logs' - - if data_format == "logs": - self.data_format = DataFormat.LOGS - elif data_format == "events": - self.data_format = DataFormat.EVENTS + Telemetry( + config['integration_name'] if 'integration_name' in config \ + else 'com.newrelic.labs.salesforce.exporter' + ) + + data_format = config.get('newrelic.data_format', 'logs').lower() + if data_format == 'logs': + data_format = DataFormat.LOGS + elif data_format == 'events': + data_format = DataFormat.EVENTS else: - sys.exit(f'invalid data_format specified. valid values are "logs" or "events"') - - # Fill credentials for NR APIs - if 'license_key' in newrelic_config: - NewRelic.logs_license_key = NewRelic.events_api_key = newrelic_config['license_key'] - else: - NewRelic.logs_license_key = NewRelic.events_api_key = getenv(NR_LICENSE_KEY) - - if self.data_format == DataFormat.EVENTS: - if 'account_id' in newrelic_config: - account_id = newrelic_config['account_id'] - else: - account_id = getenv(NR_ACCOUNT_ID) - NewRelic.set_api_endpoint(newrelic_config['api_endpoint'], account_id) - - NewRelic.set_logs_endpoint(newrelic_config['api_endpoint']) + raise ConfigException(f'invalid data format {data_format}') - def run(self): - sfdc_session = new_retry_session() + self.new_relic = newrelic_factory.new(config) + self.instances = [] - for instance in self.instances: - print_info(f"Running instance '{instance['name']}'") - - labels = instance['labels'] - client = instance['client'] - oauth_type = instance['oauth_type'] - - logs = self.auth_and_fetch(True, client, oauth_type, sfdc_session) - if self.response_empty(logs): - print_info("No data to be sent") - self.process_telemetry() - continue - - if self.data_format == DataFormat.LOGS: - self.process_logs(logs, labels, client.data_cache) - else: - self.process_events(logs, labels, client.data_cache) - - self.process_telemetry() - - def process_telemetry(self): - if not Telemetry().is_empty(): - print_info("Sending telemetry data") - self.process_logs(Telemetry().build_model(), {}, None) - Telemetry().clear() - else: + if not 'instances' in config or len(config['instances']) == 0: + print_warn('no instances found to run') + return + + for index, instance in enumerate(config['instances']): + self.instances.append(build_instance( + config, + auth_factory, + cache_factory, + pipeline_factory, + salesforce_factory, + api_factory, + query_factory, + self.new_relic, + data_format, + event_type_fields_mapping, + numeric_fields_list, + initial_delay, + instance, + index, + )) + + def 
process_telemetry(self, session: Session): + if Telemetry().is_empty(): print_info("No telemetry data") + return - def auth_and_fetch(self, retry, client, oauth_type, sfdc_session): - if not client.authenticate(oauth_type, sfdc_session): - return None + print_info("Sending telemetry data") + self.new_relic.post_logs(session, Telemetry().build_model()) + Telemetry().clear() - logs = None + def auth_and_fetch( + self, + client: salesforce.SalesForce, + session: Session, + ) -> None: try: - logs = client.fetch_logs(sfdc_session) + client.authenticate(session) + client.fetch_logs(session) + except LoginException as e: + print_err(f'authentication failed: {e}') except SalesforceApiException as e: - if e.err_code == 401: - if retry: - print_err("Invalid token, retry auth and fetch...") - client.clear_auth() - return self.auth_and_fetch(False, client, oauth_type, sfdc_session) - else: - print_err(f"Exception while fetching data from SF: {e}") - return None - else: - print_err(f"Exception while fetching data from SF: {e}") - return None + print_err(f'exception while fetching data from SF: {e}') + except CacheException as e: + print_err(f'exception while accessing Redis cache: {e}') + except NewRelicApiException as e: + print_err(f'exception while posting data to New Relic: {e}') except Exception as e: - print_err(f"Exception while fetching data from SF: {e}") - return None - - return logs - - @staticmethod - def response_empty(logs): - # Empty or None - if not logs: - return True - for l in logs: - if "log_entries" in l and l["log_entries"]: - return False - return True - - @staticmethod - def cache_processed_data(log_file_id, log_entries, data_cache: DataCache): - if data_cache and data_cache.redis: - if log_file_id == '': - # Events - for log in log_entries: - log_id = log.get('attributes', {}).get('Id', '') - data_cache.persist_event(log_id) - else: - # Logs - data_cache.persist_logs(log_file_id) - - @staticmethod - def process_logs(logs, labels, data_cache: DataCache): - nr_session = new_retry_session() - for log_file_obj in logs: - log_entries = log_file_obj['log_entries'] - if len(log_entries) == 0: - continue - - payload = [{'common': labels, 'logs': log_entries}] - log_type = log_file_obj.get('log_type', '') - log_file_id = log_file_obj.get('Id', '') - - status_code = NewRelic.post_logs(nr_session, payload) - if status_code != 202: - print_err(f'newrelic logs api returned code- {status_code}') - else: - print_info(f"Sent {len(log_entries)} log messages from log file {log_type}/{log_file_id}") - Integration.cache_processed_data(log_file_id, log_entries, data_cache) - - @staticmethod - def process_events(logs, labels, data_cache: DataCache): - nr_session = new_retry_session() - for log_file_obj in logs: - log_file_id = log_file_obj.get('Id', '') - log_entries = log_file_obj['log_entries'] - if len(log_entries) == 0: - continue - log_events = [] - for log_entry in log_entries: - log_event = {} - attributes = log_entry['attributes'] - for event_name in attributes: - # currently no need to modify as we did not see any special chars that need to be removed - modified_event_name = event_name - event_value = attributes[event_name] - if event_name in Integration.numeric_fields_list: - if event_value: - try: - log_event[modified_event_name] = int(event_value) - except (TypeError, ValueError) as e: - try: - log_event[modified_event_name] = float(event_value) - except (TypeError, ValueError) as e: - print_err(f'Type conversion error for {event_name}[{event_value}]') - log_event[modified_event_name] = 
event_value - else: - log_event[modified_event_name] = 0 - else: - log_event[modified_event_name] = event_value - log_event.update(labels) - event_type = log_event.get('EVENT_TYPE', "UnknownSFEvent") - log_event['eventType'] = event_type - log_events.append(log_event) - - # NOTE: this is probably unnecessary now, because we already have a slicing method with a limit of 1000 in SalesForce.extract_row_slice - # since the max number of events that can be posted in a single payload to New Relic is 2000 - max_events = 2000 - x = [log_events[i:i + max_events] for i in range(0, len(log_events), max_events)] - - for log_entries_slice in x: - status_code = NewRelic.post_events(nr_session, log_entries_slice) - if status_code != 200: - print_err(f'newrelic events api returned code- {status_code}') - else: - log_type = log_file_obj.get('log_type', '') - log_file_id = log_file_obj.get('Id', '') - print_info(f"Posted {len(log_entries_slice)} events from log file {log_type}/{log_file_id}") - Integration.cache_processed_data(log_file_id, log_entries, data_cache) + print_err(f'unknown exception occurred: {e}') + + def run(self): + session = new_retry_session() + + for instance in self.instances: + print_info(f'Running instance "{instance["name"]}"') + self.auth_and_fetch(instance['client'], session) + self.process_telemetry(session) diff --git a/src/newrelic_logging/newrelic.py b/src/newrelic_logging/newrelic.py index 0df766c..3d3bba1 100644 --- a/src/newrelic_logging/newrelic.py +++ b/src/newrelic_logging/newrelic.py @@ -1,31 +1,45 @@ import gzip import json -from .telemetry import print_info, print_err -from requests import RequestException -from newrelic_logging import VERSION, NAME, PROVIDER, COLLECTOR_NAME +from requests import RequestException, Session -class NewRelicApiException(Exception): - pass +from . 
import \ + VERSION, \ + NAME, \ + PROVIDER, \ + COLLECTOR_NAME, \ + NewRelicApiException +from .config import Config +from .telemetry import print_info -class NewRelic: - INGEST_SERVICE_VERSION = "v1" - US_LOGGING_ENDPOINT = "https://log-api.newrelic.com/log/v1" - EU_LOGGING_ENDPOINT = "https://log-api.eu.newrelic.com/log/v1" - LOGS_EVENT_SOURCE = 'logs' - US_EVENTS_ENDPOINT = "https://insights-collector.newrelic.com/v1/accounts/{account_id}/events" - EU_EVENTS_ENDPOINT = "https://insights-collector.eu01.nr-data.net/v1/accounts/{account_id}/events" +NR_LICENSE_KEY = 'NR_LICENSE_KEY' +NR_ACCOUNT_ID = 'NR_ACCOUNT_ID' + +US_LOGGING_ENDPOINT = 'https://log-api.newrelic.com/log/v1' +EU_LOGGING_ENDPOINT = 'https://log-api.eu.newrelic.com/log/v1' +LOGS_EVENT_SOURCE = 'logs' - CONTENT_ENCODING = 'gzip' +US_EVENTS_ENDPOINT = 'https://insights-collector.newrelic.com/v1/accounts/{account_id}/events' +EU_EVENTS_ENDPOINT = 'https://insights-collector.eu01.nr-data.net/v1/accounts/{account_id}/events' - logs_api_endpoint = US_LOGGING_ENDPOINT - logs_license_key = '' +CONTENT_ENCODING = 'gzip' +MAX_EVENTS = 2000 - events_api_endpoint = US_EVENTS_ENDPOINT - events_api_key = '' - @classmethod - def post_logs(cls, session, data): +class NewRelic: + def __init__( + self, + logs_api_endpoint, + logs_license_key, + events_api_endpoint, + events_api_key, + ): + self.logs_api_endpoint = logs_api_endpoint + self.logs_license_key = logs_license_key + self.events_api_endpoint = events_api_endpoint + self.events_api_key = events_api_key + + def post_logs(self, session: Session, data: list[dict]) -> None: # Append integration attributes for log in data[0]['logs']: if not 'attributes' in log: @@ -35,77 +49,89 @@ def post_logs(cls, session, data): log['attributes']['instrumentation.version'] = VERSION log['attributes']['collector.name'] = COLLECTOR_NAME - json_payload = json.dumps(data).encode() - - # print("----- POST DATA (LOGS) -----") - # print(json_payload.decode("utf-8")) - # print("----------------------------") - # return 202 - - payload = gzip.compress(json_payload) - headers = { - "X-License-Key": cls.logs_license_key, - "X-Event-Source": cls.LOGS_EVENT_SOURCE, - "Content-Encoding": cls.CONTENT_ENCODING, - } try: - r = session.post(cls.logs_api_endpoint, data=payload, - headers=headers) - except RequestException as e: - print_err(f"Failed posting logs to New Relic: {repr(e)}") - return 0 - - response = r.content.decode("utf-8") - print_info(f"NR Log API response body = {response}") - - return r.status_code - - @classmethod - def post_events(cls, session, data): + r = session.post( + self.logs_api_endpoint, + data=gzip.compress(json.dumps(data).encode()), + headers={ + 'X-License-Key': self.logs_license_key, + 'X-Event-Source': LOGS_EVENT_SOURCE, + 'Content-Encoding': CONTENT_ENCODING, + }, + ) + + if r.status_code != 202: + raise NewRelicApiException( + f'newrelic logs api returned code {r.status_code}' + ) + + response = r.content.decode("utf-8") + print_info(f"NR Log API response body = {response}") + except RequestException: + raise NewRelicApiException('newrelic logs api request failed') + + def post_events(self, session: Session, events: list[dict]) -> None: # Append integration attributes - for event in data: + for event in events: event['instrumentation.name'] = NAME event['instrumentation.provider'] = PROVIDER event['instrumentation.version'] = VERSION event['collector.name'] = COLLECTOR_NAME - json_payload = json.dumps(data).encode() - - # print("----- POST DATA (EVENTS) -----") - # 
print(json_payload.decode("utf-8")) - # print("------------------------------") - # return 200 - - payload = gzip.compress(json_payload) - headers = { - "Api-Key": cls.events_api_key, - "Content-Encoding": cls.CONTENT_ENCODING, - } - try: - r = session.post(cls.events_api_endpoint, data=payload, - headers=headers) - except RequestException as e: - print_err(f"Failed posting events to New Relic: {repr(e)}") - return 0 - - response = r.content.decode("utf-8") - print_info(f"NR Event API response body = {response}") - - return r.status_code - - @classmethod - def set_api_endpoint(cls, api_endpoint, account_id): - if api_endpoint == "US": - api_endpoint = NewRelic.US_EVENTS_ENDPOINT; - elif api_endpoint == "EU": - api_endpoint = NewRelic.EU_EVENTS_ENDPOINT - NewRelic.events_api_endpoint = api_endpoint.format(account_id='account_id') - - @classmethod - def set_logs_endpoint(cls, api_endpoint): - if api_endpoint == "US": - NewRelic.logs_api_endpoint = NewRelic.US_LOGGING_ENDPOINT - elif api_endpoint == "EU": - NewRelic.logs_api_endpoint = NewRelic.EU_LOGGING_ENDPOINT + # This funky code produces an array of arrays where each one will be at most + # length 2000 with the last one being <= 2000. This is done to account for + # the fact that only 2000 events can be posted at a time. + + slices = [events[i:(i + MAX_EVENTS)] \ + for i in range(0, len(events), MAX_EVENTS)] + + for slice in slices: + try: + r = session.post( + self.events_api_endpoint, + data=gzip.compress(json.dumps(slice).encode()), + headers={ + 'Api-Key': self.events_api_key, + 'Content-Encoding': CONTENT_ENCODING, + }, + ) + + if r.status_code != 200: + raise NewRelicApiException( + f'newrelic events api returned code {r.status_code}' + ) + + response = r.content.decode("utf-8") + print_info(f"NR Event API response body = {response}") + except RequestException: + raise NewRelicApiException('newrelic events api request failed') + + +class NewRelicFactory: + def __init__(self): + pass + + def new(self, config: Config): + license_key = config.get( + 'newrelic.license_key', + env_var_name=NR_LICENSE_KEY, + ) + + region = config.get('newrelic.api_endpoint') + account_id = config.get('newrelic.account_id', env_var_name=NR_ACCOUNT_ID) + + if region == "US": + logs_api_endpoint = US_LOGGING_ENDPOINT + events_api_endpoint = US_EVENTS_ENDPOINT.format(account_id=account_id) + elif region == "EU": + logs_api_endpoint = EU_LOGGING_ENDPOINT + events_api_endpoint = EU_EVENTS_ENDPOINT.format(account_id=account_id) else: - NewRelic.logs_api_endpoint = api_endpoint + raise NewRelicApiException(f'Invalid region {region}') + + return NewRelic( + logs_api_endpoint, + license_key, + events_api_endpoint, + license_key, + ) diff --git a/src/newrelic_logging/pipeline.py b/src/newrelic_logging/pipeline.py new file mode 100644 index 0000000..cecd341 --- /dev/null +++ b/src/newrelic_logging/pipeline.py @@ -0,0 +1,473 @@ +from copy import deepcopy +import csv +from datetime import datetime +import gc +import pytz +from requests import Session + +from . 
import DataFormat, SalesforceApiException +from .api import Api +from .cache import DataCache +from .config import Config +from .http_session import new_retry_session +from .newrelic import NewRelic +from .query import Query +from .telemetry import print_info +from .util import generate_record_id, \ + is_logfile_response, \ + maybe_convert_str_to_num, \ + process_query_result, \ + get_timestamp + + +DEFAULT_CHUNK_SIZE = 4096 +DEFAULT_MAX_ROWS = 1000 +MAX_ROWS = 2000 + +def init_fields_from_log_line( + record_event_type: str, + log_line: dict, + event_type_fields_mapping: dict, +) -> dict: + if record_event_type in event_type_fields_mapping: + attrs = {} + + for field in event_type_fields_mapping[record_event_type]: + attrs[field] = log_line[field] + + return attrs + + return deepcopy(log_line) + + +def get_log_line_timestamp(log_line: dict) -> float: + epoch = log_line.get('TIMESTAMP') + + if epoch: + return pytz.utc.localize( + datetime.strptime(epoch, '%Y%m%d%H%M%S.%f') + ).replace(microsecond=0).timestamp() + + return datetime.utcnow().replace(microsecond=0).timestamp() + + +def pack_log_line_into_log( + query: Query, + record_id: str, + record_event_type: str, + log_line: dict, + line_no: int, + event_type_fields_mapping: dict, +) -> dict: + attrs = init_fields_from_log_line( + record_event_type, + log_line, + event_type_fields_mapping, + ) + + timestamp = int(get_log_line_timestamp(log_line)) + attrs.pop('TIMESTAMP', None) + + attrs['LogFileId'] = record_id + + actual_event_type = attrs.pop('EVENT_TYPE', 'SFEvent') + new_event_type = query.get('event_type', actual_event_type) + attrs['EVENT_TYPE'] = new_event_type + + timestamp_field_name = query.get('rename_timestamp', 'timestamp') + attrs[timestamp_field_name] = timestamp + + log_entry = { + 'message': f'LogFile {record_id} row {str(line_no)}', + 'attributes': attrs + } + + if timestamp_field_name == 'timestamp': + log_entry[timestamp_field_name] = timestamp + + return log_entry + + +def export_log_lines( + api: Api, + session: Session, + log_file_path: str, + chunk_size: int, +): + print_info(f'Downloading log lines for log file: {log_file_path}') + return api.get_log_file(session, log_file_path, chunk_size) + + +def transform_log_lines( + iter, + query: Query, + record_id: str, + record_event_type: str, + data_cache: DataCache, + event_type_fields_mapping: dict, +): + # iter is a generator iterator that yields a single line at a time + reader = csv.DictReader(iter) + + # This should cause the reader to request the next line from the iterator + # which will cause the generator iterator to yield the next line + + row_index = 0 + + for row in reader: + # If we've already seen this log line, skip it + if data_cache and data_cache.check_or_set_log_line(record_id, row): + continue + + # Otherwise, pack it up for shipping and yield it for consumption + yield pack_log_line_into_log( + query, + record_id, + record_event_type, + row, + row_index, + event_type_fields_mapping, + ) + + row_index += 1 + + +def pack_event_record_into_log( + query: Query, + record_id: str, + record: dict, +) -> dict: + attrs = process_query_result(record) + if record_id: + attrs['Id'] = record_id + + message = query.get('event_type', 'SFEvent') + if 'attributes' in record and type(record['attributes']) == dict: + attributes = record['attributes'] + if 'type' in attributes and type(attributes['type']) == str: + attrs['EVENT_TYPE'] = message = \ + query.get('event_type', attributes['type']) + + timestamp_attr = query.get('timestamp_attr', 'CreatedDate') + if 
timestamp_attr in attrs: + created_date = attrs[timestamp_attr] + message += f' {created_date}' + timestamp = get_timestamp(created_date) + else: + timestamp = get_timestamp() + + timestamp_field_name = query.get('rename_timestamp', 'timestamp') + attrs[timestamp_field_name] = timestamp + + log_entry = { + 'message': message, + 'attributes': attrs, + } + + if timestamp_field_name == 'timestamp': + log_entry[timestamp_field_name] = timestamp + + return log_entry + + +def transform_event_records(iter, query: Query, data_cache: DataCache): + # iter here is a list which does mean it's entirely held in memory but these + # are event records not log lines so hopefully it is not as bad. + # @TODO figure out if we can stream event records + for record in iter: + config = query.get_config() + + record_id = record['Id'] if 'Id' in record \ + else generate_record_id( + config['id'] if 'id' in config else [], + record, + ) + + # If we've already seen this event record, skip it. + if data_cache and data_cache.check_or_set_event_id(record_id): + continue + + # Build a New Relic log record from the SF event record + yield pack_event_record_into_log( + query, + record_id, + record, + ) + + +def load_as_logs( + iter, + new_relic: NewRelic, + labels: dict, + max_rows: int, +) -> None: + nr_session = new_retry_session() + + logs = [] + count = total = 0 + + def send_logs(): + nonlocal logs + nonlocal count + + new_relic.post_logs(nr_session, [{'common': labels, 'logs': logs}]) + + print_info(f'Sent {count} log messages.') + + # Attempt to release memory + del logs + + logs = [] + count = 0 + + for log in iter: + if count == max_rows: + send_logs() + + logs.append(log) + + count += 1 + total += 1 + + if len(logs) > 0: + send_logs() + + print_info(f'Sent a total of {total} log messages.') + + # Attempt to reclaim memory + gc.collect() + + +def pack_log_into_event( + log: dict, + labels: dict, + numeric_fields_list: set, +) -> dict: + log_event = {} + + attributes = log['attributes'] + for key in attributes: + value = attributes[key] + + if key in numeric_fields_list: + log_event[key] = \ + maybe_convert_str_to_num(value) if value \ + else 0 + continue + + log_event[key] = value + + log_event.update(labels) + log_event['eventType'] = log_event.get('EVENT_TYPE', "UnknownSFEvent") + + return log_event + + +def load_as_events( + iter, + new_relic: NewRelic, + labels: dict, + max_rows: int, + numeric_fields_list: set, +) -> None: + nr_session = new_retry_session() + + events = [] + count = total = 0 + + def send_events(): + nonlocal events + nonlocal count + + new_relic.post_events(nr_session, events) + + print_info(f'Sent {count} events.') + + # Attempt to release memory + del events + + events = [] + count = 0 + + + for log_entry in iter: + if count == max_rows: + send_events() + + events.append(pack_log_into_event( + log_entry, + labels, + numeric_fields_list, + )) + + count += 1 + total += 1 + + if len(events) > 0: + send_events() + + print_info(f'Sent a total of {total} events.') + + # Attempt to reclaim memory + gc.collect() + + +def load_data( + logs, + new_relic: NewRelic, + data_format: DataFormat, + labels: dict, + max_rows: int, + numeric_fields_list: set, +): + if data_format == DataFormat.LOGS: + load_as_logs( + logs, + new_relic, + labels, + max_rows, + ) + return + + load_as_events( + logs, + new_relic, + labels, + max_rows, + numeric_fields_list, + ) + +class Pipeline: + def __init__( + self, + config: Config, + data_cache: DataCache, + new_relic: NewRelic, + data_format: DataFormat, + labels: dict, 
+ event_type_field_mappings: dict, + numeric_fields_list: set, + ): + self.config = config + self.data_cache = data_cache + self.new_relic = new_relic + self.data_format = data_format + self.labels = labels + self.event_type_field_mappings = event_type_field_mappings + self.numeric_fields_list = numeric_fields_list + self.max_rows = max( + self.config.get('max_rows', DEFAULT_MAX_ROWS), + MAX_ROWS, + ) + + def process_log_record( + self, + api: Api, + session: Session, + query: Query, + record: dict, + ): + record_id = str(record['Id']) + record_event_type = query.get("event_type", record['EventType']) + log_file_path = record['LogFile'] + interval = record['Interval'] + + # NOTE: only Hourly logs can be skipped, because Daily logs can change + # and the same record_id can contain different data. + if interval == 'Hourly' and self.data_cache and \ + self.data_cache.can_skip_downloading_logfile(record_id): + print_info( + f'Log lines for logfile with id {record_id} already cached, skipping download' + ) + return None + + load_data( + transform_log_lines( + export_log_lines( + api, + session, + log_file_path, + self.config.get('chunk_size', DEFAULT_CHUNK_SIZE) + ), + query, + record_id, + record_event_type, + self.data_cache, + self.event_type_field_mappings, + ), + self.new_relic, + self.data_format, + self.labels, + self.max_rows, + self.numeric_fields_list, + ) + + def process_event_records( + self, + query: Query, + records: list[dict], + ): + load_data( + transform_event_records( + records, + query, + self.data_cache, + ), + self.new_relic, + self.data_format, + self.labels, + self.max_rows, + self.numeric_fields_list, + ) + + def execute( + self, + api: Api, + session: Session, + query: Query, + records: list[dict], + ): + if is_logfile_response(records): + for record in records: + if 'LogFile' in record: + self.process_log_record( + api, + session, + query, + record, + ) + + if self.data_cache: + self.data_cache.flush() + + return + + self.process_event_records(query, records) + + # Flush the cache + if self.data_cache: + self.data_cache.flush() + +class PipelineFactory: + def __init__(self): + pass + + def new( + self, + config: Config, + data_cache: DataCache, + new_relic: NewRelic, + data_format: DataFormat, + labels: dict, + event_type_field_mappings: dict, + numeric_fields_list: set, + ): + return Pipeline( + config, + data_cache, + new_relic, + data_format, + labels, + event_type_field_mappings, + numeric_fields_list, + ) diff --git a/src/newrelic_logging/query.py b/src/newrelic_logging/query.py index 26bb501..e67223c 100644 --- a/src/newrelic_logging/query.py +++ b/src/newrelic_logging/query.py @@ -1,55 +1,83 @@ +import copy +from requests import RequestException, Session + +from . 
import SalesforceApiException +from .api import Api +from .config import Config +from .telemetry import print_info +from .util import get_iso_date_with_offset, substitute + class Query: - query = None - env = None - - def __init__(self, query) -> None: - if type(query) == dict: - self.query = query.get("query", "") - query.pop('query', None) - self.env = query - elif type(query) == str: - self.query = query - self.env = {} - - def get_query(self) -> str: - return self.query - - def set_query(self, query: str) -> None: + def __init__( + self, + api: Api, + query: str, + config: Config, + api_ver: str = None, + ): + self.api = api self.query = query - - def get_env(self) -> dict: - return self.env - -# NOTE: this sandbox can be jailbroken using the trick to exec statements inside an exec block, and run an import (and other tricks): -# https://book.hacktricks.xyz/generic-methodologies-and-resources/python/bypass-python-sandboxes#operators-and-short-tricks -# https://stackoverflow.com/a/3068475/2076108 -# Would be better to use a real sandbox like https://pypi.org/project/RestrictedPython/ or https://doc.pypy.org/en/latest/sandbox.html -# or parse a small language that only supports funcion calls and binary expressions. -def sandbox(code): - __import__ = None - __loader__ = None - __build_class__ = None - exec = None - - from datetime import datetime, timedelta - - def sf_time(t: datetime): - return t.isoformat(timespec='milliseconds') + "Z" - - def now(delta: timedelta = None): - if delta: - return sf_time(datetime.utcnow() + delta) - else: - return sf_time(datetime.utcnow()) - - try: - return eval(code) - except Exception as e: - return e - -def substitute(args: dict, query_template: str, env: dict) -> str: - for key, command in env.items(): - args[key] = sandbox(command) - for key, val in args.items(): - query_template = query_template.replace('{' + key + '}', val) - return query_template + self.config = config + self.api_ver = api_ver + + def get(self, key: str, default = None): + return self.config.get(key, default) + + def get_config(self): + return self.config + + def execute( + self, + session: Session, + ): + print_info(f'Running query {self.query}...') + return self.api.query(session, self.query, self.api_ver) + + +class QueryFactory: + def __init__(self): + pass + + def build_args( + self, + time_lag_minutes: int, + last_to_timestamp: str, + generation_interval: str, + ): + return { + 'to_timestamp': get_iso_date_with_offset(time_lag_minutes), + 'from_timestamp': last_to_timestamp, + 'log_interval_type': generation_interval, + } + + def get_env(self, q: dict) -> dict: + if 'env' in q and type(q['env']) is dict: + return q['env'] + + return {} + + def new( + self, + api: Api, + q: dict, + time_lag_minutes: int, + last_to_timestamp: str, + generation_interval: str, + ) -> Query: + qp = copy.deepcopy(q) + qq = qp.pop('query', '') + + return Query( + api, + substitute( + self.build_args( + time_lag_minutes, + last_to_timestamp, + generation_interval, + ), + qq, + self.get_env(qp), + ).replace(' ', '+'), + Config(qp), + qp.get('api_ver', None) + ) diff --git a/src/newrelic_logging/salesforce.py b/src/newrelic_logging/salesforce.py index 61085da..a117c72 100644 --- a/src/newrelic_logging/salesforce.py +++ b/src/newrelic_logging/salesforce.py @@ -1,29 +1,17 @@ -import base64 -import csv -import json -import sys from datetime import datetime, timedelta -import jwt -from cryptography.hazmat.primitives import serialization -import pytz -from requests import RequestException -import copy 
-import hashlib -from .cache import make_cache -from .env import Auth, AuthEnv -from .query import Query, substitute -from .telemetry import print_info, print_err +from requests import Session -class LoginException(Exception): - pass +from .api import ApiFactory +from .auth import Authenticator +from .cache import DataCache +from . import config as mod_config +from .pipeline import Pipeline +from . import query as mod_query +from .telemetry import print_info, print_warn +from .util import get_iso_date_with_offset -class SalesforceApiException(Exception): - err_code = 0 - def __init__(self, err_code: int, *args: object) -> None: - self.err_code = err_code - super().__init__(*args) - pass +CSV_SLICE_SIZE = 1000 SALESFORCE_CREATED_DATE_QUERY = \ "SELECT Id,EventType,CreatedDate,LogDate,Interval,LogFile,Sequence From EventLogFile Where CreatedDate>={" \ "from_timestamp} AND CreatedDate<{to_timestamp} AND Interval='{log_interval_type}'" @@ -31,548 +19,115 @@ def __init__(self, err_code: int, *args: object) -> None: "SELECT Id,EventType,CreatedDate,LogDate,Interval,LogFile,Sequence From EventLogFile Where LogDate>={" \ "from_timestamp} AND LogDate<{to_timestamp} AND Interval='{log_interval_type}'" -CSV_SLICE_SIZE = 1000 - -def base64_url_encode(json_obj): - json_str = json.dumps(json_obj) - encoded_bytes = base64.urlsafe_b64encode(json_str.encode('utf-8')) - encoded_str = str(encoded_bytes, 'utf-8') - return encoded_str class SalesForce: - auth = None - oauth_type = None - token_url = '' - query_template = None - data_cache = None - default_api_ver = '' - - def __init__(self, auth_env: AuthEnv, instance_name, config, event_type_fields_mapping, initial_delay, queries=[]): + def __init__( + self, + instance_name: str, + config: mod_config.Config, + data_cache: DataCache, + authenticator: Authenticator, + pipeline: Pipeline, + api_factory: ApiFactory, + query_factory: mod_query.QueryFactory, + initial_delay: int, + queries: list[dict] = None, + ): self.instance_name = instance_name - self.default_api_ver = config.get('api_ver', '52.0') - if 'auth' in config: - self.auth_data = config['auth'] - else: - self.auth_data = {'grant_type': auth_env.get_grant_type()} - if self.auth_data['grant_type'] == 'password': - # user/pass flow - try: - self.auth_data["client_id"] = auth_env.get_client_id() - self.auth_data["client_secret"] = auth_env.get_client_secret() - self.auth_data["username"] = auth_env.get_username() - self.auth_data["password"] = auth_env.get_password() - except: - print_err(f'Missing credentials for user/pass flow') - sys.exit(1) - elif self.auth_data['grant_type'] == 'urn:ietf:params:oauth:grant-type:jwt-bearer': - # jwt flow - try: - self.auth_data["client_id"] = auth_env.get_client_id() - self.auth_data["private_key"] = auth_env.get_private_key() - self.auth_data["subject"] = auth_env.get_subject() - self.auth_data["audience"] = auth_env.get_audience() - except: - print_err(f'Missing credentials for JWT flow') - sys.exit(1) - else: - print_err(f'Wrong or missing grant_type') - sys.exit(1) - - if 'token_url' in config: - self.token_url = config['token_url'] - else: - self.token_url = auth_env.get_token_url() - - try: - self.time_lag_minutes = config['time_lag_minutes'] - self.generation_interval = config['generation_interval'] - self.date_field = config['date_field'] - except KeyError as e: - print_err(f'Please specify a "{e.args[0]}" parameter for sfdc instance "{instance_name}" in config.yml') - sys.exit(1) - - self.last_to_timestamp = (datetime.utcnow() - timedelta( - 
minutes=self.time_lag_minutes + initial_delay)).isoformat(timespec='milliseconds') + "Z" - - if len(queries) > 0: - self.query_template = queries - else: - if self.date_field.lower() == "logdate": - self.query_template = SALESFORCE_LOG_DATE_QUERY - else: - self.query_template = SALESFORCE_CREATED_DATE_QUERY - - self.data_cache = make_cache(config) - self.event_type_fields_mapping = event_type_fields_mapping - - def clear_auth(self): - if self.data_cache: - try: - self.data_cache.redis.delete("auth") - except Exception as e: - print_err(f"Failed deleting 'auth' key from Redis: {e}") - exit(1) - self.auth = None - - def store_auth(self, auth_resp): - access_token = auth_resp['access_token'] - instance_url = auth_resp['instance_url'] - token_type = auth_resp['token_type'] - if self.data_cache: - print_info("Storing credentials on Redis.") - auth = { - "access_token": access_token, - "instance_url": instance_url, - "token_type": token_type - } - try: - self.data_cache.redis.hmset("auth", auth) - except Exception as e: - print_err(f"Failed setting 'auth' key: {e}") - exit(1) - self.auth = Auth(access_token, instance_url, token_type) - - def authenticate(self, oauth_type, session): - self.oauth_type = oauth_type - if self.data_cache: - try: - auth_exists = self.data_cache.redis.exists("auth") - except Exception as e: - print_err(f"Failed checking 'auth' key: {e}") - exit(1) - if auth_exists: - print_info("Retrieving credentials from Redis.") - #NOTE: hmget and hgetall both return byte arrays, not strings. We have to convert. - # We could fix it by adding the argument "decode_responses=True" to Redis constructor, - # but then we would have to change all places where we assume a byte array instead of a string, - # and refactoring in a language without static types is a pain. - try: - auth = self.data_cache.redis.hmget("auth", ["access_token", "instance_url", "token_type"]) - auth = { - "access_token": auth[0].decode("utf-8"), - "instance_url": auth[1].decode("utf-8"), - "token_type": auth[2].decode("utf-8") - } - self.store_auth(auth) - except Exception as e: - print_err(f"Failed getting 'auth' key: {e}") - exit(1) - - return True - - if oauth_type == 'password': - if not self.authenticate_with_password(session): - print_err(f"Error authenticating with {self.token_url}") - return False - print_info("Correctly authenticated with user/pass flow") - else: - if not self.authenticate_with_jwt(session): - print_err(f"Error authenticating with {self.token_url}") - return False - print_info("Correctly authenticated with JWT flow") - return True - - def authenticate_with_jwt(self, session): - try: - private_key_file = self.auth_data['private_key'] - client_id = self.auth_data['client_id'] - subject = self.auth_data['subject'] - audience = self.auth_data['audience'] - except KeyError as e: - print_err(f'Please specify a "{e.args[0]}" parameter under "auth" section ' - 'of salesforce instance in config.yml') - sys.exit(1) - - exp = int((datetime.utcnow() - timedelta(minutes=5)).timestamp()) - - private_key = open(private_key_file, 'r').read() - try: - key = serialization.load_ssh_private_key(private_key.encode(), password=b'') - except ValueError as e: - print_err(f'Authentication failed for {self.instance_name}. 
error message: {str(e)}') - return False - - jwt_claim_set = {"iss": client_id, - "sub": subject, - "aud": audience, - "exp": exp} - - signed_token = jwt.encode( - jwt_claim_set, - key, - algorithm='RS256', + self.data_cache = data_cache + self.pipeline = pipeline + self.query_factory = query_factory + self.time_lag_minutes = config.get( + mod_config.CONFIG_TIME_LAG_MINUTES, + mod_config.DEFAULT_TIME_LAG_MINUTES if not self.data_cache else 0, ) + self.date_field = config.get( + mod_config.CONFIG_DATE_FIELD, + mod_config.DATE_FIELD_LOG_DATE if not self.data_cache \ + else mod_config.DATE_FIELD_CREATE_DATE, + ) + self.generation_interval = config.get( + mod_config.CONFIG_GENERATION_INTERVAL, + mod_config.DEFAULT_GENERATION_INTERVAL, + ) + self.last_to_timestamp = get_iso_date_with_offset( + self.time_lag_minutes, + initial_delay, + ) + self.api = api_factory.new( + authenticator, + config.get('api_ver', '52.0'), + ) + self.queries = queries if queries else \ + [{ + 'query': SALESFORCE_LOG_DATE_QUERY \ + if self.date_field.lower() == 'logdate' \ + else SALESFORCE_CREATED_DATE_QUERY + }] - params = { - "grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer", - "assertion": signed_token, - "format": "json" - } - - headers = { - "Content-Type": "application/x-www-form-urlencoded", - "Accept": "application/json" - } - - try: - print_info(f'retrieving salesforce token at {self.token_url}') - resp = session.post(self.token_url, params=params, - headers=headers) - if resp.status_code != 200: - error_message = f'sfdc token request failed. http-status-code:{resp.status_code}, reason: {resp.text}' - print_err(f'Authentication failed for {self.instance_name}. message: {error_message}', file=sys.stderr) - return False - - self.store_auth(resp.json()) - return True - except ConnectionError as e: - print_err(f"SFDC auth failed for instance {self.instance_name}: {repr(e)}") - raise LoginException(f'authentication failed for sfdc instance {self.instance_name}') from e - except RequestException as e: - print_err(f"SFDC auth failed for instance {self.instance_name}: {repr(e)}") - raise LoginException(f'authentication failed for sfdc instance {self.instance_name}') from e - - def authenticate_with_password(self, session): - client_id = self.auth_data['client_id'] - client_secret = self.auth_data['client_secret'] - username = self.auth_data['username'] - password = self.auth_data['password'] - - params = { - "grant_type": "password", - "client_id": client_id, - "client_secret": client_secret, - "username": username, - "password": password - } - - headers = { - "Content-Type": "application/x-www-form-urlencoded", - "Accept": "application/json" - } - - try: - print_info(f'retrieving salesforce token at {self.token_url}') - resp = session.post(self.token_url, params=params, - headers=headers) - if resp.status_code != 200: - error_message = f'salesforce token request failed. 
status-code:{resp.status_code}, reason: {resp.reason}' - print_err(error_message) - return False - - self.store_auth(resp.json()) - return True - except ConnectionError as e: - print_err(f"SFDC auth failed for instance {self.instance_name}: {repr(e)}") - raise LoginException(f'authentication failed for sfdc instance {self.instance_name}') from e - except RequestException as e: - print_err(f"SFDC auth failed for instance {self.instance_name}: {repr(e)}") - raise LoginException(f'authentication failed for sfdc instance {self.instance_name}') from e - - def make_multiple_queries(self, query_objects) -> list[Query]: - return [self.make_single_query(Query(obj)) for obj in query_objects] - - def make_single_query(self, query_obj: Query) -> Query: - to_timestamp = (datetime.utcnow() - timedelta(minutes=self.time_lag_minutes)).isoformat( - timespec='milliseconds') + "Z" - from_timestamp = self.last_to_timestamp - - env = copy.deepcopy(query_obj.get_env().get('env', {})) - args = { - 'to_timestamp': to_timestamp, - 'from_timestamp': from_timestamp, - 'log_interval_type': self.generation_interval - } - query = substitute(args, query_obj.get_query(), env) - query = query.replace(' ', '+') - - query_obj.set_query(query) - return query_obj + def authenticate(self, sfdc_session: Session): + self.api.authenticate(sfdc_session) def slide_time_range(self): - self.last_to_timestamp = (datetime.utcnow() - timedelta(minutes=self.time_lag_minutes)).isoformat( - timespec='milliseconds') + "Z" - - def execute_query(self, query: Query, session): - api_ver = query.get_env().get("api_ver", self.default_api_ver) - url = f'{self.auth.get_instance_url()}/services/data/v{api_ver}/query?q={query.get_query()}' - - try: - headers = { - 'Authorization': f'Bearer {self.auth.get_access_token()}' - } - query_response = session.get(url, headers=headers) - if query_response.status_code != 200: - error_message = f'salesforce event log query failed. ' \ - f'status-code:{query_response.status_code}, ' \ - f'reason: {query_response.reason} ' \ - f'response: {query_response.text} ' - - print_err(f"SOQL query failed with code {query_response.status_code}: {error_message}") - raise SalesforceApiException(query_response.status_code, f'error when trying to run SOQL query. message: {error_message}') - return query_response.json() - except RequestException as e: - print_err(f"Error while trying SOQL query: {repr(e)}") - raise SalesforceApiException(-1, f'error when trying to run SOQL query. cause: {e}') from e + self.last_to_timestamp = get_iso_date_with_offset( + self.time_lag_minutes + ) # NOTE: Is it possible that different SF orgs have overlapping IDs? If this is possible, we should use a different # database for each org, or add a prefix to keys to avoid conflicts. - def download_file(self, session, url): - print_info(f"Downloading CSV file: {url}") - - headers = { - 'Authorization': f'Bearer {self.auth.get_access_token()}' - } - response = session.get(url, headers=headers) - if response.status_code != 200: - error_message = f'salesforce event log file download failed. 
' \ - f'status-code: {response.status_code}, ' \ - f'reason: {response.reason} ' \ - f'response: {response.text}' - print_err(error_message) - raise SalesforceApiException(response.status_code, error_message) - return response - - def parse_csv(self, download_response, record_id, record_event_type, cached_messages): - content = download_response.content.decode('utf-8') - reader = csv.DictReader(content.splitlines()) - rows = [] - for row in reader: - if self.data_cache and self.data_cache.record_or_skip_row(record_id, row, cached_messages): - continue - rows.append(row) - return rows - - def fetch_logs(self, session): - print_info(f"Query object = {self.query_template}") - - if type(self.query_template) is list: - # "query_template" contains a list of objects, each one is a Query object - queries = self.make_multiple_queries(copy.deepcopy(self.query_template)) - response = self.fetch_logs_from_multiple_req(session, queries) - self.slide_time_range() - return response - else: - # "query_template" contains a string with the SOQL to run. - query = self.make_single_query(Query(self.query_template)) - response = self.fetch_logs_from_single_req(session, query) - self.slide_time_range() - return response - - def fetch_logs_from_multiple_req(self, session, queries: list[Query]): - logs = [] - for query in queries: - part_logs = self.fetch_logs_from_single_req(session, query) - logs.extend(part_logs) - return logs - - def fetch_logs_from_single_req(self, session, query: Query): - print_info(f'Running query {query.get_query()}') - response = self.execute_query(query, session) - - # Show query response - #print("Response = ", response) - - records = response['records'] - if self.is_logfile_response(records): - logs = [] - for record in records: - if 'LogFile' in record: - log = self.build_log_from_logfile(True, session, record, query) - if log is not None: - logs.extend(log) - else: - logs = self.build_log_from_event(records, query) - - return logs - - def is_logfile_response(self, records): - if len(records) > 0: - return 'LogFile' in records[0] - else: - return True - - def build_log_from_event(self, records, query: Query): - logs = [] - while True: - part_rows = self.extract_row_slice(records) - if len(part_rows) > 0: - logs.append(self.pack_event_into_log(part_rows, query)) - else: - break - return logs - - def pack_event_into_log(self, rows, query: Query): - log_entries = [] - for row in rows: - if 'Id' in row: - record_id = row['Id'] - if self.data_cache and self.data_cache.check_cached_id(record_id): - # Record cached, skip it - continue - else: - id_keys = query.get_env().get("id", []) - compound_id = "" - for key in id_keys: - if key not in row: - print_err(f"Error building compound id, key '{key}' not found") - raise Exception(f"Error building compound id, key '{key}' not found") - compound_id = compound_id + str(row.get(key, "")) - if compound_id != "": - m = hashlib.sha3_256() - m.update(compound_id.encode('utf-8')) - row['Id'] = m.hexdigest() - record_id = row['Id'] - if self.data_cache and self.data_cache.check_cached_id(record_id): - # Record cached, skip it - continue - - timestamp_attr = query.get_env().get("timestamp_attr", "CreatedDate") - if timestamp_attr in row: - created_date = row[timestamp_attr] - timestamp = int(datetime.strptime(created_date, '%Y-%m-%dT%H:%M:%S.%f%z').timestamp() * 1000) - else: - created_date = "" - timestamp = int(datetime.now().timestamp() * 1000) - - message = query.get_env().get("event_type", "SFEvent") - if 'attributes' in row and 
type(row['attributes']) == dict: - attributes = row.pop('attributes', []) - if 'type' in attributes and type(attributes['type']) == str: - event_type_attr_name = query.get_env().get("event_type", attributes['type']) - message = event_type_attr_name - row['EVENT_TYPE'] = event_type_attr_name - - if created_date != "": - message = message + " " + created_date - - timestamp_field_name = query.get_env().get("rename_timestamp", "timestamp") - row[timestamp_field_name] = int(timestamp) - - log_entry = { - 'message': message, - 'attributes': row, - } + def fetch_logs(self, session: Session) -> list[dict]: + print_info(f"Queries = {self.queries}") - if timestamp_field_name == 'timestamp': - log_entry[timestamp_field_name] = timestamp + for q in self.queries: + query = self.query_factory.new( + self.api, + q, + self.time_lag_minutes, + self.last_to_timestamp, + self.generation_interval, + ) - log_entries.append(log_entry) - return { - 'log_entries': log_entries - } + response = query.execute(session) - def build_log_from_logfile(self, retry, session, record, query: Query): - record_file_name = record['LogFile'] - record_id = str(record['Id']) - interval = record['Interval'] - record_event_type = query.get_env().get("event_type", record['EventType']) - - # NOTE: only Hourly logs can be skipped, because Daily logs can change and the same record_id can contain different data. - if interval == 'Hourly' and self.data_cache and \ - self.data_cache.can_skip_downloading_record(record_id): - print_info(f"Record {record_id} already cached, skip downloading CSV") - return None - - cached_messages = None if not self.data_cache else \ - self.data_cache.retrieve_cached_message_list(record_id) - - try: - download_response = self.download_file(session, f'{self.auth.get_instance_url()}{record_file_name}') - if download_response is None: - return None - except SalesforceApiException as e: - if e.err_code == 401: - if retry: - print_err("invalid token while downloading CSV file, retry auth and download...") - self.clear_auth() - if self.authenticate(self.oauth_type, session): - return self.build_log_from_logfile(False, session, record, query) - else: - return None - else: - print_err(f'salesforce event log file "{record_file_name}" download failed: {e}') - return None - else: - print_err(f'salesforce event log file "{record_file_name}" download failed: {e}') - return None - except RequestException as e: - print_err(f'salesforce event log file "{record_file_name}" download failed: {e}') - return None - - csv_rows = self.parse_csv(download_response, record_id, record_event_type, cached_messages) - - print_info(f"CSV rows = {len(csv_rows)}") - - # Split CSV rows into smaller chunks to avoid hitting API payload limits - logs = [] - row_offset = 0 - while True: - part_rows = self.extract_row_slice(csv_rows) - part_rows_len = len(part_rows) - if part_rows_len > 0: - logs.append(self.pack_csv_into_log(record, row_offset, part_rows, query)) - row_offset += part_rows_len - else: - break - - return logs - - def pack_csv_into_log(self, record, row_offset, csv_rows, query: Query): - record_id = str(record['Id']) - record_event_type = query.get_env().get("event_type", record['EventType']) - - log_entries = [] - for row_index, row in enumerate(csv_rows): - message = {} - if record_event_type in self.event_type_fields_mapping: - for field in self.event_type_fields_mapping[record_event_type]: - message[field] = row[field] - else: - message = row - - if row.get('TIMESTAMP'): - timestamp_obj = datetime.strptime(row.get('TIMESTAMP'), 
'%Y%m%d%H%M%S.%f') - timestamp = pytz.utc.localize(timestamp_obj).replace(microsecond=0).timestamp() - else: - timestamp = datetime.utcnow().replace(microsecond=0).timestamp() - - message['LogFileId'] = record_id - message.pop('TIMESTAMP', None) - - actual_event_type = message.pop('EVENT_TYPE', "SFEvent") - new_event_type = query.get_env().get("event_type", actual_event_type) - message['EVENT_TYPE'] = new_event_type - - timestamp_field_name = query.get_env().get("rename_timestamp", "timestamp") - message[timestamp_field_name] = int(timestamp) - - log_entry = { - 'message': "LogFile " + record_id + " row " + str(row_index + row_offset), - 'attributes': message - } - - if timestamp_field_name == 'timestamp': - log_entry[timestamp_field_name] = int(timestamp) - - log_entries.append(log_entry) - - return { - 'log_type': record_event_type, - 'Id': record_id, - 'CreatedDate': record['CreatedDate'], - 'LogDate': record['LogDate'], - 'log_entries': log_entries - } + if not response or not 'records' in response: + print_warn(f'no records returned for query {query.query}') + continue - # Slice record into smaller chunks - def extract_row_slice(self, rows): - part_rows = [] - i = 0 - while len(rows) > 0: - part_rows.append(rows.pop()) - i += 1 - if i >= CSV_SLICE_SIZE: - break - return part_rows + self.pipeline.execute( + self.api, + session, + query, + response['records'], + ) + + self.slide_time_range() + + +class SalesForceFactory: + def __init__(self): + pass + + def new( + self, + instance_name: str, + config: mod_config.Config, + data_cache: DataCache, + authenticator: Authenticator, + pipeline: Pipeline, + api_factory: ApiFactory, + query_factory: mod_query.QueryFactory, + initial_delay: int, + queries: list[dict] = None, + ): + return SalesForce( + instance_name, + config, + data_cache, + authenticator, + pipeline, + api_factory, + query_factory, + initial_delay, + queries, + ) diff --git a/src/newrelic_logging/telemetry.py b/src/newrelic_logging/telemetry.py index 3e4174e..f8984f3 100644 --- a/src/newrelic_logging/telemetry.py +++ b/src/newrelic_logging/telemetry.py @@ -18,19 +18,19 @@ class Telemetry: def __init__(self, integration_name: str) -> None: self.integration_name = integration_name - + def is_empty(self): return len(self.logs) == 0 - + def log_info(self, msg: str): self.record_log(msg, "info") - + def log_err(self, msg: str): self.record_log(msg, "error") - + def log_warn(self, msg: str): self.record_log(msg, "warn") - + def record_log(self, msg: str, level: str): log = { "timestamp": round(time.time() * 1000), @@ -41,15 +41,16 @@ def record_log(self, msg: str, level: str): } } self.logs.append(log) - + def clear(self): self.logs = [] def build_model(self): return [{ - "log_entries": self.logs + "common": {}, + "logs": self.logs, }] - + def print_log(msg: str, level: str): print(json.dumps({ "message": msg, diff --git a/src/newrelic_logging/util.py b/src/newrelic_logging/util.py new file mode 100644 index 0000000..d12d74d --- /dev/null +++ b/src/newrelic_logging/util.py @@ -0,0 +1,168 @@ +from copy import deepcopy +from datetime import datetime, timedelta +import hashlib +from typing import Any, Union + +from .telemetry import print_warn + +PRIMITIVE_TYPES = (str, int, float, bool, type(None)) + + +def is_logfile_response(records): + if len(records) > 0: + return 'LogFile' in records[0] + + return True + + +def generate_record_id(id_keys: list[str], record: dict) -> str: + compound_id = '' + for key in id_keys: + if key not in record: + raise Exception( + f'error building compound 
id, key \'{key}\' not found'
+            )
+
+        compound_id = compound_id + str(record.get(key, ''))
+
+    if compound_id != '':
+        m = hashlib.sha3_256()
+        m.update(compound_id.encode('utf-8'))
+        return m.hexdigest()
+
+    return ''
+
+
+def maybe_convert_str_to_num(val: str) -> Union[int, str, float]:
+    try:
+        return int(val)
+    except (TypeError, ValueError) as _:
+        try:
+            return float(val)
+        except (TypeError, ValueError) as _:
+            print_warn(f'Type conversion error for "{val}"')
+            return val
+
+
+def is_primitive(val: Any) -> bool:
+    vt = type(val)
+
+    for t in PRIMITIVE_TYPES:
+        if vt == t:
+            return True
+
+    return False
+
+
+def process_query_result_helper(
+    item: tuple[str, Any],
+    name: list[str] = [],
+    vals: list[tuple[str, Any]] = [],
+) -> list[tuple[str, Any]]:
+    (k, v) = item
+
+    if k == 'attributes':
+        return vals
+
+    if is_primitive(v):
+        return vals + [('.'.join(name + [k]), v)]
+
+    if not type(v) is dict:
+        print_warn(f'ignoring structured element {k} in query result')
+        return vals
+
+    new_vals = vals
+
+    for item0 in v.items():
+        new_vals = process_query_result_helper(item0, name + [k], new_vals)
+
+    return new_vals
+
+
+def process_query_result(query_result: dict) -> dict:
+    out = {}
+
+    for item in query_result.items():
+        for (k, v) in process_query_result_helper(item):
+            out[k] = v
+
+    return out
+
+
+# Make testing easier
+def _utcnow():
+    return datetime.utcnow()
+
+_UTCNOW = _utcnow
+
+def get_iso_date_with_offset(
+    time_lag_minutes: int = 0,
+    initial_delay: int = 0,
+) -> str:
+    return (
+        _UTCNOW() - timedelta(
+            minutes=(time_lag_minutes + initial_delay)
+        )
+    ).isoformat(
+        timespec='milliseconds'
+    ) + 'Z'
+
+
+# Make testing easier
+def _now():
+    return datetime.now()
+
+_NOW = _now
+
+
+def get_timestamp(date_string: str = None):
+    if not date_string:
+        return int(_NOW().timestamp() * 1000)
+
+    return int(
+        datetime.strptime(
+            date_string,
+            '%Y-%m-%dT%H:%M:%S.%f%z'
+        ).timestamp() * 1000
+    )
+
+
+# NOTE: this sandbox can be jailbroken using the trick to exec statements inside
+# an exec block, and run an import (and other tricks):
+# https://book.hacktricks.xyz/generic-methodologies-and-resources/python/bypass-python-sandboxes#operators-and-short-tricks
+# https://stackoverflow.com/a/3068475/2076108
+# It would be better to use a real sandbox like
+# https://pypi.org/project/RestrictedPython/ or https://doc.pypy.org/en/latest/sandbox.html
+# or parse a small language that only supports function calls and binary
+# expressions.
+#
+# @TODO See if we can do this a different way. We shouldn't be executing eval, ever. 
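+#
+# Rough, untested sketch only (not wired into this module): the same config
+# expressions could be evaluated against an explicit whitelist rather than by
+# shadowing builtins, for example:
+#
+#     allowed = {'now': now, 'sf_time': sf_time, 'timedelta': timedelta}
+#     return eval(code, {'__builtins__': {}}, allowed)
+#
+# 'allowed' is a hypothetical name used only for illustration. This still is
+# not a real sandbox, but it narrows what the evaluated expressions can reach.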
+ +def sandbox(code): + __import__ = None + __loader__ = None + __build_class__ = None + exec = None + + + def sf_time(t: datetime): + return t.isoformat(timespec='milliseconds') + "Z" + + def now(delta: timedelta = None): + if delta: + return sf_time(_UTCNOW() + delta) + else: + return sf_time(_UTCNOW()) + + try: + return eval(code) + except Exception as e: + return e + + +def substitute(args: dict, template: str, env: dict) -> str: + for key, command in env.items(): + args[key] = sandbox(command) + for key, val in args.items(): + template = template.replace('{' + key + '}', val) + return template diff --git a/src/tests/__init__.py b/src/tests/__init__.py new file mode 100644 index 0000000..30540fb --- /dev/null +++ b/src/tests/__init__.py @@ -0,0 +1,536 @@ +from datetime import timedelta +import json +from redis import RedisError +from requests import Session, RequestException + +from newrelic_logging import DataFormat, LoginException, SalesforceApiException +from newrelic_logging.api import Api, ApiFactory +from newrelic_logging.auth import Authenticator +from newrelic_logging.cache import DataCache +from newrelic_logging.config import Config +from newrelic_logging.newrelic import NewRelic +from newrelic_logging.pipeline import Pipeline +from newrelic_logging.query import Query, QueryFactory + + +class ApiStub: + def __init__( + self, + authenticator: Authenticator = None, + api_ver: str = None, + query_result: dict = None, + lines: list[str] = None, + raise_error = False, + raise_login_error = False, + ): + self.authenticator = authenticator + self.api_ver = api_ver + self.query_result = query_result + self.lines = lines + self.soql = None + self.query_api_ver = None + self.log_file_path = None + self.chunk_size = None + self.raise_error = raise_error + self.raise_login_error = raise_login_error + + def authenticate(self, session: Session): + if self.raise_login_error: + raise LoginException() + + self.authenticator.authenticate(session) + + def query(self, session: Session, soql: str, api_ver: str = None) -> dict: + self.soql = soql + self.query_api_ver = api_ver + + if self.raise_error: + raise SalesforceApiException() + + if self.raise_login_error: + raise LoginException() + + return self.query_result + + def get_log_file( + self, + session: Session, + log_file_path: str, + chunk_size: int, + ): + self.log_file_path = log_file_path + self.chunk_size = chunk_size + + if self.raise_error: + raise SalesforceApiException() + + if self.raise_login_error: + raise LoginException() + + yield from self.lines + +class ApiFactoryStub: + def __init__(self): + pass + + def new(self, authenticator: Authenticator, api_ver: str): + return ApiStub(authenticator, api_ver) + + +class AuthenticatorStub: + def __init__( + self, + config: Config = None, + data_cache: DataCache = None, + token_url: str = '', + access_token: str = '', + access_token_2: str = '', + instance_url: str = '', + grant_type: str = '', + authenticate_called: bool = False, + reauthenticate_called: bool = False, + raise_login_error = False, + ): + self.config = config + self.data_cache = data_cache + self.token_url = token_url + self.access_token = access_token + self.access_token_2 = access_token_2 + self.instance_url = instance_url + self.grant_type = grant_type + self.authenticate_called = authenticate_called + self.reauthenticate_called = reauthenticate_called + self.raise_login_error = raise_login_error + + def get_access_token(self) -> str: + return self.access_token + + def get_instance_url(self) -> str: + return 
self.instance_url + + def get_grant_type(self) -> str: + return self.grant_type + + def set_auth_data(self, access_token: str, instance_url: str) -> None: + pass + + def clear_auth(self) -> None: + pass + + def load_auth_from_cache(self) -> bool: + return False + + def store_auth(self, auth_resp: dict) -> None: + pass + + def authenticate_with_jwt(self, session: Session) -> None: + pass + + def authenticate_with_password(self, session: Session) -> None: + pass + + def authenticate( + self, + session: Session, + ) -> None: + self.authenticate_called = True + if self.raise_login_error: + raise LoginException('Unauthorized') + + def reauthenticate( + self, + session: Session, + ) -> None: + self.reauthenticate_called = True + if self.raise_login_error: + raise LoginException('Unauthorized') + + self.access_token = self.access_token_2 + + +class AuthenticatorFactoryStub: + def __init__(self): + pass + + def new(self, config: Config, data_cache: DataCache) -> Authenticator: + return AuthenticatorStub(config, data_cache) + + +class DataCacheStub: + def __init__( + self, + config: Config = None, + cached_logs = {}, + cached_events = [], + skip_record_ids = [], + ): + self.config = config + self.cached_logs = cached_logs + self.cached_events = cached_events + self.skip_record_ids = skip_record_ids + self.flush_called = False + + def can_skip_downloading_logfile(self, record_id: str) -> bool: + return record_id in self.skip_record_ids + + def check_or_set_log_line(self, record_id: str, row: dict) -> bool: + return record_id in self.cached_logs and \ + row['REQUEST_ID'] in self.cached_logs[record_id] + + def check_or_set_event_id(self, record_id: str) -> bool: + return record_id in self.cached_events + + def flush(self) -> None: + self.flush_called = True + + +class CacheFactoryStub: + def __init__(self): + pass + + def new(self, config: Config): + return DataCacheStub(config) + + +class NewRelicStub: + def __init__(self, config: Config = None): + self.config = config + self.logs = [] + self.events = [] + + def post_logs(self, session: Session, data: list[dict]) -> None: + self.logs.append(data) + + def post_events(self, session: Session, events: list[dict]) -> None: + self.events.append(events) + + +class NewRelicFactoryStub: + def __init__(self): + pass + + def new(self, config: Config): + return NewRelicStub(config) + + +class QueryStub: + def __init__( + self, + api: Api = None, + query: str = '', + config: Config = Config({}), + api_ver: str = None, + result: dict = { 'records': [] }, + ): + self.api = api + self.query = query + self.config = config + self.api_ver = api_ver + self.executed = False + self.result = result + + def get(self, key: str, default = None): + return self.config.get(key, default) + + def get_config(self): + return self.config + + def execute(self, session: Session = None): + self.executed = True + return self.result + + +class QueryFactoryStub: + def __init__(self, query: QueryStub = None ): + self.query = query + self.queries = [] if not query else None + pass + + def new( + self, + api: Api, + q: dict, + time_lag_minutes: int = 0, + last_to_timestamp: str = '', + generation_interval: str = '', + ) -> Query: + if self.query: + return self.query + + qq = QueryStub(api, q['query'], q) + self.queries.append(qq) + return qq + + +class PipelineStub: + def __init__( + self, + config: Config = Config({}), + data_cache: DataCache = None, + new_relic: NewRelic = None, + data_format: DataFormat = DataFormat.LOGS, + labels: dict = {}, + event_type_fields_mapping: dict = {}, + 
numeric_fields_list: set = set(), + raise_error: bool = False, + raise_login_error: bool = False, + ): + self.config = config + self.data_cache = data_cache + self.new_relic = new_relic + self.data_format = data_format + self.labels = labels + self.event_type_fields_mapping = event_type_fields_mapping + self.numeric_fields_list = numeric_fields_list + self.queries = [] + self.executed = False + self.raise_error = raise_error + self.raise_login_error = raise_login_error + + def execute( + self, + api: Api, + session: Session, + query: Query, + records: list[dict], + ): + if self.raise_error: + raise SalesforceApiException() + + if self.raise_login_error: + raise LoginException() + + self.queries.append(query) + self.executed = True + + +class PipelineFactoryStub: + def __init__(self): + pass + + def new( + self, + config: Config, + data_cache: DataCache, + new_relic: NewRelic, + data_format: DataFormat, + labels: dict, + event_type_fields_mapping: dict, + numeric_fields_list: set, + ): + return PipelineStub( + config, + data_cache, + new_relic, + data_format, + labels, + event_type_fields_mapping, + numeric_fields_list, + ) + + +class RedisStub: + def __init__(self, test_cache, raise_error = False): + self.expiry = {} + self.test_cache = test_cache + self.raise_error = raise_error + + def exists(self, key): + if self.raise_error: + raise RedisError('raise_error set') + + return key in self.test_cache + + def set(self, key, item): + if self.raise_error: + raise RedisError('raise_error set') + + self.test_cache[key] = item + + def smembers(self, key): + if self.raise_error: + raise RedisError('raise_error set') + + if not key in self.test_cache: + return set() + + if not type(self.test_cache[key]) is set: + raise RedisError(f'{key} is not a set') + + return self.test_cache[key] + + def sadd(self, key, *values): + if self.raise_error: + raise RedisError('raise_error set') + + if key in self.test_cache and not type(self.test_cache[key]) is set: + raise RedisError(f'{key} is not a set') + + if not key in self.test_cache: + self.test_cache[key] = set() + + for v in values: + self.test_cache[key].add(v) + + def expire(self, key, time): + if self.raise_error: + raise RedisError('raise_error set') + + self.expiry[key] = time + + +class BackendStub: + def __init__(self, test_cache, raise_error = False): + self.redis = RedisStub(test_cache, raise_error) + + def exists(self, key): + return self.redis.exists(key) + + def put(self, key, item): + self.redis.set(key, item) + + def get_set(self, key): + return self.redis.smembers(key) + + def set_add(self, key, *values): + self.redis.sadd(key, *values) + + def set_expiry(self, key, days): + self.redis.expire(key, timedelta(days=days)) + + +class BackendFactoryStub: + def __init__(self, raise_error = False): + self.raise_error = raise_error + + def new(self, _: Config): + if self.raise_error: + raise RedisError('raise_error set') + + return BackendStub({}) + + +class MultiRequestSessionStub: + def __init__(self, responses=[], raise_error=False): + self.raise_error = raise_error + self.requests = [] + self.responses = responses + self.count = 0 + + def get(self, *args, **kwargs): + self.requests.append({ + 'url': args[0], + 'headers': kwargs['headers'], + 'stream': kwargs['stream'] if 'stream' in kwargs else None, + }) + + if self.raise_error: + raise RequestException('raise_error set') + + if self.count < len(self.responses): + self.count += 1 + + return self.responses[self.count - 1] + + +class ResponseStub: + def __init__(self, status_code, reason, text, 
lines, encoding=None): + self.status_code = status_code + self.reason = reason + self.text = text + self.lines = lines + self.chunk_size = None + self.decode_unicode = None + self.encoding = encoding + self.iter_lines_called = False + + def iter_lines(self, *args, **kwargs): + self.iter_lines_called = True + + if 'chunk_size' in kwargs: + self.chunk_size = kwargs['chunk_size'] + + if 'decode_unicode' in kwargs: + self.decode_unicode = kwargs['decode_unicode'] + + yield from self.lines + + def json(self, *args, **kwargs): + return json.loads(self.text) + + +class SalesForceStub: + def __init__( + self, + instance_name: str, + config: Config, + data_cache: DataCache, + authenticator: Authenticator, + pipeline: Pipeline, + api_factory: ApiFactory, + query_factory: QueryFactory, + initial_delay: int, + queries: list[dict] = None, + ): + self.instance_name = instance_name + self.config = config + self.data_cache = data_cache + self.authenticator = authenticator + self.pipeline = pipeline + self.api_factory = api_factory + self.query_factory = query_factory + self.initial_delay = initial_delay + self.queries = queries + + +class SalesForceFactoryStub: + def __init__(self): + pass + + def new( + self, + instance_name: str, + config: Config, + data_cache: DataCache, + authenticator: Authenticator, + pipeline: Pipeline, + api_factory: ApiFactory, + query_factory: QueryFactory, + initial_delay: int, + queries: list[dict] = None, + ): + return SalesForceStub( + instance_name, + config, + data_cache, + authenticator, + pipeline, + api_factory, + query_factory, + initial_delay, + queries, + ) + + +class SessionStub: + def __init__(self, raise_error=False, raise_connection_error=False): + self.raise_error = raise_error + self.raise_connection_error = raise_connection_error + self.response = None + self.headers = None + self.url = None + self.stream = None + + def get(self, *args, **kwargs): + self.url = args[0] + self.headers = kwargs['headers'] + if 'stream' in kwargs: + self.stream = kwargs['stream'] + + if self.raise_connection_error: + raise ConnectionError('raise_connection_error set') + + if self.raise_error: + raise RequestException('raise_error set') + + return self.response diff --git a/src/tests/sample_event_records.json b/src/tests/sample_event_records.json new file mode 100644 index 0000000..5f7b1a1 --- /dev/null +++ b/src/tests/sample_event_records.json @@ -0,0 +1,76 @@ +[ + { + "attributes": { + "type": "Account", + "url": "/services/data/v58.0/sobjects/Account/12345" + }, + "Id": "000012345", + "Name": "My Account", + "BillingCity": null, + "CreatedDate": "2024-03-11T00:00:00.000+0000", + "CreatedBy": { + "attributes": { + "type": "User", + "url": "/services/data/v55.0/sobjects/User/12345" + }, + "Name": "Foo Bar", + "Profile": { + "attributes": { + "type": "Profile", + "url": "/services/data/v55.0/sobjects/Profile/12345" + }, + "Name": "Beep Boop" + }, + "UserType": "Bip Bop" + } + }, + { + "attributes": { + "type": "Account", + "url": "/services/data/v58.0/sobjects/Account/54321" + }, + "Id": "000054321", + "Name": "My Other Account", + "BillingCity": null, + "CreatedDate": "2024-03-10T00:00:00.000+0000", + "CreatedBy": { + "attributes": { + "type": "User", + "url": "/services/data/v55.0/sobjects/User/12345" + }, + "Name": "Foo Bar", + "Profile": { + "attributes": { + "type": "Profile", + "url": "/services/data/v55.0/sobjects/Profile/12345" + }, + "Name": "Beep Boop" + }, + "UserType": "Bip Bop" + } + }, + { + "attributes": { + "type": "Account", + "url": 
"/services/data/v58.0/sobjects/Account/00000" + }, + "Name": "My Last Account", + "BillingCity": null, + "CreatedDate": "2024-03-09T00:00:00.000+0000", + "CreatedBy": { + "attributes": { + "type": "User", + "url": "/services/data/v55.0/sobjects/User/12345" + }, + "Name": "Foo Bar", + "Profile": { + "attributes": { + "type": "Profile", + "url": "/services/data/v55.0/sobjects/Profile/12345" + }, + "Name": "Beep Boop" + }, + "UserType": "Bip Bop" + } + } +] diff --git a/src/tests/sample_log_lines.csv b/src/tests/sample_log_lines.csv new file mode 100644 index 0000000..f8a120a --- /dev/null +++ b/src/tests/sample_log_lines.csv @@ -0,0 +1,3 @@ +"EVENT_TYPE","TIMESTAMP","REQUEST_ID","ORGANIZATION_ID","USER_ID","RUN_TIME","CPU_TIME","URI","SESSION_KEY","LOGIN_KEY","TYPE","METHOD","SUCCESS","STATUS_CODE","TIME","REQUEST_SIZE","RESPONSE_SIZE","URL","TIMESTAMP_DERIVED","USER_ID_DERIVED","CLIENT_IP","URI_ID_DERIVED" +"ApexCallout","20240311160000.000","YYZ:abcdef123456","001122334455667","000000001111111","2112","10","TEST-LOG-1","","","REST","POST","1","200","1234","8192","4096","""https://test.local.test""","2024-03-11T16:00:00.000Z","001122334455667","","" +"ApexCallout","20240311170000.000","YYZ:fedcba654321","776655443322110","111111110000000","5150","20","TEST-LOG-2","","","REST","POST","1","200","4321","8192","4096","""https://test.local.test""","2024-03-11T17:00:00.000Z","776655443322110","","" diff --git a/src/tests/sample_log_lines.json b/src/tests/sample_log_lines.json new file mode 100644 index 0000000..10d1eff --- /dev/null +++ b/src/tests/sample_log_lines.json @@ -0,0 +1,48 @@ +[{ + "EVENT_TYPE": "ApexCallout", + "TIMESTAMP": "20240311160000.000", + "REQUEST_ID": "YYZ:abcdef123456", + "ORGANIZATION_ID": "000000001111111", + "USER_ID": "001122334455667", + "RUN_TIME": "2112", + "CPU_TIME": "10", + "URI": "TEST-LOG-1", + "SESSION_KEY": "", + "LOGIN_KEY": "", + "TYPE": "REST", + "METHOD": "POST", + "SUCCESS": "1", + "STATUS_CODE": "200", + "TIME": "1234", + "REQUEST_SIZE": "8192", + "RESPONSE_SIZE": "4096", + "URL": "\"https://test.local.test\"", + "TIMESTAMP_DERIVED": "2024-03-11T16:00:00.000Z", + "USER_ID_DERIVED": "001122334455667", + "CLIENT_IP": "", + "URI_ID_DERIVED": "" +}, +{ + "EVENT_TYPE": "ApexCallout", + "TIMESTAMP": "20240311170000.000", + "REQUEST_ID": "YYZ:fedcba654321", + "ORGANIZATION_ID": "111111110000000", + "USER_ID": "776655443322110", + "RUN_TIME": "5150", + "CPU_TIME": "20", + "URI": "TEST-LOG-2", + "SESSION_KEY": "", + "LOGIN_KEY": "", + "TYPE": "REST", + "METHOD": "POST", + "SUCCESS": "1", + "STATUS_CODE": "200", + "TIME": "4321", + "REQUEST_SIZE": "8192", + "RESPONSE_SIZE": "4096", + "URL": "\"https://test.local.test\"", + "TIMESTAMP_DERIVED": "2024-03-11T17:00:00.000Z", + "USER_ID_DERIVED": "776655443322110", + "CLIENT_IP": "", + "URI_ID_DERIVED": "" +}] diff --git a/src/tests/sample_log_records.json b/src/tests/sample_log_records.json new file mode 100644 index 0000000..b48cdd5 --- /dev/null +++ b/src/tests/sample_log_records.json @@ -0,0 +1,28 @@ +[ + { + "attributes": { + "type": "EventLogFile", + "url": "/services/data/v52.0/sobjects/EventLogFile/00001111AAAABBBB" + }, + "Id": "00001111AAAABBBB", + "EventType": "ApexCallout", + "CreatedDate": "2024-03-11T15:00:00.000+0000", + "LogDate": "2024-03-11T02:00:00.000+0000", + "Interval": "Hourly", + "LogFile": "/services/data/v52.0/sobjects/EventLogFile/00001111AAAABBBB/LogFile", + "Sequence": 1 + }, + { + "attributes": { + "type": "EventLogFile", + "url": 
"/services/data/v52.0/sobjects/EventLogFile/00002222AAAABBBB" + }, + "Id": "00002222AAAABBBB", + "EventType": "ApexCallout", + "CreatedDate": "2024-03-11T16:00:00.000+0000", + "LogDate": "2024-03-11T03:00:00.000+0000", + "Interval": "Hourly", + "LogFile": "/services/data/v52.0/sobjects/EventLogFile/00002222AAAABBBB/LogFile", + "Sequence": 2 + } +] diff --git a/src/tests/test_api.py b/src/tests/test_api.py new file mode 100644 index 0000000..9fd7371 --- /dev/null +++ b/src/tests/test_api.py @@ -0,0 +1,891 @@ +from requests import Session +import unittest + +from newrelic_logging import api, SalesforceApiException, LoginException +from . import \ + AuthenticatorStub, \ + ResponseStub, \ + SessionStub, \ + MultiRequestSessionStub + + +class TestApi(unittest.TestCase): + def test_get_calls_session_get_with_url_access_token_and_default_stream_flag_and_invokes_cb_and_returns_result_on_200(self): + ''' + get() makes a request with the given URL, access token, and default stream flag and invokes the callback with response and returns the result on a 200 status code + given: an authenticator + and given: a session + and given: a url + and given: a service URL + and given: a callback + when: get() is called + then: session.get() is called with full URL, access token, and default stream flag + and when: response status code is 200 + then: invokes callback with response and returns result + ''' + + # setup + auth = AuthenticatorStub( + instance_url='https://my.salesforce.test', + access_token='123456', + ) + response = ResponseStub(200, 'OK', 'OK', []) + session = SessionStub() + session.response = response + + def cb(response): + return response.text + + # execute + val = api.get(auth, session, '/foo', cb) + + # verify + self.assertEqual(session.url, 'https://my.salesforce.test/foo') + self.assertTrue('Authorization' in session.headers) + self.assertEqual(session.headers['Authorization'], 'Bearer 123456') + self.assertFalse(session.stream) + self.assertIsNotNone(val) + self.assertEqual(val, 'OK') + + def test_get_calls_session_get_with_url_access_token_and_given_stream_flag_and_invokes_cb_and_returns_result_on_200(self): + ''' + get() makes a request with the given URL, access token, and stream flag and invokes the callback with response and returns the result on a 200 status code + given: an authenticator + and given: a session + and given: a url + and given: a callback + when: get() is called + then: session.get() is called with full URL, access token, and stream flag + and when: response status code is 200 + then: invokes callback with response and returns result + ''' + + # setup + auth = AuthenticatorStub( + instance_url='https://my.salesforce.test', + access_token='123456', + ) + response = ResponseStub(200, 'OK', 'OK', []) + session = SessionStub() + session.response = response + + def cb(response): + return response.text + + # execute + val = api.get(auth, session, '/foo', cb, stream=True) + + # verify + self.assertEqual(session.url, 'https://my.salesforce.test/foo') + self.assertTrue('Authorization' in session.headers) + self.assertEqual(session.headers['Authorization'], 'Bearer 123456') + self.assertTrue(session.stream) + self.assertIsNotNone(val) + self.assertEqual(val, 'OK') + + def test_get_raises_on_response_not_200_or_401(self): + ''' + get() raises a SalesforceApiException when the response status code is not 200 or 401 + given: an authenticator + and given: a session + and given: a service url + and given: a callback + when: get() is called + then: session.get() is called with full 
URL, access token, and default stream flag + and when: response status code is not 200 + and when: response status code is not 401 + then: raises a SalesforceApiException + ''' + + # setup + auth = AuthenticatorStub( + instance_url='https://my.salesforce.test', + access_token='123456', + ) + response = ResponseStub(500, 'Server Error', 'Server Error', []) + session = SessionStub() + session.response = response + + def cb(response): + return response + + # execute / verify + with self.assertRaises(SalesforceApiException) as _: + _ = api.get( + auth, + session, + '/foo', + cb, + ) + + self.assertEqual(session.url, 'https://my.salesforce.test/foo') + self.assertTrue('Authorization' in session.headers) + self.assertEqual(session.headers['Authorization'], 'Bearer 123456') + self.assertFalse(session.stream) + + def test_get_raises_on_connection_error(self): + ''' + get() raises a SalesforceApiException when session.get() raises a ConnectionError + given: an authenticator + and given: a session + and given: a service url + and given: a callback + when: get() is called + then: session.get() is called with full URL, access token, and default stream flag + and when: session.get() raises a ConnectionError + then: raises a SalesforceApiException + ''' + + # setup + auth = AuthenticatorStub( + instance_url='https://my.salesforce.test', + access_token='123456', + ) + response = ResponseStub(500, 'Server Error', 'Server Error', []) + session = SessionStub(raise_connection_error=True) + session.response = response + + def cb(response): + return response + + # execute / verify + with self.assertRaises(SalesforceApiException) as _: + _ = api.get( + auth, + session, + '/foo', + cb, + ) + + self.assertEqual(session.url, 'https://my.salesforce.test/foo') + self.assertTrue('Authorization' in session.headers) + self.assertEqual(session.headers['Authorization'], 'Bearer 123456') + self.assertFalse(session.stream) + + def test_get_raises_on_request_exception(self): + ''' + get() raises a SalesforceApiException when session.get() raises a RequestException + given: an authenticator + and given: a session + and given: a service url + and given: a callback + when: get() is called + then: session.get() is called with full URL, access token, and default stream flag + and when: session.get() raises a RequestException + then: raises a SalesforceApiException + ''' + + # setup + auth = AuthenticatorStub( + instance_url='https://my.salesforce.test', + access_token='123456', + ) + response = ResponseStub(500, 'Server Error', 'Server Error', []) + session = SessionStub(raise_error=True) + session.response = response + + def cb(response): + return response + + # execute / verify + with self.assertRaises(SalesforceApiException) as _: + _ = api.get( + auth, + session, + '/foo', + cb, + ) + + self.assertEqual(session.url, 'https://my.salesforce.test/foo') + self.assertTrue('Authorization' in session.headers) + self.assertEqual(session.headers['Authorization'], 'Bearer 123456') + self.assertFalse(session.stream) + + def test_get_raises_login_exception_if_reauthenticate_does(self): + ''' + get() raises a LoginException when the status code is 401 and reauthenticate() raises a LoginException + given: an authenticator + and given: a session + and given: a service url + and given: a callback + when: get() is called + then: session.get() is called with full URL, access token, and default stream flag + and when: response status code is 401 + then: reauthenticate() is called + and when: reauthenticate() raises a LoginException + then: 
raises a LoginException + ''' + + # setup + auth = AuthenticatorStub( + instance_url='https://my.salesforce.test', + access_token='123456', + raise_login_error=True + ) + response = ResponseStub(401, 'Unauthorized', 'Unauthorized', []) + session = SessionStub() + session.response = response + + def cb(response): + return response + + # execute / verify + with self.assertRaises(LoginException) as _: + _ = api.get( + auth, + session, + '/foo', + cb, + ) + + self.assertEqual(session.url, 'https://my.salesforce.test/foo') + self.assertTrue('Authorization' in session.headers) + self.assertEqual(session.headers['Authorization'], 'Bearer 123456') + self.assertFalse(session.stream) + self.assertTrue(auth.reauthenticate_called) + + def test_get_calls_reauthenticate_on_401_and_invokes_cb_with_response_on_200(self): + ''' + get() calls reauthenticate() on a 401 and then invokes the callback with response and returns the result on a 200 status code + given: an authenticator + and given: a session + and given: a service url + and given: a callback + when: get() is called + then: session.get() is called with given URL, access token, and default stream flag + and when: response status code is 401 + then: reauthenticate() is called + and when: reauthenticate() does not throw a LoginException + then: request is executed again with the same URL and stream setting as the first call to session.get() and the second access token + and when: it returns a 200 + then: calls callback with response and returns result + ''' + + # setup + auth = AuthenticatorStub( + instance_url='https://my.salesforce.test', + access_token='123456', + access_token_2='567890', + ) + response1 = ResponseStub(401, 'Unauthorized', 'Unauthorized', []) + response2 = ResponseStub(200, 'OK', 'OK', []) + session = MultiRequestSessionStub(responses=[response1, response2]) + + def cb(response): + return response.text + + # execute + val = api.get(auth, session, '/foo', cb) + + # verify + self.assertEqual(len(session.requests), 2) + self.assertEqual( + session.requests[0]['url'], + 'https://my.salesforce.test/foo', + ) + self.assertTrue('Authorization' in session.requests[0]['headers']) + self.assertEqual( + session.requests[0]['headers']['Authorization'], + 'Bearer 123456', + ) + self.assertEqual( + session.requests[0]['stream'], + False, + ) + self.assertTrue(auth.reauthenticate_called) + self.assertEqual( + session.requests[1]['url'], + 'https://my.salesforce.test/foo', + ) + self.assertTrue('Authorization' in session.requests[1]['headers']) + self.assertEqual( + session.requests[1]['headers']['Authorization'], + 'Bearer 567890', + ) + self.assertEqual( + session.requests[1]['stream'], + False, + ) + self.assertIsNotNone(val) + self.assertEqual(val, 'OK') + + def test_get_passed_correct_params_after_reauthenticate(self): + ''' + get() receives the correct set of parameters when it is called after reauthenticate() succeeds + given: an authenticator + and given: a session + and given: a service url + and given: a callback + when: get() is called + then: session.get() is called with full URL, access token, and default stream flag + and when: response status code is 401 + then: reauthenticate() is called + and when: reauthenticate() does not throw a LoginException + then: request is executed again with the same URL and stream setting as the first call to session.get() and the second access token + ''' + + # setup + auth = AuthenticatorStub( + instance_url='https://my.salesforce.test', + access_token='123456', + access_token_2='567890' + ) + 
response1 = ResponseStub(401, 'Unauthorized', 'Unauthorized', []) + response2 = ResponseStub(200, 'OK', 'OK', []) + session = MultiRequestSessionStub(responses=[response1, response2]) + + def cb(response): + return response.text + + # execute + val = api.get(auth, session, '/foo', cb, stream=True) + + # verify + self.assertEqual(len(session.requests), 2) + self.assertTrue(auth.reauthenticate_called) + self.assertEqual( + session.requests[0]['url'], + 'https://my.salesforce.test/foo', + ) + self.assertTrue('Authorization' in session.requests[0]['headers']) + self.assertEqual( + session.requests[0]['headers']['Authorization'], + 'Bearer 123456', + ) + self.assertEqual( + session.requests[0]['stream'], + True, + ) + self.assertEqual( + session.requests[1]['url'], + 'https://my.salesforce.test/foo', + ) + self.assertTrue('Authorization' in session.requests[1]['headers']) + self.assertEqual( + session.requests[1]['headers']['Authorization'], + 'Bearer 567890', + ) + self.assertEqual( + session.requests[1]['stream'], + True, + ) + self.assertIsNotNone(val) + self.assertEqual(val, 'OK') + + def test_get_calls_reauthenticate_on_401_and_raises_on_non_200(self): + ''' + get() function calls reauthenticate() on a 401 and then throws a SalesforceApiException on a non-200 status code + given: an authenticator + and given: a session + and given: a service url + and given: a callback + when: get() is called + then: session.get() is called with full URL, access token, and default stream flag + and when: response status code is 401 + then: reauthenticate() is called + and when: reauthenticate() does not throw a LoginException + then: request is executed again with the same URL and stream setting as the first call to session.get() and the second access token + and when: it returns a non-200 status code + then: throws a SalesforceApiException + ''' + + # setup + auth = AuthenticatorStub( + instance_url='https://my.salesforce.test', + access_token='123456', + access_token_2='567890', + ) + response1 = ResponseStub(401, 'Unauthorized', 'Unauthorized', []) + response2 = ResponseStub(401, 'Unauthorized', 'Unauthorized 2', []) + session = MultiRequestSessionStub(responses=[response1, response2]) + + def cb(response): + return response + + # execute / verify + with self.assertRaises(SalesforceApiException) as _: + _ = api.get( + auth, + session, + '/foo', + cb, + ) + + self.assertEqual(len(session.requests), 2) + self.assertTrue(auth.reauthenticate_called) + self.assertEqual( + session.requests[0]['url'], + 'https://my.salesforce.test/foo', + ) + self.assertTrue('Authorization' in session.requests[0]['headers']) + self.assertEqual( + session.requests[0]['headers']['Authorization'], + 'Bearer 123456', + ) + self.assertEqual( + session.requests[0]['stream'], + False, + ) + self.assertEqual( + session.requests[1]['url'], + 'https://my.salesforce.test/foo', + ) + self.assertTrue('Authorization' in session.requests[1]['headers']) + self.assertEqual( + session.requests[1]['headers']['Authorization'], + 'Bearer 567890', + ) + self.assertEqual( + session.requests[1]['stream'], + False, + ) + + def test_stream_lines_sets_fallback_encoding_and_calls_iter_lines_with_chunk_size_and_decode_unicode(self): + ''' + stream_lines() sets a default encoding on the response and calls iter_lines with the given chunk size and the decode_unicode flag = True + given: a response + and given: a chunk size + when: encoding on response is None + then: fallback utf-8 is used and iter_lines is called with given chunk size and decode_unicode 
flag = True + ''' + + # setup + response = ResponseStub(200, 'OK', 'OK', ['foo lines', 'bar line']) + + # execute + lines = api.stream_lines(response, 1024) + + # verify + next(lines) + next(lines) + # NOTE: this has to be done _after_ the generator iterator is called at + # least once since the generator function is not run until the first + # call to next() + self.assertEqual(response.encoding, 'utf-8') + self.assertEqual(response.chunk_size, 1024) + self.assertTrue(response.decode_unicode) + self.assertTrue(response.iter_lines_called) + + def test_stream_lines_uses_existing_encoding_and_calls_iter_lines_with_chunk_size_and_decode_unicode(self): + ''' + stream_lines() uses the existing encoding on the response and calls iter_lines with the given chunk size and the decode_unicode flag = True + given: a response + and given: a chunk size + when: encoding on response is set + then: response encoding is used and iter_lines is called with given chunk size and decode_unicode flag = True + ''' + + # setup + response = ResponseStub( + 200, + 'OK', + 'OK', + ['foo lines', 'bar line'], + encoding='iso-8859-1', + ) + + # execute + lines = api.stream_lines(response, 1024) + + # verify + next(lines) + next(lines) + # NOTE: this has to be done _after_ the generator iterator is called at + # least once since the generator function is not run until the first + # call to next() + self.assertEqual(response.encoding, 'iso-8859-1') + self.assertEqual(response.chunk_size, 1024) + self.assertTrue(response.decode_unicode) + self.assertTrue(response.iter_lines_called) + + def test_authenticate_calls_authenticator_authenticate(self): + ''' + authenticate() calls authenticate() on the backing authenticator + given: an authenticator + and given: an api version + and given: a session + when: authenticate() is called + then: authenticator.authenticate() is called + ''' + + # setup + auth = AuthenticatorStub( + instance_url='https://my.salesforce.test', + access_token='123456', + ) + session = SessionStub() + + # execute + sf_api = api.Api(auth, '55.0') + sf_api.authenticate(session) + + # verify + self.assertTrue(auth.authenticate_called) + + def test_authenticate_raises_login_exception_if_authenticator_authenticate_does(self): + ''' + authenticate() calls authenticate() on the backing authenticator and raises a LoginException if authenticate() does + given: an authenticator + and given: an api version + and given: a session + when: authenticate() is called + then: authenticator.authenticate() is called + and when: authenticator.authenticate() raises a LoginException + then: authenticate() raises a LoginException + ''' + + # setup + auth = AuthenticatorStub( + instance_url='https://my.salesforce.test', + access_token='123456', + raise_login_error=True, + ) + session = SessionStub() + + # execute / verify + with self.assertRaises(LoginException) as _: + sf_api = api.Api(auth, '55.0') + sf_api.authenticate(session) + + self.assertTrue(auth.authenticate_called) + + def test_query_requests_correct_url_with_access_token_and_returns_json_response_on_success(self): + ''' + query() calls the correct query API url with the access token and returns a JSON response when no errors occur + given: an authenticator + and given: an api version + and given: a session + and given: a query + when: query() is called + then: session.get() is called with correct URL and access token + and: stream is set to False + and when: session.get() response status code is 200 + then: calls callback with response and returns a JSON response + ''' + + # 
setup + auth = AuthenticatorStub( + instance_url='https://my.salesforce.test', + access_token='123456', + ) + session = SessionStub() + session.response = ResponseStub(200, 'OK', '{"foo": "bar"}', [] ) + + # execute + sf_api = api.Api(auth, '55.0') + resp = sf_api.query( + session, + 'SELECT+LogFile+FROM+EventLogFile', + ) + + # verify + + self.assertEqual( + session.url, + f'https://my.salesforce.test/services/data/v55.0/query?q=SELECT+LogFile+FROM+EventLogFile', + ) + self.assertTrue('Authorization' in session.headers) + self.assertEqual(session.headers['Authorization'], 'Bearer 123456') + self.assertFalse(session.stream) + self.assertIsNotNone(resp) + self.assertTrue(type(resp) is dict) + self.assertTrue('foo' in resp) + self.assertEqual(resp['foo'], 'bar') + + def test_query_requests_correct_url_with_access_token_given_api_version_and_returns_json_response_on_success(self): + ''' + query() calls the correct query API url with the access token when a specific api version is given and returns a JSON response + given: an authenticator + and given: an api version + and given: a session + and given: a query + when: query() is called + and when: the api version parameter is specified + then: session.get() is called with correct URL and access token + and: stream is set to False + and when: response status code is 200 + then: calls callback with response and returns a JSON response + ''' + + # setup + auth = AuthenticatorStub( + instance_url='https://my.salesforce.test', + access_token='123456', + ) + session = SessionStub() + session.response = ResponseStub(200, 'OK', '{"foo": "bar"}', [] ) + + # execute + sf_api = api.Api(auth, '55.0') + resp = sf_api.query( + session, + 'SELECT+LogFile+FROM+EventLogFile', + '52.0', + ) + + # verify + + self.assertEqual( + session.url, + f'https://my.salesforce.test/services/data/v52.0/query?q=SELECT+LogFile+FROM+EventLogFile', + ) + self.assertTrue('Authorization' in session.headers) + self.assertEqual(session.headers['Authorization'], 'Bearer 123456') + self.assertFalse(session.stream) + self.assertIsNotNone(resp) + self.assertTrue(type(resp) is dict) + self.assertTrue('foo' in resp) + self.assertEqual(resp['foo'], 'bar') + + def test_query_raises_login_exception_if_get_does(self): + ''' + query() calls the correct query API url with the access token and raises LoginException if get does + given: an authenticator + and given: an api version + and given: a session + and given: a query + when: query() is called + then: session.get() is called with correct URL and access token + and: stream is set to False + and when: session.get() raises a LoginException + then: query() raises a LoginException + ''' + + # setup + auth = AuthenticatorStub( + instance_url='https://my.salesforce.test', + access_token='123456', + raise_login_error=True + ) + session = SessionStub() + session.response = ResponseStub(401, 'Unauthorized', '{"foo": "bar"}', [] ) + + # execute / verify + with self.assertRaises(LoginException) as _: + sf_api = api.Api(auth, '55.0') + _ = sf_api.query( + session, + 'SELECT+LogFile+FROM+EventLogFile', + ) + + self.assertEqual( + session.url, + f'https://my.salesforce.test/services/data/v55.0/query?q=SELECT+LogFile+FROM+EventLogFile', + ) + self.assertTrue('Authorization' in session.headers) + self.assertEqual(session.headers['Authorization'], 'Bearer 123456') + self.assertFalse(session.stream) + + def test_query_raises_salesforce_exception_if_get_does(self): + ''' + query() calls the correct query API url with the access token and raises 
SalesforceApiException if get does + given: an authenticator + and given: an api version + and given: a session + and given: a query + when: query() is called + then: session.get() is called with correct URL and access token + and: stream is set to False + and when: session.get() raises a SalesforceApiException + then: query() raises a SalesforceApiException + ''' + + # setup + auth = AuthenticatorStub( + instance_url='https://my.salesforce.test', + access_token='123456', + ) + session = SessionStub() + session.response = ResponseStub(500, 'ServerError', '{"foo": "bar"}', [] ) + + # execute / verify + with self.assertRaises(SalesforceApiException) as _: + sf_api = api.Api(auth, '55.0') + _ = sf_api.query( + session, + 'SELECT+LogFile+FROM+EventLogFile', + ) + + self.assertEqual( + session.url, + f'https://my.salesforce.test/services/data/v55.0/query?q=SELECT+LogFile+FROM+EventLogFile', + ) + self.assertTrue('Authorization' in session.headers) + self.assertEqual(session.headers['Authorization'], 'Bearer 123456') + self.assertFalse(session.stream) + + def test_get_log_file_requests_correct_url_with_access_token_and_returns_generator_on_success(self): + ''' + get_log_file() calls the correct url with the access token and returns a generator iterator + given: an authenticator + and given: an api version + and given: a session + and given: a log file path + and given: a chunk size + when: get_log_file() is called + then: session.get() is called with correct URL and access token + and: stream is set to True + and when: response status code is 200 + and when: get() returns a response + then: calls callback with response + and: iter_lines is called with the correct chunk size + and: returns a generator iterator + ''' + + # setup + auth = AuthenticatorStub( + instance_url='https://my.salesforce.test', + access_token='123456', + ) + session = SessionStub() + session.response = ResponseStub( + 200, + 'OK', + '', + [ 'COL1,COL2,COL3', 'foo,bar,baz' ], + ) + + # execute + sf_api = api.Api(auth, '55.0') + resp = sf_api.get_log_file( + session, + '/services/data/v52.0/sobjects/EventLogFile/00001111AAAABBBB/LogFile', + chunk_size=8192, + ) + + # verify + self.assertEqual( + session.url, + f'https://my.salesforce.test/services/data/v52.0/sobjects/EventLogFile/00001111AAAABBBB/LogFile', + ) + self.assertTrue('Authorization' in session.headers) + self.assertEqual(session.headers['Authorization'], 'Bearer 123456') + self.assertTrue(session.stream) + self.assertIsNotNone(resp) + line = next(resp) + self.assertEqual('COL1,COL2,COL3', line) + line = next(resp) + self.assertEqual('foo,bar,baz', line) + line = next(resp, None) + self.assertIsNone(line) + # NOTE: this has to be done _after_ the generator iterator is called at + # least once since the generator function is not run until the first + # call to next() + self.assertTrue(session.response.iter_lines_called) + self.assertEqual(session.response.chunk_size, 8192) + + def test_get_log_file_raises_login_exception_if_get_does(self): + ''' + get_log_file() calls the correct query API url with the access token and raises a LoginException if get does + given: an authenticator + and given: an api version + and given: a session + and given: a log file path + and given: a chunk size + when: get_log_file() is called + then: session.get() is called with correct URL and access token + and: stream is set to True + and when: get() raises a LoginException + then: get_log_file() raises a LoginException + ''' + + # setup + auth = AuthenticatorStub( + 
instance_url='https://my.salesforce.test', + access_token='123456', + raise_login_error=True, + ) + session = SessionStub() + session.response = ResponseStub( + 401, + 'Unauthorized', + 'Unauthorized', + [ 'COL1,COL2,COL3', 'foo,bar,baz' ], + ) + + # execute / verify + with self.assertRaises(LoginException) as _: + sf_api = api.Api(auth, '55.0') + _ = sf_api.get_log_file( + session, + '/services/data/v52.0/sobjects/EventLogFile/00001111AAAABBBB/LogFile', + chunk_size=8192, + ) + + self.assertEqual( + session.url, + f'https://my.salesforce.test/services/data/v52.0/sobjects/EventLogFile/00001111AAAABBBB/LogFile', + ) + self.assertTrue('Authorization' in session.headers) + self.assertEqual(session.headers['Authorization'], 'Bearer 123456') + self.assertTrue(session.stream) + + def test_get_log_file_raises_salesforce_exception_if_get_does(self): + ''' + get_log_file() calls the correct query API url with the access token and raises a SalesforceApiException if get does + given: an authenticator + and given: an api version + and given: a session + and given: a log file path + and given: a chunk size + when: get_log_file() is called + then: session.get() is called with correct URL and access token + and: stream is set to True + and when: get() raises a SalesforceApiException + then: get_log_file() raises a SalesforceApiException + ''' + + # setup + auth = AuthenticatorStub( + instance_url='https://my.salesforce.test', + access_token='123456', + ) + session = SessionStub(raise_error=True) + session.response = ResponseStub( + 200, + 'OK', + '', + [ 'COL1,COL2,COL3', 'foo,bar,baz' ], + ) + + # execute / verify + with self.assertRaises(SalesforceApiException) as _: + sf_api = api.Api(auth, '55.0') + _ = sf_api.get_log_file( + session, + '/services/data/v52.0/sobjects/EventLogFile/00001111AAAABBBB/LogFile', + chunk_size=8192, + ) + + self.assertEqual( + session.url, + f'https://my.salesforce.test/services/data/v52.0/sobjects/EventLogFile/00001111AAAABBBB/LogFile', + ) + self.assertTrue('Authorization' in session.headers) + self.assertEqual(session.headers['Authorization'], 'Bearer 123456') + self.assertTrue(session.stream) + + +class TestApiFactory(unittest.TestCase): + def test_new_returns_api_with_correct_authenticator_and_version(self): + ''' + new() returns a new Api instance with the given authenticator and version + given: an authenticator + and given: an api version + when: new() is called + then: returns a new Api instance with the given authenticator and version + ''' + + # setup + auth = AuthenticatorStub( + instance_url='https://my.salesforce.test', + access_token='123456', + ) + + # execute + api_factory = api.ApiFactory() + sf_api = api_factory.new(auth, '55.0') + + # verify + self.assertEqual(sf_api.authenticator, auth) + self.assertEqual(sf_api.api_ver, '55.0') diff --git a/src/tests/test_cache.py b/src/tests/test_cache.py new file mode 100644 index 0000000..9d75950 --- /dev/null +++ b/src/tests/test_cache.py @@ -0,0 +1,728 @@ +from datetime import timedelta +from redis import RedisError +import unittest + +from newrelic_logging import cache, CacheException, config as mod_config +from . 
import RedisStub, BackendStub, BackendFactoryStub + +class TestRedisBackend(unittest.TestCase): + def test_exists(self): + ''' + backend exists returns redis exists + given: a redis instance + when: exists is called + then: the redis instance exists command result is returned + ''' + + # setup + redis = RedisStub({ 'foo': 'bar', 'beep': 'boop' }) + + # execute + backend = cache.RedisBackend(redis) + foo_exists = backend.exists('foo') + baz_exists = backend.exists('baz') + + # verify + self.assertTrue(foo_exists) + self.assertFalse(baz_exists) + + def test_exists_raises_when_redis_does(self): + ''' + backend exists raises error if redis exists does + given: a redis instance + when: exists is called + and when: the redis instance raises an exception + then: the redis instance exception is raised + ''' + + # setup + redis = RedisStub({ 'foo': 'bar', 'beep': 'boop' }, raise_error=True) + + # execute + backend = cache.RedisBackend(redis) + + # verify + with self.assertRaises(RedisError): + backend.exists('foo') + + def test_put(self): + ''' + backend put calls redis set + given: a redis instance + when: put is called + then: the redis instance set command is called + ''' + + # setup + redis = RedisStub({ 'foo': 'bar', 'beep': 'boop' }) + + # preconditions + self.assertFalse(redis.exists('r2')) + + # execute + backend = cache.RedisBackend(redis) + backend.put('r2', 'd2') + + # verify + self.assertTrue(redis.exists('r2')) + self.assertEqual(redis.test_cache['r2'], 'd2') + + def test_put_raises_if_redis_does(self): + ''' + backend put raises error if redis set does + given: a redis instance + when: put is called + and when: the redis instance raises an exception + then: the redis instance exception is raised + ''' + + # setup + redis = RedisStub({ 'foo': 'bar', 'beep': 'boop' }, raise_error=True) + + # execute + backend = cache.RedisBackend(redis) + + # verify + with self.assertRaises(RedisError) as _: + backend.put('r2', 'd2') + + + def test_get_set(self): + ''' + backend calls redis smembers + given: a redis instance + when: get_set is called + and when: key does not exist + or when: key exists and is set + then: the redis instance smembers command result is returned + ''' + + # setup + redis = RedisStub({ 'foo': set(['bar']) }) + + # execute + backend = cache.RedisBackend(redis) + foo_set = backend.get_set('foo') + beep_set = backend.get_set('beep') + + # verify + self.assertEqual(len(foo_set), 1) + self.assertEqual(foo_set, set(['bar'])) + self.assertEqual(len(beep_set), 0) + self.assertEqual(beep_set, set()) + + def test_get_set_raises_if_redis_does(self): + ''' + backend raises exception if redis smembers does + given: a redis instance + when: get_set is called + and when: redis instance raises an exception + then: the redis instance exception is raised + ''' + + # setup + redis = RedisStub({ 'foo': 'bar' }) + + # execute + backend = cache.RedisBackend(redis) + + # verify + with self.assertRaises(RedisError) as _: + backend.get_set('foo') + + def test_set_add(self): + ''' + backend set_add calls sadd + given: a redis instance + when: set_add is called + and when: key does not exist + or when: key exists and is set + then: the redis instance sadd command is called + ''' + + # setup + redis = RedisStub({ 'foo': set(['bar', 'beep', 'boop']) }) + + # execute + backend = cache.RedisBackend(redis) + backend.set_add('foo', 'baz') + backend.set_add('bop', 'biz') + + args = [1, 2, 3] + backend.set_add('bip', *args) + + # verify + self.assertEqual( + redis.test_cache['foo'], set(['bar', 'beep', 
'boop', 'baz']), + ) + self.assertEqual(redis.test_cache['bop'], set(['biz'])) + self.assertEqual(redis.test_cache['bip'], set([1, 2, 3])) + + def test_set_add_raises_if_redis_does(self): + ''' + backend raises exception if redis sadd does + given: a redis instance + when: set_add is called + and when: redis instance raises an exception + then: the redis instance exception is raised + ''' + + # setup + redis = RedisStub({ 'foo': 'bar' }) + + # execute + backend = cache.RedisBackend(redis) + + # verify + with self.assertRaises(RedisError) as _: + backend.set_add('foo', 'beep') + + def test_set_expiry(self): + ''' + backend set_expiry calls expire + given: a redis instance + when: set_expiry is called + then: the redis instance expire command is called + ''' + + # setup + redis = RedisStub({ 'foo': set('bar') }) + + # execute + backend = cache.RedisBackend(redis) + backend.set_expiry('foo', 5) + + # verify + self.assertTrue('foo' in redis.expiry) + self.assertEqual(redis.expiry['foo'], timedelta(5)) + + def test_set_expiry_raises_if_redis_does(self): + ''' + backend set_expiry raises exception if redis expire does + given: a redis instance + when: set_expiry is called + and when: redis instance raises an exception + then: the redis instance exception is raised + ''' + + # setup + redis = RedisStub({ 'foo': set('bar') }, raise_error=True) + + # execute + backend = cache.RedisBackend(redis) + + # verify + with self.assertRaises(RedisError) as _: + backend.set_expiry('foo', 5) + + +class TestBufferedAddSetCache(unittest.TestCase): + def test_check_or_set_true_when_item_exists(self): + ''' + check_or_set returns true when item exists + given: a set + and when: the set contains the item to be checked + then: returns true + ''' + + # execute + s = cache.BufferedAddSetCache(set(['foo'])) + contains_foo = s.check_or_set('foo') + + # verify + self.assertTrue(contains_foo) + + def test_check_or_set_false_and_adds_item_when_item_missing(self): + ''' + check_or_set returns false and adds item when item is not in set + given: a set + when: the set does not contain the item to be checked + then: returns false + and then: the item is in the set + ''' + + # execute + s = cache.BufferedAddSetCache(set()) + contains_foo = s.check_or_set('foo') + + # verify + self.assertFalse(contains_foo) + self.assertTrue('foo' in s.get_buffer()) + + def test_check_or_set_checks_both_sets(self): + ''' + check_or_set checks both the given set and buffer for each check_or_set + given: a set + when: the set does not contain the item to be checked + then: returns false + and then: the item is in the set + and when: the item is added again + then: returns true + ''' + + # execute + s = cache.BufferedAddSetCache(set()) + contains_foo = s.check_or_set('foo') + + # verify + self.assertFalse(contains_foo) + self.assertTrue('foo' in s.get_buffer()) + + # execute + contains_foo = s.check_or_set('foo') + + # verify + # this verifies that check_or_set also checks the buffer + self.assertTrue(contains_foo) + + +class TestDataCache(unittest.TestCase): + def test_can_skip_download_logfile_true_when_key_exists(self): + ''' + dl logfile returns true if key exists in cache + given: a backend instance + when: can_skip_download_logfile is called + and when: the key exists in the cache + then: return true + ''' + + # setup + backend = BackendStub({ 'foo': ['bar'] }) + + # execute + data_cache = cache.DataCache(backend, 5) + can = data_cache.can_skip_downloading_logfile('foo') + + # verify + self.assertTrue(can) + + def 
test_can_skip_download_logfile_false_when_key_missing(self): + ''' + dl logfile returns false if key does not exist + given: a backend instance + when: can_skip_download_logfile is called + and when: the key does not exist in the backend + then: return false + ''' + + # setup + backend = BackendStub({}) + + # execute + data_cache = cache.DataCache(backend, 5) + can = data_cache.can_skip_downloading_logfile('foo') + + # verify + self.assertFalse(can) + + def test_can_skip_download_logfile_raises_if_backend_does(self): + ''' + dl logfile raises CacheException if backend raises any Exception + given: a backend instance + when: can_skip_download_logfile is called + and when: backend raises any exception + then: a CacheException is raised + ''' + + # setup + backend = BackendStub({}, raise_error=True) + + # execute + data_cache = cache.DataCache(backend, 5) + + # verify + with self.assertRaises(CacheException) as _: + data_cache.can_skip_downloading_logfile('foo') + + def test_check_or_set_log_line_true_when_exists(self): + ''' + check_or_set_log_line returns true when line ID is in the cached set + given: a backend instance + when: check_or_set_log_line is called + and when: row['REQUEST_ID'] is in the set for key + then: returns true + ''' + + # setup + backend = BackendStub({ 'foo': set(['bar']) }) + line = { 'REQUEST_ID': 'bar' } + + # execute + data_cache = cache.DataCache(backend, 5) + val = data_cache.check_or_set_log_line('foo', line) + + # verify + self.assertTrue(val) + + def test_check_or_set_log_line_false_and_adds_when_missing(self): + ''' + check_or_set_log_line returns false and adds line ID when line ID is not in the cached set + given: a backend instance + when: check_or_set_log_line is called + and when: row['REQUEST_ID'] is not in set for key + then: the line ID is added to the set and false is returned + ''' + + # setup + backend = BackendStub({ 'foo': set() }) + row = { 'REQUEST_ID': 'bar' } + + # preconditions + self.assertFalse('bar' in backend.redis.test_cache['foo']) + + # execute + data_cache = cache.DataCache(backend, 5) + val = data_cache.check_or_set_log_line('foo', row) + + # verify + self.assertFalse(val) + + # Need to flush before we can check the cache as how it is stored + # in memory is an implementation detail + data_cache.flush() + self.assertTrue('bar' in backend.redis.test_cache['foo']) + + def test_check_or_set_log_line_raises_if_backend_does(self): + ''' + check_or_set_log_line raises CacheException if backend raises any Exception + given: a backend instance + when: check_or_set_log_line is called + and when: backend raises any exception + then: a CacheException is raised + ''' + + # setup + backend = BackendStub({}, raise_error=True) + line = { 'REQUEST_ID': 'bar' } + + # execute / verify + data_cache = cache.DataCache(backend, 5) + + with self.assertRaises(CacheException) as _: + data_cache.check_or_set_log_line('foo', line) + + def test_check_or_set_event_id_true_when_exists(self): + ''' + check_or_set_event_id returns true when event ID is in the cached set + given: a backend instance + when: check_or_set_event_id is called + and when: event ID is in the set 'event_ids' + then: returns true + ''' + + # setup + backend = BackendStub({ 'event_ids': set(['foo']) }) + + # execute + data_cache = cache.DataCache(backend, 5) + val = data_cache.check_or_set_event_id('foo') + + # verify + self.assertTrue(val) + + def test_check_or_set_event_id_false_and_adds_when_missing(self): + ''' + check_or_set_event_id returns false and adds event ID when event ID is not 
in the cached set + given: a backend instance + when: check_or_set_event_id is called + and when: event ID is not in set 'event_ids' + then: the event ID is added to the set and False is returned + ''' + + # setup + backend = BackendStub({ 'event_ids': set() }) + + # preconditions + self.assertFalse('foo' in backend.redis.test_cache['event_ids']) + + # execute + data_cache = cache.DataCache(backend, 5) + val = data_cache.check_or_set_event_id('foo') + + # verify + self.assertFalse(val) + + # Need to flush before we can check the cache as how it is stored + # in memory is an implementation detail + data_cache.flush() + self.assertTrue('foo' in backend.redis.test_cache['event_ids']) + + def test_check_or_set_event_id_raises_if_backend_does(self): + ''' + check_or_set_event_id raises CacheException if backend raises any Exception + given: a backend instance + when: check_or_set_event_id is called + and when: backend raises any exception + then: a CacheException is raised + ''' + + # setup + backend = BackendStub({}, raise_error=True) + + # execute / verify + data_cache = cache.DataCache(backend, 5) + + with self.assertRaises(CacheException) as _: + data_cache.check_or_set_event_id('foo') + + def test_flush_does_not_affect_cache_when_add_buffers_empty(self): + ''' + backend cache is empty if flush is called when BufferedAddSet buffers are empty + given: a backend instance + when: flush is called + and when: the log lines and event ID add buffers are empty + then: the backend cache remains empty + ''' + + # setup + backend = BackendStub({}) + + # preconditions + self.assertEqual(len(backend.redis.test_cache), 0) + + # execute + data_cache = cache.DataCache(backend, 5) + data_cache.flush() + + # verify + self.assertEqual(len(backend.redis.test_cache), 0) + + def test_flush_writes_log_lines_when_add_buffer_not_empty(self): + ''' + flush writes any buffered log lines when log lines add buffer is not empty + given: a backend instance + when: flush is called + and when: the log lines add buffer is not empty + then: the backend cache is updated with items from the add buffer + ''' + + # setup + backend = BackendStub({}) + line1 = { 'REQUEST_ID': 'bar1' } + line2 = { 'REQUEST_ID': 'bar2' } + line3 = { 'REQUEST_ID': 'boop' } + + # preconditions + self.assertEqual(len(backend.redis.test_cache), 0) + + # execute + data_cache = cache.DataCache(backend, 5) + data_cache.check_or_set_log_line('foo', line1) + data_cache.check_or_set_log_line('foo', line2) + data_cache.check_or_set_log_line('beep', line3) + data_cache.flush() + + # verify + self.assertEqual(len(backend.redis.test_cache), 2) + self.assertTrue('foo' in backend.redis.test_cache) + self.assertEqual(backend.redis.test_cache['foo'], set(['bar1', 'bar2'])) + self.assertTrue('beep' in backend.redis.test_cache) + self.assertEqual(backend.redis.test_cache['beep'], set(['boop'])) + + def test_flush_writes_event_ids_when_add_buffer_not_empty(self): + ''' + flush writes any buffered event IDs when event IDs add buffer is not empty + given: a backend instance + when: flush is called + and when: the event IDs add buffer is not empty + then: the backend cache is updated with items from the add buffer + ''' + + # setup + backend = BackendStub({}) + + # preconditions + self.assertEqual(len(backend.redis.test_cache), 0) + + # execute + data_cache = cache.DataCache(backend, 5) + data_cache.check_or_set_event_id('foo') + data_cache.check_or_set_event_id('bar') + data_cache.check_or_set_event_id('beep') + data_cache.check_or_set_event_id('boop') + 
data_cache.flush() + + # verify + self.assertEqual(len(backend.redis.test_cache), 5) + self.assertTrue('event_ids' in backend.redis.test_cache) + self.assertEqual( + backend.redis.test_cache['event_ids'], + set(['foo', 'bar', 'beep', 'boop']) + ) + self.assertTrue('foo' in backend.redis.test_cache) + self.assertEqual(backend.redis.test_cache['foo'], 1) + self.assertTrue('bar' in backend.redis.test_cache) + self.assertEqual(backend.redis.test_cache['bar'], 1) + self.assertTrue('beep' in backend.redis.test_cache) + self.assertEqual(backend.redis.test_cache['beep'], 1) + self.assertTrue('boop' in backend.redis.test_cache) + self.assertEqual(backend.redis.test_cache['boop'], 1) + + def test_flush_sets_expiry_on_write(self): + ''' + flush sets the expiry time of any keys it writes + given: a backend instance and expiry + when: flush is called + and when: add buffers are not empty + then: the backend cache is updated with new cached sets with specified expiration time + ''' + + # setup + backend = BackendStub({}) + line1 = { 'REQUEST_ID': 'bar' } + + # preconditions + self.assertEqual(len(backend.redis.test_cache), 0) + self.assertEqual(len(backend.redis.expiry), 0) + + # execute + data_cache = cache.DataCache(backend, 5) + data_cache.check_or_set_log_line('foo', line1) + data_cache.check_or_set_event_id('bar') + data_cache.flush() + + # verify + self.assertEqual(len(backend.redis.test_cache), 3) + self.assertTrue('foo' in backend.redis.test_cache) + self.assertEqual(backend.redis.test_cache['foo'], set(['bar'])) + self.assertTrue('event_ids' in backend.redis.test_cache) + self.assertEqual(backend.redis.test_cache['event_ids'], set(['bar'])) + self.assertTrue('bar' in backend.redis.test_cache) + self.assertEqual(backend.redis.test_cache['bar'], 1) + self.assertEqual(len(backend.redis.expiry), 3) + self.assertTrue('foo' in backend.redis.expiry) + self.assertEqual(backend.redis.expiry['foo'], timedelta(days=5)) + self.assertTrue('bar' in backend.redis.expiry) + self.assertEqual(backend.redis.expiry['bar'], timedelta(days=5)) + self.assertTrue('event_ids' in backend.redis.expiry) + self.assertEqual(backend.redis.expiry['event_ids'], timedelta(days=5)) + + def test_flush_does_not_write_dups(self): + ''' + flush only writes items from add buffers + given: a backend instance + when: flush is called + and when: add buffers are not empty + then: the backend cache is updated with ONLY the items from the add buffers + ''' + + # setup + backend = BackendStub({ + 'foo': set(['bar', 'baz']), + 'beep': 1, + 'boop': 1, + 'event_ids': set(['beep', 'boop']) + }) + line1 = { 'REQUEST_ID': 'bip' } + line2 = { 'REQUEST_ID': 'bop' } + + # execute + data_cache = cache.DataCache(backend, 5) + data_cache.check_or_set_log_line('foo', line1) + data_cache.check_or_set_log_line('foo', line2) + data_cache.check_or_set_event_id('bim') + data_cache.check_or_set_event_id('bam') + data_cache.flush() + + # verify + self.assertEqual(len(backend.redis.test_cache), 6) + self.assertTrue('foo' in backend.redis.test_cache) + self.assertEqual( + backend.redis.test_cache['foo'], + set(['bar', 'baz', 'bip', 'bop']), + ) + self.assertTrue('event_ids' in backend.redis.test_cache) + self.assertEqual( + backend.redis.test_cache['event_ids'], + set(['beep', 'boop', 'bim', 'bam']), + ) + self.assertTrue('bim' in backend.redis.expiry) + self.assertEqual(backend.redis.expiry['bim'], timedelta(days=5)) + self.assertTrue('bam' in backend.redis.expiry) + self.assertEqual(backend.redis.expiry['bam'], timedelta(days=5)) + + def 
test_flush_raises_if_backend_does(self): + ''' + flush raises CacheException if backend raises any Exception + given: a backend instance + when: flush is called + and when: backend raises any exception + then: a CacheException is raised + ''' + + # setup + backend = BackendStub({}) + + # execute / verify + data_cache = cache.DataCache(backend, 5) + + # have to add data to be cached before setting raise_error + data_cache.check_or_set_event_id('foo') + + # now set error and execute/verify + backend.redis.raise_error = True + + with self.assertRaises(CacheException) as _: + data_cache.flush() + + +class TestCacheFactory(unittest.TestCase): + def test_new_returns_none_if_disabled(self): + ''' + new returns None if cache disabled + given: a backend factory and a configuration + when: new is called + and when: cache_enabled is False in config + then: None is returned + ''' + + # setup + config = mod_config.Config({ 'cache_enabled': 'false' }) + backend_factory = BackendFactoryStub() + + # execute + cache_factory = cache.CacheFactory(backend_factory) + data_cache = cache_factory.new(config) + + # verify + self.assertIsNone(data_cache) + + def test_new_returns_data_cache_with_backend_if_enabled(self): + ''' + new returns DataCache instance with backend from specified backend factory if cache enabled + given: a backend factory and a configuration + when: new is called + and when: cache_enabled is True in config + then: a DataCache is returned with the a backend from the specified backend factory + ''' + + # setup + config = mod_config.Config({ 'cache_enabled': 'true' }) + backend_factory = BackendFactoryStub() + + # execute + cache_factory = cache.CacheFactory(backend_factory) + data_cache = cache_factory.new(config) + + # verify + self.assertIsNotNone(data_cache) + self.assertTrue(type(data_cache.backend) is BackendStub) + + def test_new_raises_if_backend_factory_does(self): + ''' + new raises CacheException if backend factory does + given: a backend factory and a configuration + when: new is called + and when: backend factory new() raises any exception + then: a CacheException is raised + ''' + + # setup + config = mod_config.Config({ 'cache_enabled': 'true' }) + backend_factory = BackendFactoryStub(raise_error=True) + + # execute / verify + cache_factory = cache.CacheFactory(backend_factory) + + with self.assertRaises(CacheException): + _ = cache_factory.new(config) diff --git a/src/tests/test_integration.py b/src/tests/test_integration.py new file mode 100644 index 0000000..96e9666 --- /dev/null +++ b/src/tests/test_integration.py @@ -0,0 +1,467 @@ +import unittest + +from . 
import ApiFactoryStub, \ + AuthenticatorFactoryStub, \ + CacheFactoryStub, \ + NewRelicStub, \ + NewRelicFactoryStub, \ + PipelineFactoryStub, \ + QueryFactoryStub, \ + SalesForceFactoryStub +from newrelic_logging import ConfigException, \ + DataFormat, \ + config as mod_config, \ + integration + +class TestIntegration(unittest.TestCase): + def test_build_instance(self): + ''' + given: a Config instance, a NewRelic instance, a data format, a set of + event type field mappings, a set of numeric fields an initial + delay, an instance config, and a instance index + when: no prefix or queries are provided in the instance config + then: return a Salesforce instance configured with appropriate values + ''' + + # setup + config = mod_config.Config({ + 'instances': [ + { + 'name': 'test-inst-1', + 'labels': { + 'foo': 'bar', + 'beep': 'boop', + }, + 'arguments': { + 'token_url': 'https://my.salesforce.test/token', + 'cache_enabled': False, + }, + } + ] + }) + api_factory = ApiFactoryStub() + auth_factory = AuthenticatorFactoryStub() + cache_factory = CacheFactoryStub() + pipeline_factory = PipelineFactoryStub() + salesforce_factory = SalesForceFactoryStub() + query_factory = QueryFactoryStub() + new_relic = NewRelicStub(config) + event_type_fields_mapping = { 'event': ['field1'] } + numeric_fields_list = set(['field1', 'field2']) + + # execute + instance = integration.build_instance( + config, + auth_factory, + cache_factory, + pipeline_factory, + salesforce_factory, + api_factory, + query_factory, + new_relic, + DataFormat.EVENTS, + event_type_fields_mapping, + numeric_fields_list, + 603, + config['instances'][0], + 0, + ) + + # verify + self.assertTrue('client' in instance) + self.assertTrue('name' in instance) + self.assertEqual(instance['name'], 'test-inst-1') + client = instance['client'] + self.assertEqual(client.instance_name, 'test-inst-1') + inst_config = client.config + self.assertIsNotNone(inst_config) + self.assertEqual(inst_config.prefix, '') + self.assertTrue('token_url' in inst_config) + data_cache = client.data_cache + self.assertIsNotNone(data_cache) + self.assertIsNotNone(data_cache.config) + self.assertTrue('cache_enabled' in data_cache.config) + self.assertFalse(data_cache.config['cache_enabled']) + authenticator = client.authenticator + self.assertIsNotNone(authenticator) + self.assertEqual(authenticator.data_cache, data_cache) + self.assertIsNotNone(authenticator.config) + self.assertEqual( + authenticator.config['token_url'], + 'https://my.salesforce.test/token', + ) + p = client.pipeline + self.assertIsNotNone(p) + self.assertEqual(p.data_cache, data_cache) + self.assertEqual(p.new_relic, new_relic) + self.assertEqual(p.data_format, DataFormat.EVENTS) + self.assertIsNotNone(p.labels) + self.assertTrue(type(p.labels) is dict) + labels = p.labels + self.assertTrue('foo' in labels) + self.assertEqual(labels['foo'], 'bar') + self.assertTrue('beep' in labels) + self.assertEqual(labels['beep'], 'boop') + self.assertIsNotNone(p.event_type_fields_mapping) + self.assertTrue(type(p.event_type_fields_mapping) is dict) + event_type_fields_mapping = p.event_type_fields_mapping + self.assertTrue('event' in event_type_fields_mapping) + self.assertTrue(type(event_type_fields_mapping['event']) is list) + self.assertEqual(len(event_type_fields_mapping['event']), 1) + self.assertEqual(event_type_fields_mapping['event'][0], 'field1') + self.assertIsNotNone(p.numeric_fields_list) + self.assertTrue(type(p.numeric_fields_list) is set) + numeric_fields_list = p.numeric_fields_list + 
self.assertEqual(len(numeric_fields_list), 2) + self.assertTrue('field1' in numeric_fields_list) + self.assertTrue('field2' in numeric_fields_list) + self.assertEqual(client.query_factory, query_factory) + self.assertEqual(client.initial_delay, 603) + self.assertIsNone(client.queries) + + ''' + given: a Config instance, a NewRelic instance, a data format, a set of + event type field mappings, a set of numeric fields, an initial + delay, an instance config, and an instance index + when: a prefix is provided in the instance config + then: the prefix should be set on the instance config object + ''' + + # setup + config = mod_config.Config({ + 'instances': [ + { + 'name': 'test-inst-1', + 'labels': { + 'foo': 'bar', + 'beep': 'boop', + }, + 'arguments': { + 'token_url': 'https://my.salesforce.test/token', + 'cache_enabled': False, + 'auth_env_prefix': 'ABCDEF_' + }, + } + ] + }) + api_factory = ApiFactoryStub() + auth_factory = AuthenticatorFactoryStub() + cache_factory = CacheFactoryStub() + pipeline_factory = PipelineFactoryStub() + salesforce_factory = SalesForceFactoryStub() + query_factory = QueryFactoryStub() + + # execute + instance = integration.build_instance( + config, + auth_factory, + cache_factory, + pipeline_factory, + salesforce_factory, + api_factory, + query_factory, + new_relic, + DataFormat.EVENTS, + event_type_fields_mapping, + numeric_fields_list, + 603, + config['instances'][0], + 0, + ) + + # verify + self.assertTrue('client' in instance) + client = instance['client'] + inst_config = client.config + self.assertIsNotNone(inst_config) + self.assertEqual(inst_config.prefix, 'ABCDEF_') + + ''' + given: a Config instance, a NewRelic instance, a data format, a set of + event type field mappings, a set of numeric fields, an initial + delay, an instance config, and an instance index + when: queries are provided in the config + then: queries should be set in the Salesforce client instance + ''' + + # setup + config = mod_config.Config({ + 'instances': [ + { + 'name': 'test-inst-1', + 'labels': { + 'foo': 'bar', + 'beep': 'boop', + }, + 'arguments': { + 'token_url': 'https://my.salesforce.test/token', + 'cache_enabled': False, + }, + } + ], + 'queries': [ + { + 'query': 'SELECT foo FROM Account' + } + ] + }) + api_factory = ApiFactoryStub() + auth_factory = AuthenticatorFactoryStub() + cache_factory = CacheFactoryStub() + pipeline_factory = PipelineFactoryStub() + salesforce_factory = SalesForceFactoryStub() + query_factory = QueryFactoryStub() + + # execute + instance = integration.build_instance( + config, + auth_factory, + cache_factory, + pipeline_factory, + salesforce_factory, + api_factory, + query_factory, + new_relic, + DataFormat.EVENTS, + event_type_fields_mapping, + numeric_fields_list, + 603, + config['instances'][0], + 0, + ) + + # verify + self.assertTrue('client' in instance) + client = instance['client'] + self.assertIsNotNone(client.queries) + self.assertEqual(len(client.queries), 1) + self.assertTrue('query' in client.queries[0]) + self.assertEqual(client.queries[0]['query'], 'SELECT foo FROM Account') + + def test_init(self): + ''' + given: a Config instance, set of factories, a data format, an event + type fields mapping, a set of numeric fields and an initial + delay + when: an integration instance is created + then: instances should be created with pipelines that use the correct + data format and a newrelic instance should be created with the + correct config instance + ''' + + # setup + config = mod_config.Config({ + 'instances': [ + { + 'name': 'test-inst-1', + 
'labels': { + 'foo': 'bar', + 'beep': 'boop', + }, + 'arguments': { + 'token_url': 'https://my.salesforce.test/token', + 'cache_enabled': False, + }, + }, + { + 'name': 'test-inst-2', + 'labels': { + 'foo': 'bar', + 'beep': 'boop', + }, + 'arguments': { + 'token_url': 'https://my.salesforce.test/token', + 'cache_enabled': False, + }, + } + ], + 'newrelic': { + 'data_format': 'events', + } + }) + + api_factory = ApiFactoryStub() + auth_factory = AuthenticatorFactoryStub() + cache_factory = CacheFactoryStub() + pipeline_factory = PipelineFactoryStub() + salesforce_factory = SalesForceFactoryStub() + query_factory = QueryFactoryStub() + newrelic_factory = NewRelicFactoryStub() + event_type_fields_mapping = { 'event': ['field1'] } + numeric_fields_list = set(['field1', 'field2']) + + # execute + + i = integration.Integration( + config, + auth_factory, + cache_factory, + pipeline_factory, + salesforce_factory, + api_factory, + query_factory, + newrelic_factory, + event_type_fields_mapping, + numeric_fields_list, + 603 + ) + + # verify + + self.assertIsNotNone(i.instances) + self.assertTrue(type(i.instances) is list) + self.assertEqual(len(i.instances), 2) + self.assertTrue('client' in i.instances[0]) + self.assertTrue('name' in i.instances[0]) + self.assertEqual(i.instances[0]['name'], 'test-inst-1') + self.assertTrue('client' in i.instances[1]) + self.assertTrue('name' in i.instances[1]) + self.assertEqual(i.instances[1]['name'], 'test-inst-2') + client = i.instances[0]['client'] + self.assertEqual(client.pipeline.data_format, DataFormat.EVENTS) + client = i.instances[1]['client'] + self.assertEqual(client.pipeline.data_format, DataFormat.EVENTS) + self.assertIsNotNone(i.new_relic) + self.assertIsNotNone(i.new_relic.config) + self.assertEqual(i.new_relic.config, config) + + ''' + given: a Config instance, set of factories, a data format, an event + type fields mapping, a set of numeric fields and an initial + delay + when: an integration instance is created + then: an exception should be raised if an invalid data format is + passed + ''' + + # setup + config = mod_config.Config({ + 'instances': [ + { + 'name': 'test-inst-1', + 'labels': { + 'foo': 'bar', + 'beep': 'boop', + }, + 'arguments': { + 'token_url': 'https://my.salesforce.test/token', + 'cache_enabled': False, + }, + }, + ], + 'newrelic': { + 'data_format': 'invalid', + } + }) + + api_factory = ApiFactoryStub() + auth_factory = AuthenticatorFactoryStub() + cache_factory = CacheFactoryStub() + pipeline_factory = PipelineFactoryStub() + salesforce_factory = SalesForceFactoryStub() + query_factory = QueryFactoryStub() + newrelic_factory = NewRelicFactoryStub() + + # execute/verify + + with self.assertRaises(ConfigException): + i = integration.Integration( + config, + auth_factory, + cache_factory, + pipeline_factory, + salesforce_factory, + api_factory, + query_factory, + newrelic_factory, + event_type_fields_mapping, + numeric_fields_list, + 603 + ) + + ''' + given: a Config instance, set of factories, a data format, an event + type fields mapping, a set of numeric fields and an initial + delay + when: no instance configurations are provided + then: integration instances should be the empty set + ''' + + # setup + config = mod_config.Config({ + 'instances': [], + 'newrelic': { + 'data_format': 'logs', + } + }) + + api_factory = ApiFactoryStub() + auth_factory = AuthenticatorFactoryStub() + cache_factory = CacheFactoryStub() + pipeline_factory = PipelineFactoryStub() + salesforce_factory = SalesForceFactoryStub() + query_factory = 
QueryFactoryStub() + newrelic_factory = NewRelicFactoryStub() + + # execute + + i = integration.Integration( + config, + auth_factory, + cache_factory, + pipeline_factory, + salesforce_factory, + api_factory, + query_factory, + newrelic_factory, + event_type_fields_mapping, + numeric_fields_list, + 603 + ) + + # verify + + self.assertEqual(len(i.instances), 0) + + ''' + given: a Config instance, set of factories, a data format, an event + type fields mapping, a set of numeric fields and an initial + delay + when: instances property is missing + then: integration instances should be the empty set + ''' + + # setup + config = mod_config.Config({ + 'newrelic': { + 'data_format': 'logs', + } + }) + + api_factory = ApiFactoryStub() + auth_factory = AuthenticatorFactoryStub() + cache_factory = CacheFactoryStub() + pipeline_factory = PipelineFactoryStub() + salesforce_factory = SalesForceFactoryStub() + query_factory = QueryFactoryStub() + newrelic_factory = NewRelicFactoryStub() + + # execute + + i = integration.Integration( + config, + auth_factory, + cache_factory, + pipeline_factory, + salesforce_factory, + api_factory, + query_factory, + newrelic_factory, + event_type_fields_mapping, + numeric_fields_list, + 603 + ) + + # verify + + self.assertEqual(len(i.instances), 0) diff --git a/src/tests/test_pipeline.py b/src/tests/test_pipeline.py new file mode 100644 index 0000000..d79c6f7 --- /dev/null +++ b/src/tests/test_pipeline.py @@ -0,0 +1,1721 @@ +import copy +from datetime import datetime +import json +import pytz +import unittest + + +from newrelic_logging import \ + config, \ + DataFormat, \ + LoginException, \ + pipeline, \ + util, \ + SalesforceApiException +from . import \ + ApiStub, \ + DataCacheStub, \ + NewRelicStub, \ + QueryStub, \ + SessionStub + + +class TestPipeline(unittest.TestCase): + def setUp(self): + with open('./tests/sample_log_lines.csv') as stream: + self.log_rows = stream.readlines() + + with open('./tests/sample_log_lines.json') as stream: + self.log_lines = json.load(stream) + + with open('./tests/sample_event_records.json') as stream: + self.event_records = json.load(stream) + + with open('./tests/sample_log_records.json') as stream: + self.log_records = json.load(stream) + + def test_init_fields_from_log_line(self): + ''' + given: an event type, log line, and event fields mapping + when: there is no matching mapping for the event type in the event + fields mapping + then: copy all fields in log line + ''' + + # setup + log_line = { + 'foo': 'bar', + 'beep': 'boop', + } + + # execute + attrs = pipeline.init_fields_from_log_line('ApexCallout', log_line, {}) + + # verify + self.assertTrue(len(attrs) == 2) + self.assertTrue(attrs['foo'] == 'bar') + self.assertTrue(attrs['beep'] == 'boop') + + ''' + given: an event type, log line, and event fields mapping + when: there is a matching mapping for the event type in the event + fields mapping + then: copy only the fields in the event fields mapping + ''' + + # execute + attrs = pipeline.init_fields_from_log_line( + 'ApexCallout', + log_line, + { 'ApexCallout': ['foo'] } + ) + + # verify + self.assertTrue(len(attrs) == 1) + self.assertTrue(attrs['foo'] == 'bar') + self.assertTrue(not 'beep' in attrs) + + def test_get_log_line_timestamp(self): + ''' + given: a log line + when: there is no TIMESTAMP attribute + then: return the current timestamp + ''' + + # setup + now = datetime.utcnow().replace(microsecond=0) + + # execute + ts = pipeline.get_log_line_timestamp({}) + + # verify + self.assertEqual(now.timestamp(), ts) + + 
''' + given: a log line + when: there is a TIMESTAMP attribute + then: parse the string in the format YYYYMMDDHHmmss.FFF and return + the representative timestamp + ''' + + # setup + epoch = now.strftime('%Y%m%d%H%M%S.%f') + + # execute + ts1 = pytz.utc.localize(now).replace(microsecond=0).timestamp() + ts2 = pipeline.get_log_line_timestamp({ 'TIMESTAMP': epoch }) + + # verify + self.assertEqual(ts1, ts2) + + def test_pack_log_line_into_log(self): + ''' + given: a query object, record ID, event type, log line, line number, + and event fields mapping + when: there is a TIMESTAMP and EVENT_TYPE field, no query options, and + no matching event mapping + then: return a log entry with the message "LogFile $ID row $LINENO", + and attributes dict containing all fields from the log line, + an EVENT_TYPE field with the log line event type, a TIMESTAMP + field with the log line timestamp, and a timestamp field with + the timestamp epoch value + ''' + + # setup + query = QueryStub({}) + + # execute + log = pipeline.pack_log_line_into_log( + query, + '00001111AAAABBBB', + 'ApexCallout', + self.log_lines[0], + 0, + {}, + ) + + # verify + self.assertTrue('message' in log) + self.assertTrue('attributes' in log) + self.assertEqual(log['message'], 'LogFile 00001111AAAABBBB row 0') + + attrs = log['attributes'] + self.assertEqual(attrs['EVENT_TYPE'], 'ApexCallout') + self.assertEqual(attrs['LogFileId'], '00001111AAAABBBB') + self.assertTrue('timestamp' in attrs) + ts = pipeline.get_log_line_timestamp({ + 'TIMESTAMP': '20240311160000.000' + }) + self.assertEqual(int(ts), attrs['timestamp']) + self.assertTrue('REQUEST_ID' in attrs) + self.assertTrue('RUN_TIME' in attrs) + self.assertTrue('CPU_TIME' in attrs) + self.assertEqual('YYZ:abcdef123456', attrs['REQUEST_ID']) + self.assertEqual('2112', attrs['RUN_TIME']) + self.assertEqual('10', attrs['CPU_TIME']) + + ''' + given: a query object, record ID, event type, log line, line number, + and event fields mapping + when: there is a TIMESTAMP and EVENT_TYPE field, no matching event + mapping, and the event_type and rename_timestamp query options + then: return the same as case 1 but with the event type specified + in the query options, the epoch value in the field specified + in the query options and no timestamp field + ''' + + # setup + query = QueryStub(config={ + 'event_type': 'CustomSFEvent', + 'rename_timestamp': 'custom_timestamp', + }) + + # execute + log = pipeline.pack_log_line_into_log( + query, + '00001111AAAABBBB', + 'ApexCallout', + self.log_lines[0], + 0, + {}, + ) + + # verify + attrs = log['attributes'] + self.assertTrue('custom_timestamp' in attrs) + self.assertEqual(attrs['EVENT_TYPE'], 'CustomSFEvent') + self.assertEqual(attrs['custom_timestamp'], int(ts)) + self.assertTrue(not 'timestamp' in log) + + def test_export_log_lines(self): + ''' + given: an Api instance, an http session, log file path, and chunk size + when: api.get_log_file() is called + and when: api.get_log_file() raises a LoginException + then: raise a LoginException + ''' + + # setup + api = ApiStub(raise_login_error=True) + session = SessionStub() + + # execute/verify + with self.assertRaises(LoginException): + lines = pipeline.export_log_lines(api, session, '', 100) + # Have to use next to cause the generator to execute the function + # else get_log_file() won't get executed and our stub won't have + # a chance to throw the fake exception. 
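+ # (Because export_log_lines is presumably a generator function, its body + # is deferred until iteration begins, so the stubbed exception surfaces on + # the first next() call rather than at the call site above.)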
+ next(lines) + + ''' + given: an Api instance, an http session, log file path, and chunk size + when: api.get_log_file() is called + and when: api.get_log_file() raises a SalesforceApiException + then: raise a SalesforceApiException + ''' + + # setup + api = ApiStub(raise_error=True) + session = SessionStub() + + # execute/verify + with self.assertRaises(SalesforceApiException): + lines = pipeline.export_log_lines(api, session, '', 100) + # Have to use next to cause the generator to execute the function + # else get_log_file() won't get executed and our stub won't have + # a chance to throw the fake exception. + next(lines) + + ''' + given: an Api instance, an http session, log file path, and chunk size + when: the response produces a 200 status code + then: return a generator iterator that yields one line of data at a time + ''' + + # setup + api = ApiStub(lines=self.log_rows) + session = SessionStub() + + # execute + response = pipeline.export_log_lines(api, session, '', 100) + + lines = [] + for line in response: + lines.append(line) + + # verify + self.assertEqual(len(lines), 3) + self.assertEqual(lines[0], self.log_rows[0]) + self.assertEqual(lines[1], self.log_rows[1]) + self.assertEqual(lines[2], self.log_rows[2]) + + def test_transform_log_lines(self): + ''' + given: an iterable of log rows, query, record id, event type, event + types mapping and no data cache + when: the rows include a header row followed by data rows + then: return a generator iterator that yields one New Relic log + object for each row except the header row + ''' + + # execute + logs = pipeline.transform_log_lines( + self.log_rows, + QueryStub({}), + '00001111AAAABBBB', + 'ApexCallout', + None, + {}, + ) + + l = [] + + for log in logs: + l.append(log) + + # verify + self.assertEqual(len(l), 2) + self.assertTrue('message' in l[0]) + self.assertTrue('attributes' in l[0]) + self.assertEqual(l[0]['message'], 'LogFile 00001111AAAABBBB row 0') + attrs = l[0]['attributes'] + self.assertTrue('EVENT_TYPE' in attrs) + self.assertTrue('timestamp' in attrs) + self.assertEqual(1710172800, attrs['timestamp']) + + self.assertTrue('message' in l[1]) + self.assertTrue('attributes' in l[1]) + self.assertEqual(l[1]['message'], 'LogFile 00001111AAAABBBB row 1') + attrs = l[1]['attributes'] + self.assertTrue('EVENT_TYPE' in attrs) + self.assertTrue('timestamp' in attrs) + self.assertEqual(1710176400, attrs['timestamp']) + + ''' + given: an iterable of log rows, query, record id, event type, event + types mapping and a data cache + when: the data cache contains the REQUEST_ID for some of the log lines + then: return a generator iterator that yields one New Relic log + object for each row whose REQUEST_ID is not in the cache + ''' + + # execute + logs = pipeline.transform_log_lines( + self.log_rows, + QueryStub({}), + '00001111AAAABBBB', + 'ApexCallout', + DataCacheStub(cached_logs={ + '00001111AAAABBBB': [ 'YYZ:abcdef123456' ] + }), + {}, + ) + + l = [] + + for log in logs: + l.append(log) + + # verify + self.assertEqual(len(l), 1) + self.assertEqual( + l[0]['attributes']['REQUEST_ID'], + 'YYZ:fedcba654321' + ) + + + def test_pack_event_record_into_log(self): + created_date = self.event_records[0]['CreatedDate'] + timestamp = util.get_timestamp(created_date) + + base_expected_attrs = { + 'Id': '00001111AAAABBBB', + 'Name': 'My Account', + 'BillingCity': None, + 'CreatedDate': created_date, + 'CreatedBy.Name': 'Foo Bar', + 'CreatedBy.Profile.Name': 'Beep Boop', + 'CreatedBy.UserType': 'Bip Bop', + 'EVENT_TYPE': 'Account', + 'timestamp': timestamp, + } + + ''' + 
given: a query, record id and event record + when: there are no query options, the record id is not None, and the + event record contains a 'type' field in the 'attributes' field + then: return a log with the 'message' attribute set to the event type + specified in the record's 'attributes.type' field + the created + date; where the 'attributes' attribute contains all attributes + according to process_query_result, an 'Id' attribute set to the + passed record ID, an 'EVENT_TYPE' attribute set to the event type + specified in the record's 'attributes.type' field, and a + 'timestamp' attribute set to the epoch value representing the + record's 'CreatedDate' field; and with the 'timestamp' attribute + also set to the epoch value representing the record's + 'CreatedDate' field. + ''' + + # setup + query = QueryStub({}) + + # execute + log = pipeline.pack_event_record_into_log( + query, + '00001111AAAABBBB', + self.event_records[0] + ) + + # verify + self.assertTrue('message' in log) + self.assertTrue('attributes' in log) + self.assertTrue('timestamp' in log) + self.assertEqual(log['message'], f'Account {created_date}') + self.assertEqual(log['attributes'], base_expected_attrs) + self.assertEqual(log['timestamp'], timestamp) + + ''' + given: a query, record id, and an event record + when: there are no query options, the record id is None, and the + event record contains a 'type' field in the 'attributes' field + then: return a log as in use case 1 but with the 'Id' value in the + log 'attributes' attribute set to the 'Id' value from the event + record + ''' + + # setup + event_record = copy.deepcopy(self.event_records[0]) + expected_attrs = copy.deepcopy(base_expected_attrs) + expected_attrs['Id'] = '000012345' + + # execute + log = pipeline.pack_event_record_into_log( + query, + None, + event_record + ) + + # verify + self.assertTrue('message' in log) + self.assertTrue('attributes' in log) + self.assertTrue('timestamp' in log) + self.assertEqual(log['message'], f'Account {created_date}') + self.assertEqual(log['attributes'], expected_attrs) + self.assertEqual(log['timestamp'], timestamp) + + ''' + given: a query, record id, and an event record + when: the 'event_type' query option is specified, the record id is not + None, and the event record contains a 'type' field in the + 'attributes' field + then: return a log as in use case 1 but with the event type in the log + 'message' attribute set to the custom event type specified in the + 'event_type' query option plus the created date, and with the + 'EVENT_TYPE' attribute in the log 'attributes' attribute set to + the custom event type + ''' + + # setup + event_record = copy.deepcopy(self.event_records[0]) + expected_attrs = copy.deepcopy(base_expected_attrs) + expected_attrs['EVENT_TYPE'] = 'CustomEvent' + + # execute + log = pipeline.pack_event_record_into_log( + QueryStub(config={ 'event_type': 'CustomEvent' }), + '00001111AAAABBBB', + event_record + ) + + # verify + self.assertTrue('message' in log) + self.assertTrue('attributes' in log) + self.assertTrue('timestamp' in log) + self.assertEqual(log['message'], f'CustomEvent {created_date}') + self.assertEqual(log['attributes'], expected_attrs) + self.assertEqual(log['timestamp'], timestamp) + + ''' + given: a query, record id, and an event record + when: there are no query options, the record id is not None, and the + event record does not contain an 'attributes' field + then: return a log as in use case 1 but with the event type in the log + 'message' attribute set to the default event type 
plus the created + date, and with no 'EVENT_TYPE' attribute in the log 'attributes' + attribute + ''' + + # setup + event_record = copy.deepcopy(self.event_records[0]) + event_record.pop('attributes') + expected_attrs = copy.deepcopy(base_expected_attrs) + expected_attrs.pop('EVENT_TYPE') + + # execute + log = pipeline.pack_event_record_into_log( + query, + '00001111AAAABBBB', + event_record + ) + + # verify + self.assertTrue('message' in log) + self.assertTrue('attributes' in log) + self.assertTrue('timestamp' in log) + self.assertEqual(log['message'], f'SFEvent {created_date}') + self.assertEqual(log['attributes'], expected_attrs) + self.assertEqual(log['timestamp'], timestamp) + + ''' + given: a query, record id, and an event record + when: there are no query options, the record id is not None, and the + event record does contain an 'attributes' field but it is not a + dictionary + then: return a log as in the previous use case + ''' + + # setup + event_record = copy.deepcopy(self.event_records[0]) + event_record['attributes'] = 'test' + expected_attrs = copy.deepcopy(base_expected_attrs) + expected_attrs.pop('EVENT_TYPE') + + # execute + log = pipeline.pack_event_record_into_log( + query, + '00001111AAAABBBB', + event_record + ) + + # verify + self.assertTrue('message' in log) + self.assertTrue('attributes' in log) + self.assertTrue('timestamp' in log) + self.assertEqual(log['message'], f'SFEvent {created_date}') + self.assertEqual(log['attributes'], expected_attrs) + self.assertEqual(log['timestamp'], timestamp) + + ''' + given: a query, record id, and an event record + when: there are no query options, the record id is not None, and the + event record does not contain a 'type' field in the 'attributes' + field + then: return a log as in the previous use case + ''' + + # setup + event_record = copy.deepcopy(self.event_records[0]) + event_record['attributes'].pop('type') + expected_attrs = copy.deepcopy(base_expected_attrs) + expected_attrs.pop('EVENT_TYPE') + + # execute + log = pipeline.pack_event_record_into_log( + query, + '00001111AAAABBBB', + event_record + ) + + # verify + self.assertTrue('message' in log) + self.assertTrue('attributes' in log) + self.assertTrue('timestamp' in log) + self.assertEqual(log['message'], f'SFEvent {created_date}') + self.assertEqual(log['attributes'], expected_attrs) + self.assertEqual(log['timestamp'], timestamp) + + ''' + given: a query, record id, and an event record + when: there are no query options, the record id is not None, and the + event record does contain a 'type' field in the 'attributes' + field but it is not a string + then: return a log as in the previous use case + ''' + + # setup + event_record = copy.deepcopy(self.event_records[0]) + event_record['attributes']['type'] = 12345 + expected_attrs = copy.deepcopy(base_expected_attrs) + expected_attrs.pop('EVENT_TYPE') + + # execute + log = pipeline.pack_event_record_into_log( + query, + '00001111AAAABBBB', + event_record + ) + + # verify + self.assertTrue('message' in log) + self.assertTrue('attributes' in log) + self.assertTrue('timestamp' in log) + self.assertEqual(log['message'], f'SFEvent {created_date}') + self.assertEqual(log['attributes'], expected_attrs) + self.assertEqual(log['timestamp'], timestamp) + + ''' + given: a query, record id, and an event record + when: the 'timestamp_attr' query option is specified, the field + specified in the 'timestamp_attr' query option exists in the event + record, the record id is not None, and the event record contains a + 'type' field in the 
'attributes' field + then: return a log as in use case 1 but using the timestamp from the + field specified in the 'timestamp_attr' query option. + ''' + + # setup + __now = datetime.now() + + def _now(): + nonlocal __now + return __now + + util._NOW = _now + + created_date_2 = self.event_records[1]['CreatedDate'] + timestamp = util.get_timestamp(created_date_2) + + event_record = copy.deepcopy(self.event_records[0]) + event_record['CustomDate'] = created_date_2 + expected_attrs = copy.deepcopy(base_expected_attrs) + expected_attrs['CustomDate'] = created_date_2 + expected_attrs['timestamp'] = timestamp + + # execute + log = pipeline.pack_event_record_into_log( + QueryStub(config={ 'timestamp_attr': 'CustomDate' }), + '00001111AAAABBBB', + event_record + ) + + # verify + self.assertTrue('message' in log) + self.assertTrue('attributes' in log) + self.assertTrue('timestamp' in log) + self.assertEqual(log['message'], f'Account {created_date_2}') + self.assertEqual(log['attributes'], expected_attrs) + self.assertEqual(log['timestamp'], timestamp) + + ''' + given: a query, record id, and an event record + when: the 'timestamp_attr' query option is specified but the specified + attribute name is not in the event record, the record id is not + None, and the event record contains a 'type' field in the + 'attributes' field + then: return a log as in use case 1 but the log 'message' attribute does + not contain a date, the 'timestamp' attribute of the log + 'attributes' attribute is set to the current time, and with the + log 'timestamp' attribute set to the current time. + ''' + + # setup + event_record = copy.deepcopy(self.event_records[0]) + expected_attrs = copy.deepcopy(base_expected_attrs) + timestamp = util.get_timestamp() + expected_attrs['timestamp'] = timestamp + + # execute + log = pipeline.pack_event_record_into_log( + QueryStub(config={ 'timestamp_attr': 'NotPresent' }), + '00001111AAAABBBB', + event_record + ) + + # verify + self.assertTrue('message' in log) + self.assertTrue('attributes' in log) + self.assertTrue('timestamp' in log) + self.assertEqual(log['message'], f'Account') + self.assertEqual(log['attributes'], expected_attrs) + self.assertEqual(log['timestamp'], timestamp) + + ''' + given: a query, record id, and an event record + when: no query options are specified, the record id is not None, and the + event record does not contain a 'CreatedDate' field + then: return the same as the previous use case + ''' + + # setup + event_record = copy.deepcopy(self.event_records[0]) + event_record.pop('CreatedDate') + expected_attrs = copy.deepcopy(base_expected_attrs) + expected_attrs.pop('CreatedDate') + timestamp = util.get_timestamp() + expected_attrs['timestamp'] = timestamp + + # execute + log = pipeline.pack_event_record_into_log( + query, + '00001111AAAABBBB', + event_record + ) + + # verify + self.assertTrue('message' in log) + self.assertTrue('attributes' in log) + self.assertTrue('timestamp' in log) + self.assertEqual(log['message'], f'Account') + self.assertEqual(log['attributes'], expected_attrs) + self.assertEqual(log['timestamp'], timestamp) + + ''' + given: a query, record id, and an event record + when: the 'rename_timestamp' query option is specified, the record id is + not None, and the event record contains a 'type' field in the + 'attributes' field + then: return the same as use case 1 but with an attribute with the name + specified in the 'rename_timestamp' query option set to the + created date in the log 'attributes' attribute, no 'timestamp' + attribute in the 
log 'attributes' attribute, and with no log + 'timestamp' attribute + ''' + + # setup + event_record = copy.deepcopy(self.event_records[0]) + expected_attrs = copy.deepcopy(base_expected_attrs) + timestamp = util.get_timestamp(created_date) + expected_attrs['custom_timestamp'] = timestamp + expected_attrs.pop('timestamp') + + # execute + log = pipeline.pack_event_record_into_log( + QueryStub(config={ 'rename_timestamp': 'custom_timestamp' }), + '00001111AAAABBBB', + event_record + ) + + # verify + self.assertTrue('message' in log) + self.assertTrue('attributes' in log) + self.assertTrue(not 'timestamp' in log) + self.assertEqual(log['message'], f'Account {created_date}') + self.assertEqual(log['attributes'], expected_attrs) + + def test_transform_event_records(self): + ''' + given: an event record, query, and no data cache + when: the record contains an 'Id' field + then: return a log with the 'Id' attribute in the log 'attributes' set + to the value of the 'Id' attribute + when: the record does not contain an 'Id' field and the 'id' query + option is not set + then: return a log with no 'Id' attribute in the log 'attributes' field + ''' + + # execute + logs = pipeline.transform_event_records( + self.event_records, + QueryStub({}), + None, + ) + + l = [] + + for log in logs: + l.append(log) + + # verify + self.assertEqual(len(l), 3) + self.assertTrue('Id' in l[0]['attributes']) + self.assertTrue('Id' in l[1]['attributes']) + self.assertEqual(l[0]['attributes']['Id'], '000012345') + self.assertEqual(l[1]['attributes']['Id'], '000054321') + self.assertTrue(not 'Id' in l[2]['attributes']) + + ''' + given: an event record, query, and no data cache + when: the record does not contain an 'Id' field, the 'id' query option + is set and the record contains the field set in the 'id' query + option + then: return a log with a generated id + ''' + + # execute + logs = pipeline.transform_event_records( + self.event_records[2:], + QueryStub(config={ 'id': ['Name'] }), + None, + ) + + l = [] + + for log in logs: + l.append(log) + + # verify + customId = util.generate_record_id([ 'Name' ], self.event_records[2]) + self.assertTrue('Id' in l[0]['attributes']) + self.assertEqual(l[0]['attributes']['Id'], customId) + + ''' + given: an event record, query, and a data cache + when: the data cache contains cached record IDs + then: return only logs with record ids that do not match those in the + cache + ''' + + # execute + logs = pipeline.transform_event_records( + self.event_records, + QueryStub({}), + DataCacheStub( + cached_logs={}, + cached_events=[ '000012345', '000054321' ] + ), + ) + + l = [] + + for log in logs: + l.append(log) + + # verify + self.assertEqual(len(l), 1) + self.assertEqual( + l[0]['attributes']['Name'], + 'My Last Account', + ) + + def test_load_as_logs(self): + def logs(n = 50): + for i in range(0, n): + yield { + 'message': f'log {i}', + 'attributes': { + 'EVENT_TYPE': f'SFEvent{i}', + 'REQUEST_ID': f'abcdef-{i}', + }, + } + + ''' + given: a generator iterator of logs, newrelic instance, set of labels, + and a max rows value + when: there are less than the maximum number of rows, n + then: a single Logs API post should be made with a 'common' property + that contains all the labels and a 'logs' property that contains + n log entries + ''' + + # setup + labels = { 'foo': 'bar' } + newrelic = NewRelicStub() + + # preconditions + self.assertEqual(len(newrelic.logs), 0) + + # execute + pipeline.load_as_logs( + logs(), + newrelic, + labels, + pipeline.DEFAULT_MAX_ROWS, + ) + + # verify + 
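+ # Expected capture in NewRelicStub (an assumed test double that records each + # Logs API payload): a single post shaped like + # [ { 'common': { 'foo': 'bar' }, 'logs': [ <50 log entries> ] } ].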
self.assertEqual(len(newrelic.logs), 1) + l = newrelic.logs[0] + self.assertEqual(len(l), 1) + l = l[0] + self.assertTrue('logs' in l) + self.assertTrue('common' in l) + self.assertTrue(type(l['common']) is dict) + self.assertTrue('foo' in l['common']) + self.assertEqual(l['common']['foo'], 'bar') + self.assertEqual(len(l['logs']), 50) + for i, log in enumerate(l['logs']): + self.assertTrue('message' in log) + self.assertEqual(log['message'], f'log {i}') + self.assertTrue('attributes' in log) + self.assertTrue('EVENT_TYPE' in log['attributes']) + self.assertEqual(log['attributes']['EVENT_TYPE'], f'SFEvent{i}') + self.assertTrue('REQUEST_ID' in log['attributes']) + self.assertEqual(log['attributes']['REQUEST_ID'], f'abcdef-{i}') + + ''' + given: a generator iterator of logs, newrelic instance, set of labels, + and a max rows value + when: there are more than the maximum number of rows, n + then: floor(n / max) Logs API posts should be made each containing max + logs in the 'logs' property and a 'common' property that contains + all the labels. IFF n % max > 0, an additional Logs API post is + made containing n % max logs in the 'logs' property and a 'common' + property that contains all the labels. + ''' + + # setup + newrelic = NewRelicStub() + + # preconditions + self.assertEqual(len(newrelic.logs), 0) + + # execute + pipeline.load_as_logs( + logs(150), + newrelic, + labels, + 50, + ) + + # verify + self.assertEqual(len(newrelic.logs), 3) + for i in range(0, 3): + l = newrelic.logs[i] + self.assertEqual(len(l), 1) + l = l[0] + self.assertTrue('logs' in l) + self.assertTrue('common' in l) + self.assertEqual(len(l['logs']), 50) + + # setup + newrelic = NewRelicStub() + + # preconditions + self.assertEqual(len(newrelic.logs), 0) + + # execute + pipeline.load_as_logs( + logs(53), + newrelic, + labels, + 50, + ) + + # verify + self.assertEqual(len(newrelic.logs), 2) + l = newrelic.logs[0] + self.assertEqual(len(l), 1) + l = l[0] + self.assertTrue('logs' in l) + self.assertTrue('common' in l) + self.assertEqual(len(l['logs']), 50) + l = newrelic.logs[1] + self.assertEqual(len(l), 1) + l = l[0] + self.assertTrue('logs' in l) + self.assertTrue('common' in l) + self.assertEqual(len(l['logs']), 3) + + def test_pack_log_into_event(self): + ''' + given: a single log, set of labels, and a set of numeric field names + when: the set of numeric field names is the empty set and the log + contains an 'EVENT_TYPE' property + then: return a single event with a property for each attribute specified + in the 'attributes' field of the log, a property for each label, + and an 'eventType' property set to the value of the 'EVENT_TYPE' + property of the log entry + ''' + + # setup + log = { + 'message': 'Foo and Bar', + 'attributes': self.log_lines[0] + } + + # execute + event = pipeline.pack_log_into_event( + log, + { 'foo': 'bar' }, + set(), + ) + + # verify + self.assertTrue('eventType' in event) + self.assertEqual(event['eventType'], self.log_lines[0]['EVENT_TYPE']) + self.assertTrue('foo' in event) + self.assertEqual(event['foo'], 'bar') + self.assertEqual(len(event), len(self.log_lines[0]) + 2) + for k in self.log_lines[0]: + self.assertTrue(k in event) + self.assertEqual(event[k], self.log_lines[0][k]) + + ''' + given: a single log, set of labels, and a set of numeric field names + when: the set of numeric field names is not empty + then: return the same as use case 1 except each property in the returned + event matching a property in the numeric field names set is + converted to a number and non-numeric 
values are left as is. + ''' + + # execute + event = pipeline.pack_log_into_event( + log, + { 'foo': 'bar' }, + set(['RUN_TIME', 'CPU_TIME', 'SUCCESS', 'URI']), + ) + + # verify + self.assertEqual(len(event), len(self.log_lines[0]) + 2) + self.assertTrue(type(event['RUN_TIME']) == int) + self.assertTrue(type(event['CPU_TIME']) == int) + self.assertTrue(type(event['SUCCESS']) == int) + self.assertTrue(type(event['URI']) == str) + + ''' + given: a single log, set of labels, and a set of numeric field names + when: the set of numeric field names is the empty set and the log + does not contain an 'EVENT_TYPE' property + then: return the same as use case 1 except the 'eventType' attribute is + set to the default event name. + ''' + + # setup + log_lines = copy.deepcopy(self.log_lines) + + del log_lines[0]['EVENT_TYPE'] + + log['attributes'] = log_lines[0] + + # execute + event = pipeline.pack_log_into_event( + log, + { 'foo': 'bar' }, + set(), + ) + + # verify + self.assertEqual(len(event), len(log_lines[0]) + 2) + self.assertTrue('eventType' in event) + self.assertEqual(event['eventType'], 'UnknownSFEvent') + + def test_load_as_events(self): + def logs(n = 50): + for i in range(0, n): + yield { + 'message': f'log {i}', + 'attributes': { + 'EVENT_TYPE': f'SFEvent{i}', + 'REQUEST_ID': f'abcdef-{i}', + }, + } + + ''' + given: a generator iterator of logs, newrelic instance, set of labels, + max rows value, and a set of numeric field names + when: there are less than the maximum number of rows, n + then: a single Events API post should be made with n events + ''' + + # setup + labels = { 'foo': 'bar' } + newrelic = NewRelicStub() + + # preconditions + self.assertEqual(len(newrelic.events), 0) + + # execute + pipeline.load_as_events( + logs(), + newrelic, + labels, + pipeline.DEFAULT_MAX_ROWS, + set(), + ) + + # verify + self.assertEqual(len(newrelic.events), 1) + l = newrelic.events[0] + self.assertEqual(len(l), 50) + for i, event in enumerate(l): + self.assertTrue('foo' in event) + self.assertEqual(event['foo'], 'bar') + self.assertTrue('EVENT_TYPE' in event) + self.assertEqual(event['EVENT_TYPE'], f'SFEvent{i}') + self.assertTrue('REQUEST_ID' in event) + self.assertEqual(event['REQUEST_ID'], f'abcdef-{i}') + + ''' + given: a generator iterator of logs, newrelic instance, set of labels, + max rows value, and a set of numeric field names + when: there are more than the maximum number of rows, n + then: floor(n / max) Events API posts should be made each containing max + events. IFF n % max > 0, an additional Events API post is made + containing n % max events. 
+ ''' + + # setup + newrelic = NewRelicStub() + + # preconditions + self.assertEqual(len(newrelic.logs), 0) + + # execute + pipeline.load_as_events( + logs(150), + newrelic, + labels, + 50, + set(), + ) + + # verify + self.assertEqual(len(newrelic.events), 3) + for i in range(0, 3): + l = newrelic.events[i] + self.assertEqual(len(l), 50) + + # setup + newrelic = NewRelicStub() + + # preconditions + self.assertEqual(len(newrelic.logs), 0) + + # execute + pipeline.load_as_events( + logs(53), + newrelic, + labels, + 50, + set(), + ) + + # verify + self.assertEqual(len(newrelic.events), 2) + l = newrelic.events[0] + self.assertEqual(len(l), 50) + l = newrelic.events[1] + self.assertEqual(len(l), 3) + + def test_load_data(self): + def logs(n = 50): + for i in range(0, n): + yield { + 'message': f'log {i}', + 'attributes': { + 'EVENT_TYPE': f'SFEvent{i}', + 'REQUEST_ID': f'abcdef-{i}', + }, + } + + ''' + given: a generator iterator of logs, newrelic instance, data format, + set of labels, max rows value, and a set of numeric field names + when: the data format is set to DataFormat.LOGS + then: log data is sent via the New Relic Logs API. + ''' + + # setup + newrelic = NewRelicStub() + + # preconditions + self.assertEqual(len(newrelic.logs), 0) + self.assertEqual(len(newrelic.events), 0) + + # execute + pipeline.load_data( + logs(), + newrelic, + DataFormat.LOGS, + { 'foo': 'bar' }, + 50, + set() + ) + + # verify + self.assertEqual(len(newrelic.logs), 1) + self.assertEqual(len(newrelic.events), 0) + self.assertEqual(len(newrelic.logs[0]), 1) + self.assertTrue('logs' in newrelic.logs[0][0]) + self.assertEqual(len(newrelic.logs[0][0]['logs']), 50) + + ''' + given: a generator iterator of logs, newrelic instance, data format, + set of labels, max rows value, and a set of numeric field names + when: the data format is set to DataFormat.EVENTS + then: log data is sent via the New Relic Events API. 
+ ''' + + # setup + newrelic = NewRelicStub() + + # preconditions + self.assertEqual(len(newrelic.logs), 0) + self.assertEqual(len(newrelic.events), 0) + + # execute + pipeline.load_data( + logs(), + newrelic, + DataFormat.EVENTS, + { 'foo': 'bar' }, + 50, + set() + ) + + # verify + self.assertEqual(len(newrelic.logs), 0) + self.assertEqual(len(newrelic.events), 1) + self.assertEqual(len(newrelic.events[0]), 50) + + def test_pipeline_process_log_record(self): + ''' + given: an instance configuration, data cache, http session, newrelic + instance, data format, set of labels, event type fields mapping, + set of numeric field names, query, instance url, access token and + log record + when: the pipeline is configured with the configuration, session, + newrelic instance, data format, labels, event type fields mapping, + and numeric field names + and when: the data format is set to DataFormat.LOGS + and when: a log record is being processed + and when: the number of log lines to be processed is less than the + maximum number of rows + and when: no data cache is specified + then: a single Logs API post is made containing all labels in the + 'common' property of the logs post and one log for each exported + and transformed log line with the correct attributes from the + corresponding log line using the record ID, event type, and file + name from the given record + ''' + + # setup + api = ApiStub(lines=self.log_rows) + cfg = config.Config({}) + session = SessionStub() + newrelic = NewRelicStub() + query = QueryStub({}) + + p = pipeline.Pipeline( + cfg, + None, + newrelic, + DataFormat.LOGS, + { 'foo': 'bar' }, + {}, + set(), + ) + + record = self.log_records[0] + + # preconditions + self.assertEqual(len(newrelic.logs), 0) + + # execute + p.process_log_record( + api, + session, + query, + record, + ) + + # verify + self.assertEqual(len(newrelic.logs), 1) + l = newrelic.logs[0] + self.assertEqual(len(l), 1) + l = l[0] + self.assertTrue('logs' in l) + self.assertTrue('common' in l) + self.assertTrue(type(l['common']) is dict) + self.assertTrue('foo' in l['common']) + self.assertEqual(l['common']['foo'], 'bar') + self.assertEqual(len(l['logs']), 2) + + logs = l['logs'] + log0 = logs[0] + log1 = logs[1] + + self.assertTrue('message' in log0) + self.assertEqual(log0['message'], 'LogFile 00001111AAAABBBB row 0') + self.assertTrue('attributes' in log0) + self.assertTrue('EVENT_TYPE' in log0['attributes']) + self.assertEqual(log0['attributes']['EVENT_TYPE'], f'ApexCallout') + self.assertTrue('REQUEST_ID' in log0['attributes']) + self.assertEqual(log0['attributes']['REQUEST_ID'], f'YYZ:abcdef123456') + + self.assertTrue('message' in log1) + self.assertEqual(log1['message'], 'LogFile 00001111AAAABBBB row 1') + self.assertTrue('attributes' in log1) + self.assertTrue('EVENT_TYPE' in log1['attributes']) + self.assertEqual(log1['attributes']['EVENT_TYPE'], f'ApexCallout') + self.assertTrue('REQUEST_ID' in log0['attributes']) + self.assertEqual(log1['attributes']['REQUEST_ID'], f'YYZ:fedcba654321') + + ''' + given: the values from use case 1 + when: the pipeline is configured as in use case 1 + and when: export_log_lines() raises a LoginException + then: raise a LoginException + ''' + + # setup + api = ApiStub(raise_login_error=True) + newrelic = NewRelicStub() + + p = pipeline.Pipeline( + cfg, + None, + newrelic, + DataFormat.LOGS, + { 'foo': 'bar' }, + {}, + set(), + ) + + record = self.log_records[0] + + # execute / verify + with self.assertRaises(LoginException) as _: + p.process_log_record( + api, + session, 
+ query, + record, + ) + + ''' + given: the values from use case 1 + when: the pipeline is configured as in use case 1 + and when: export_log_lines() raises a SalesforceApiException + then: raise a SalesforceApiException + ''' + + # setup + api = ApiStub(raise_error=True) + newrelic = NewRelicStub() + + p = pipeline.Pipeline( + cfg, + None, + newrelic, + DataFormat.LOGS, + { 'foo': 'bar' }, + {}, + set(), + ) + + record = self.log_records[0] + + # execute / verify + with self.assertRaises(SalesforceApiException) as _: + p.process_log_record( + api, + session, + query, + record, + ) + + ''' + given: the values from use case 1 + when: the pipeline is configured as in use case 1 + and when: a data cache is specified + and when: the record ID matches a record ID in the data cache + then: no log entries are sent + ''' + + # setup + api = ApiStub() + data_cache = DataCacheStub(skip_record_ids=['00001111AAAABBBB']) + newrelic = NewRelicStub() + + p = pipeline.Pipeline( + cfg, + data_cache, + newrelic, + DataFormat.LOGS, + { 'foo': 'bar' }, + {}, + set(), + ) + + record = self.log_records[0] + + # preconditions + self.assertEqual(len(newrelic.logs), 0) + + # execute + p.process_log_record( + api, + session, + query, + record, + ) + + # verify + self.assertEqual(len(newrelic.logs), 0) + + ''' + given: the values from use case 2 + when: the pipeline is configured as in use case 2 + and when: the data format is set to DataFormat.LOGS + and when: the number of log lines to be processed is less than the + maximum number of rows, + and when: a data cache is specified + and when: the record ID matches a record ID in the data cache + and when: the 'Interval' value of the record is set to 'Daily' + then: A Logs API post is made for the record anyway + ''' + + # setup + api = ApiStub(lines=self.log_rows) + data_cache = DataCacheStub(skip_record_ids=['00001111AAAABBBB']) + newrelic = NewRelicStub() + + p = pipeline.Pipeline( + cfg, + data_cache, + newrelic, + DataFormat.LOGS, + { 'foo': 'bar' }, + {}, + set(), + ) + + new_record = copy.deepcopy(record) + new_record['Interval'] = 'Daily' + + # preconditions + self.assertEqual(len(newrelic.logs), 0) + + # execute + p.process_log_record( + api, + session, + query, + new_record, + ) + + # verify + self.assertEqual(len(newrelic.logs), 1) + l = newrelic.logs[0] + self.assertEqual(len(l), 1) + l = l[0] + self.assertTrue('logs' in l) + self.assertTrue('common' in l) + self.assertTrue(type(l['common']) is dict) + self.assertTrue('foo' in l['common']) + self.assertEqual(l['common']['foo'], 'bar') + self.assertEqual(len(l['logs']), 2) + + ''' + given: the values from use case 3 + when: the pipeline is configured as in use case 3 + and when: the data format is set to DataFormat.LOGS + and when: the number of log lines to be processed is less than the + maximum number of rows, + and when: a data cache is specified + and when: the cache contains a list of log line IDs for the record ID + then: A Logs API post is made for the record containing log entries only + for log lines that have log line IDs that are not in the list of + cached log lines for the record ID + ''' + + # setup + api = ApiStub(lines=self.log_rows) + data_cache = DataCacheStub( + cached_logs={ + '00001111AAAABBBB': ['YYZ:abcdef123456'] + } + ) + newrelic = NewRelicStub() + + p = pipeline.Pipeline( + cfg, + data_cache, + newrelic, + DataFormat.LOGS, + { 'foo': 'bar' }, + {}, + set(), + ) + + record = self.log_records[0] + + # preconditions + self.assertEqual(len(newrelic.logs), 0) + + # execute + 
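+ # With log line ID 'YYZ:abcdef123456' already cached for record + # '00001111AAAABBBB', only the remaining log line should survive the + # transform and be posted below.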
p.process_log_record( + api, + session, + query, + record, + ) + + # verify + self.assertEqual(len(newrelic.logs), 1) + l = newrelic.logs[0] + self.assertEqual(len(l), 1) + l = l[0] + self.assertTrue('logs' in l) + self.assertTrue('common' in l) + self.assertTrue(type(l['common']) is dict) + self.assertTrue('foo' in l['common']) + self.assertEqual(l['common']['foo'], 'bar') + self.assertEqual(len(l['logs']), 1) + + def test_pipeline_process_event_records(self): + ''' + given: an instance configuration, data cache, http session, newrelic + instance, data format, set of labels, event type fields mapping, + set of numeric field names, query, and a set of event records + when: the pipeline is configured with the configuration, session, + newrelic instance, data format, labels, event type fields mapping, + and numeric field names + and when: the data format is set to DataFormat.LOGS + and when: event records are being processed + and when: the number of event records to be processed is less than the + maximum number of rows + and when: no data cache is specified + then: a single Logs API post is made containing all labels in the + 'common' property of the post and one log for each exported + and transformed event record with the correct attributes from the + corresponding event record + ''' + + # setup + cfg = config.Config({}) + newrelic = NewRelicStub() + query = QueryStub(config={ 'id': ['Name'] }) + + p = pipeline.Pipeline( + cfg, + None, + newrelic, + DataFormat.LOGS, + { 'foo': 'bar' }, + {}, + set(), + ) + + # preconditions + self.assertEqual(len(newrelic.logs), 0) + + # execute + p.process_event_records(query, self.event_records) + + # verify + self.assertEqual(len(newrelic.logs), 1) + l = newrelic.logs[0] + self.assertEqual(len(l), 1) + l = l[0] + self.assertTrue('logs' in l) + self.assertTrue('common' in l) + self.assertTrue(type(l['common']) is dict) + self.assertTrue('foo' in l['common']) + self.assertEqual(l['common']['foo'], 'bar') + self.assertEqual(len(l['logs']), 3) + + logs = l['logs'] + log0 = logs[0] + log1 = logs[1] + log2 = logs[2] + + self.assertTrue('message' in log0) + self.assertEqual(log0['message'], 'Account 2024-03-11T00:00:00.000+0000') + self.assertTrue('attributes' in log0) + self.assertTrue('Id' in log0['attributes']) + self.assertEqual(log0['attributes']['Id'], f'000012345') + self.assertTrue('Name' in log0['attributes']) + self.assertEqual(log0['attributes']['Name'], f'My Account') + + self.assertTrue('message' in log1) + self.assertEqual(log1['message'], 'Account 2024-03-10T00:00:00.000+0000') + self.assertTrue('attributes' in log1) + self.assertTrue('Id' in log1['attributes']) + self.assertEqual(log1['attributes']['Id'], f'000054321') + self.assertTrue('Name' in log1['attributes']) + self.assertEqual(log1['attributes']['Name'], f'My Other Account') + + customId = util.generate_record_id([ 'Name' ], self.event_records[2]) + + self.assertTrue('message' in log2) + self.assertEqual(log2['message'], 'Account 2024-03-09T00:00:00.000+0000') + self.assertTrue('attributes' in log2) + self.assertTrue('Id' in log2['attributes']) + self.assertEqual(log2['attributes']['Id'], customId) + self.assertTrue('Name' in log2['attributes']) + self.assertEqual(log2['attributes']['Name'], f'My Last Account') + + def test_pipeline_execute(self): + ''' + given: an instance configuration, data cache, http session, newrelic + instance, data format, set of labels, event type fields mapping, + set of numeric field names, query, and a set of query result + records + when: the 
pipeline is configured with the configuration, session, + newrelic instance, data format, labels, event type fields mapping, + and numeric field names + and when: the first record in the result set contains a 'LogFile' + attribute + and when: the number of log lines to be processed is less than the + maximum number of rows + and when: a data cache is specified + then: a single Logs API post is made containing all labels in the + 'common' property of the logs post and one log for each exported + and transformed log line, and the cache is flushed + ''' + + # setup + api = ApiStub(lines=self.log_rows) + cfg = config.Config({}) + session = SessionStub() + newrelic = NewRelicStub() + query = QueryStub({}) + data_cache = DataCacheStub() + + p = pipeline.Pipeline( + cfg, + data_cache, + newrelic, + DataFormat.LOGS, + { 'foo': 'bar' }, + {}, + set(), + ) + + # preconditions + self.assertEqual(len(newrelic.logs), 0) + + # execute + p.execute( + api, + session, + query, + self.log_records, + ) + + # verify + self.assertEqual(len(newrelic.logs), 2) + + for _, l in enumerate(newrelic.logs): + self.assertEqual(len(l), 1) + l = l[0] + self.assertTrue('logs' in l) + self.assertTrue('common' in l) + self.assertTrue(type(l['common']) is dict) + self.assertTrue('foo' in l['common']) + self.assertEqual(l['common']['foo'], 'bar') + self.assertEqual(len(l['logs']), 2) + + self.assertTrue(data_cache.flush_called) + + ''' + given: the values from use case 1 + when: the pipeline is configured with the configuration, session, + newrelic instance, data format, labels, event type fields mapping, + and numeric field names + and when: the first record in the result set contains a 'LogFile' + attribute + and when: process_log_record() raises a LoginException + then: raise a LoginException + ''' + + # setup + api = ApiStub(raise_login_error=True) + cfg = config.Config({}) + session = SessionStub() + newrelic = NewRelicStub() + query = QueryStub({}) + data_cache = DataCacheStub() + + p = pipeline.Pipeline( + cfg, + data_cache, + newrelic, + DataFormat.LOGS, + { 'foo': 'bar' }, + {}, + set(), + ) + + # execute / verify + with self.assertRaises(LoginException) as _: + p.execute( + api, + session, + query, + self.log_records, + ) + + ''' + given: the values from use case 1 + when: the pipeline is configured with the configuration, session, + newrelic instance, data format, labels, event type fields mapping, + and numeric field names + and when: the first record in the result set contains a 'LogFile' + attribute + and when: process_log_record() raises a SalesforceApiException + then: raise a SalesforceApiException + ''' + + # setup + api = ApiStub(raise_error=True) + cfg = config.Config({}) + session = SessionStub() + newrelic = NewRelicStub() + query = QueryStub({}) + data_cache = DataCacheStub() + + p = pipeline.Pipeline( + cfg, + data_cache, + newrelic, + DataFormat.LOGS, + { 'foo': 'bar' }, + {}, + set(), + ) + + # execute / verify + with self.assertRaises(SalesforceApiException) as _: + p.execute( + api, + session, + query, + self.log_records, + ) + + ''' + given: the values from use case 1 + when: the pipeline is configured as in use case 1 + and when: the first record in the result set does not contain a + 'LogFile' attribute + and when: a data cache is specified + and when: the number of event records to be processed is less than the + maximum number of rows + then: a single Logs API post is made containing all labels in the + 'common' property of the logs post and one log for each exported + and transformed event 
record, and the cache is flushed + ''' + api = ApiStub() + cfg = config.Config({}) + newrelic = NewRelicStub() + query = QueryStub({ 'id': ['Name'] }) + data_cache = DataCacheStub() + + p = pipeline.Pipeline( + cfg, + data_cache, + newrelic, + DataFormat.LOGS, + { 'foo': 'bar' }, + {}, + set(), + ) + + self.assertEqual(len(newrelic.logs), 0) + + p.execute( + api, + session, + query, + self.event_records, + ) + + self.assertEqual(len(newrelic.logs), 1) + l = newrelic.logs[0] + self.assertEqual(len(l), 1) + l = l[0] + self.assertTrue('logs' in l) + self.assertTrue('common' in l) + self.assertTrue(type(l['common']) is dict) + self.assertTrue('foo' in l['common']) + self.assertEqual(l['common']['foo'], 'bar') + self.assertEqual(len(l['logs']), 3) + + self.assertTrue(data_cache.flush_called) + + +if __name__ == '__main__': + unittest.main() diff --git a/src/tests/test_query.py b/src/tests/test_query.py new file mode 100644 index 0000000..2cc7645 --- /dev/null +++ b/src/tests/test_query.py @@ -0,0 +1,423 @@ +from datetime import datetime +import unittest + +from newrelic_logging import LoginException, SalesforceApiException +from newrelic_logging import config as mod_config, query, util +from . import \ + ApiStub, \ + SessionStub + +class TestQuery(unittest.TestCase): + def test_get_returns_backing_config_value_when_key_exists(self): + ''' + get() returns the value of the key in the backing config when the key exists + given: a query string, a configuration, and an api version + and given: a key + when: get is called with the key + then: returns value of the key from backing config + ''' + + # setup + api = ApiStub() + config = mod_config.Config({ 'foo': 'bar' }) + + # execute + q = query.Query( + api, + 'SELECT+LogFile+FROM+EventLogFile', + config, + ) + val = q.get('foo') + + # verify + self.assertEqual(val, 'bar') + + def test_get_returns_backing_config_default_when_key_missing(self): + ''' + get() returns the default value passed to the backing config.get when key does not exist in the backing config + given: a query string, a configuration, and an api version + when: get is called with a key and a default value + and when: the key does not exist in the backing config + then: returns default value passed to the backing config.get + ''' + + # setup + api = ApiStub() + config = mod_config.Config({}) + + # execute + q = query.Query( + api, + 'SELECT+LogFile+FROM+EventLogFile', + config, + ) + val = q.get('foo', 'beep') + + # verify + self.assertEqual(val, 'beep') + + def test_execute_raises_login_exception_if_api_query_does(self): + ''' + execute() raises a LoginException if api.query() raises a LoginException + given: an api instance, a query string, and a configuration + when: execute() is called with an http session + then: calls api.query() with the given session, query string, and no api version + and when: api.query() raises a LoginException (as a result of a reauthenticate) + then: raise a LoginException + ''' + + # setup + api = ApiStub(raise_login_error=True) + config = mod_config.Config({}) + session = SessionStub() + + # execute/verify + q = query.Query( + api, + 'SELECT+LogFile+FROM+EventLogFile', + config, + ) + + with self.assertRaises(LoginException) as _: + q.execute(session) + + def test_execute_raises_salesforce_api_exception_if_api_query_does(self): + ''' + execute() raises a SalesforceApiException if api.query() raises a SalesforceApiException + given: an api instance, a query string, and a configuration + when: execute() is called with an http session + then: calls 
api.query() with the given session, query string, and no api version + and when: api.query() raises a SalesforceApiException + then: raise a SalesforceApiException + ''' + + # setup + api = ApiStub(raise_error=True) + config = mod_config.Config({}) + session = SessionStub() + + # execute/verify + q = query.Query( + api, + 'SELECT+LogFile+FROM+EventLogFile', + config, + ) + + with self.assertRaises(SalesforceApiException) as _: + q.execute(session) + + def test_execute_calls_query_api_with_query_and_returns_result(self): + ''' + execute() calls api.query() with the given session, query string, and no api version and returns the result + given: an api instance, a query string, and a configuration + when: execute() is called with an http session + then: calls api.query() with the given session, query string, and no api version + and: returns query result + ''' + + # setup + api = ApiStub(query_result={ 'foo': 'bar' }) + config = mod_config.Config({}) + session = SessionStub() + + # execute + q = query.Query( + api, + 'SELECT+LogFile+FROM+EventLogFile', + config, + ) + + resp = q.execute(session) + + # verify + self.assertEqual('SELECT+LogFile+FROM+EventLogFile', api.soql) + self.assertIsNone(api.query_api_ver) + self.assertIsNotNone(resp) + self.assertTrue(type(resp) is dict) + self.assertTrue('foo' in resp) + self.assertEqual(resp['foo'], 'bar') + + def test_execute_calls_query_api_with_query_and_api_ver_and_returns_result(self): + ''' + execute() calls api.query() with the given session, query string, and api version and returns the result + given: an api instance, a query string, a configuration, and an api version + when: execute() is called with an http session + then: calls api.query() with the given session, query string, and api version + and: returns query result + ''' + + # setup + api = ApiStub(query_result={ 'foo': 'bar' }) + config = mod_config.Config({}) + session = SessionStub() + + # execute + q = query.Query( + api, + 'SELECT+LogFile+FROM+EventLogFile', + config, + '52.0', + ) + + resp = q.execute(session) + + # verify + self.assertEqual('SELECT+LogFile+FROM+EventLogFile', api.soql) + self.assertEqual(api.query_api_ver, '52.0') + self.assertIsNotNone(resp) + self.assertTrue(type(resp) is dict) + self.assertTrue('foo' in resp) + self.assertEqual(resp['foo'], 'bar') + + +class TestQueryFactory(unittest.TestCase): + def test_build_args_creates_expected_dict(self): + ''' + build_args() returns dictionary with expected properties + given: a query factory + when: build_args() is called with a time lag, timestamp, and generation interval + then: returns dict with expected properties + ''' + + # setup + _now = datetime.utcnow() + + def _utcnow(): + nonlocal _now + return _now + + util._UTCNOW = _utcnow + + time_lag_minutes = 500 + last_to_timestamp = util.get_iso_date_with_offset( + time_lag_minutes=time_lag_minutes * 2, + ) + to_timestamp = util.get_iso_date_with_offset( + time_lag_minutes=time_lag_minutes, + ) + + # execute + f = query.QueryFactory() + args = f.build_args(time_lag_minutes, last_to_timestamp, 'Daily') + + # verify + self.assertIsNotNone(args) + self.assertTrue(type(args) is dict) + self.assertTrue('to_timestamp' in args) + self.assertEqual(args['to_timestamp'], to_timestamp) + self.assertTrue('from_timestamp' in args) + self.assertEqual(args['from_timestamp'], last_to_timestamp) + self.assertTrue('log_interval_type' in args) + self.assertEqual(args['log_interval_type'], 'Daily') + + def test_get_env_returns_empty_dict_if_no_env(self): + ''' + get_env() returns an 
empty dict if env is not in the passed query dict + given: a query factory + when: get_env() is called with a query dict + and when: there is no env property in the query dict + then: returns an empty dict + ''' + + # setup + q = { 'query': 'SELECT LogFile FROM EventLogFile' } + + # execute + f = query.QueryFactory() + env = f.get_env(q) + + # verify + self.assertIsNotNone(env) + self.assertTrue(type(env) is dict) + self.assertEqual(len(env), 0) + + def test_get_env_returns_empty_dict_if_env_not_dict(self): + ''' + get_env() returns an empty dict if the passed query dict has an env property but it is not a dict + given: a query factory + when: get_env() is called with a query dict + and when: there is an env property in the query dict + and when: the env property is not a dict + then: returns an empty dict + ''' + + # setup + q = { 'query': 'SELECT LogFile FROM EventLogFile', 'env': 'foo' } + + # execute + f = query.QueryFactory() + env = f.get_env(q) + + # verify + self.assertIsNotNone(env) + self.assertTrue(type(env) is dict) + self.assertEqual(len(env), 0) + + def test_get_env_returns_env_dict_from_query_dict(self): + ''' + get_env() returns the env dict from the query dict when one is present + given: a query factory + when: get_env() is called with a query dict + and when: there is an env property in the query dict + and when: the env property is a dict + then: returns the env dict + ''' + + # setup + q = { + 'query': 'SELECT LogFile FROM EventLogFile', + 'env': { 'foo': 'bar' }, + } + + # execute + f = query.QueryFactory() + env = f.get_env(q) + + # verify + self.assertIsNotNone(env) + self.assertTrue(type(env) is dict) + self.assertEqual(len(env), 1) + self.assertTrue('foo' in env) + self.assertEqual(env['foo'], 'bar') + + def test_new_returns_query_obj_with_encoded_query_with_args_replaced(self): + ''' + new() returns a query instance with the given query with arguments replaced and URL encoded + given: a query factory + when: new() is called with an Api instance, query dict, lag time, timestamp, and generation interval + then: returns a query instance with the input query with arguments replaced and URL encoded + ''' + + # setup + _now = datetime.utcnow() + + def _utcnow(): + nonlocal _now + return _now + + util._UTCNOW = _utcnow + + api = ApiStub() + to_timestamp = util.get_iso_date_with_offset(time_lag_minutes=500) + last_to_timestamp = util.get_iso_date_with_offset(time_lag_minutes=1000) + now = _now.isoformat(timespec='milliseconds') + "Z" + env = { 'foo': 'now()' } + + # execute + f = query.QueryFactory() + q = f.new( + api, + { + 'query': 'SELECT LogFile FROM EventLogFile WHERE CreatedDate>={from_timestamp} AND CreatedDate<{to_timestamp} AND LogIntervalType={log_interval_type} AND Foo={foo}', + 'env': env, + }, + 500, + last_to_timestamp, + 'Daily', + ) + + # verify + self.assertEqual( + q.query, + f'SELECT+LogFile+FROM+EventLogFile+WHERE+CreatedDate>={last_to_timestamp}+AND+CreatedDate<{to_timestamp}+AND+LogIntervalType=Daily+AND+Foo={now}' + ) + + def test_new_returns_query_obj_with_expected_config(self): + ''' + new() returns a query instance with the input query dict minus the query property + given: a query factory + when: new() is called with an Api instance, query dict, lag time, timestamp, and generation interval + then: returns a query instance with a config equal to the input query dict minus the query property + ''' + + # setup + last_to_timestamp = util.get_iso_date_with_offset(time_lag_minutes=1000) + + # execute + api = ApiStub() + f = query.QueryFactory() + q 
= f.new( + api, + { + 'query': 'SELECT LogFile FROM EventLogFile', + 'foo': 'bar', + 'beep': 'boop', + 'bip': 0, + 'bop': 5, + }, + 500, + last_to_timestamp, + 'Daily', + ) + + # verify + config = q.get_config() + + self.assertIsNotNone(config) + self.assertTrue(type(config) is mod_config.Config) + self.assertFalse('query' in config) + self.assertTrue('foo' in config) + self.assertEqual(config['foo'], 'bar') + self.assertTrue('beep' in config) + self.assertEqual(config['beep'], 'boop') + self.assertTrue('bip' in config) + self.assertEqual(config['bip'], 0) + self.assertTrue('bop' in config) + self.assertEqual(config['bop'], 5) + + def test_new_returns_query_obj_with_given_api_ver(self): + ''' + new() returns a query instance with the api version specified in the query dict + given: a query factory + when: new() is called with an Api instance, query dict, lag time, timestamp, and generation interval + and when: an api version is specified in the query dict + then: returns a query instance with the api version specified in the query dict + ''' + + # setup + api = ApiStub(api_ver='54.0') + last_to_timestamp = util.get_iso_date_with_offset(time_lag_minutes=1000) + + # execute + f = query.QueryFactory() + q = f.new( + api, + { + 'query': 'SELECT LogFile FROM EventLogFile', + 'api_ver': '58.0' + }, + 500, + last_to_timestamp, + 'Daily', + ) + + # verify + self.assertEqual(q.api_ver, '58.0') + + def test_new_returns_query_obj_with_default_api_ver(self): + ''' + new() returns a query instance without an api version + given: a query factory + when: new() is called with an Api instance, query dict, lag time, timestamp, and generation interval + and when: no api version is specified in the query dict + then: returns a query instance without an api version + ''' + + # setup + api = ApiStub(api_ver='54.0') + last_to_timestamp = util.get_iso_date_with_offset(time_lag_minutes=1000) + + # execute + f = query.QueryFactory() + q = f.new( + api, + { + 'query': 'SELECT LogFile FROM EventLogFile', + }, + 500, + last_to_timestamp, + 'Daily', + ) + + # verify + self.assertIsNone(q.api_ver) diff --git a/src/tests/test_salesforce.py b/src/tests/test_salesforce.py new file mode 100644 index 0000000..2e61137 --- /dev/null +++ b/src/tests/test_salesforce.py @@ -0,0 +1,717 @@ +from datetime import datetime, timedelta +import unittest + + +from . 
import \ + ApiFactoryStub, \ + AuthenticatorStub, \ + DataCacheStub, \ + PipelineStub, \ + QueryStub, \ + QueryFactoryStub, \ + SessionStub +from newrelic_logging import \ + config, \ + salesforce, \ + util, \ + LoginException, \ + SalesforceApiException + + +class TestSalesforce(unittest.TestCase): + def test_init(self): + _now = datetime.utcnow() + + def _utcnow(): + nonlocal _now + return _now + + util._UTCNOW = _utcnow + + ''' + given: an instance name and configuration, a data cache, authenticator, + pipeline, query factory, initial delay value, and a set of + queries + when: values are present in the configuration for all relevant config + properties + and when: no data cache is specified + and when: no queries are specified + then: a new Salesforce instance is created with the correct values + ''' + + # setup + time_lag_minutes = 603 + initial_delay = 5 + cfg = config.Config({ + 'api_ver': '55.0', + 'time_lag_minutes': time_lag_minutes, + 'date_field': 'CreateDate', + 'generation_interval': 'Hourly', + }) + auth = AuthenticatorStub() + pipeline = PipelineStub() + data_cache = DataCacheStub() + api_factory = ApiFactoryStub() + query_factory = QueryFactoryStub() + last_to_timestamp = util.get_iso_date_with_offset( + time_lag_minutes, + initial_delay, + ) + + # execute + client = salesforce.SalesForce( + 'test_instance', + cfg, + None, + auth, + pipeline, + api_factory, + query_factory, + initial_delay, + ) + + # verify + self.assertEqual(client.instance_name, 'test_instance') + self.assertEqual(client.data_cache, None) + self.assertEqual(client.time_lag_minutes, time_lag_minutes) + self.assertEqual(client.date_field, 'CreateDate') + self.assertEqual(client.generation_interval, 'Hourly') + self.assertEqual(client.last_to_timestamp, last_to_timestamp) + self.assertIsNotNone(client.api) + self.assertEqual(client.api.authenticator, auth) + self.assertEqual(client.api.api_ver, '55.0') + self.assertTrue(len(client.queries) == 1) + self.assertTrue('query' in client.queries[0]) + self.assertEqual( + client.queries[0]['query'], + salesforce.SALESFORCE_CREATED_DATE_QUERY, + ) + + ''' + given: an instance name and configuration, a data cache, authenticator, + pipeline, query factory, initial delay value, and a set of + queries + and when: no data cache is specified + and when: no queries are specified + and when: no lag time is specified in the config + then: a new Salesforce instance is created with the default lag time + ''' + + # setup + cfg = config.Config({ + 'api_ver': '55.0', + 'date_field': 'CreateDate', + 'generation_interval': 'Hourly', + }) + last_to_timestamp = util.get_iso_date_with_offset( + config.DEFAULT_TIME_LAG_MINUTES, + initial_delay, + ) + + # execute + client = salesforce.SalesForce( + 'test_instance', + cfg, + None, + auth, + pipeline, + api_factory, + query_factory, + initial_delay, + ) + + # verify + self.assertEqual( + client.time_lag_minutes, + config.DEFAULT_TIME_LAG_MINUTES, + ) + self.assertEqual(client.last_to_timestamp, last_to_timestamp) + + ''' + given: an instance name and configuration, a data cache, authenticator, + pipeline, query factory, initial delay value, and a set of + queries + and when: a data cache is specified + and when: no queries are specified + and when: no lag time is specified in the config + then: a new Salesforce instance is created with the lag time set to 0 + ''' + + # setup + last_to_timestamp = util.get_iso_date_with_offset( + 0, + initial_delay, + ) + + # execute + client = salesforce.SalesForce( + 'test_instance', + cfg, + data_cache, 
+ auth, + pipeline, + api_factory, + query_factory, + initial_delay, + ) + + # verify + self.assertEqual( + client.time_lag_minutes, + 0, + ) + self.assertEqual(client.last_to_timestamp, last_to_timestamp) + + ''' + given: an instance name and configuration, a data cache, authenticator, + pipeline, query factory, initial delay value, and a set of + queries + and when: no data cache is specified + and when: no queries are specified + and when: no date field is specified in the config + then: a new Salesforce instance is created with the default date field + ''' + + # setup + cfg = config.Config({ + 'api_ver': '55.0', + 'time_lag_minutes': time_lag_minutes, + 'generation_interval': 'Hourly', + }) + + # execute + client = salesforce.SalesForce( + 'test_instance', + cfg, + None, + auth, + pipeline, + api_factory, + query_factory, + initial_delay, + ) + + # verify + self.assertEqual(client.date_field, config.DATE_FIELD_LOG_DATE) + + ''' + given: an instance name and configuration, a data cache, authenticator, + pipeline, query factory, initial delay value, and a set of + queries + and when: a data cache is specified + and when: no queries are specified + and when: no date field is specified in the config + then: a new Salesforce instance is created with the default date field + ''' + + # execute + client = salesforce.SalesForce( + 'test_instance', + cfg, + data_cache, + auth, + pipeline, + api_factory, + query_factory, + initial_delay, + ) + + # verify + self.assertEqual(client.date_field, config.DATE_FIELD_CREATE_DATE) + + ''' + given: an instance name and configuration, a data cache, authenticator, + pipeline, query factory, initial delay value, and a set of + queries + and when: no data cache is specified + and when: no queries are specified + and when: no generation interval is specified + then: a new Salesforce instance is created with the default generation + interval + ''' + + # setup + cfg = config.Config({ + 'api_ver': '55.0', + 'time_lag_minutes': time_lag_minutes, + 'date_field': 'CreateDate' + }) + + # execute + client = salesforce.SalesForce( + 'test_instance', + cfg, + None, + auth, + pipeline, + api_factory, + query_factory, + initial_delay, + ) + + # verify + self.assertEqual( + client.generation_interval, + config.DEFAULT_GENERATION_INTERVAL, + ) + + ''' + given: an instance name and configuration, a data cache, authenticator, + pipeline, query factory, initial delay value, and a set of + queries + and when: no data cache is specified + and when: no queries are specified + and when: the date field option is set to 'LogDate' + then: a new Salesforce instance is created with the default log date + query + ''' + + # setup + cfg = config.Config({ + 'api_ver': '55.0', + 'time_lag_minutes': time_lag_minutes, + 'date_field': 'LogDate', + 'generation_interval': 'Hourly', + }) + + # execute + client = salesforce.SalesForce( + 'test_instance', + cfg, + None, + auth, + pipeline, + api_factory, + query_factory, + initial_delay, + ) + + # verify + self.assertTrue(len(client.queries) == 1) + self.assertTrue('query' in client.queries[0]) + self.assertEqual( + client.queries[0]['query'], + salesforce.SALESFORCE_LOG_DATE_QUERY, + ) + + ''' + given: an instance name and configuration, a data cache, authenticator, + pipeline, query factory, initial delay value, and a set of + queries + and when: no data cache is specified + and when: queries are specified + then: a new Salesforce instance is created with the specified queries + ''' + + # setup + queries = [ { 'query': 'foo' }, { 'query': 'bar' }] 
+ + # execute + client = salesforce.SalesForce( + 'test_instance', + cfg, + None, + auth, + pipeline, + api_factory, + query_factory, + initial_delay, + queries + ) + + # verify + self.assertTrue(len(client.queries) == 2) + self.assertTrue('query' in client.queries[0]) + self.assertEqual(client.queries[0]['query'], 'foo') + self.assertTrue('query' in client.queries[1]) + self.assertEqual(client.queries[1]['query'], 'bar') + + def test_authenticate_is_called(self): + ''' + given: an instance name and configuration, a data cache, authenticator, + pipeline, query factory, initial delay value, set of queries, + and an http session + when: called + then: the underlying api is called + ''' + + # setup + time_lag_minutes = 603 + initial_delay = 5 + cfg = config.Config({ + 'api_ver': '55.0', + 'time_lag_minutes': time_lag_minutes, + 'date_field': 'CreateDate', + 'generation_interval': 'Hourly', + }) + auth = AuthenticatorStub() + pipeline = PipelineStub() + api_factory = ApiFactoryStub() + query_factory = QueryFactoryStub() + session = SessionStub() + + client = salesforce.SalesForce( + 'test_instance', + cfg, + None, + auth, + pipeline, + api_factory, + query_factory, + initial_delay, + ) + + # execute + client.authenticate(session) + + # verify + self.assertIsNotNone(client.api) + self.assertEqual(client.api.authenticator, auth) + self.assertTrue(auth.authenticate_called) + + def test_authenticate_raises_login_exception_if_authenticate_does(self): + ''' + given: an instance name and configuration, a data cache, authenticator, + pipeline, query factory, initial delay value, set of queries, + and an http session + when: called + and when: authenticator.authenticate() raises a LoginException + then: raise a LoginException + ''' + + # setup + time_lag_minutes = 603 + initial_delay = 5 + cfg = config.Config({ + 'api_ver': '55.0', + 'time_lag_minutes': time_lag_minutes, + 'date_field': 'CreateDate', + 'generation_interval': 'Hourly', + }) + auth = AuthenticatorStub(raise_login_error=True) + pipeline = PipelineStub() + api_factory = ApiFactoryStub() + query_factory = QueryFactoryStub() + session = SessionStub() + + client = salesforce.SalesForce( + 'test_instance', + cfg, + None, + auth, + pipeline, + api_factory, + query_factory, + initial_delay, + ) + + # execute / verify + with self.assertRaises(LoginException) as _: + client.authenticate(session) + + # verify + self.assertIsNotNone(client.api) + self.assertEqual(client.api.authenticator, auth) + self.assertTrue(auth.authenticate_called) + + def test_slide_time_range(self): + _now = datetime.utcnow() + + def _utcnow(): + nonlocal _now + return _now + + util._UTCNOW = _utcnow + + ''' + given: an instance name and configuration, data cache, authenticator, + pipeline, query factory, initial delay value, and a set of + queries + when: called + then: the 'last_to_timestamp' is updated + ''' + + # setup + time_lag_minutes = 603 + initial_delay = 5 + cfg = config.Config({ + 'api_ver': '55.0', + 'time_lag_minutes': time_lag_minutes, + 'date_field': 'CreateDate', + 'generation_interval': 'Hourly', + }) + auth = AuthenticatorStub() + pipeline = PipelineStub() + api_factory = ApiFactoryStub() + query_factory = QueryFactoryStub() + + client = salesforce.SalesForce( + 'test_instance', + cfg, + None, + auth, + pipeline, + api_factory, + query_factory, + initial_delay, + ) + + last_to_before = client.last_to_timestamp + + # pretend it's 10 minutes from now to ensure this is different from + # the timestamp calculated during object creation + _now = 
datetime.utcnow() + timedelta(minutes=10) + + # execute + client.slide_time_range() + + # verify + last_to_after = client.last_to_timestamp + + self.assertNotEqual(last_to_after, last_to_before) + + def test_fetch_logs(self): + ''' + given: an instance name and configuration, data cache, authenticator, + pipeline, query factory, initial delay value, a set of queries, + and an http session + when: the set of queries is the empty set + then: a default query should be executed + ''' + + # setup + time_lag_minutes = 603 + initial_delay = 5 + cfg = config.Config({ + 'api_ver': '55.0', + 'time_lag_minutes': time_lag_minutes, + 'date_field': 'CreateDate', + 'generation_interval': 'Hourly', + }) + auth = AuthenticatorStub() + pipeline = PipelineStub() + api_factory = ApiFactoryStub() + query_factory = QueryFactoryStub() + session = SessionStub() + + client = salesforce.SalesForce( + 'test_instance', + cfg, + None, + auth, + pipeline, + api_factory, + query_factory, + initial_delay, + ) + + # execute + client.fetch_logs(session) + + # verify + + self.assertEqual(len(query_factory.queries), 1) + query = query_factory.queries[0] + self.assertTrue(query.executed) + self.assertTrue(pipeline.executed) + self.assertEqual(len(pipeline.queries), 1) + self.assertEqual(query, pipeline.queries[0]) + + ''' + given: an instance name and configuration, data cache, authenticator, + pipeline, query factory, initial delay value, a set of queries, + and an http session + when: the set of queries is not the empty set + then: each query should be executed + ''' + + auth = AuthenticatorStub() + pipeline = PipelineStub() + query_factory = QueryFactoryStub() + api_factory = ApiFactoryStub() + session = SessionStub() + queries = [ + { + 'query': 'foo', + }, + { + 'query': 'bar', + }, + { + 'query': 'beep', + }, + { + 'query': 'boop', + }, + ] + + client = salesforce.SalesForce( + 'test_instance', + cfg, + None, + auth, + pipeline, + api_factory, + query_factory, + initial_delay, + queries, + ) + + # execute + client.fetch_logs(session) + + # verify + + self.assertEqual(len(query_factory.queries), 4) + + query1 = query_factory.queries[0] + self.assertEqual(query1.query, 'foo') + self.assertTrue(query1.executed) + query2 = query_factory.queries[1] + self.assertEqual(query2.query, 'bar') + self.assertTrue(query2.executed) + query3 = query_factory.queries[2] + self.assertEqual(query3.query, 'beep') + self.assertTrue(query3.executed) + query4 = query_factory.queries[3] + self.assertEqual(query4.query, 'boop') + self.assertTrue(query4.executed) + + self.assertTrue(pipeline.executed) + self.assertEqual(len(pipeline.queries), 4) + self.assertEqual(query1, pipeline.queries[0]) + self.assertEqual(query2, pipeline.queries[1]) + self.assertEqual(query3, pipeline.queries[2]) + self.assertEqual(query4, pipeline.queries[3]) + + ''' + given: an instance name and configuration, data cache, authenticator, + pipeline, query factory, initial delay value, a set of queries, + and an http session + when: no response is returned from a query + then: query should be executed and pipeline should not be executed + ''' + + auth = AuthenticatorStub() + pipeline = PipelineStub() + query = QueryStub(result=None) + query_factory = QueryFactoryStub(query) + api_factory = ApiFactoryStub() + session = SessionStub() + + client = salesforce.SalesForce( + 'test_instance', + cfg, + None, + auth, + pipeline, + api_factory, + query_factory, + initial_delay, + ) + + # execute + client.fetch_logs(session) + + # verify + + self.assertTrue(query.executed) + 
self.assertFalse(pipeline.executed) + + ''' + given: an instance name and configuration, data cache, authenticator, + pipeline, query factory, initial delay value, a set of queries, + and an http session + when: there is no 'records' attribute in response returned from query + then: query should be executed and pipeline should not be executed + ''' + + auth = AuthenticatorStub() + pipeline = PipelineStub() + query = QueryStub(result={ 'foo': 'bar' }) + query_factory = QueryFactoryStub(query) + api_factory = ApiFactoryStub() + session = SessionStub() + + client = salesforce.SalesForce( + 'test_instance', + cfg, + None, + auth, + pipeline, + api_factory, + query_factory, + initial_delay, + ) + + # execute + client.fetch_logs(session) + + # verify + + self.assertTrue(query.executed) + self.assertFalse(pipeline.executed) + + ''' + given: an instance name and configuration, data cache, authenticator, + pipeline, api factory, query factory, initial delay value, and an + http session + when: pipeline.execute() raises a LoginException + then: raise a LoginException + ''' + + auth = AuthenticatorStub() + pipeline = PipelineStub(raise_login_error=True) + query = QueryStub() + query_factory = QueryFactoryStub(query) + api_factory = ApiFactoryStub() + session = SessionStub() + + client = salesforce.SalesForce( + 'test_instance', + cfg, + None, + auth, + pipeline, + api_factory, + query_factory, + initial_delay, + ) + + # execute / verify + with self.assertRaises(LoginException) as _: + client.fetch_logs(session) + + ''' + given: an instance name and configuration, data cache, authenticator, + pipeline, api factory, query factory, initial delay value, and an + http session + when: pipeline.execute() raises a SalesforceApiException + then: raise a SalesforceApiException + ''' + + auth = AuthenticatorStub() + pipeline = PipelineStub(raise_error=True) + query = QueryStub() + query_factory = QueryFactoryStub(query) + api_factory = ApiFactoryStub() + session = SessionStub() + + client = salesforce.SalesForce( + 'test_instance', + cfg, + None, + auth, + pipeline, + api_factory, + query_factory, + initial_delay, + ) + + # execute / verify + with self.assertRaises(SalesforceApiException) as _: + client.fetch_logs(session) + + +if __name__ == '__main__': + unittest.main() diff --git a/src/tests/test_util.py b/src/tests/test_util.py new file mode 100644 index 0000000..4ad247f --- /dev/null +++ b/src/tests/test_util.py @@ -0,0 +1,509 @@ +from datetime import datetime, timedelta +import hashlib +import json +import unittest + +from newrelic_logging import util + +class TestUtilities(unittest.TestCase): + def test_is_logfile_response(self): + ''' + given: a set of query result records + when: the set is the empty set + then: return true + ''' + + # execute/verify + self.assertTrue(util.is_logfile_response([])) + + ''' + given: a set of query result records + when: the first record in the set contains a 'LogFile' property + then: return true + ''' + + # execute/verify + self.assertTrue(util.is_logfile_response([ + { 'LogFile': 'example' } + ])) + + ''' + given: a set of query result records + when: the first record in the set does not contain a 'LogFile' property + then: return false + ''' + + # execute/verify + self.assertFalse(util.is_logfile_response([{}])) + + def test_generate_record_id(self): + ''' + given: a set of id keys and a query result record + when: the set of id keys is the empty set + then: return the empty string + ''' + + # execute + record_id = util.generate_record_id([], { 'Name': 'foo' }) + + # 
verify
+        self.assertEqual(record_id, '')
+
+        '''
+        given: a set of id keys and a query result record
+        when: the set of id keys is not empty
+        and when: there is some key for which the query result record does not
+            have a property
+        then: an exception is raised
+        '''
+
+        # execute/verify
+        with self.assertRaises(Exception):
+            util.generate_record_id([ 'EventType' ], { 'Name': 'foo' })
+
+        '''
+        given: a set of id keys and a query result record
+        when: the set of id keys is not empty
+        and when: the query result record has a property for the key but the
+            value for that key is the empty string
+        then: return the empty string
+        '''
+
+        # execute
+        record_id = util.generate_record_id([ 'Name' ], { 'Name': '' })
+
+        # verify
+        self.assertEqual(record_id, '')
+
+        '''
+        given: a set of id keys and a query result record
+        when: the set of id keys is not empty
+        and when: the query result record has a property for the key with a
+            value that is not the empty string
+        then: return a value obtained by concatenating the values for all id
+            keys and creating a sha3 256 message digest over that value
+        '''
+
+        # execute
+        record_id = util.generate_record_id([ 'Name' ], { 'Name': 'foo' })
+
+        # verify
+        m = hashlib.sha3_256()
+        m.update('foo'.encode('utf-8'))
+        expected = m.hexdigest()
+
+        self.assertEqual(expected, record_id)
+
+    def test_maybe_convert_str_to_num(self):
+        '''
+        given: a string
+        when: the string contains a valid integer
+        then: the string value is converted to an integer
+        '''
+
+        # execute
+        val = util.maybe_convert_str_to_num('2')
+
+        # verify
+        self.assertTrue(type(val) is int)
+        self.assertEqual(val, 2)
+
+        '''
+        given: a string
+        when: the string contains a valid floating point number
+        then: the string value is converted to a float
+        '''
+
+        # execute
+        val = util.maybe_convert_str_to_num('3.14')
+
+        # verify
+        self.assertTrue(type(val) is float)
+        self.assertEqual(val, 3.14)
+
+        '''
+        given: a string
+        when: the string contains neither a valid integer nor a valid floating
+            point number
+        then: the string value is returned
+        '''
+
+        # execute
+        val = util.maybe_convert_str_to_num('not a number')
+
+        # verify
+        self.assertTrue(type(val) is str)
+        self.assertEqual(val, 'not a number')
+
+    def test_get_iso_date_with_offset(self):
+        _now = datetime.utcnow()
+
+        def _utcnow():
+            nonlocal _now
+            return _now
+
+        util._UTCNOW = _utcnow
+
+        '''
+        given: a time lag and initial delay
+        when: neither are specified (both default to 0)
+        then: return the current time in iso format
+        '''
+
+        # setup
+        val = _now.isoformat(timespec='milliseconds') + 'Z'
+
+        # execute
+        isonow = util.get_iso_date_with_offset()
+
+        # verify
+        self.assertEqual(val, isonow)
+
+        '''
+        given: a time lag and initial delay
+        when: time lag is specified
+        then: return the current time minus the time lag in iso format
+        '''
+
+        # setup
+        time_lag_minutes = 412
+        val = (_now - timedelta(minutes=time_lag_minutes)) \
+            .isoformat(timespec='milliseconds') + 'Z'
+
+        # execute
+        isonow = util.get_iso_date_with_offset(
+            time_lag_minutes=time_lag_minutes
+        )
+
+        # verify
+        self.assertEqual(val, isonow)
+
+        '''
+        given: a time lag and initial delay
+        when: initial delay is specified
+        then: return the current time minus the initial delay in iso format
+        '''
+
+        # setup
+        initial_delay = 678
+        val = (_now - timedelta(minutes=initial_delay)) \
+            .isoformat(timespec='milliseconds') + 'Z'
+
+        # execute
+        isonow = util.get_iso_date_with_offset(initial_delay=initial_delay)
+
+        # verify
+        self.assertEqual(val, isonow)
+
+        '''
+        
given: a time lag and initial delay + when: both are specified + then: return the current time minus the sum of the time lag and initial + delay in iso format + ''' + + # setup + initial_delay = 678 + val = (_now - timedelta(minutes=(time_lag_minutes + initial_delay))) \ + .isoformat(timespec='milliseconds') + 'Z' + + # execute + isonow = util.get_iso_date_with_offset( + time_lag_minutes, + initial_delay, + ) + + # verify + self.assertEqual(val, isonow) + + def test_is_primitive_true_for_primitive_types(self): + ''' + is_primitive() returns true for types considered to be "primitive" (str, int, float, bool, None) + given: a set of "primitive" values + when: is_primitive() is called + then: returns True + ''' + + # setup + vals = ('string', 100, 62.1, False, None) + + # execute / verify + for v in vals: + b = util.is_primitive(v) + self.assertTrue(b) + + def test_is_primitive_false_for_non_primitive_types(self): + ''' + is_primitive() returns false for types not considered to be "primitive" + given: a set of "non-primitive" values + when: is_primitive() is called + then: returns False + ''' + + # setup + vals = (['list'], (1, 2), { 'foo': 'bar' }, Exception()) + + # execute / verify + for v in vals: + b = util.is_primitive(v) + self.assertFalse(b) + + def test_process_query_result_copies_primitive_fields(self): + ''' + process_query_result() copies all primitive fields to the new dict + given: JSON result from an SOQL query + when: the SOQL query result contains only primitive (non-nested) fields + then: returns dict containing all primitive fields + ''' + + # setup + query_result = json.loads('''{ + "Action": "PermSetFlsChanged", + "CreatedByContext": null, + "CreatedById": "0058W00000A7LvTQAV", + "CreatedByIssuer": 1, + "CreatedDate": "2023-11-30T17:33:08.000+0000", + "DelegateUser": 2.0, + "Display": "Changed permission set 00e1U000000XFwxQAG: field-level security for Task: Related To was changed from Read/Write to No Access", + "Id": "0Ym7c00001RGP6MCAX", + "ResponsibleNamespacePrefix": false, + "Section": "Manage Users" + }''') + + expected_result = { + 'Action': 'PermSetFlsChanged', + 'CreatedByContext': None, + 'CreatedById': '0058W00000A7LvTQAV', + 'CreatedByIssuer': 1, + 'CreatedDate': '2023-11-30T17:33:08.000+0000', + 'DelegateUser': 2.0, + 'Display': 'Changed permission set 00e1U000000XFwxQAG: field-level security for Task: Related To was changed from Read/Write to No Access', + 'Id': '0Ym7c00001RGP6MCAX', + 'ResponsibleNamespacePrefix': False, + 'Section': 'Manage Users' + } + + # execute + result = util.process_query_result(query_result) + + # verify + self.assertEqual(expected_result, result) + + def test_process_query_result_flattens_and_copies_nested_fields(self): + ''' + process_query_result() flattens all nested fields and copies them to the new dict with keys that use the syntax "field1.nestedfield1.nestedfield1" and so on + given: JSON result from an SOQL query + when: the SOQL query result contains primitive and nested fields + then: returns dict containing all primitive fields + and: contains primitives from all nested fields + and: the keys for all nested fields use the syntax 'field1.nestedfield1', 'field2.nestedfield2.nestedfield1' and so on + ''' + + # setup + query_result = json.loads('''{ + "Action": "PermSetFlsChanged", + "CreatedByContext": null, + "CreatedById": "0058W00000A7LvTQAV", + "CreatedBy": { + "Name": "Chetan Gupta", + "Profile": { + "Name": "System Administrator" + }, + "UserType": "Standard" + }, + "CreatedByIssuer": null, + "CreatedDate": 
"2023-11-30T17:33:08.000+0000", + "DelegateUser": null, + "Display": "Changed permission set 00e1U000000XFwxQAG: field-level security for Task: Related To was changed from Read/Write to No Access", + "Id": "0Ym7c00001RGP6MCAX", + "ResponsibleNamespacePrefix": null, + "Section": "Manage Users" + }''') + + expected_result = { + 'Action': 'PermSetFlsChanged', + 'CreatedBy.Name': 'Chetan Gupta', + 'CreatedBy.Profile.Name': 'System Administrator', + 'CreatedBy.UserType': 'Standard', + 'CreatedByContext': None, + 'CreatedById': '0058W00000A7LvTQAV', + 'CreatedByIssuer': None, + 'CreatedDate': '2023-11-30T17:33:08.000+0000', + 'DelegateUser': None, + 'Display': 'Changed permission set 00e1U000000XFwxQAG: field-level security for Task: Related To was changed from Read/Write to No Access', + 'Id': '0Ym7c00001RGP6MCAX', + 'ResponsibleNamespacePrefix': None, + 'Section': 'Manage Users' + } + + # execute + result = util.process_query_result(query_result) + + # verify + self.assertEqual(expected_result, result) + + def test_process_query_result_ignores_attributes_fields(self): + ''' + process_query_result() ignores all fields and nested fields named 'attributes' + given: JSON result from an SOQL query + when: the SOQL query result contains primitive and nested fields + and when: the SOQL query result contains fields and nested fields named 'attributes' + then: returns dict containing all primitive fields + and: contains primitives from all nested fields + and: the keys for all nested fields use the syntax 'field1.nestedfield1', 'field2.nestedfield2.nestedfield1' and so on + and: fields and nested fields named 'attributes' are ignored + ''' + + # setup + query_result = json.loads('''{ + "attributes": { + "type": "SetupAuditTrail", + "url": "/services/data/v55.0/sobjects/SetupAuditTrail/0Ym7c00001RGP6MCAX" + }, + "Action": "PermSetFlsChanged", + "CreatedByContext": null, + "CreatedById": "0058W00000A7LvTQAV", + "CreatedBy": { + "attributes": { + "type": "User", + "url": "/services/data/v55.0/sobjects/User/0058W00000A7LvTQAV" + }, + "Name": "Chetan Gupta", + "Profile": { + "attributes": { + "type": "Profile", + "url": "/services/data/v55.0/sobjects/Profile/00e1U000001wRS1QAM" + }, + "Name": "System Administrator" + }, + "UserType": "Standard" + }, + "CreatedByIssuer": null, + "CreatedDate": "2023-11-30T17:33:08.000+0000", + "DelegateUser": null, + "Display": "Changed permission set 00e1U000000XFwxQAG: field-level security for Task: Related To was changed from Read/Write to No Access", + "Id": "0Ym7c00001RGP6MCAX", + "ResponsibleNamespacePrefix": null, + "Section": "Manage Users" + }''') + + expected_result = { + 'Action': 'PermSetFlsChanged', + 'CreatedBy.Name': 'Chetan Gupta', + 'CreatedBy.Profile.Name': 'System Administrator', + 'CreatedBy.UserType': 'Standard', + 'CreatedByContext': None, + 'CreatedById': '0058W00000A7LvTQAV', + 'CreatedByIssuer': None, + 'CreatedDate': '2023-11-30T17:33:08.000+0000', + 'DelegateUser': None, + 'Display': 'Changed permission set 00e1U000000XFwxQAG: field-level security for Task: Related To was changed from Read/Write to No Access', + 'Id': '0Ym7c00001RGP6MCAX', + 'ResponsibleNamespacePrefix': None, + 'Section': 'Manage Users' + } + + # execute + result = util.process_query_result(query_result) + + # verify + self.assertEqual(expected_result, result) + + def test_process_query_result_ignores_non_dict_structured_fields(self): + ''' + process_query_result() ignores all fields and nested fields that are neither "primitive" nor type dict + given: JSON result from an SOQL 
query + when: the SOQL query result contains primitive and nested fields + and when: the SOQL query result contains fields and nested fields that are neither "primitive" nor type dict + then: returns dict containing all primitive fields + and: contains primitives from all nested fields + and: the keys for all nested fields use the syntax 'field1.nestedfield1', 'field2.nestedfield2.nestedfield1' and so on + and: fields and nested fields that are neither "primitive" nor type dict are ignored + ''' + + # setup + query_result = json.loads('''{ + "Action": "PermSetFlsChanged", + "CreatedByContext": null, + "CreatedById": "0058W00000A7LvTQAV", + "CreatedByIssuer": null, + "CreatedDate": "2023-11-30T17:33:08.000+0000", + "DelegateUser": null, + "Display": "Changed permission set 00e1U000000XFwxQAG: field-level security for Task: Related To was changed from Read/Write to No Access", + "Id": "0Ym7c00001RGP6MCAX", + "ResponsibleNamespacePrefix": null, + "Section": "Manage Users", + "RandomNested": { + "RandomNestedArray": ["beep", "boop"] + }, + "RandomArray": ["foo", "bar"] + }''') + + expected_result = { + 'Action': 'PermSetFlsChanged', + 'CreatedByContext': None, + 'CreatedById': '0058W00000A7LvTQAV', + 'CreatedByIssuer': None, + 'CreatedDate': '2023-11-30T17:33:08.000+0000', + 'DelegateUser': None, + 'Display': 'Changed permission set 00e1U000000XFwxQAG: field-level security for Task: Related To was changed from Read/Write to No Access', + 'Id': '0Ym7c00001RGP6MCAX', + 'ResponsibleNamespacePrefix': None, + 'Section': 'Manage Users' + } + + # execute + result = util.process_query_result(query_result) + + # verify + self.assertEqual(expected_result, result) + + def test_get_timestamp_returns_current_posix_ms_as_int(self): + ''' + get_timestamp() returns current posix time in ms as an integer + given: a date string + when: the date string is None + then: returns the current posix time in ms as an integer + ''' + + # setup + + __now = datetime.now() + + def _now(): + nonlocal __now + return __now + + util._NOW = _now + + expected = int(__now.timestamp() * 1000) + + # execute + timestamp = util.get_timestamp() + + # verify + self.assertEqual(expected, timestamp) + + def test_get_timestamp_returns_posix_ms_as_int_for_date_string(self): + ''' + get_timestamp() returns the posix time in ms as an integer for the time specified in the date string + given: a date string + when: the date string is not None + and when: the date string is of the form %Y-%m-%dT%H:%M:%S.%f%z + then: returns the posix time in ms as an integer for the date string + ''' + + # setup + date_string = '2024-03-11T00:00:00.000+0000' + time = datetime.strptime(date_string, '%Y-%m-%dT%H:%M:%S.%f%z') + expected = int(time.timestamp() * 1000) + + # execute + timestamp = util.get_timestamp(date_string) + + # verify + self.assertEqual(expected, timestamp) + + +if __name__ == '__main__': + unittest.main()
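
For reference, the assertions in test_get_iso_date_with_offset and the two get_timestamp tests pin down the observable behaviour of the time helpers in newrelic_logging.util that the patch exercises. The following is a minimal sketch consistent with those assertions, not the exporter's actual implementation; the parameter name date_string and the exact structure are assumptions for illustration only.

from datetime import datetime, timedelta

# Module-level indirection points; the tests above monkey-patch these
# (util._UTCNOW / util._NOW) to pin "now" to a fixed value.
_UTCNOW = datetime.utcnow
_NOW = datetime.now

def get_iso_date_with_offset(time_lag_minutes=0, initial_delay=0):
    # Current UTC time minus (time lag + initial delay) minutes,
    # in ISO-8601 form with millisecond precision and a literal 'Z' suffix.
    offset = timedelta(minutes=time_lag_minutes + initial_delay)
    return (_UTCNOW() - offset).isoformat(timespec='milliseconds') + 'Z'

def get_timestamp(date_string=None):
    # POSIX time in milliseconds as an integer, either for "now" or for a
    # date string in '%Y-%m-%dT%H:%M:%S.%f%z' form (e.g. a Salesforce
    # CreatedDate value), matching what the tests assert.
    if date_string is None:
        return int(_NOW().timestamp() * 1000)
    time = datetime.strptime(date_string, '%Y-%m-%dT%H:%M:%S.%f%z')
    return int(time.timestamp() * 1000)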