Skip to content

Commit

Permalink
FIX: ensure compatibility between sklearn and spark tfidf vectors for…
Browse files Browse the repository at this point in the history
… skl>=1.5

Change for TfidfTransformer of sklearn v1.5 in order to ensure compatibility between the
pandas and spark version of emm. In sklearn v1.5+ TfidfTransformer no longer has the _idf_diag attribute,
needed for setting the compatibility.
  • Loading branch information
mbaak committed Sep 5, 2024
1 parent 338b410 commit 80a5683
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 5 deletions.
20 changes: 16 additions & 4 deletions emm/indexing/pandas_normalized_tfidf.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,13 +76,25 @@ def fit(self, X: pd.Series | pd.DataFrame) -> TfidfVectorizer:
super().fit(X)

timer.label("normalize")
idf_diag = self._tfidf._idf_diag
n_features = self.idf_.shape[0]

# 1. this max_idf_square value is used in normalization step for simulating out-of-vocabulary tokens
idf_diag = scipy.sparse.diags(
self.idf_, offsets=0, shape=(n_features, n_features), format="csr", dtype=self.dtype
)
idf_diag = idf_diag - scipy.sparse.diags(np.ones(idf_diag.shape[0]), shape=idf_diag.shape, dtype=self.dtype)
self._tfidf._idf_diag = idf_diag
assert self._tfidf._idf_diag.dtype == self.dtype
# this value is used in normalization step for simulating out-of-vocabulary tokens
self.max_idf_square = idf_diag.max() ** 2

# 2. ensure compatibility between sklearn and spark tfidf vectors
if hasattr(self._tfidf, "_idf_diag"):
# sklearn < 1.5
self._tfidf._idf_diag = idf_diag
assert self._tfidf._idf_diag.dtype == self.dtype
else:
# sklearn >= 1.5
self.idf_ = self.idf_ - np.ones(n_features, dtype=self.dtype)
assert self.idf_.dtype == self.dtype

timer.log_params({"n": len(X), "n_features": idf_diag.shape[0]})

return self
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ dependencies = [
# Fix for error ValueError: numpy.ndarray size changed, may indicate binary incompatibility. Expected 96 from C header, got 88 from PyObject.
"numpy>=1.20.1",
"scipy",
"scikit-learn<1.5.0",
"scikit-learn>=1.0.0",
"pandas>=1.1.0,!=1.5.0",
"jinja2", # for pandas https://pandas.pydata.org/docs/getting_started/install.html#visualization
"rapidfuzz<3.0.0",
Expand Down

0 comments on commit 80a5683

Please sign in to comment.