Skip to content

Commit

Permalink
fix: IndexError with columns full of NaNs
Browse files Browse the repository at this point in the history
  • Loading branch information
lecardozo committed Oct 18, 2023
1 parent 9f0ba33 commit 3c1450b
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 1 deletion.
5 changes: 4 additions & 1 deletion nvtabular/ops/categorify.py
Original file line number Diff line number Diff line change
Expand Up @@ -1704,7 +1704,10 @@ def _encode(
codes = type(df)({"order": dispatch.arange(len(df), like_df=df)}, index=df.index)

for cl, cr in zip(selection_l.names, selection_r.names):
if isinstance(df[cl].dropna().iloc[0], (np.ndarray, list)):
column_without_nans = df[cl].dropna()
if len(column_without_nans) and isinstance(
column_without_nans.iloc[0], (np.ndarray, list)
):
ser = df[cl].copy()
codes[cl] = dispatch.flatten_list_column_values(ser).astype(value[cr].dtype)
else:
Expand Down
14 changes: 14 additions & 0 deletions tests/unit/ops/test_categorify.py
Original file line number Diff line number Diff line change
Expand Up @@ -734,3 +734,17 @@ def test_categorify_inference():
output_tensors = inference_op.transform(cats.input_columns, input_tensors)
for key in input_tensors:
assert output_tensors[key].dtype == np.dtype("int64")


def test_categorify_transform_only_nans_column():
train_df = make_df({"cat_column": ["a", "a", "b", "c", np.nan]})
cat_features = ["cat_column"] >> nvt.ops.Categorify(max_size=4)
train_dataset = nvt.Dataset(train_df)

workflow = nvt.Workflow(cat_features)
workflow.fit(train_dataset)

inference_df = make_df({"cat_column": [np.nan] * 10})
inference_dataset = nvt.Dataset(inference_df)

workflow.transform(inference_dataset).compute()

0 comments on commit 3c1450b

Please sign in to comment.