diff --git a/xgboost_ray/data_sources/data_source.py b/xgboost_ray/data_sources/data_source.py index 774bf6c..c9bcfc7 100644 --- a/xgboost_ray/data_sources/data_source.py +++ b/xgboost_ray/data_sources/data_source.py @@ -1,5 +1,5 @@ from enum import Enum -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple, Union import pandas as pd from ray.actor import ActorHandle @@ -118,12 +118,12 @@ def convert_to_series(data: Any) -> pd.Series: @classmethod def get_column( cls, data: pd.DataFrame, column: Any - ) -> Tuple[pd.Series, Optional[str]]: + ) -> Tuple[pd.Series, Optional[Union[str, List]]]: """Helper method wrapping around convert to series. This method should usually not be overwritten. """ - if isinstance(column, str): + if isinstance(column, str) or isinstance(column, List): return data[column], column elif column is not None: return cls.convert_to_series(column), None diff --git a/xgboost_ray/matrix.py b/xgboost_ray/matrix.py index f3f7523..fe12b5f 100644 --- a/xgboost_ray/matrix.py +++ b/xgboost_ray/matrix.py @@ -307,7 +307,10 @@ def _split_dataframe( label, exclude = data_source.get_column(local_data, self.label) if exclude: - exclude_cols.add(exclude) + if isinstance(exclude, List): + exclude_cols.update(exclude) + else: + exclude_cols.add(exclude) weight, exclude = data_source.get_column(local_data, self.weight) if exclude: @@ -406,7 +409,11 @@ def get_data_source(self) -> Type[DataSource]: ): # noqa: E721: # Label is an object of a different type than the main data. # We have to make sure they are compatible - if not data_source.is_data_type(self.label): + # if it's a parquet data source and label is a list, + # then we consider it a multi-label data + if not data_source.is_data_type(self.label) and not ( + isinstance(self.label, List) and data_source.__name__ == "Parquet" + ): raise ValueError( "The passed `data` and `label` types are not compatible." "\nFIX THIS by passing the same types to the " @@ -521,7 +528,11 @@ def get_data_source(self) -> Type[DataSource]: f"RayDMatrix." ) - if self.label is not None and not isinstance(self.label, str): + if ( + self.label is not None + and not isinstance(self.label, str) + and not isinstance(self.label, List) + ): raise ValueError( f"Invalid `label` value for distributed datasets: " f"{self.label}. Only strings are supported. " diff --git a/xgboost_ray/tests/test_matrix.py b/xgboost_ray/tests/test_matrix.py index 6c76449..ac1e3cb 100644 --- a/xgboost_ray/tests/test_matrix.py +++ b/xgboost_ray/tests/test_matrix.py @@ -33,6 +33,15 @@ def setUp(self): * repeat ) self.y = np.array([0, 1, 2, 3] * repeat) + self.multi_y = np.array( + [ + [1, 0, 0, 0], + [0, 1, 0, 0], + [0, 0, 1, 1], + [0, 0, 1, 0], + ] + * repeat + ) @classmethod def setUpClass(cls): @@ -62,7 +71,7 @@ def testColumnOrdering(self): assert data.columns.tolist() == cols[:-1] - def _testMatrixCreation(self, in_x, in_y, **kwargs): + def _testMatrixCreation(self, in_x, in_y, multi_label=False, **kwargs): if "sharding" not in kwargs: kwargs["sharding"] = RayShardingMode.BATCH mat = RayDMatrix(in_x, in_y, **kwargs) @@ -81,7 +90,10 @@ def _load_data(params): x, y = _load_data(params) self.assertTrue(np.allclose(self.x, x)) - self.assertTrue(np.allclose(self.y, y)) + if multi_label: + self.assertTrue(np.allclose(self.multi_y, y)) + else: + self.assertTrue(np.allclose(self.y, y)) # Multi actor check mat = RayDMatrix(in_x, in_y, **kwargs) @@ -95,7 +107,10 @@ def _load_data(params): x2, y2 = _load_data(params) self.assertTrue(np.allclose(self.x, concat_dataframes([x1, x2]))) - self.assertTrue(np.allclose(self.y, concat_dataframes([y1, y2]))) + if multi_label: + self.assertTrue(np.allclose(self.multi_y, concat_dataframes([y1, y2]))) + else: + self.assertTrue(np.allclose(self.y, concat_dataframes([y1, y2]))) def testFromNumpy(self): in_x = self.x @@ -276,6 +291,22 @@ def testFromMultiCSVString(self): [data_file_1, data_file_2], "label", distributed=True ) + def testFromParquetStringMultiLabel(self): + with tempfile.TemporaryDirectory() as dir: + data_file = os.path.join(dir, "data.parquet") + + data_df = pd.DataFrame(self.x, columns=["a", "b", "c", "d"]) + labels = [f"label_{label}" for label in range(4)] + data_df[labels] = self.multi_y + data_df.to_parquet(data_file) + + self._testMatrixCreation( + data_file, labels, multi_label=True, distributed=False + ) + self._testMatrixCreation( + data_file, labels, multi_label=True, distributed=True + ) + def testFromParquetString(self): with tempfile.TemporaryDirectory() as dir: data_file = os.path.join(dir, "data.parquet") @@ -287,6 +318,28 @@ def testFromParquetString(self): self._testMatrixCreation(data_file, "label", distributed=False) self._testMatrixCreation(data_file, "label", distributed=True) + def testFromMultiParquetStringMultiLabel(self): + with tempfile.TemporaryDirectory() as dir: + data_file_1 = os.path.join(dir, "data_1.parquet") + data_file_2 = os.path.join(dir, "data_2.parquet") + + data_df = pd.DataFrame(self.x, columns=["a", "b", "c", "d"]) + labels = [f"label_{label}" for label in range(4)] + data_df[labels] = self.multi_y + + df_1 = data_df[0 : len(data_df) // 2] + df_2 = data_df[len(data_df) // 2 :] + + df_1.to_parquet(data_file_1) + df_2.to_parquet(data_file_2) + + self._testMatrixCreation( + [data_file_1, data_file_2], labels, multi_label=True, distributed=False + ) + self._testMatrixCreation( + [data_file_1, data_file_2], labels, multi_label=True, distributed=True + ) + def testFromMultiParquetString(self): with tempfile.TemporaryDirectory() as dir: data_file_1 = os.path.join(dir, "data_1.parquet")