Skip to content

Commit

Permalink
Use GroupShuffleSplit to ensure subgraphs of decision clusters are entirely in test or training but not spread
Browse files Browse the repository at this point in the history
  • Loading branch information
jbothma committed Aug 23, 2024
1 parent 8defbb4 commit 18cdb13
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 25 deletions.
11 changes: 5 additions & 6 deletions nomenklatura/matching/pairs.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,25 +14,24 @@ class JudgedPair(object):
(or not) by a user.
"""

__slots__ = ("left", "right", "judgement")
__slots__ = ("left", "right", "judgement", "group")

def __init__(
self, left: EntityProxy, right: EntityProxy, judgement: Judgement
self, left: EntityProxy, right: EntityProxy, judgement: Judgement, group: int
) -> None:
self.left = left
self.right = right
self.judgement = judgement
self.group = group

def to_dict(self) -> Dict[str, Any]:
return {
"left": self.left.to_dict(),
"right": self.right.to_dict(),
"judgement": self.judgement.value,
"group": self.group,
}

def __hash__(self):
return hash((self.left.id, self.right.id, self.judgement.value))


def read_pairs(pairs_file: PathLike) -> Generator[JudgedPair, None, None]:
"""Read judgement pairs (training data) from a JSON file."""
Expand All @@ -44,7 +43,7 @@ def read_pairs(pairs_file: PathLike) -> Generator[JudgedPair, None, None]:
judgement = Judgement(data["judgement"])
if judgement not in (Judgement.POSITIVE, Judgement.NEGATIVE):
continue
yield JudgedPair(left_entity, right_entity, judgement)
yield JudgedPair(left_entity, right_entity, judgement, data["group"])


def read_pair_sets(pairs_file: PathLike) -> List[Set[JudgedPair]]:
Expand Down
45 changes: 26 additions & 19 deletions nomenklatura/matching/regression_v3/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@
from numpy.typing import NDArray
from sklearn.pipeline import make_pipeline # type: ignore
from sklearn.preprocessing import StandardScaler # type: ignore
from sklearn.model_selection import train_test_split # type: ignore
from sklearn.model_selection import GroupShuffleSplit, train_test_split # type: ignore
from sklearn.linear_model import LogisticRegression # type: ignore
from sklearn import metrics # type: ignore
from concurrent.futures import ThreadPoolExecutor

from nomenklatura.judgement import Judgement
from nomenklatura.matching.pairs import read_pair_sets, JudgedPair
from nomenklatura.matching.pairs import read_pairs, JudgedPair
from nomenklatura.matching.regression_v3.model import RegressionV3
from nomenklatura.util import PathLike

Expand Down Expand Up @@ -46,23 +46,30 @@ def pairs_to_arrays(


def train_matcher(pairs_file: PathLike) -> None:
pair_sets = read_pair_sets(pairs_file)

positive = sum([len([p for p in s if p.judgement == Judgement.POSITIVE]) for s in pair_sets])
negative = sum([len([p for p in s if p.judgement == Judgement.NEGATIVE]) for s in pair_sets])

log.info("Total pairs loaded: %d (%d pos/%d neg)", positive+negative, positive, negative)
log.info("Total independent sets loaded: %d", len(pair_sets))

train_sets, test_sets = train_test_split(pair_sets, test_size=0.33)
log.info("Training sets: %d, Test sets: %d - test is %d%%", len(train_sets), len(test_sets), 100*len(test_sets)/(len(pair_sets)))
train_pairs = [p for s in train_sets for p in s]
test_pairs = [p for s in test_sets for p in s]
log.info("Training pairs: %d, Test pairs: %d, test is %d%%", len(train_pairs), len(test_pairs), 100*len(test_pairs)/(len(train_pairs)+len(test_pairs)))

X_train, y_train = pairs_to_arrays(train_pairs)
X_test, y_test = pairs_to_arrays(test_pairs)

pairs = []
for pair in read_pairs(pairs_file):
# HACK: support more eventually:
# if not pair.left.schema.is_a("LegalEntity"):
# continue
if pair.judgement == Judgement.UNSURE:
pair.judgement = Judgement.NEGATIVE
# randomize_entity(pair.left)
# randomize_entity(pair.right)
pairs.append(pair)
# random.shuffle(pairs)
# pairs = pairs[:30000]
positive = len([p for p in pairs if p.judgement == Judgement.POSITIVE])
negative = len([p for p in pairs if p.judgement == Judgement.NEGATIVE])
log.info("Total pairs loaded: %d (%d pos/%d neg)", len(pairs), positive, negative)
X, y = pairs_to_arrays(pairs)
groups = [p.group for p in pairs]
gss = GroupShuffleSplit(test_size=.33)
train_indices, test_indices = next(gss.split(X, y, groups=groups))
X_train = [X[i] for i in train_indices]
X_test = [X[i] for i in test_indices]
y_train = [y[i] for i in train_indices]
y_test = [y[i] for i in test_indices]
# X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33)
# logreg = LogisticRegression(class_weight={0: 95, 1: 1})
# logreg = LogisticRegression(penalty="l1", solver="liblinear")
logreg = LogisticRegression(penalty="l2")
Expand Down

0 comments on commit 18cdb13

Please sign in to comment.