Skip to content

Commit

Permalink
fetching type_hints from table directly (#413)
Browse files Browse the repository at this point in the history
* fetching type_hints from table directly

* downgrading pycaret

* correcting output_table for snowflake
  • Loading branch information
joker2411 authored Aug 13, 2024
1 parent c769691 commit fbcb38a
Show file tree
Hide file tree
Showing 5 changed files with 12 additions and 55 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -893,7 +893,6 @@ def safe_parse_json(entry):
def generate_type_hint(
self,
df: pd.DataFrame,
column_types: Dict[str, List[str]],
):
return None

Expand Down
2 changes: 1 addition & 1 deletion src/predictions/profiles_mlcorelib/connectors/Connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -468,7 +468,7 @@ def get_material_registry_table(
pass

@abstractmethod
def generate_type_hint(self, df: Any, column_types: Dict[str, List[str]]):
def generate_type_hint(self, df: Any):
pass

@abstractmethod
Expand Down
22 changes: 10 additions & 12 deletions src/predictions/profiles_mlcorelib/connectors/SnowflakeConnector.py
Original file line number Diff line number Diff line change
Expand Up @@ -1078,17 +1078,11 @@ def get_material_registry_table(
def generate_type_hint(
self,
df: snowflake.snowpark.Table,
column_types: Dict[str, List[str]],
):
types = []

for col in df.columns:
dtype_str = column_types[col]
for category, mapping in self.data_type_mapping.items():
if dtype_str in mapping:
types.append(mapping[dtype_str])
break

schema_fields = df.schema.fields
for field in schema_fields:
types.append(field.datatype)
return types

def call_prediction_udf(
Expand Down Expand Up @@ -1129,16 +1123,20 @@ def call_prediction_udf(
preds = preds.withColumn("columnindex", F.row_number().over(w))
extracted_df = extracted_df.withColumn("columnindex", F.row_number().over(w))
preds = preds.join(
extracted_df, preds.columnindex == extracted_df.columnindex, "inner"
).drop("columnindex")
extracted_df,
preds.columnindex == extracted_df.columnindex,
"inner",
lsuffix="_left",
rsuffix="_right",
).drop("columnindex_left", "columnindex_right")

# Remove the dummy label column in case of Regression
if "label" not in pred_output_df_columns:
preds.drop(output_label_column)

preds_with_percentile = preds.withColumn(
percentile_column_name,
F.percent_rank().over(Window.orderBy(F.col(score_column_name))),
(F.percent_rank().over(Window.orderBy(F.col(score_column_name)))) * 100,
)

return preds_with_percentile
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -123,9 +123,7 @@ def preprocess_and_predict(
input_df = connector.select_relevant_columns(
predict_data, required_features_upper_case
)
types = connector.generate_type_hint(
input_df, results["column_names"]["feature_table_column_types"]
)
types = connector.generate_type_hint(input_df)

predict_data = connector.add_index_timestamp_colum_for_predict_data(
predict_data, trainer.index_timestamp, end_ts
Expand Down
38 changes: 0 additions & 38 deletions tests/unit/SnowflakeConnector.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,44 +63,6 @@ def test_label_table_does_not_change_label_value_for_regression(self):
self.assertListEqual(actual_label_col_vals, expected_label_col_vals)


class TestGenerateTypeHint(unittest.TestCase):
def setUp(self) -> None:
self.session = Session.builder.config("local_testing", True).create()
self.connector = MockSnowflakeConnector()

# Returns a list of type hints for given pandas DataFrame's fields
def test_returns_type_hints(self):
df = pd.DataFrame.from_dict(
{
"COL1": ["a", "b"],
"COL2": [1, 2],
"COL3": [1.1, 2.2],
"COL4": ["a1", "b1"],
}
)
table = self.session.create_dataframe(df)
column_types = {
"COL1": "StringType",
"COL2": "IntegerType",
"COL3": "FloatType",
"COL4": "StringType",
}

type_hints = self.connector.generate_type_hint(table, column_types)
self.assertEqual(
type_hints, [T.StringType(), T.IntegerType(), T.FloatType(), T.StringType()]
)

# Handles DataFrame with single row and column
def test_handles_single_row_and_column(self):
df = pd.DataFrame({"COL1": [1]})
table = self.session.create_dataframe(df)
column_types = {"COL1": "IntegerType"}

type_hints = self.connector.generate_type_hint(table, column_types)
self.assertEqual(type_hints, [T.IntegerType()])


class TestSelectRelevantColumns(unittest.TestCase):
def setUp(self) -> None:
self.connector = MockSnowflakeConnector()
Expand Down

0 comments on commit fbcb38a

Please sign in to comment.