Skip to content

Commit

Permalink
Fix code generation model create and update (#378)
Browse files Browse the repository at this point in the history
  • Loading branch information
wu-clan authored Aug 11, 2024
1 parent f30a188 commit 144d704
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 19 deletions.
10 changes: 0 additions & 10 deletions backend/app/generator/crud/crud_gen_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,6 @@ async def get(self, db: AsyncSession, pk: int) -> GenModel | None:
"""
return await self.select_model_by_id(db, pk)

async def get_by_name(self, db: AsyncSession, name: str) -> GenModel | None:
"""
通过 name 获取代码生成模型表
:param db:
:param name:
:return:
"""
return await self.select_model_by_column(db, 'name', name)

async def get_all_by_business_id(self, db: AsyncSession, business_id: int) -> Sequence[GenModel]:
gen_model = await db.execute(
select(self.model).where(self.model.gen_business_id == business_id).order_by(self.model.sort)
Expand Down
2 changes: 1 addition & 1 deletion backend/app/generator/model/gen_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ class GenModel(DataClassBase):
__tablename__ = 'sys_gen_model'

id: Mapped[id_key] = mapped_column(init=False)
name: Mapped[str] = mapped_column(String(50), unique=True, comment='列名称')
name: Mapped[str] = mapped_column(String(50), comment='列名称')
comment: Mapped[str | None] = mapped_column(String(255), default=None, comment='列描述')
type: Mapped[str] = mapped_column(String(20), default='str', comment='SQLA 模型列类型')
pd_type: Mapped[str] = mapped_column(String(20), default='str', comment='列类型对应的 pydantic 类型')
Expand Down
16 changes: 8 additions & 8 deletions backend/app/generator/service/gen_model_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,15 @@ async def get(*, pk: int) -> GenModel:
@staticmethod
async def get_by_business(*, business_id: int) -> Sequence[GenModel]:
async with async_db_session() as db:
gen_model = await gen_model_dao.get_all_by_business_id(db, business_id)
return gen_model
gen_models = await gen_model_dao.get_all_by_business_id(db, business_id)
return gen_models

@staticmethod
async def create(*, obj: CreateGenModelParam) -> None:
async with async_db_session.begin() as db:
gen_model = await gen_model_dao.get_by_name(db, obj.name)
if gen_model:
raise errors.ForbiddenError(msg='禁止添加相同列到模型表')
gen_models = await gen_model_dao.get_all_by_business_id(db, obj.gen_business_id)
if obj.name in [gen_model.name for gen_model in gen_models]:
raise errors.ForbiddenError(msg='禁止添加相同列到同一模型表')
pd_type = sql_type_to_pydantic(obj.type)
await gen_model_dao.create(db, obj, pd_type=pd_type)

Expand All @@ -37,9 +37,9 @@ async def update(*, pk: int, obj: UpdateGenModelParam) -> int:
async with async_db_session.begin() as db:
model = await gen_model_dao.get(db, pk)
if obj.name != model.name:
model_check = await gen_model_dao.get_by_name(db, obj.name)
if model_check:
raise errors.ForbiddenError(msg='禁止添加相同列到模型表')
gen_models = await gen_model_dao.get_all_by_business_id(db, obj.gen_business_id)
if obj.name in [gen_model.name for gen_model in gen_models]:
raise errors.ForbiddenError(msg='模型列名已存在')
pd_type = sql_type_to_pydantic(obj.type)
count = await gen_model_dao.update(db, pk, obj, pd_type=pd_type)
return count
Expand Down

0 comments on commit 144d704

Please sign in to comment.