From 3f20da8af282c71f7e9a64eafd72048f7702973d Mon Sep 17 00:00:00 2001
From: Christos Pylianidis
Date: Sun, 1 Dec 2024 08:58:18 +0100
Subject: [PATCH] FIX: Refactor prepare_name_pairs_pd to pass arguments to
 create_positive_negative_samples (#32)

* passed correct_col to prepare_name_pairs and replaced hardcoded column names

* passed uid_col to prepare_name_pairs and changed hardcoded mentions of it

* added docstring for correct_col in prepare_name_pairs

* passed more columns to prepare_name_pairs and replaced their corresponding hardcoded values

* also passed the new columns to the spark version of create_training_name_pairs

* added branch to test.yml

* passed positive_set_col to create_positive_negative_samples

* removed branch from test.yml
---
 emm/data/prepare_name_pairs.py         | 29 ++++++++++++++++++-------
 emm/pipeline/pandas_entity_matching.py |  5 +++++
 emm/pipeline/spark_entity_matching.py  |  5 +++++
 3 files changed, 30 insertions(+), 9 deletions(-)

diff --git a/emm/data/prepare_name_pairs.py b/emm/data/prepare_name_pairs.py
index 0c0b29f..f953295 100644
--- a/emm/data/prepare_name_pairs.py
+++ b/emm/data/prepare_name_pairs.py
@@ -39,7 +39,11 @@ def prepare_name_pairs_pd(
     entity_id_col="entity_id",
     gt_entity_id_col="gt_entity_id",
     positive_set_col="positive_set",
+    correct_col="correct",
     uid_col="uid",
+    gt_uid_col="gt_uid",
+    preprocessed_col="preprocessed",
+    gt_preprocessed_col="gt_preprocessed",
     random_seed=42,
 ):
     """Prepare dataset of name-pair candidates for training of supervised model.
@@ -70,7 +74,12 @@ def prepare_name_pairs_pd(
             For matching name-pairs entity_id == gt_entity_id.
         positive_set_col: column that specifies which candidates remain positive and which become negative,
             default is "positive_set".
+        correct_col: column that indicates a correct match, default is "correct".
+            For entity_id == gt_entity_id the column value is True.
         uid_col: uid column for names to match, default is "uid".
+        gt_uid_col: uid column of ground-truth names, default is "gt_uid".
+        preprocessed_col: name of the preprocessed names column, default is "preprocessed".
+        gt_preprocessed_col: name of the preprocessed ground-truth names column, default is "gt_preprocessed".
         random_seed: random seed for selection of negative names, default is 42.
     """
     """We can have the following dataset.columns, or much more like 'count', 'counterparty_account_count_distinct', 'type1_sum':
@@ -84,7 +93,7 @@ def prepare_name_pairs_pd(
     assert entity_id_col in candidates_pd.columns
     assert gt_entity_id_col in candidates_pd.columns
 
-    candidates_pd["correct"] = candidates_pd[entity_id_col] == candidates_pd[gt_entity_id_col]
+    candidates_pd[correct_col] = candidates_pd[entity_id_col] == candidates_pd[gt_entity_id_col]
 
     # negative sample creation?
     # if so, add positive_set_col column for negative sample creation
@@ -110,14 +119,14 @@ def prepare_name_pairs_pd(
     # - happens with one correct/positive case, we just pick the correct one
     if drop_duplicate_candidates:
         candidates_pd = candidates_pd.sort_values(
-            ["uid", "gt_preprocessed", "correct"], ascending=False
-        ).drop_duplicates(subset=["uid", "gt_preprocessed"], keep="first")
+            [uid_col, gt_preprocessed_col, correct_col], ascending=False
+        ).drop_duplicates(subset=[uid_col, gt_preprocessed_col], keep="first")
 
     # Similar, for a training set remove all equal names that are not considered a match.
     # This can happen a lot in actual data, e.g. with franchises that are independent but have the same name.
     # It's a true effect in data, but this screws up our intuitive notion that identical names should be related.
     if drop_samename_nomatch:
-        samename_nomatch = (candidates_pd["preprocessed"] == candidates_pd["gt_preprocessed"]) & ~candidates_pd[
-            "correct"
+        samename_nomatch = (candidates_pd[preprocessed_col] == candidates_pd[gt_preprocessed_col]) & ~candidates_pd[
+            correct_col
         ]
         candidates_pd = candidates_pd[~samename_nomatch]
@@ -133,7 +142,9 @@ def prepare_name_pairs_pd(
     # is referred to in: resources/data/howto_create_unittest_sample_namepairs.txt
     # create negative sample and rerank negative candidates
     # this drops, in part, the negative correct candidates
-    candidates_pd = create_positive_negative_samples(candidates_pd)
+    candidates_pd = create_positive_negative_samples(
+        candidates_pd, uid_col=uid_col, correct_col=correct_col, positive_set_col=positive_set_col
+    )
 
     # It could be that we dropped all candidates, so we need to re-introduce the no-candidate rows
     names_to_match_after = candidates_pd[names_to_match_cols].drop_duplicates()
@@ -142,12 +153,12 @@ def prepare_name_pairs_pd(
     )
     names_to_match_missing = names_to_match_missing[names_to_match_missing["_merge"] == "left_only"]
     names_to_match_missing = names_to_match_missing.drop(columns=["_merge"])
-    names_to_match_missing["correct"] = False
+    names_to_match_missing[correct_col] = False
     # Since this column is used to calculate benchmark metrics
     names_to_match_missing["score_0_rank"] = 1
     candidates_pd = pd.concat([candidates_pd, names_to_match_missing], ignore_index=True)
-    candidates_pd["gt_preprocessed"] = candidates_pd["gt_preprocessed"].fillna("")
-    candidates_pd["no_candidate"] = candidates_pd["gt_uid"].isnull()
+    candidates_pd[gt_preprocessed_col] = candidates_pd[gt_preprocessed_col].fillna("")
+    candidates_pd["no_candidate"] = candidates_pd[gt_uid_col].isnull()
 
     return candidates_pd

diff --git a/emm/pipeline/pandas_entity_matching.py b/emm/pipeline/pandas_entity_matching.py
index 190cbdf..1d4c591 100644
--- a/emm/pipeline/pandas_entity_matching.py
+++ b/emm/pipeline/pandas_entity_matching.py
@@ -384,6 +384,11 @@ def create_training_name_pairs(
             else drop_duplicate_candidates,
             create_negative_sample_fraction=create_negative_sample_fraction,
             positive_set_col=self.parameters.get("positive_set_col", "positive_set"),
+            correct_col=self.parameters.get("correct_col", "correct"),
+            uid_col=self.parameters.get("uid_col", "uid"),
+            gt_uid_col=self.parameters.get("gt_uid_col", "gt_uid"),
+            preprocessed_col=self.parameters.get("preprocessed_col", "preprocessed"),
+            gt_preprocessed_col=self.parameters.get("gt_preprocessed_col", "gt_preprocessed"),
             random_seed=random_seed,
             **kwargs,
         )
diff --git a/emm/pipeline/spark_entity_matching.py b/emm/pipeline/spark_entity_matching.py
index 0df6565..fce8a3d 100644
--- a/emm/pipeline/spark_entity_matching.py
+++ b/emm/pipeline/spark_entity_matching.py
@@ -412,6 +412,11 @@ def create_training_name_pairs(
             else drop_duplicate_candidates,
             create_negative_sample_fraction=create_negative_sample_fraction,
             positive_set_col=self.parameters.get("positive_set_col", "positive_set"),
+            correct_col=self.parameters.get("correct_col", "correct"),
+            uid_col=self.parameters.get("uid_col", "uid"),
+            gt_uid_col=self.parameters.get("gt_uid_col", "gt_uid"),
+            preprocessed_col=self.parameters.get("preprocessed_col", "preprocessed"),
+            gt_preprocessed_col=self.parameters.get("gt_preprocessed_col", "gt_preprocessed"),
             random_seed=random_seed,
             **kwargs,
         )
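
Note: the snippet below is not part of the patch. It is a minimal standalone pandas sketch of the refactor the patch performs in prepare_name_pairs_pd: the previously hardcoded column names ("uid", "correct", "gt_preprocessed", ...) become function arguments. The helper name mark_and_dedup and the sample data are hypothetical and chosen only to exercise non-default column names.

import pandas as pd

def mark_and_dedup(candidates, uid_col="uid", entity_id_col="entity_id",
                   gt_entity_id_col="gt_entity_id", correct_col="correct",
                   gt_preprocessed_col="gt_preprocessed"):
    # a name pair is correct when the matched entity id equals the ground-truth entity id
    candidates[correct_col] = candidates[entity_id_col] == candidates[gt_entity_id_col]
    # keep one row per (uid, ground-truth name), preferring a correct candidate if present
    return (
        candidates.sort_values([uid_col, gt_preprocessed_col, correct_col], ascending=False)
        .drop_duplicates(subset=[uid_col, gt_preprocessed_col], keep="first")
    )

# made-up candidate pairs with non-default column names
pairs = pd.DataFrame(
    {
        "row_uid": [1, 1, 2],
        "eid": [100, 100, 200],
        "gt_eid": [100, 101, 300],
        "gt_clean_name": ["acme inc", "acme inc", "bar llc"],
    }
)
deduped = mark_and_dedup(
    pairs,
    uid_col="row_uid",
    entity_id_col="eid",
    gt_entity_id_col="gt_eid",
    correct_col="is_match",
    gt_preprocessed_col="gt_clean_name",
)
print(deduped)

With the patch, these column names flow from the EntityMatching parameters (e.g. "uid_col", "correct_col") through create_training_name_pairs into prepare_name_pairs_pd and on to create_positive_negative_samples, instead of being fixed strings inside the function.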