Skip to content
This repository has been archived by the owner on Jan 2, 2024. It is now read-only.

Commit

Permalink
Merge pull request #820 from Avaiga/feature/#815-integrate-filter-to-…
Browse files Browse the repository at this point in the history
…db-query

Feature/#815- Filter database datanodes should not read all data
  • Loading branch information
trgiangdo authored Nov 8, 2023
2 parents e63aba0 + 0a2a587 commit 76ed789
Show file tree
Hide file tree
Showing 9 changed files with 288 additions and 167 deletions.
86 changes: 72 additions & 14 deletions src/taipy/core/data/_abstract_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import urllib.parse
from abc import abstractmethod
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Set
from typing import Dict, List, Optional, Set, Tuple, Union

import modin.pandas as modin_pd
import numpy as np
Expand All @@ -24,6 +24,7 @@
from taipy.config.common.scope import Scope

from .._version._version_manager_factory import _VersionManagerFactory
from ..data.operator import JoinOperator, Operator
from ..exceptions.exceptions import MissingRequiredProperty, UnknownDatabaseEngine
from ._abstract_tabular import _AbstractTabularDataNode
from .data_node import DataNode
Expand Down Expand Up @@ -198,6 +199,15 @@ def _conn_string(self) -> str:

raise UnknownDatabaseEngine(f"Unknown engine: {engine}")

def filter(self, operators: Optional[Union[List, Tuple]] = None, join_operator=JoinOperator.AND):
    """Read the data node with the filter pushed down into the SQL query.

    Dispatches to the reader matching the configured exposed type; any other
    exposed type falls back to the custom-class reader (`_read_as`).
    """
    exposed_type = self.properties[self.__EXPOSED_TYPE_PROPERTY]
    readers = {
        self.__EXPOSED_TYPE_PANDAS: self._read_as_pandas_dataframe,
        self.__EXPOSED_TYPE_MODIN: self._read_as_modin_dataframe,
        self.__EXPOSED_TYPE_NUMPY: self._read_as_numpy,
    }
    reader = readers.get(exposed_type, self._read_as)
    return reader(operators=operators, join_operator=join_operator)

def _read(self):
    """Read the full (unfiltered) data node content in the configured exposed type.

    Mirrors the dispatch order used by `filter`: pandas, then modin, then
    numpy, falling back to the custom-class reader.
    """
    if self.properties[self.__EXPOSED_TYPE_PROPERTY] == self.__EXPOSED_TYPE_PANDAS:
        return self._read_as_pandas_dataframe()
    if self.properties[self.__EXPOSED_TYPE_PROPERTY] == self.__EXPOSED_TYPE_MODIN:
        return self._read_as_modin_dataframe()
    if self.properties[self.__EXPOSED_TYPE_PROPERTY] == self.__EXPOSED_TYPE_NUMPY:
        return self._read_as_numpy()
    return self._read_as()

def _read_as(self, operators: Optional[Union[List, Tuple]] = None, join_operator=JoinOperator.AND):
    """Read the query result as a list of instances of the user-provided custom class.

    Parameters:
        operators: A (key, value, Operator) tuple or a list of such tuples,
            pushed down into the SQL query's WHERE clause.
        join_operator: How multiple conditions are combined (AND / OR).
    """
    custom_class = self.properties[self.__EXPOSED_TYPE_PROPERTY]
    with self._get_engine().connect() as connection:
        query_result = connection.execute(text(self._get_read_query(operators, join_operator)))
        # Materialize inside the `with` so rows are consumed before the connection closes.
        return list(map(lambda row: custom_class(**row), query_result))

def _read_as_numpy(
    self, operators: Optional[Union[List, Tuple]] = None, join_operator=JoinOperator.AND
) -> np.ndarray:
    """Read the (optionally filtered) query result as a NumPy array, via the pandas reader."""
    return self._read_as_pandas_dataframe(operators=operators, join_operator=join_operator).to_numpy()

def _read_as_pandas_dataframe(
    self,
    columns: Optional[List[str]] = None,
    operators: Optional[Union[List, Tuple]] = None,
    join_operator=JoinOperator.AND,
):
    """Read the (optionally filtered) query result as a pandas DataFrame.

    Parameters:
        columns: Optional projection; when given, only these columns are kept.
        operators: A (key, value, Operator) tuple or a list of such tuples.
        join_operator: How multiple conditions are combined (AND / OR).
    """
    with self._get_engine().connect() as conn:
        # Execute once; apply the column projection on the resulting frame.
        dataframe = pd.DataFrame(conn.execute(text(self._get_read_query(operators, join_operator))))
    return dataframe[columns] if columns else dataframe

def _read_as_modin_dataframe(
    self,
    columns: Optional[List[str]] = None,
    operators: Optional[Union[List, Tuple]] = None,
    join_operator=JoinOperator.AND,
):
    """Read the (optionally filtered) query result as a Modin DataFrame.

    Parameters:
        columns: Optional projection; when given, only these columns are kept.
        operators: A (key, value, Operator) tuple or a list of such tuples.
        join_operator: How multiple conditions are combined (AND / OR).
    """
    dataframe = modin_pd.read_sql_query(self._get_read_query(operators, join_operator), con=self._get_engine())
    return dataframe[columns] if columns else dataframe

def _get_read_query(self, operators: Optional[Union[List, Tuple]] = None, join_operator=JoinOperator.AND):
    """Build the SQL read query, appending a WHERE clause derived from *operators*.

    Parameters:
        operators: A (key, value, Operator) tuple or a list of such tuples.
        join_operator: How multiple conditions are combined (AND / OR).

    Returns:
        The subclass-provided base query, extended with a WHERE clause when
        *operators* is non-empty.

    Raises:
        NotImplementedError: If an operator or join operator is not supported.
    """
    query = self._get_base_read_query()

    if not operators:
        return query

    if not isinstance(operators, list):
        operators = [operators]

    # Comparison operator -> SQL symbol.
    sql_symbols = {
        Operator.EQUAL: "=",
        Operator.NOT_EQUAL: "<>",
        Operator.GREATER_THAN: ">",
        Operator.GREATER_OR_EQUAL: ">=",
        Operator.LESS_THAN: "<",
        Operator.LESS_OR_EQUAL: "<=",
    }

    conditions = []
    for key, value, operator in operators:
        symbol = sql_symbols.get(operator)
        if symbol is None:
            # Silently dropping a condition would return unfiltered rows; fail loudly,
            # consistent with the unknown-join-operator handling below.
            raise NotImplementedError(f"Operator {operator} is not implemented.")
        # NOTE(review): values are interpolated into the SQL text rather than bound as
        # parameters (the method's contract is "return a query string"). Doubling
        # embedded single quotes keeps the query valid for quoted values, but a
        # parameterized query would be the safer long-term fix — TODO confirm callers.
        escaped_value = str(value).replace("'", "''")
        conditions.append(f"{key} {symbol} '{escaped_value}'")

    if join_operator == JoinOperator.AND:
        query += f" WHERE {' AND '.join(conditions)}"
    elif join_operator == JoinOperator.OR:
        query += f" WHERE {' OR '.join(conditions)}"
    else:
        raise NotImplementedError(f"Join operator {join_operator} not implemented.")

    return query

@abstractmethod
def _get_base_read_query(self) -> str:
    """Return the base SELECT query for this data node; implemented by subclasses."""
    raise NotImplementedError

def _write(self, data) -> None:
Expand All @@ -248,6 +302,10 @@ def _write(self, data) -> None:
else:
transaction.commit()

@abstractmethod
def _do_write(self, data, engine, connection) -> None:
    """Write *data* using the given engine and open connection; implemented by subclasses."""
    raise NotImplementedError

def __setattr__(self, key: str, value) -> None:
if key in self.__ENGINE_PROPERTIES:
self._engine = None
Expand Down
3 changes: 0 additions & 3 deletions src/taipy/core/data/data_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,7 @@
from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional, Set, Tuple, Union

import modin.pandas as modin_pd
import networkx as nx
import numpy as np
import pandas as pd

from taipy.config.common._validate_id import _validate_id
from taipy.config.common.scope import Scope
Expand Down
41 changes: 36 additions & 5 deletions src/taipy/core/data/mongo.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,13 @@

from datetime import datetime, timedelta
from inspect import isclass
from typing import Any, Dict, List, Optional, Set
from typing import Any, Dict, List, Optional, Set, Tuple, Union

from taipy.config.common.scope import Scope

from .._version._version_manager_factory import _VersionManagerFactory
from ..common._mongo_connector import _connect_mongodb
from ..data.operator import JoinOperator, Operator
from ..exceptions.exceptions import InvalidCustomDocument, MissingRequiredProperty
from .data_node import DataNode
from .data_node_id import DataNodeId, Edit
Expand Down Expand Up @@ -175,19 +176,49 @@ def _check_custom_document(self, custom_document):
def storage_type(cls) -> str:
return cls.__STORAGE_TYPE

def filter(self, operators: Optional[Union[List, Tuple]] = None, join_operator=JoinOperator.AND):
    """Read only the documents matching *operators*, decoded with the configured decoder.

    The filter is pushed down into the Mongo query (see `_read_by_query`)
    instead of being applied after reading the whole collection.
    """
    return [self._decoder(document) for document in self._read_by_query(operators, join_operator)]

def _read(self):
    """Read every document in the collection, decoded with the configured decoder."""
    return [self._decoder(document) for document in self._read_by_query()]

def _read_by_query(self):
def _read_by_query(self, operators: Optional[Union[List, Tuple]] = None, join_operator=JoinOperator.AND):
    """Query the Mongo collection, optionally filtered by *operators*.

    Parameters:
        operators: A (key, value, Operator) tuple or a list of such tuples.
        join_operator: How multiple conditions are combined (AND / OR).

    Returns:
        A pymongo cursor over the matching documents.

    Raises:
        NotImplementedError: If an operator or join operator is not supported.
    """
    if not operators:
        return self.collection.find()

    if not isinstance(operators, list):
        operators = [operators]

    # Comparison operator -> MongoDB query operator (EQUAL uses the plain {key: value} form).
    mongo_operators = {
        Operator.NOT_EQUAL: "$ne",
        Operator.GREATER_THAN: "$gt",
        Operator.GREATER_OR_EQUAL: "$gte",
        Operator.LESS_THAN: "$lt",
        Operator.LESS_OR_EQUAL: "$lte",
    }

    conditions = []
    for key, value, operator in operators:
        if operator == Operator.EQUAL:
            conditions.append({key: value})
        elif operator in mongo_operators:
            conditions.append({key: {mongo_operators[operator]: value}})
        else:
            # Silently dropping a condition would return unfiltered documents; fail loudly,
            # consistent with the unknown-join-operator handling below.
            raise NotImplementedError(f"Operator {operator} is not supported.")

    if join_operator == JoinOperator.AND:
        query = {"$and": conditions}
    elif join_operator == JoinOperator.OR:
        query = {"$or": conditions}
    else:
        raise NotImplementedError(f"Join operator {join_operator} is not supported.")

    return self.collection.find(query)

def _write(self, data) -> None:
"""Check data against a collection of types to handle insertion on the database."""

if not isinstance(data, list):
data = [data]

Expand Down
2 changes: 1 addition & 1 deletion src/taipy/core/data/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def __init__(
def storage_type(cls) -> str:
return cls.__STORAGE_TYPE

def _get_base_read_query(self) -> str:
    """Return the user-configured read query from the data node properties."""
    return self.properties.get(self.__READ_QUERY_KEY)

def _do_write(self, data, engine, connection) -> None:
Expand Down
2 changes: 1 addition & 1 deletion src/taipy/core/data/sql_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def __init__(
def storage_type(cls) -> str:
return cls.__STORAGE_TYPE

def _get_base_read_query(self) -> str:
    """Return a SELECT * query over the configured table.

    NOTE(review): the table name is interpolated directly into the SQL text;
    it comes from the data node configuration, not end-user input — TODO confirm.
    """
    return f"SELECT * FROM {self.properties[self.__TABLE_KEY]}"

def _do_write(self, data, engine, connection) -> None:
Expand Down
12 changes: 6 additions & 6 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,9 +144,9 @@ def tmp_sqlite_db_file_path(tmpdir_factory):
file_extension = ".db"
db = create_engine("sqlite:///" + os.path.join(fn.strpath, f"{db_name}{file_extension}"))
conn = db.connect()
conn.execute(text("CREATE TABLE foo (foo int, bar int);"))
conn.execute(text("INSERT INTO foo (foo, bar) VALUES (1, 2);"))
conn.execute(text("INSERT INTO foo (foo, bar) VALUES (3, 4);"))
conn.execute(text("CREATE TABLE example (foo int, bar int);"))
conn.execute(text("INSERT INTO example (foo, bar) VALUES (1, 2);"))
conn.execute(text("INSERT INTO example (foo, bar) VALUES (3, 4);"))
conn.commit()
conn.close()
db.dispose()
Expand All @@ -162,9 +162,9 @@ def tmp_sqlite_sqlite3_file_path(tmpdir_factory):

db = create_engine("sqlite:///" + os.path.join(fn.strpath, f"{db_name}{file_extension}"))
conn = db.connect()
conn.execute(text("CREATE TABLE foo (foo int, bar int);"))
conn.execute(text("INSERT INTO foo (foo, bar) VALUES (1, 2);"))
conn.execute(text("INSERT INTO foo (foo, bar) VALUES (3, 4);"))
conn.execute(text("CREATE TABLE example (foo int, bar int);"))
conn.execute(text("INSERT INTO example (foo, bar) VALUES (1, 2);"))
conn.execute(text("INSERT INTO example (foo, bar) VALUES (3, 4);"))
conn.commit()
conn.close()
db.dispose()
Expand Down
14 changes: 14 additions & 0 deletions tests/core/data/test_mongo_data_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from dataclasses import dataclass
from datetime import datetime
from unittest.mock import patch

import mongomock
import pymongo
Expand Down Expand Up @@ -339,3 +340,16 @@ def test_filter(self, properties):
{"bar": 2},
{},
]

@mongomock.patch(servers=(("localhost", 27017),))
@pytest.mark.parametrize("properties", __properties)
def test_filter_does_not_read_all_entities(self, properties):
    """filter() must query Mongo directly and never fall back to _read()."""
    mongo_dn = MongoCollectionDataNode("foo", Scope.SCENARIO, properties=properties)

    # MongoCollectionDataNode.filter() should not call the MongoCollectionDataNode._read() method
    with patch.object(MongoCollectionDataNode, "_read") as read_mock:
        mongo_dn.filter(("foo", 1, Operator.EQUAL))
        mongo_dn.filter(("bar", 2, Operator.NOT_EQUAL))
        mongo_dn.filter([("bar", 1, Operator.EQUAL), ("bar", 2, Operator.EQUAL)], JoinOperator.OR)

        # BUG FIX: `read_mock` IS the mock replacing _read. Subscripting it
        # (read_mock["_read"]) produced a fresh child mock whose call_count is
        # always 0, so the original assertion could never fail.
        read_mock.assert_not_called()
Loading

0 comments on commit 76ed789

Please sign in to comment.