diff --git a/src/predictions/profiles_mlcorelib/connectors/CommonWarehouseConnector.py b/src/predictions/profiles_mlcorelib/connectors/CommonWarehouseConnector.py index 3be868b52..551bff6ca 100644 --- a/src/predictions/profiles_mlcorelib/connectors/CommonWarehouseConnector.py +++ b/src/predictions/profiles_mlcorelib/connectors/CommonWarehouseConnector.py @@ -893,7 +893,6 @@ def safe_parse_json(entry): def generate_type_hint( self, df: pd.DataFrame, - column_types: Dict[str, List[str]], ): return None diff --git a/src/predictions/profiles_mlcorelib/connectors/Connector.py b/src/predictions/profiles_mlcorelib/connectors/Connector.py index e778f4f8b..30b589a09 100644 --- a/src/predictions/profiles_mlcorelib/connectors/Connector.py +++ b/src/predictions/profiles_mlcorelib/connectors/Connector.py @@ -468,7 +468,7 @@ def get_material_registry_table( pass @abstractmethod - def generate_type_hint(self, df: Any, column_types: Dict[str, List[str]]): + def generate_type_hint(self, df: Any): pass @abstractmethod diff --git a/src/predictions/profiles_mlcorelib/connectors/SnowflakeConnector.py b/src/predictions/profiles_mlcorelib/connectors/SnowflakeConnector.py index fe8436f5d..702f24493 100644 --- a/src/predictions/profiles_mlcorelib/connectors/SnowflakeConnector.py +++ b/src/predictions/profiles_mlcorelib/connectors/SnowflakeConnector.py @@ -1078,17 +1078,11 @@ def get_material_registry_table( def generate_type_hint( self, df: snowflake.snowpark.Table, - column_types: Dict[str, List[str]], ): types = [] - - for col in df.columns: - dtype_str = column_types[col] - for category, mapping in self.data_type_mapping.items(): - if dtype_str in mapping: - types.append(mapping[dtype_str]) - break - + schema_fields = df.schema.fields + for field in schema_fields: + types.append(field.datatype) return types def call_prediction_udf( @@ -1129,8 +1123,12 @@ def call_prediction_udf( preds = preds.withColumn("columnindex", F.row_number().over(w)) extracted_df = extracted_df.withColumn("columnindex", F.row_number().over(w)) preds = preds.join( - extracted_df, preds.columnindex == extracted_df.columnindex, "inner" - ).drop("columnindex") + extracted_df, + preds.columnindex == extracted_df.columnindex, + "inner", + lsuffix="_left", + rsuffix="_right", + ).drop("columnindex_left", "columnindex_right") # Remove the dummy label column in case of Regression if "label" not in pred_output_df_columns: @@ -1138,7 +1136,7 @@ def call_prediction_udf( preds_with_percentile = preds.withColumn( percentile_column_name, - F.percent_rank().over(Window.orderBy(F.col(score_column_name))), + (F.percent_rank().over(Window.orderBy(F.col(score_column_name)))) * 100, ) return preds_with_percentile diff --git a/src/predictions/profiles_mlcorelib/ml_core/preprocess_and_predict.py b/src/predictions/profiles_mlcorelib/ml_core/preprocess_and_predict.py index a26cb653f..ec6dfa327 100644 --- a/src/predictions/profiles_mlcorelib/ml_core/preprocess_and_predict.py +++ b/src/predictions/profiles_mlcorelib/ml_core/preprocess_and_predict.py @@ -123,9 +123,7 @@ def preprocess_and_predict( input_df = connector.select_relevant_columns( predict_data, required_features_upper_case ) - types = connector.generate_type_hint( - input_df, results["column_names"]["feature_table_column_types"] - ) + types = connector.generate_type_hint(input_df) predict_data = connector.add_index_timestamp_colum_for_predict_data( predict_data, trainer.index_timestamp, end_ts diff --git a/tests/unit/SnowflakeConnector.py b/tests/unit/SnowflakeConnector.py index e5586a586..6a0a7f44f 100644 --- a/tests/unit/SnowflakeConnector.py +++ b/tests/unit/SnowflakeConnector.py @@ -63,44 +63,6 @@ def test_label_table_does_not_change_label_value_for_regression(self): self.assertListEqual(actual_label_col_vals, expected_label_col_vals) -class TestGenerateTypeHint(unittest.TestCase): - def setUp(self) -> None: - self.session = Session.builder.config("local_testing", True).create() - self.connector = MockSnowflakeConnector() - - # Returns a list of type hints for given pandas DataFrame's fields - def test_returns_type_hints(self): - df = pd.DataFrame.from_dict( - { - "COL1": ["a", "b"], - "COL2": [1, 2], - "COL3": [1.1, 2.2], - "COL4": ["a1", "b1"], - } - ) - table = self.session.create_dataframe(df) - column_types = { - "COL1": "StringType", - "COL2": "IntegerType", - "COL3": "FloatType", - "COL4": "StringType", - } - - type_hints = self.connector.generate_type_hint(table, column_types) - self.assertEqual( - type_hints, [T.StringType(), T.IntegerType(), T.FloatType(), T.StringType()] - ) - - # Handles DataFrame with single row and column - def test_handles_single_row_and_column(self): - df = pd.DataFrame({"COL1": [1]}) - table = self.session.create_dataframe(df) - column_types = {"COL1": "IntegerType"} - - type_hints = self.connector.generate_type_hint(table, column_types) - self.assertEqual(type_hints, [T.IntegerType()]) - - class TestSelectRelevantColumns(unittest.TestCase): def setUp(self) -> None: self.connector = MockSnowflakeConnector()