Skip to content

Commit

Permalink
FIX: Refactor prepare_name_pairs_pd to pass arguments to create_posit…
Browse files Browse the repository at this point in the history
…ive_negative_samples (#32)

* passed correct_col to prepare_name pairs and replaced hardcoded column names

* passed uid_col to prepare_name pairs and changed hardcoded mentions of it

* added docstring for correct_col in prepare name pairs

* passed more columns to prepare name pairs and replaced their corresponding hardcoded values

* passed the new columns also to spark version of training name pairs

* added branch to test.yml

* passed positive_set_col to create_negative_name_pairs

* removed branch from test.yml
  • Loading branch information
chrispyl authored Dec 1, 2024
1 parent e7f5658 commit 3f20da8
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 9 deletions.
29 changes: 20 additions & 9 deletions emm/data/prepare_name_pairs.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,11 @@ def prepare_name_pairs_pd(
entity_id_col="entity_id",
gt_entity_id_col="gt_entity_id",
positive_set_col="positive_set",
correct_col="correct",
uid_col="uid",
gt_uid_col="gt_uid",
preprocessed_col="preprocessed",
gt_preprocessed_col="gt_preprocessed",
random_seed=42,
):
"""Prepare dataset of name-pair candidates for training of supervised model.
Expand Down Expand Up @@ -70,7 +74,12 @@ def prepare_name_pairs_pd(
For matching name-pairs entity_id == gt_entity_id.
positive_set_col: column that specifies which candidates remain positive and which become negative,
default is "positive_set".
correct_col: column that indicates a correct match, default is "correct".
For entity_id == gt_entity_id the column value is True.
uid_col: uid column for names to match, default is "uid".
gt_uid_col: uid column of ground-truth names, default is "gt_uid".
preprocessed_col: name of the preprocessed names column, default is "preprocessed".
gt_preprocessed_col: name of the preprocessed ground-truth names column, default is "gt_preprocessed".
random_seed: random seed for selection of negative names, default is 42.
"""
"""We can have the following dataset.columns, or much more like 'count', 'counterparty_account_count_distinct', 'type1_sum':
Expand All @@ -84,7 +93,7 @@ def prepare_name_pairs_pd(
assert entity_id_col in candidates_pd.columns
assert gt_entity_id_col in candidates_pd.columns

candidates_pd["correct"] = candidates_pd[entity_id_col] == candidates_pd[gt_entity_id_col]
candidates_pd[correct_col] = candidates_pd[entity_id_col] == candidates_pd[gt_entity_id_col]

# negative sample creation?
# if so, add positive_set_col column for negative sample creation
Expand All @@ -110,14 +119,14 @@ def prepare_name_pairs_pd(
# - happens with one correct/positive case, we just pick the correct one
if drop_duplicate_candidates:
candidates_pd = candidates_pd.sort_values(
["uid", "gt_preprocessed", "correct"], ascending=False
).drop_duplicates(subset=["uid", "gt_preprocessed"], keep="first")
[uid_col, gt_preprocessed_col, correct_col], ascending=False
).drop_duplicates(subset=[uid_col, gt_preprocessed_col], keep="first")
# Similarly, for a training set, remove all equal names that are not considered a match.
# This can happen a lot in actual data, e.g. with franchises that are independent but have the same name.
# It's a true effect in data, but this screws up our intuitive notion that identical names should be related.
if drop_samename_nomatch:
samename_nomatch = (candidates_pd["preprocessed"] == candidates_pd["gt_preprocessed"]) & ~candidates_pd[
"correct"
samename_nomatch = (candidates_pd[preprocessed_col] == candidates_pd[gt_preprocessed_col]) & ~candidates_pd[
correct_col
]
candidates_pd = candidates_pd[~samename_nomatch]

Expand All @@ -133,7 +142,9 @@ def prepare_name_pairs_pd(
# is referred to in: resources/data/howto_create_unittest_sample_namepairs.txt
# create negative sample and rerank negative candidates
# this drops, in part, the negative correct candidates
candidates_pd = create_positive_negative_samples(candidates_pd)
candidates_pd = create_positive_negative_samples(
candidates_pd, uid_col=uid_col, correct_col=correct_col, positive_set_col=positive_set_col
)

# It could be that we dropped all candidates, so we need to re-introduce the no-candidate rows
names_to_match_after = candidates_pd[names_to_match_cols].drop_duplicates()
Expand All @@ -142,12 +153,12 @@ def prepare_name_pairs_pd(
)
names_to_match_missing = names_to_match_missing[names_to_match_missing["_merge"] == "left_only"]
names_to_match_missing = names_to_match_missing.drop(columns=["_merge"])
names_to_match_missing["correct"] = False
names_to_match_missing[correct_col] = False
# Since this column is used to calculate benchmark metrics
names_to_match_missing["score_0_rank"] = 1

candidates_pd = pd.concat([candidates_pd, names_to_match_missing], ignore_index=True)
candidates_pd["gt_preprocessed"] = candidates_pd["gt_preprocessed"].fillna("")
candidates_pd["no_candidate"] = candidates_pd["gt_uid"].isnull()
candidates_pd[gt_preprocessed_col] = candidates_pd[gt_preprocessed_col].fillna("")
candidates_pd["no_candidate"] = candidates_pd[gt_uid_col].isnull()

return candidates_pd
5 changes: 5 additions & 0 deletions emm/pipeline/pandas_entity_matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,6 +384,11 @@ def create_training_name_pairs(
else drop_duplicate_candidates,
create_negative_sample_fraction=create_negative_sample_fraction,
positive_set_col=self.parameters.get("positive_set_col", "positive_set"),
correct_col=self.parameters.get("correct_col", "correct"),
uid_col=self.parameters.get("uid_col", "uid"),
gt_uid_col=self.parameters.get("gt_uid_col", "gt_uid"),
preprocessed_col=self.parameters.get("preprocessed_col", "preprocessed"),
gt_preprocessed_col=self.parameters.get("gt_preprocessed_col", "gt_preprocessed"),
random_seed=random_seed,
**kwargs,
)
Expand Down
5 changes: 5 additions & 0 deletions emm/pipeline/spark_entity_matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,6 +412,11 @@ def create_training_name_pairs(
else drop_duplicate_candidates,
create_negative_sample_fraction=create_negative_sample_fraction,
positive_set_col=self.parameters.get("positive_set_col", "positive_set"),
correct_col=self.parameters.get("correct_col", "correct"),
uid_col=self.parameters.get("uid_col", "uid"),
gt_uid_col=self.parameters.get("gt_uid_col", "gt_uid"),
preprocessed_col=self.parameters.get("preprocessed_col", "preprocessed"),
gt_preprocessed_col=self.parameters.get("gt_preprocessed_col", "gt_preprocessed"),
random_seed=random_seed,
**kwargs,
)
Expand Down

0 comments on commit 3f20da8

Please sign in to comment.