Skip to content
This repository has been archived by the owner on Jan 2, 2024. It is now read-only.

Commit

Permalink
added to list to models
Browse files Browse the repository at this point in the history
  • Loading branch information
Toan Quach authored and Toan Quach committed Nov 2, 2023
1 parent 38233ff commit 2b13eaa
Show file tree
Hide file tree
Showing 8 changed files with 161 additions and 41 deletions.
30 changes: 23 additions & 7 deletions src/taipy/core/_repository/_sql_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,15 @@
connection = None


from taipy.config.config import Config

from .._repository._abstract_repository import _AbstractRepository
from .._repository.db._sql_session import _SQLSession
from ..exceptions import MissingRequiredProperty, ModelNotFound

connection = None


def dict_factory(cursor, row):
d = {}
for idx, col in enumerate(cursor.description):
Expand All @@ -42,7 +51,6 @@ def init_db():
except KeyError:
raise MissingRequiredProperty("Missing property db_location")

# More sql databases can be easily added in the future
sqlite3.threadsafety = 3

global connection
Expand Down Expand Up @@ -138,6 +146,9 @@ def _delete(self, entity_id: str):
if cursor.rowcount == 0:
raise ModelNotFound(str(self.model_type.__name__), entity_id)

if cursor.rowcount == 0:
raise ModelNotFound(str(self.model_type.__name__), entity_id)

def _delete_all(self):
self.db.execute(str(self.model_type.__table__.delete().compile(dialect=sqlite.dialect())))
self.db.commit()
Expand Down Expand Up @@ -224,22 +235,27 @@ def _get_by_configs_and_owner_ids(self, configs_and_owner_ids, filters: Optional
return res

def __get_entities_by_config_and_owner(
self, config_id: str, owner_id: Optional[str] = "", filters: Optional[List[Dict]] = None
self, config_id: str, owner_id: Optional[str] = None, filters: Optional[List[Dict]] = None
) -> ModelType:
if not filters:
filters = []
versions = [item.get("version") for item in filters if item.get("version")]

query = self.model_type.__table__.select().filter_by(config_id=config_id)
parameters = [config_id]

if owner_id:
query = query.filter_by(owner_id=owner_id)
else:
query = query.filter_by(owner_id=None)
parameters.append(owner_id)
query = query.filter_by(owner_id=owner_id)

if versions:
query = query.filter(self.model_type.version.in_(versions)) # type: ignore
query = str(query.filter(self.model_type.version.in_(versions)).compile(dialect=sqlite.dialect())) # type: ignore
return self.db.execute(query)

query = str(query.compile(dialect=sqlite.dialect()))
return self.db.execute(query).fetchone()
if entry := self.db.execute(query, parameters).fetchone():
return self.model_type.from_dict(entry)
return None

#############################
# ## Private methods ## #
Expand Down
12 changes: 12 additions & 0 deletions src/taipy/core/_version/_version_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
# specific language governing permissions and limitations under the License.

import json
from dataclasses import dataclass
from typing import Any, Dict

Expand Down Expand Up @@ -42,3 +43,14 @@ def from_dict(data: Dict[str, Any]):
config=data["config"],
creation_date=data["creation_date"],
)

@staticmethod
def to_list(model):
return [
model.id,
model.config,
model.creation_date,
model.is_production,
model.is_development,
model.is_latest,
]
20 changes: 15 additions & 5 deletions src/taipy/core/_version/_version_sql_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def __init__(self):
super().__init__(model_type=_VersionModel, converter=_VersionConverter)

def _set_latest_version(self, version_number):
if old_latest := self.db.query(self.model_type).filter_by(is_latest=True).first():
if old_latest := self.db.execute(str(self.model_type.__table__.select().filter_by(is_latest=True))).fetchone():
old_latest.is_latest = False

version = self.__get_by_id(version_number)
Expand All @@ -39,7 +39,9 @@ def _get_latest_version(self):
return ""

def _set_development_version(self, version_number):
if old_development := self.db.query(self.model_type).filter_by(is_development=True).first():
if old_development := self.db.execute(
str(self.model_type.__table__.select().filter_by(is_development=True))
).fetchone():
old_development.is_development = False

version = self.__get_by_id(version_number)
Expand All @@ -50,7 +52,9 @@ def _set_development_version(self, version_number):
self.db.commit()

def _get_development_version(self):
if development := self.db.query(self.model_type).filter_by(is_development=True).first():
if development := self.db.execute(
str(self.model_type.__table__.select().filter_by(is_development=True))
).fetchone():
return development.id
raise ModelNotFound(self.model_type, "")

Expand All @@ -63,7 +67,11 @@ def _set_production_version(self, version_number):
self.db.commit()

def _get_production_versions(self):
if productions := self.db.query(self.model_type).filter_by(is_production=True).all():
if productions := self.db.execute(
str(self.model_type.__table__.select().filter_by(is_production=True).compile(dialect=sqlite.dialect())),
).fetchall():

# if productions := self.db.query(self.model_type).filter_by(is_production=True).all():
return [p.id for p in productions]
return []

Expand All @@ -77,4 +85,6 @@ def _delete_production_version(self, version_number):
self.db.commit()

def __get_by_id(self, version_id):
return self.db.query(self.model_type).filter_by(id=version_id).first()
query = str(self.model_type.__table__.select().filter_by(id=version_id).compile(dialect=sqlite.dialect()))
entry = self.db.execute(query, [version_id]).fetchone()
return self.model_type.from_dict(entry) if entry else None
9 changes: 8 additions & 1 deletion src/taipy/core/data/_data_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,13 @@ class _DataNodeModel(_BaseModel):
@staticmethod
def from_dict(data: Dict[str, Any]):
dn_properties = data["data_node_properties"]
if isinstance(dn_properties, str):
dn_properties = json.loads(dn_properties.replace("'", '"'))

edits = data["edits"]
if isinstance(edits, str):
edits = json.loads(edits.replace("'", '"'))

return _DataNodeModel(
id=data["id"],
config_id=data["config_id"],
Expand All @@ -76,7 +83,7 @@ def from_dict(data: Dict[str, Any]):
owner_id=data.get("owner_id"),
parent_ids=data.get("parent_ids", []),
last_edit_date=data.get("last_edit_date"),
edits=data["edits"],
edits=edits,
version=data["version"],
validity_days=data["validity_days"],
validity_seconds=data["validity_seconds"],
Expand Down
12 changes: 10 additions & 2 deletions src/taipy/core/job/_job_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,14 @@ class _JobModel(_BaseModel):

@staticmethod
def from_dict(data: Dict[str, Any]):
subscribers = data["subscribers"]
if isinstance(subscribers, str):
subscribers = json.loads(subscribers.replace("'", '"'))

stacktrace = data["stacktrace"]
if isinstance(stacktrace, str):
stacktrace = json.loads(stacktrace.replace("'", '"'))

return _JobModel(
id=data["id"],
task_id=data["task_id"],
Expand All @@ -58,8 +66,8 @@ def from_dict(data: Dict[str, Any]):
submit_id=data["submit_id"],
submit_entity_id=data["submit_entity_id"],
creation_date=data["creation_date"],
subscribers=data["subscribers"],
stacktrace=data["stacktrace"],
subscribers=subscribers,
stacktrace=stacktrace,
version=data["version"],
)

Expand Down
54 changes: 48 additions & 6 deletions src/taipy/core/scenario/_scenario_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
# specific language governing permissions and limitations under the License.

import json
from dataclasses import dataclass
from typing import Any, Dict, List, Optional

Expand Down Expand Up @@ -56,17 +57,58 @@ class _ScenarioModel(_BaseModel):

@staticmethod
def from_dict(data: Dict[str, Any]):
tasks = data.get("tasks", None)
if isinstance(tasks, str):
tasks = json.loads(tasks.replace("'", '"'))

additional_data_nodes = data.get("additional_data_nodes", None)
if isinstance(additional_data_nodes, str):
additional_data_nodes = json.loads(additional_data_nodes.replace("'", '"'))

properties = data["properties"]
if isinstance(properties, str):
properties = json.loads(properties.replace("'", '"'))

subscribers = data["subscribers"]
if isinstance(subscribers, str):
subscribers = json.loads(subscribers.replace("'", '"'))

tags = data["tags"]
if isinstance(tags, str):
tags = json.loads(tags.replace("'", '"'))

sequences = data.get("sequences", None)
if isinstance(sequences, str):
sequences = json.loads(sequences.replace("'", '"'))

return _ScenarioModel(
id=data["id"],
config_id=data["config_id"],
tasks=data.get("tasks", None),
additional_data_nodes=data.get("additional_data_nodes", None),
properties=data["properties"],
tasks=tasks,
additional_data_nodes=additional_data_nodes,
properties=properties,
creation_date=data["creation_date"],
primary_scenario=data["primary_scenario"],
subscribers=data["subscribers"],
tags=data["tags"],
subscribers=subscribers,
tags=tags,
version=data["version"],
sequences=data.get("sequences", None),
sequences=sequences,
cycle=CycleId(data["cycle"]) if "cycle" in data else None,
)

@staticmethod
def to_list(model):
return [
model.id,
model.config_id,
json.dumps(model.tasks),
json.dumps(model.additional_data_nodes),
json.dumps(model.properties),
model.creation_date,
model.primary_scenario,
json.dumps(model.subscribers),
json.dumps(model.tags),
model.version,
json.dumps(model.sequences),
model.cycle,
]
24 changes: 20 additions & 4 deletions src/taipy/core/task/_task_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,18 +51,34 @@ class _TaskModel(_BaseModel):

@staticmethod
def from_dict(data: Dict[str, Any]):
parent_ids = data.get("parent_ids", [])
if isinstance(parent_ids, str):
parent_ids = json.loads(parent_ids.replace("'", '"'))

input_ids = data["input_ids"]
if isinstance(input_ids, str):
input_ids = json.loads(input_ids.replace("'", '"'))

output_ids = data["output_ids"]
if isinstance(output_ids, str):
output_ids = json.loads(output_ids.replace("'", '"'))

properties = data["properties"] if "properties" in data.keys() else {}
if isinstance(properties, str):
properties = json.loads(properties.replace("'", '"'))

return _TaskModel(
id=data["id"],
owner_id=data.get("owner_id"),
parent_ids=data.get("parent_ids", []),
parent_ids=parent_ids,
config_id=data["config_id"],
input_ids=data["input_ids"],
input_ids=input_ids,
function_name=data["function_name"],
function_module=data["function_module"],
output_ids=data["output_ids"],
output_ids=output_ids,
version=data["version"],
skippable=data["skippable"],
properties=data["properties"] if "properties" in data.keys() else {},
properties=properties,
)

@staticmethod
Expand Down
41 changes: 25 additions & 16 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,11 @@
import pandas as pd
import pytest
from sqlalchemy import create_engine, text
from sqlalchemy.dialects import sqlite
from sqlalchemy.schema import CreateTable, DropTable

from src.taipy.core._orchestrator._orchestrator_factory import _OrchestratorFactory
from src.taipy.core._repository.db._sql_session import _build_engine
from src.taipy.core._repository._sql_repository import connection
from src.taipy.core._version._version import _Version
from src.taipy.core._version._version_manager_factory import _VersionManagerFactory
from src.taipy.core._version._version_model import _VersionModel
Expand Down Expand Up @@ -442,20 +444,27 @@ def init_sql_repo(tmp_sqlite):
Config.configure_core(repository_type="sql", repository_properties={"db_location": tmp_sqlite})

# Clean SQLite database
engine = _build_engine()

_CycleModel.__table__.drop(bind=engine, checkfirst=True)
_DataNodeModel.__table__.drop(bind=engine, checkfirst=True)
_JobModel.__table__.drop(bind=engine, checkfirst=True)
_ScenarioModel.__table__.drop(bind=engine, checkfirst=True)
_TaskModel.__table__.drop(bind=engine, checkfirst=True)
_VersionModel.__table__.drop(bind=engine, checkfirst=True)

_CycleModel.__table__.create(bind=engine, checkfirst=True)
_DataNodeModel.__table__.create(bind=engine, checkfirst=True)
_JobModel.__table__.create(bind=engine, checkfirst=True)
_ScenarioModel.__table__.create(bind=engine, checkfirst=True)
_TaskModel.__table__.create(bind=engine, checkfirst=True)
_VersionModel.__table__.create(bind=engine, checkfirst=True)
if connection:
connection.execute(str(DropTable(_CycleModel.__table__, if_exists=True).compile(dialect=sqlite.dialect())))
connection.execute(str(DropTable(_DataNodeModel.__table__, if_exists=True).compile(dialect=sqlite.dialect())))
connection.execute(str(DropTable(_JobModel.__table__, if_exists=True).compile(dialect=sqlite.dialect())))
connection.execute(str(DropTable(_ScenarioModel.__table__, if_exists=True).compile(dialect=sqlite.dialect())))
connection.execute(str(DropTable(_TaskModel.__table__, if_exists=True).compile(dialect=sqlite.dialect())))
connection.execute(str(DropTable(_VersionModel.__table__, if_exists=True).compile(dialect=sqlite.dialect())))

connection.execute(
str(CreateTable(_CycleModel.__table__, if_not_exists=True).compile(dialect=sqlite.dialect()))
)
connection.execute(
str(CreateTable(_DataNodeModel.__table__, if_not_exists=True).compile(dialect=sqlite.dialect()))
)
connection.execute(str(CreateTable(_JobModel.__table__, if_not_exists=True).compile(dialect=sqlite.dialect())))
connection.execute(
str(CreateTable(_ScenarioModel.__table__, if_not_exists=True).compile(dialect=sqlite.dialect()))
)
connection.execute(str(CreateTable(_TaskModel.__table__, if_not_exists=True).compile(dialect=sqlite.dialect())))
connection.execute(
str(CreateTable(_VersionModel.__table__, if_not_exists=True).compile(dialect=sqlite.dialect()))
)

return tmp_sqlite

0 comments on commit 2b13eaa

Please sign in to comment.