Skip to content

Commit

Permalink
move class and testing function out into individual files
Browse files Browse the repository at this point in the history
  • Loading branch information
frankhereford committed Nov 27, 2023
1 parent 0323453 commit d6fcef0
Showing 1 changed file with 28 additions and 98 deletions.
126 changes: 28 additions & 98 deletions atd-etl/cris_import/cris_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
import lib.sql as util
import lib.graphql as graphql
from lib.helpers_import import insert_crash_change_template as insert_change_template
from lib.sshkeytempdir import SshKeyTempDir, write_key_to_file
from lib.testing import mess_with_incoming_records_to_ensure_updates

DEPLOYMENT_ENVIRONMENT = os.environ.get(
"ENVIRONMENT", "development"
Expand Down Expand Up @@ -82,13 +84,13 @@ def main():
SFTP_ENDPOINT_SSH_PRIVATE_KEY = secrets["sftp_endpoint_private_key"]

local_mode = False
if bool(glob.glob('/app/development_extracts/*.zip')):
if bool(glob.glob("/app/development_extracts/*.zip")):
local_mode = True

zip_location = None
if not local_mode: # Production
if not local_mode: # Production
zip_location = download_archives()
else: # Development. Put a zip in the development_extracts directory to use it.
else: # Development. Put a zip in the development_extracts directory to use it.
zip_location = specify_extract_location()

if not zip_location:
Expand All @@ -103,10 +105,9 @@ def main():
pgloader_command_files = pgloader_csvs_into_database(schema_name)
trimmed_token = remove_trailing_carriage_returns(pgloader_command_files)
typed_token = align_db_typing(trimmed_token)
#typed_token = mess_with_incoming_records_to_ensure_updates(typed_token) # for testing PR 1316
align_records_token = align_records(typed_token)
clean_up_import_schema(align_records_token)
if not local_mode: # We're using a locally provided zip file, so skip these steps
if not local_mode: # We're using a locally provided zip file, so skip these steps
remove_archives_from_sftp_endpoint(zip_location)
upload_csv_files_to_s3(archive)

Expand Down Expand Up @@ -197,12 +198,12 @@ def get_secrets():
"opitem": "SFTP Endpoint Key",
"opfield": ".private key",
"opvault": VAULT_ID,
},
},
"bastion_ssh_private_key": {
"opitem": "RDS Bastion Key",
"opfield": ".private key",
"opvault": VAULT_ID,
},
},
}

# instantiate a 1Password client
Expand All @@ -211,64 +212,8 @@ def get_secrets():
return onepasswordconnectsdk.load_dict(client, REQUIRED_SECRETS)


def mess_with_incoming_records_to_ensure_updates(map_state):
print(map_state)
schema = map_state["import_schema"]
with SshKeyTempDir() as key_directory:
write_key_to_file(key_directory + "/id_ed25519", DB_BASTION_HOST_SSH_PRIVATE_KEY + "\n")
ssh_tunnel = SSHTunnelForwarder(
(DB_BASTION_HOST),
ssh_username=DB_BASTION_HOST_SSH_USERNAME,
ssh_private_key=f"{key_directory}/id_ed25519",
remote_bind_address=(DB_RDS_HOST, 5432),
)
ssh_tunnel.start()

pg = psycopg2.connect(
host="localhost",
port=ssh_tunnel.local_bind_port,
user=DB_USER,
password=DB_PASS,
dbname=DB_NAME,
sslmode=DB_SSL_REQUIREMENT,
sslrootcert="/root/rds-combined-ca-bundle.pem",
)

sql = f"""UPDATE {schema}.crash
SET rpt_street_name = rpt_street_name || ' ' || lpad(to_hex((floor(random() * 16777215)::int)), 6, '0');
"""
print(sql)
cursor = pg.cursor(cursor_factory=psycopg2.extras.RealDictCursor)
cursor.execute(sql)

sql = f"""UPDATE {schema}.unit
SET vin = vin || ' ' || lpad(to_hex((floor(random() * 16777215)::int)), 6, '0');
"""
print(sql)
cursor = pg.cursor(cursor_factory=psycopg2.extras.RealDictCursor)
cursor.execute(sql)

sql = f"""UPDATE {schema}.person
SET prsn_last_name = prsn_last_name || ' ' || lpad(to_hex((floor(random() * 16777215)::int)), 6, '0');
"""
print(sql)
cursor = pg.cursor(cursor_factory=psycopg2.extras.RealDictCursor)
cursor.execute(sql)

sql = f"""UPDATE {schema}.primaryperson
SET prsn_last_name = prsn_last_name || ' ' || lpad(to_hex((floor(random() * 16777215)::int)), 6, '0');
"""
print(sql)
cursor = pg.cursor(cursor_factory=psycopg2.extras.RealDictCursor)
cursor.execute(sql)

pg.commit()

return map_state


def specify_extract_location():
zip_files = glob.glob('/app/development_extracts/*.zip')
zip_files = glob.glob("/app/development_extracts/*.zip")
if not zip_files:
return False

Expand All @@ -280,7 +225,6 @@ def specify_extract_location():
return zip_tmpdir



def download_archives():
"""
Connect to the SFTP endpoint which receives archives from CRIS and
Expand All @@ -290,7 +234,9 @@ def download_archives():
"""

with SshKeyTempDir() as key_directory:
write_key_to_file(key_directory + "/id_ed25519", SFTP_ENDPOINT_SSH_PRIVATE_KEY + "\n")
write_key_to_file(
key_directory + "/id_ed25519", SFTP_ENDPOINT_SSH_PRIVATE_KEY + "\n"
)

zip_tmpdir = tempfile.mkdtemp()
rsync = None
Expand Down Expand Up @@ -409,7 +355,9 @@ def remove_archives_from_sftp_endpoint(zip_location):
Returns: None
"""
with SshKeyTempDir() as key_directory:
write_key_to_file(key_directory + "/id_ed25519", SFTP_ENDPOINT_SSH_PRIVATE_KEY + "\n")
write_key_to_file(
key_directory + "/id_ed25519", SFTP_ENDPOINT_SSH_PRIVATE_KEY + "\n"
)

print(zip_location)
for archive in os.listdir(zip_location):
Expand Down Expand Up @@ -445,7 +393,10 @@ def pgloader_csvs_into_database(map_state):
print(f"Command file: {command_file}")

with SshKeyTempDir() as key_directory:
write_key_to_file(key_directory + "/id_ed25519", DB_BASTION_HOST_SSH_PRIVATE_KEY + "\n")
write_key_to_file(
key_directory + "/id_ed25519",
DB_BASTION_HOST_SSH_PRIVATE_KEY + "\n",
)
# we're going to get away with opening up this tunnel here for all pgloader commands
# because they get executed before this goes out of scope
ssh_tunnel = SSHTunnelForwarder(
Expand Down Expand Up @@ -488,9 +439,10 @@ def pgloader_csvs_into_database(map_state):


def remove_trailing_carriage_returns(map_state):

with SshKeyTempDir() as key_directory:
write_key_to_file(key_directory + "/id_ed25519", DB_BASTION_HOST_SSH_PRIVATE_KEY + "\n")
write_key_to_file(
key_directory + "/id_ed25519", DB_BASTION_HOST_SSH_PRIVATE_KEY + "\n"
)
ssh_tunnel = SSHTunnelForwarder(
(DB_BASTION_HOST),
ssh_username=DB_BASTION_HOST_SSH_USERNAME,
Expand All @@ -517,7 +469,6 @@ def remove_trailing_carriage_returns(map_state):


def align_db_typing(map_state):

"""
This function compares the target table in the VZDB with the corollary table in the import schema. For each column pair,
the type of the VZDB table's column is applied to the import table. This acts as a strong typing check for all input data,
Expand Down Expand Up @@ -596,7 +547,6 @@ def align_db_typing(map_state):


def align_records(map_state):

"""
This function begins by preparing a number of list and string variables containing SQL fragments.
These fragments are used to create queries which inspect the data differences between a pair of records.
Expand Down Expand Up @@ -796,9 +746,10 @@ def create_import_schema_name(mapped_state):


def create_target_import_schema(map_state):

with SshKeyTempDir() as key_directory:
write_key_to_file(key_directory + "/id_ed25519", DB_BASTION_HOST_SSH_PRIVATE_KEY + "\n")
write_key_to_file(
key_directory + "/id_ed25519", DB_BASTION_HOST_SSH_PRIVATE_KEY + "\n"
)
ssh_tunnel = SSHTunnelForwarder(
(DB_BASTION_HOST),
ssh_username=DB_BASTION_HOST_SSH_USERNAME,
Expand Down Expand Up @@ -846,7 +797,9 @@ def create_target_import_schema(map_state):

def clean_up_import_schema(map_state):
with SshKeyTempDir() as key_directory:
write_key_to_file(key_directory + "/id_ed25519", DB_BASTION_HOST_SSH_PRIVATE_KEY + "\n")
write_key_to_file(
key_directory + "/id_ed25519", DB_BASTION_HOST_SSH_PRIVATE_KEY + "\n"
)
ssh_tunnel = SSHTunnelForwarder(
(DB_BASTION_HOST),
ssh_username=DB_BASTION_HOST_SSH_USERNAME,
Expand Down Expand Up @@ -875,29 +828,6 @@ def clean_up_import_schema(map_state):

return map_state

# these temp directories are used to store ssh keys, because they will
# automatically clean themselves up when they go out of scope.
class SshKeyTempDir:
def __init__(self):
self.path = None

def __enter__(self):
self.path = tempfile.mkdtemp(dir='/tmp')
return self.path

def __exit__(self, exc_type, exc_val, exc_tb):
shutil.rmtree(self.path)

def write_key_to_file(path, content):
# Open the file with write permissions and create it if it doesn't exist
fd = os.open(path, os.O_WRONLY | os.O_CREAT, 0o600)

# Write the content to the file
os.write(fd, content.encode())

# Close the file
os.close(fd)


if __name__ == "__main__":
main()

0 comments on commit d6fcef0

Please sign in to comment.