Skip to content

Commit

Permalink
allow user to specify their own model
Browse files Browse the repository at this point in the history
  • Loading branch information
CangyuanLi committed Nov 6, 2023
1 parent ad20f2d commit 1c117af
Showing 1 changed file with 23 additions and 10 deletions.
33 changes: 23 additions & 10 deletions src/pyethnicity/_ml_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,10 @@ def _pad_sequences(


def predict_race_fl(
first_name: Name, last_name: Name, chunksize: int = CHUNKSIZE
first_name: Name,
last_name: Name,
chunksize: int = CHUNKSIZE,
_model: onnxruntime.InferenceSession = None,
) -> pd.DataFrame:
"""Predict race from first and last name.
Expand Down Expand Up @@ -189,13 +192,14 @@ def predict_race_fl(
]
X = _pad_sequences(X, maxlen=30).astype(np.float32)

model = MODEL_LOADER.load("first_last")
if _model is None:
_model = MODEL_LOADER.load("first_last")

input_name = model.get_inputs()[0].name
input_name = _model.get_inputs()[0].name

y_pred = []
for input_ in tqdm.tqdm(cutils.chunk_seq(X, chunksize)):
y_pred.extend(model.run(None, input_feed={input_name: input_})[0])
y_pred.extend(_model.run(None, input_feed={input_name: input_})[0])

preds: dict[str, list] = {r: [] for r in RACES}
for row in y_pred:
Expand All @@ -217,6 +221,7 @@ def predict_race_flg(
geography: Geography,
geo_type: GeoType,
chunksize: int = CHUNKSIZE,
_model: onnxruntime.InferenceSession = None,
) -> pd.DataFrame:
r"""Predict race from first name, last name, and geography. The output from
pyethnicity.predict_race_fl is combined with geography using Naive Bayes:
Expand Down Expand Up @@ -266,7 +271,7 @@ def predict_race_flg(
>>> geography=[11106, 27106], geo_type="zcta"
>>> )
"""
fl_preds = predict_race_fl(first_name, last_name, chunksize)
fl_preds = predict_race_fl(first_name, last_name, chunksize, _model)

return _bng(pl.from_pandas(fl_preds), geography, geo_type)

Expand All @@ -277,6 +282,7 @@ def predict_race(
geography: Geography,
geo_type: GeoType,
chunksize: int = CHUNKSIZE,
_model: onnxruntime.InferenceSession = None,
) -> pd.DataFrame:
"""Predict race from first name, last name, and geography. The output from
pyethnicity.predict_race_flg is ensembled with pyethnicity.bisg and pyethnicity.bifsg.
Expand Down Expand Up @@ -321,7 +327,9 @@ def predict_race(
>>> geography=[11106, 27106], geo_type="zcta"
>>> )
"""
flz = predict_race_flg(first_name, last_name, geography, geo_type, chunksize)
flz = predict_race_flg(
first_name, last_name, geography, geo_type, chunksize, _model
)
bifsg_ = bifsg(first_name, last_name, geography, geo_type)
bisg_ = bisg(last_name, geography, geo_type)

Expand Down Expand Up @@ -353,7 +361,11 @@ def predict_race(
return df


def predict_sex_f(first_name: Name, chunksize: int = CHUNKSIZE) -> pd.DataFrame:
def predict_sex_f(
first_name: Name,
chunksize: int = CHUNKSIZE,
_model: onnxruntime.InferenceSession = None,
) -> pd.DataFrame:
"""Predict sex from first name.
Parameters
Expand All @@ -379,14 +391,15 @@ def predict_sex_f(first_name: Name, chunksize: int = CHUNKSIZE) -> pd.DataFrame:
X = [_encode_name(fn, mapper=VALID_NAME_CHARS_DICT) for fn in first_name_cleaned]
X = _pad_sequences(X, maxlen=15).astype(np.float32)

model = MODEL_LOADER.load("first_sex")
if _model is None:
_model = MODEL_LOADER.load("first_sex")

input_name = model.get_inputs()[0].name
input_name = _model.get_inputs()[0].name

with tqdm.tqdm(total=len(X)) as pbar:
y_pred = []
for input_ in cutils.chunk_seq(X, chunksize):
y_pred.extend(model.run(None, input_feed={input_name: input_})[0])
y_pred.extend(_model.run(None, input_feed={input_name: input_})[0])
pbar.update(len(input_))

pct_male = [row[0] for row in y_pred]
Expand Down

0 comments on commit 1c117af

Please sign in to comment.