Skip to content

Commit

Permalink
Update _simpler_sf.py
Browse files Browse the repository at this point in the history
Added full-path column names for queries with more than two levels of relationship. Refactored variables in _recursive_unnest for explicitness. Removed progress bar for single queries. Added docstring to _smart_query().
  • Loading branch information
benvigano authored Jan 3, 2024
1 parent d3995e2 commit c3ca5ae
Showing 1 changed file with 40 additions and 23 deletions.
63 changes: 40 additions & 23 deletions simpler_sf/_simpler_sf.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,18 @@
import re


def _recursive_unnest(d, k1, r):
def _recursive_unnest(data, parent_path='', results={}):
'''Recursively un-nest records'''
for k2 in d:
if isinstance(d[k2], Mapping) and "attributes" in d[k2]:
r = _recursive_unnest(d[k2], k2, r)
for current_level_key in data:
path = '.'.join(filter(None, [parent_path, current_level_key]))
if isinstance(data[current_level_key], Mapping) and "attributes" in data[current_level_key]:
results = _recursive_unnest(data[current_level_key], path, results)
else:
if k2 != "attributes":
if k1 == "":
r[k2] = d[k2]
else:
r['.'.join([k1, k2])] = d[k2]
return r
if current_level_key != "attributes":
results[path] = data[current_level_key]
else:
pass
return results


def _unnest_query_output(records) -> dict:
Expand All @@ -31,20 +31,22 @@ def _unnest_query_output(records) -> dict:
results = {}

for record in records:
unnested_record = _recursive_unnest(record, "", {})

unnested_record = _recursive_unnest(record)
# Get the row index, to know how many blank fillers
# to insert in case a new field is found
if len(results.keys()) != 0:
row_ix = len(results[list(results.keys())[0]])
else:
row_ix = 0

# Notice: when a record has a nan value, that field is not included in the record dictionary
'''
Notice: when a record has a nan value, that field is not included in the record dictionary
# If the record has a field that wasn't present in the previous records,
# insert as many blank fillers as the number of previous records.
# Note: in the first iteration this will just write in 'results' an empty list for each field.
If the record has a field that wasn't present in the previous records,
insert as many blank fillers as the number of previous records.
Note: in the first iteration this will just write in 'results' an empty list for each field.
'''
for key in unnested_record:
if key not in list(results.keys()):
results[key] = [blank_filler] * row_ix
Expand Down Expand Up @@ -72,7 +74,7 @@ def _determine_object(query):
'''Parse a query to determine the Salesforce object'''

if " from " not in query.lower():
raise Exception(f"'from' not found in query '{query.lower()}'")
raise Exception(f"'FROM' statement not found in query '{query.lower()}'")

return re.split(" from ", query, flags=re.IGNORECASE)[1].split(" ")[0]

Expand Down Expand Up @@ -142,12 +144,24 @@ def _smart_query(
filter_values: [object] = None,
not_in: bool = False
):
'''
Parameters
query: str,
show_progress: [bool] (default: True),
filter_field: str (default: None) : The field in 'WHERE -field- IN -values-',
filter_values: [object] (default: None) : The values in 'WHERE -field- IN -values-',
not_in: bool (default: False): If true, filter becomes 'WHERE -field- NOT IN -values-
Returns
pd.DataFrame
'''

# Make the query inline to ease parsing
query = " ".join(line.strip() for line in query.splitlines())

# Determine the Salesforce object by parsing the query
object = getattr(self.bulk, _determine_object(query))
object_name = _determine_object(query)
object = getattr(self.bulk, object_name)

# Determine the fields by parsing the query
fields = _determine_fields(query)
Expand All @@ -158,7 +172,7 @@ def _smart_query(
sub_queries = [query]

dfs = []
for sub_query in tqdm(sub_queries):
for sub_query in tqdm(sub_queries) if len(sub_queries) > 1 else sub_queries:
results = _unnest_query_output(object.query(sub_query))
partial_df = pd.DataFrame(results)
dfs.append(partial_df)
Expand All @@ -171,10 +185,13 @@ def _smart_query(
else:
pass

# Remove any unrequested columns that were returned
# This can happen for example in the query "SELECT Account.Id FROM Contact", infact
# if contacts that are not linked to an account are present, the additional column "Account" is returned.
unrequested_columns = [c for c in output_df.columns if c not in fields]
'''
Remove any unrequested columns that were returned, accounting for case insensitivitiy and optional object name prefix
This can happen for example in the query "SELECT Account.Id FROM Contact", infact
if contacts that are not linked to an account are present, the additional column "Account" is returned.
'''
fields_lower = [f.lower() for f in fields]
unrequested_columns = [c for c in output_df.columns if c.lower() not in fields_lower and object_name.lower() + "." + c.lower() not in fields_lower]
output_df.drop(columns=unrequested_columns, inplace=True)

return output_df
Expand Down

0 comments on commit c3ca5ae

Please sign in to comment.