Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update create_training_name_pairs to pass additional arguments to prepare_name_pairs #29

Merged
merged 10 commits into from
Nov 15, 2024
3 changes: 3 additions & 0 deletions emm/pipeline/pandas_entity_matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,7 @@ def create_training_name_pairs(
n_train_ids: int = -1,
random_seed: int = 42,
drop_duplicate_candidates: bool | None = None,
**kwargs,
) -> pd.DataFrame:
"""Create name-pairs for training from positive names that match to the ground truth.

Expand All @@ -333,6 +334,7 @@ def create_training_name_pairs(
drop_duplicate_candidates: if True drop any duplicate training candidates and keep just one,
if available keep the correct match. Recommended for string-similarity models, eg. with
without_rank_features=True. default is False.
kwargs: extra key-word arguments meant to be passed to prepare_name_pairs_pd.

Returns:
pandas dataframe with name-pair candidates to be used for training.
Expand Down Expand Up @@ -383,6 +385,7 @@ def create_training_name_pairs(
create_negative_sample_fraction=create_negative_sample_fraction,
positive_set_col=self.parameters.get("positive_set_col", "positive_set"),
random_seed=random_seed,
**kwargs,
)

def fit_classifier(
Expand Down
7 changes: 6 additions & 1 deletion emm/pipeline/spark_entity_matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from __future__ import annotations

import re
from typing import TYPE_CHECKING, Any, Callable, Literal, Mapping
from typing import TYPE_CHECKING, Any, Literal

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -48,6 +48,8 @@
from emm.supervised_model.spark_supervised_model import SparkSupervisedLayerEstimator

if TYPE_CHECKING:
from collections.abc import Callable, Mapping

from pyspark.ml import Pipeline, PipelineModel


Expand Down Expand Up @@ -343,6 +345,7 @@ def create_training_name_pairs(
n_train_ids=-1,
random_seed=42,
drop_duplicate_candidates: bool | None = None,
**kwargs,
) -> pd.DataFrame:
"""Create name-pairs for training from positive names that match to the ground truth.

Expand All @@ -364,6 +367,7 @@ def create_training_name_pairs(
drop_duplicate_candidates: if True drop any duplicate training candidates and keep just one,
if available keep the correct match. Recommended for string-similarity models, eg. with
without_rank_features=True. default is False.
kwargs: extra key-word arguments meant to be passed to prepare_name_pairs_pd.

Returns:
pandas dataframe with name-pair candidates to be used for training.
Expand Down Expand Up @@ -409,6 +413,7 @@ def create_training_name_pairs(
create_negative_sample_fraction=create_negative_sample_fraction,
positive_set_col=self.parameters.get("positive_set_col", "positive_set"),
random_seed=random_seed,
**kwargs,
)

def fit_classifier(
Expand Down
Loading