This repository has been archived by the owner on Jan 2, 2024. It is now read-only.

Feature/#336 - Filter numpy array exposed type #768

Merged
merged 11 commits on Nov 3, 2023
3 changes: 2 additions & 1 deletion src/taipy/core/data/_abstract_sql.py
@@ -17,6 +17,7 @@
from typing import Dict, List, Optional, Set

import modin.pandas as modin_pd
+import numpy as np
import pandas as pd
from sqlalchemy import create_engine, text

@@ -214,7 +215,7 @@ def _read_as(self):
query_result = connection.execute(text(self._get_read_query()))
return list(map(lambda row: custom_class(**row), query_result))

-def _read_as_numpy(self):
+def _read_as_numpy(self) -> np.ndarray:
return self._read_as_pandas_dataframe().to_numpy()

def _read_as_pandas_dataframe(self, columns: Optional[List[str]] = None):
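
The only substantive change in this file (mirrored in csv.py and parquet.py below) is the `np.ndarray` return annotation, which documents what pandas' `DataFrame.to_numpy()` already returns. A quick plain-pandas illustration, not code from this PR:

```python
# Plain-pandas illustration of the conversion behind _read_as_numpy.
import numpy as np
import pandas as pd

df = pd.DataFrame({"a": [1, 4], "b": [2, 5], "c": [3, 6]})
arr = df.to_numpy()  # -> np.ndarray of shape (2, 3)
assert isinstance(arr, np.ndarray) and arr.shape == (2, 3)
```
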
339 changes: 208 additions & 131 deletions src/taipy/core/data/_filter.py

Large diffs are not rendered by default.
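
The rewritten `_filter.py` is the heart of this PR, but its diff is not rendered here. Judging from the call sites in data_node.py (`_FilterDataNode._filter(data, operators, join_operator)`) and the numpy tests below, the new numpy branch plausibly builds one boolean mask per `(key, value, operator)` 3-tuple, treating the key as a column index, and combines the masks with the join operator. The sketch below is a hypothetical reconstruction under those assumptions, not the PR's actual code; the `Operator`/`JoinOperator` enums are inlined stand-ins for `taipy.core.data.operator`.

```python
# Hypothetical sketch of the numpy filtering path; not the PR's actual code.
import operator as op
from enum import Enum
from typing import List, Tuple, Union

import numpy as np


class Operator(Enum):  # stand-in for taipy.core.data.operator.Operator
    EQUAL, NOT_EQUAL, LESS_THAN, LESS_OR_EQUAL, GREATER_THAN, GREATER_OR_EQUAL = range(6)


class JoinOperator(Enum):  # stand-in for taipy.core.data.operator.JoinOperator
    AND, OR = range(2)


_COMPARATORS = {
    Operator.EQUAL: op.eq,
    Operator.NOT_EQUAL: op.ne,
    Operator.LESS_THAN: op.lt,
    Operator.LESS_OR_EQUAL: op.le,
    Operator.GREATER_THAN: op.gt,
    Operator.GREATER_OR_EQUAL: op.ge,
}


def _filter_numpy(data: np.ndarray, operators: Union[List, Tuple],
                  join_operator=JoinOperator.AND) -> np.ndarray:
    if len(operators) == 0:
        return data
    if not isinstance(operators[0], (list, tuple)):
        operators = [operators]  # accept a single 3-tuple as well as a list
    # For numpy data, the key is a column index rather than a column name.
    masks = [_COMPARATORS[operator](data[:, key], value) for key, value, operator in operators]
    if join_operator == JoinOperator.AND:
        return data[np.logical_and.reduce(masks)]
    if join_operator == JoinOperator.OR:
        return data[np.logical_or.reduce(masks)]
    raise NotImplementedError


data = np.array([[1, 1], [1, 2], [1, 3], [2, 1], [2, 2], [2, 3]])
assert len(_filter_numpy(data, (0, 1, Operator.EQUAL))) == 3
assert len(_filter_numpy(data, [(0, 1, Operator.EQUAL), (1, 2, Operator.EQUAL)], JoinOperator.OR)) == 4
```

The assertions at the end mirror `test_filter_numpy_exposed_type` further down, which is how this sketch was sanity-checked.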

3 changes: 2 additions & 1 deletion src/taipy/core/data/csv.py
@@ -16,6 +16,7 @@
from typing import Any, Dict, List, Optional, Set

import modin.pandas as modin_pd
+import numpy as np
import pandas as pd

from taipy.config.common.scope import Scope
@@ -198,7 +199,7 @@ def _read_as(self):
res.append(custom_class(*line))
return res

-def _read_as_numpy(self):
+def _read_as_numpy(self) -> np.ndarray:
return self._read_as_pandas_dataframe().to_numpy()

def _read_as_pandas_dataframe(
103 changes: 10 additions & 93 deletions src/taipy/core/data/data_node.py
@@ -13,7 +13,6 @@
import uuid
from abc import abstractmethod
from datetime import datetime, timedelta
-from functools import reduce
from typing import Any, Dict, List, Optional, Set, Tuple, Union

import modin.pandas as modin_pd
@@ -35,7 +34,7 @@
from ..job.job_id import JobId
from ._filter import _FilterDataNode
from .data_node_id import DataNodeId, Edit
-from .operator import JoinOperator, Operator
+from .operator import JoinOperator


class DataNode(_Entity, _Labeled):
@@ -423,103 +422,24 @@ def filter(self, operators: Union[List, Tuple], join_operator=JoinOperator.AND):
The data is filtered by the provided list of 3-tuples (key, value, `Operator^`).

If multiple filter operators are provided, filtered data will be joined based on the
-join operator (_AND_ or _OR_).
+join operator (*AND* or *OR*).

Parameters:
operators (Union[List[Tuple], Tuple]): A 3-element tuple or a list of 3-element tuples,
each is in the form of (key, value, `Operator^`).
join_operator (JoinOperator^): The operator used to join the multiple filter
3-tuples.
Returns:
The filtered data.
Raises:
NotImplementedError: If the data type is not supported.
"""
data = self._read()
-if len(operators) == 0:
-return data
-if not ((type(operators[0]) == list) or (type(operators[0]) == tuple)):
-if isinstance(data, (pd.DataFrame, modin_pd.DataFrame)):
-return DataNode.__filter_dataframe_per_key_value(data, operators[0], operators[1], operators[2])
-if isinstance(data, List):
-return DataNode.__filter_list_per_key_value(data, operators[0], operators[1], operators[2])
-else:
-if isinstance(data, (pd.DataFrame, modin_pd.DataFrame)):
-return DataNode.__filter_dataframe(data, operators, join_operator=join_operator)
-if isinstance(data, List):
-return DataNode.__filter_list(data, operators, join_operator=join_operator)
-raise NotImplementedError

-@staticmethod
-def __filter_dataframe(
-df_data: Union[pd.DataFrame, modin_pd.DataFrame], operators: Union[List, Tuple], join_operator=JoinOperator.AND
-):
-filtered_df_data = []
-if join_operator == JoinOperator.AND:
-how = "inner"
-elif join_operator == JoinOperator.OR:
-how = "outer"
-else:
-raise NotImplementedError
-for key, value, operator in operators:
-filtered_df_data.append(DataNode.__filter_dataframe_per_key_value(df_data, key, value, operator))
-return DataNode.__dataframe_merge(filtered_df_data, how) if filtered_df_data else pd.DataFrame()

-@staticmethod
-def __filter_dataframe_per_key_value(
-df_data: Union[pd.DataFrame, modin_pd.DataFrame], key: str, value, operator: Operator
-):
-df_by_col = df_data[key]
-if operator == Operator.EQUAL:
-df_by_col = df_by_col == value
-if operator == Operator.NOT_EQUAL:
-df_by_col = df_by_col != value
-if operator == Operator.LESS_THAN:
-df_by_col = df_by_col < value
-if operator == Operator.LESS_OR_EQUAL:
-df_by_col = df_by_col <= value
-if operator == Operator.GREATER_THAN:
-df_by_col = df_by_col > value
-if operator == Operator.GREATER_OR_EQUAL:
-df_by_col = df_by_col >= value
-return df_data[df_by_col]
+return _FilterDataNode._filter(data, operators, join_operator)

-@staticmethod
-def __dataframe_merge(df_list: List, how="inner"):
-return reduce(lambda df1, df2: pd.merge(df1, df2, how=how), df_list)

-@staticmethod
-def __filter_list(list_data: List, operators: Union[List, Tuple], join_operator=JoinOperator.AND):
-filtered_list_data = []
-for key, value, operator in operators:
-filtered_list_data.append(DataNode.__filter_list_per_key_value(list_data, key, value, operator))
-if len(filtered_list_data) == 0:
-return filtered_list_data
-if join_operator == JoinOperator.AND:
-return DataNode.__list_intersect(filtered_list_data)
-elif join_operator == JoinOperator.OR:
-return list(set(np.concatenate(filtered_list_data)))
-else:
-raise NotImplementedError

-@staticmethod
-def __filter_list_per_key_value(list_data: List, key: str, value, operator: Operator):
-filtered_list = []
-for row in list_data:
-row_value = getattr(row, key)
-if operator == Operator.EQUAL and row_value == value:
-filtered_list.append(row)
-if operator == Operator.NOT_EQUAL and row_value != value:
-filtered_list.append(row)
-if operator == Operator.LESS_THAN and row_value < value:
-filtered_list.append(row)
-if operator == Operator.LESS_OR_EQUAL and row_value <= value:
-filtered_list.append(row)
-if operator == Operator.GREATER_THAN and row_value > value:
-filtered_list.append(row)
-if operator == Operator.GREATER_OR_EQUAL and row_value >= value:
-filtered_list.append(row)
-return filtered_list

-@staticmethod
-def __list_intersect(list_data):
-return list(set(list_data.pop()).intersection(*map(set, list_data)))
+def __getitem__(self, item):
+data = self._read()
+return _FilterDataNode._filter_by_key(data, item)

@abstractmethod
def _read(self):
@@ -529,9 +449,6 @@ def _read(self):
def _write(self, data):
raise NotImplementedError

-def __getitem__(self, items):
-return _FilterDataNode(self.id, self._read())[items]

@property # type: ignore
@_self_reload(_MANAGER_NAME)
def is_ready_for_reading(self) -> bool:
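
With `filter` and `__getitem__` now both thin wrappers over `_FilterDataNode`, call sites keep the same shape. A usage sketch mirroring the pandas test further down; `dn` is assumed to be the CSV data node from that test, already holding its five `{foo, bar}` records:

```python
# Usage sketch; `dn` and the import path come from the tests in this PR.
from src.taipy.core.data.operator import JoinOperator, Operator

assert len(dn.filter(("foo", 1, Operator.EQUAL))) == 3  # a single 3-tuple
assert len(dn.filter([("bar", 1, Operator.EQUAL), ("bar", 2, Operator.EQUAL)], JoinOperator.OR)) == 4
col = dn["foo"]  # routed through _FilterDataNode._filter_by_key(data, "foo")
head = dn[:2]    # slices pass through to the underlying pandas DataFrame
```
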
3 changes: 2 additions & 1 deletion src/taipy/core/data/parquet.py
@@ -15,6 +15,7 @@
from typing import Any, Dict, List, Optional, Set

import modin.pandas as modin_pd
+import numpy as np
import pandas as pd

from taipy.config.common.scope import Scope
@@ -220,7 +221,7 @@ def _read_as(self, read_kwargs: Dict):
list_of_dicts = self._read_as_pandas_dataframe(read_kwargs).to_dict(orient="records")
return [custom_class(**dct) for dct in list_of_dicts]

-def _read_as_numpy(self, read_kwargs: Dict):
+def _read_as_numpy(self, read_kwargs: Dict) -> np.ndarray:
return self._read_as_pandas_dataframe(read_kwargs).to_numpy()

def _read_as_pandas_dataframe(self, read_kwargs: Dict) -> pd.DataFrame:
8 changes: 6 additions & 2 deletions src/taipy/core/data/sql.py
@@ -12,6 +12,8 @@
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Set

+from sqlalchemy import text

from taipy.config.common.scope import Scope

from .._version._version_manager_factory import _VersionManagerFactory
@@ -133,6 +135,8 @@ def _do_write(self, data, engine, connection) -> None:
queries = [queries]
for query in queries:
if isinstance(query, str):
-connection.execute(query)
+connection.execute(text(query))
else:
-connection.execute(*query)
+statement = query[0]
+parameters = query[1]
+connection.execute(text(statement), parameters)
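
The two rewritten branches track SQLAlchemy's 2.x API, where `Connection.execute()` no longer accepts a raw SQL string: plain statements must be wrapped in `text()`, and bound parameters are passed as a separate argument. A standalone sketch of the pattern (generic SQLAlchemy usage, not code from this PR):

```python
# Standalone illustration of the SQLAlchemy 2.x execute pattern used above.
from sqlalchemy import create_engine, text

engine = create_engine("sqlite:///:memory:")
with engine.begin() as conn:
    conn.execute(text("CREATE TABLE t (a INTEGER, b INTEGER)"))
    # A plain string query, wrapped in text():
    conn.execute(text("INSERT INTO t (a, b) VALUES (1, 2)"))
    # A (statement, parameters) pair, as in the rewritten _do_write:
    conn.execute(text("INSERT INTO t (a, b) VALUES (:a, :b)"), {"a": 3, "b": 4})
```
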
3 changes: 1 addition & 2 deletions src/taipy/core/data/sql_table.py
@@ -141,9 +141,8 @@ def _do_write(self, data, engine, connection) -> None:

def _create_table(self, engine) -> Table:
return Table(
-self.table,
+self.properties[self.__TABLE_KEY],
MetaData(),
-autoload=True,
autoload_with=engine,
)

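The reflection change tracks the same SQLAlchemy upgrade: `autoload=True` was removed in SQLAlchemy 1.4/2.0, and passing `autoload_with` alone both requests reflection and supplies the bind. A standalone sketch (again generic SQLAlchemy, not code from this PR):

```python
# Standalone illustration of table reflection via autoload_with.
from sqlalchemy import MetaData, Table, create_engine, text

engine = create_engine("sqlite:///:memory:")
with engine.begin() as conn:
    conn.execute(text("CREATE TABLE example (a INTEGER, b INTEGER)"))

table = Table("example", MetaData(), autoload_with=engine)  # reflected from the DB
print(table.columns.keys())  # ['a', 'b']
```
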
15 changes: 7 additions & 8 deletions tests/conftest.py
@@ -8,9 +8,8 @@
# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
# specific language governing permissions and limitations under the License.
-import json

import os
import pathlib
import pickle
import shutil
from datetime import datetime
@@ -145,9 +144,9 @@ def tmp_sqlite_db_file_path(tmpdir_factory):
file_extension = ".db"
db = create_engine("sqlite:///" + os.path.join(fn.strpath, f"{db_name}{file_extension}"))
conn = db.connect()
conn.execute(text("CREATE TABLE example (a int, b int, c int);"))
conn.execute(text("INSERT INTO example (a, b, c) VALUES (1, 2, 3);"))
conn.execute(text("INSERT INTO example (a, b, c) VALUES (4, 5, 6);"))
conn.execute(text("CREATE TABLE foo (foo int, bar int);"))
conn.execute(text("INSERT INTO foo (foo, bar) VALUES (1, 2);"))
conn.execute(text("INSERT INTO foo (foo, bar) VALUES (3, 4);"))
conn.commit()
conn.close()
db.dispose()
@@ -163,9 +162,9 @@ def tmp_sqlite_sqlite3_file_path(tmpdir_factory):

db = create_engine("sqlite:///" + os.path.join(fn.strpath, f"{db_name}{file_extension}"))
conn = db.connect()
conn.execute(text("CREATE TABLE example (a int, b int, c int);"))
conn.execute(text("INSERT INTO example (a, b, c) VALUES (1, 2, 3);"))
conn.execute(text("INSERT INTO example (a, b, c) VALUES (4, 5, 6);"))
conn.execute(text("CREATE TABLE foo (foo int, bar int);"))
conn.execute(text("INSERT INTO foo (foo, bar) VALUES (1, 2);"))
conn.execute(text("INSERT INTO foo (foo, bar) VALUES (3, 4);"))
conn.commit()
conn.close()
db.dispose()
67 changes: 67 additions & 0 deletions tests/core/data/test_csv_data_node.py
@@ -22,6 +22,7 @@
from src.taipy.core.data._data_manager import _DataManager
from src.taipy.core.data.csv import CSVDataNode
from src.taipy.core.data.data_node_id import DataNodeId
+from src.taipy.core.data.operator import JoinOperator, Operator
from src.taipy.core.exceptions.exceptions import InvalidExposedType, NoData
from taipy.config.common.scope import Scope
from taipy.config.config import Config
@@ -302,6 +303,72 @@ def test_pandas_exposed_type(self):
dn = CSVDataNode("foo", Scope.SCENARIO, properties={"path": path, "exposed_type": "pandas"})
assert isinstance(dn.read(), pd.DataFrame)

def test_filter_pandas_exposed_type(self, csv_file):
dn = CSVDataNode("foo", Scope.SCENARIO, properties={"path": csv_file, "exposed_type": "pandas"})
dn.write(
[
{"foo": 1, "bar": 1},
{"foo": 1, "bar": 2},
{"foo": 1},
{"foo": 2, "bar": 2},
{"bar": 2},
]
)

assert len(dn.filter(("foo", 1, Operator.EQUAL))) == 3
assert len(dn.filter(("foo", 1, Operator.NOT_EQUAL))) == 2
assert len(dn.filter(("bar", 2, Operator.EQUAL))) == 3
assert len(dn.filter([("bar", 1, Operator.EQUAL), ("bar", 2, Operator.EQUAL)], JoinOperator.OR)) == 4

assert dn["foo"].equals(pd.Series([1, 1, 1, 2, None]))
assert dn["bar"].equals(pd.Series([1, 2, None, 2, 2]))
assert dn[:2].equals(pd.DataFrame([{"foo": 1.0, "bar": 1.0}, {"foo": 1.0, "bar": 2.0}]))

def test_filter_modin_exposed_type(self, csv_file):
dn = CSVDataNode("foo", Scope.SCENARIO, properties={"path": csv_file, "exposed_type": "modin"})
dn.write(
[
{"foo": 1, "bar": 1},
{"foo": 1, "bar": 2},
{"foo": 1},
{"foo": 2, "bar": 2},
{"bar": 2},
]
)

assert len(dn.filter(("foo", 1, Operator.EQUAL))) == 3
assert len(dn.filter(("foo", 1, Operator.NOT_EQUAL))) == 2
assert len(dn.filter(("bar", 2, Operator.EQUAL))) == 3
assert len(dn.filter([("bar", 1, Operator.EQUAL), ("bar", 2, Operator.EQUAL)], JoinOperator.OR)) == 4

assert dn["foo"].equals(modin_pd.Series([1, 1, 1, 2, None]))
assert dn["bar"].equals(modin_pd.Series([1, 2, None, 2, 2]))
assert dn[:2].equals(modin_pd.DataFrame([{"foo": 1.0, "bar": 1.0}, {"foo": 1.0, "bar": 2.0}]))

def test_filter_numpy_exposed_type(self, csv_file):
dn = CSVDataNode("foo", Scope.SCENARIO, properties={"path": csv_file, "exposed_type": "numpy"})
dn.write(
[
[1, 1],
[1, 2],
[1, 3],
[2, 1],
[2, 2],
[2, 3],
]
)

assert len(dn.filter((0, 1, Operator.EQUAL))) == 3
assert len(dn.filter((0, 1, Operator.NOT_EQUAL))) == 3
assert len(dn.filter((1, 2, Operator.EQUAL))) == 2
assert len(dn.filter([(0, 1, Operator.EQUAL), (1, 2, Operator.EQUAL)], JoinOperator.OR)) == 4

assert np.array_equal(dn[0], np.array([1, 1]))
assert np.array_equal(dn[1], np.array([1, 2]))
assert np.array_equal(dn[:3], np.array([[1, 1], [1, 2], [1, 3]]))
assert np.array_equal(dn[:, 0], np.array([1, 1, 1, 2, 2, 2]))
assert np.array_equal(dn[1:4, :1], np.array([[1], [1], [2]]))
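
Note on the last three assertions: because `__getitem__` hands the key straight to `_FilterDataNode._filter_by_key` (see data_node.py above), multi-axis keys such as `dn[1:4, :1]` appear to behave exactly like raw ndarray indexing; that reading is an assumption based on the visible call site. The plain-numpy equivalent of those assertions:

```python
# Plain-numpy equivalent of the multi-axis assertions above.
import numpy as np

data = np.array([[1, 1], [1, 2], [1, 3], [2, 1], [2, 2], [2, 3]])
assert np.array_equal(data[:, 0], np.array([1, 1, 1, 2, 2, 2]))  # whole first column
assert np.array_equal(data[1:4, :1], np.array([[1], [1], [2]]))  # rows 1-3, column 0 kept 2-D
```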

def test_raise_error_invalid_exposed_type(self):
path = os.path.join(pathlib.Path(__file__).parent.resolve(), "data_sample/example.csv")
with pytest.raises(InvalidExposedType):