diff --git a/tests/unit/ops/test_categorify.py b/tests/unit/ops/test_categorify.py index 2db9ef0f7b..5092bb7fa7 100644 --- a/tests/unit/ops/test_categorify.py +++ b/tests/unit/ops/test_categorify.py @@ -738,7 +738,7 @@ def test_categorify_inference(): def test_categorify_transform_only_nans_column(): train_df = make_df({"cat_column": ["a", "a", "b", "c", np.nan]}) - cat_features = ["cat_column"] >> nvt.ops.Categorify(max_size=4) + cat_features = ["cat_column"] >> nvt.ops.Categorify() train_dataset = nvt.Dataset(train_df) workflow = nvt.Workflow(cat_features) @@ -747,4 +747,5 @@ def test_categorify_transform_only_nans_column(): inference_df = make_df({"cat_column": [np.nan] * 10}) inference_dataset = nvt.Dataset(inference_df) - workflow.transform(inference_dataset).compute() + output = workflow.transform(inference_dataset).compute() + assert len(output) == len(inference_df)