Skip to content
This repository has been archived by the owner on Jan 2, 2024. It is now read-only.

Commit

Permalink
fixed failed version query
Browse files Browse the repository at this point in the history
  • Loading branch information
Toan Quach authored and Toan Quach committed Nov 3, 2023
1 parent ed621ce commit 54df259
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 19 deletions.
12 changes: 7 additions & 5 deletions src/taipy/core/_repository/_sql_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def __init__(self, model_type: Type[ModelType], converter: Type[Converter], sess
def _save(self, entity: Entity):
obj = self.converter._entity_to_model(entity)
if self._exists(entity.id):
self.__update_entry(obj)
self._update_entry(obj)
return
self.__insert_model(obj)

Expand Down Expand Up @@ -247,12 +247,14 @@ def __get_entities_by_config_and_owner(
if owner_id:
parameters.append(owner_id)
query = query.filter_by(owner_id=owner_id)
query = str(query.compile(dialect=sqlite.dialect()))

if versions:
query = str(query.filter(self.model_type.version.in_(versions)).compile(dialect=sqlite.dialect())) # type: ignore
return self.db.execute(query)

query = str(query.compile(dialect=sqlite.dialect()))
query = query + f" AND {self.model_type.__table__.name}.version IN ({','.join(['?']*len(versions))})"
# query = str(query.filter(self.model_type.version.in_(versions)).compile(dialect=sqlite.dialect())) # type: ignore
parameters.extend(versions)

if entry := self.db.execute(query, parameters).fetchone():
return self.model_type.from_dict(entry)
return None
Expand All @@ -265,7 +267,7 @@ def __insert_model(self, model: ModelType):
self.db.execute(query, model.to_list(model))
self.db.commit()

def __update_entry(self, model):
def _update_entry(self, model):
query = str(self.model_type.__table__.update().filter_by(id=model.id).compile(dialect=sqlite.dialect()))
self.db.execute(query, model.to_list(model) + [model.id])
self.db.commit()
6 changes: 5 additions & 1 deletion src/taipy/core/_version/_version_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,15 @@ class _VersionModel(_BaseModel):

@staticmethod
def from_dict(data: Dict[str, Any]):
return _VersionModel(
model = _VersionModel(
id=data["id"],
config=data["config"],
creation_date=data["creation_date"],
)
model.is_production = data.get("is_production")
model.is_development = data.get("is_development")
model.is_latest = data.get("is_latest")
return model

@staticmethod
def to_list(model):
Expand Down
24 changes: 11 additions & 13 deletions src/taipy/core/_version/_version_sql_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,55 +24,54 @@ def __init__(self):

def _set_latest_version(self, version_number):
if old_latest := self.db.execute(str(self.model_type.__table__.select().filter_by(is_latest=True))).fetchone():
old_latest = self.model_type.from_dict(old_latest)
old_latest.is_latest = False
self._update_entry(old_latest)

version = self.__get_by_id(version_number)
version.is_latest = True

self.db.commit()
self._update_entry(version)

def _get_latest_version(self):
if latest := self.db.execute(
str(self.model_type.__table__.select().filter_by(is_latest=True).compile(dialect=sqlite.dialect()))
).fetchone():
return latest.id
return latest["id"]
return ""

def _set_development_version(self, version_number):
if old_development := self.db.execute(
str(self.model_type.__table__.select().filter_by(is_development=True))
).fetchone():
old_development = self.model_type.from_dict(old_development)
old_development.is_development = False
self._update_entry(old_development)

version = self.__get_by_id(version_number)
version.is_development = True
self._update_entry(version)

self._set_latest_version(version_number)

self.db.commit()

def _get_development_version(self):
if development := self.db.execute(
str(self.model_type.__table__.select().filter_by(is_development=True))
).fetchone():
return development.id
return development["id"]
raise ModelNotFound(self.model_type, "")

def _set_production_version(self, version_number):
version = self.__get_by_id(version_number)
version.is_production = True
self._update_entry(version)

self._set_latest_version(version_number)

self.db.commit()

def _get_production_versions(self):
if productions := self.db.execute(
str(self.model_type.__table__.select().filter_by(is_production=True).compile(dialect=sqlite.dialect())),
).fetchall():

# if productions := self.db.query(self.model_type).filter_by(is_production=True).all():
return [p.id for p in productions]
return [p["id"] for p in productions]
return []

def _delete_production_version(self, version_number):
Expand All @@ -81,8 +80,7 @@ def _delete_production_version(self, version_number):
if not version or not version.is_production:
raise VersionIsNotProductionVersion(f"Version '{version_number}' is not a production version.")
version.is_production = False

self.db.commit()
self._update_entry(version)

def __get_by_id(self, version_id):
query = str(self.model_type.__table__.select().filter_by(id=version_id).compile(dialect=sqlite.dialect()))
Expand Down
1 change: 1 addition & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,6 +445,7 @@ def sql_engine():
def init_sql_repo(tmp_sqlite):
Config.configure_core(repository_type="sql", repository_properties={"db_location": tmp_sqlite})

init_managers()
# Clean SQLite database
if connection:
connection.execute(str(DropTable(_CycleModel.__table__, if_exists=True).compile(dialect=sqlite.dialect())))
Expand Down

0 comments on commit 54df259

Please sign in to comment.