Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update some code generation api and params #373

Merged
merged 14 commits into from
Aug 4, 2024
122 changes: 103 additions & 19 deletions backend/app/generator/api/v1/gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,24 +2,29 @@
# -*- coding: utf-8 -*-
from typing import Annotated

from fastapi import APIRouter, Path, Query
from fastapi import APIRouter, Body, Depends, Path, Query
from fastapi.responses import StreamingResponse

from backend.app.generator.conf import generator_settings
from backend.app.generator.schema.gen_business import CreateGenBusinessParam, UpdateGenBusinessParam
from backend.app.generator.schema.gen_model import CreateGenModelParam, UpdateGenModelParam
from backend.app.generator.schema.gen_business import (
CreateGenBusinessParam,
GetGenBusinessListDetails,
UpdateGenBusinessParam,
)
from backend.app.generator.schema.gen_model import CreateGenModelParam, GetGenModelListDetails, UpdateGenModelParam
from backend.app.generator.service.gen_business_service import gen_business_service
from backend.app.generator.service.gen_model_service import gen_model_service
from backend.app.generator.service.gen_service import gen_service
from backend.common.response.response_schema import ResponseModel, response_base
from backend.common.security.jwt import DependsJwtAuth
from backend.common.security.permission import RequestPermission
from backend.common.security.rbac import DependsRBAC
from backend.utils.serializers import select_list_serialize
from backend.utils.serializers import select_as_dict, select_list_serialize

router = APIRouter()


@router.get('/all', summary='获取所有代码生成业务', dependencies=[DependsJwtAuth])
@router.get('/businesses/all', summary='获取所有代码生成业务', dependencies=[DependsJwtAuth])
async def get_all_businesses() -> ResponseModel:
businesses = await gen_business_service.get_all()
data = await select_list_serialize(businesses)
Expand All @@ -28,47 +33,105 @@ async def get_all_businesses() -> ResponseModel:

@router.get('/businesses/{pk}', summary='获取代码生成业务详情', dependencies=[DependsJwtAuth])
async def get_business(pk: Annotated[int, Path(...)]) -> ResponseModel:
data = await gen_service.get_business_and_model(pk=pk)
business = await gen_service.get_business_with_model(pk=pk)
data = GetGenBusinessListDetails(**await select_as_dict(business))
return await response_base.success(data=data)


@router.post('/businesses', summary='创建代码生成业务', deprecated=True, dependencies=[DependsRBAC])
@router.get('/businesses/{pk}/models', summary='获取代码生成业务所有模型', dependencies=[DependsJwtAuth])
async def get_business_models(pk: Annotated[int, Path(...)]) -> ResponseModel:
models = await gen_model_service.get_by_business(business_id=pk)
data = await select_list_serialize(models)
return await response_base.success(data=data)


@router.post(
'/businesses',
summary='创建代码生成业务',
deprecated=True,
dependencies=[
Depends(RequestPermission('gen:code:business:add')),
DependsRBAC,
],
)
async def create_business(obj: CreateGenBusinessParam) -> ResponseModel:
await gen_business_service.create(obj=obj)
return await response_base.success()


@router.put('/businesses/{pk}', summary='更新代码生成业务', dependencies=[DependsRBAC])
@router.put(
'/businesses/{pk}',
summary='更新代码生成业务',
dependencies=[
Depends(RequestPermission('gen:code:business:edit')),
DependsRBAC,
],
)
async def update_business(pk: Annotated[int, Path(...)], obj: UpdateGenBusinessParam) -> ResponseModel:
count = await gen_business_service.update(pk=pk, obj=obj)
if count > 0:
return await response_base.success()
return await response_base.fail()


@router.delete('/businesses', summary='删除代码生成业务', dependencies=[DependsRBAC])
async def delete_business(pk: Annotated[int, Query(...)]) -> ResponseModel:
@router.delete(
'/businesses/{pk}',
summary='删除代码生成业务',
dependencies=[
Depends(RequestPermission('gen:code:business:del')),
DependsRBAC,
],
)
async def delete_business(pk: Annotated[int, Path(...)]) -> ResponseModel:
count = await gen_business_service.delete(pk=pk)
if count > 0:
return await response_base.success()
return await response_base.fail()


@router.post('/models', summary='创建代码生成模型', dependencies=[DependsRBAC])
@router.get('/models/{pk}', summary='获取代码生成模型详情', dependencies=[DependsJwtAuth])
async def get_model(pk: Annotated[int, Path(...)]) -> ResponseModel:
model = await gen_model_service.get(pk=pk)
data = GetGenModelListDetails(**await select_as_dict(model))
return await response_base.success(data=data)


@router.post(
'/models',
summary='创建代码生成模型',
dependencies=[
Depends(RequestPermission('gen:code:model:add')),
DependsRBAC,
],
)
async def create_model(obj: CreateGenModelParam) -> ResponseModel:
await gen_model_service.create(obj=obj)
return await response_base.success()


@router.put('/models/{pk}', summary='更新代码生成模型', dependencies=[DependsRBAC])
@router.put(
'/models/{pk}',
summary='更新代码生成模型',
dependencies=[
Depends(RequestPermission('gen:code:model:edit')),
DependsRBAC,
],
)
async def update_model(pk: Annotated[int, Path(...)], obj: UpdateGenModelParam) -> ResponseModel:
count = await gen_model_service.update(pk=pk, obj=obj)
if count > 0:
return await response_base.success()
return await response_base.fail()


@router.delete('/models/{pk}', summary='删除代码生成模型', dependencies=[DependsRBAC])
@router.delete(
'/models/{pk}',
summary='删除代码生成模型',
dependencies=[
Depends(RequestPermission('gen:code:model:del')),
DependsRBAC,
],
)
async def delete_model(pk: Annotated[int, Path(...)]) -> ResponseModel:
count = await gen_model_service.delete(pk=pk)
if count > 0:
Expand All @@ -82,11 +145,18 @@ async def get_all_tables(table_schema: Annotated[str, Query(..., description='
return await response_base.success(data=data)


@router.post('/import', summary='导入代码生成业务和模型列', dependencies=[DependsRBAC])
@router.post(
'/import',
summary='导入代码生成业务和模型列',
dependencies=[
Depends(RequestPermission('')),
DependsRBAC,
],
)
async def import_table(
app: Annotated[str, Query(..., description='应用名称,用于代码生成到指定 app')],
table_name: Annotated[str, Query(..., description='数据库表名')],
table_schema: Annotated[str, Query(..., description='数据库名')] = 'fba',
app: Annotated[str, Body(..., description='应用名称,用于代码生成到指定 app')],
table_name: Annotated[str, Body(..., description='数据库表名')],
table_schema: Annotated[str, Body(..., description='数据库名')] = 'fba',
) -> ResponseModel:
await gen_service.import_business_and_model(app=app, table_schema=table_schema, table_name=table_name)
return await response_base.success()
Expand All @@ -98,13 +168,27 @@ async def preview_code(pk: Annotated[int, Path(..., description='业务ID')]) ->
return await response_base.success(data=data)


@router.post('/generate/{pk}', summary='生成代码', description='文件磁盘写入,请谨慎操作', dependencies=[DependsRBAC])
@router.get('/generate/{pk}/path', summary='获取代码生成路径', dependencies=[DependsJwtAuth])
async def generate_path(pk: Annotated[int, Path(..., description='业务ID')]):
data = await gen_service.get_generate_path(pk=pk)
return await response_base.success(data=data)


@router.post(
'/generate/{pk}',
summary='代码生成',
description='文件磁盘写入,请谨慎操作',
dependencies=[
Depends(RequestPermission('gen:code:generate')),
DependsRBAC,
],
)
async def generate_code(pk: Annotated[int, Path(..., description='业务ID')]) -> ResponseModel:
await gen_service.generate(pk=pk)
return await response_base.success()


@router.post('/download/{pk}', summary='下载代码', dependencies=[DependsRBAC])
@router.get('/download/{pk}', summary='下载代码', dependencies=[DependsRBAC])
async def download_code(pk: Annotated[int, Path(..., description='业务ID')]):
bio = await gen_service.download(pk=pk)
return StreamingResponse(
Expand Down
33 changes: 22 additions & 11 deletions backend/app/generator/crud/crud_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,36 +2,47 @@
# -*- coding: utf-8 -*-
from typing import Sequence

from sqlalchemy import Row, text
from sqlalchemy import Row, select, text
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload

from backend.app.generator.model import GenBusiness


class CRUDGen:
@staticmethod
async def get_business_with_model(db: AsyncSession, business_id: int) -> GenBusiness:
result = await db.execute(
select(GenBusiness).options(selectinload(GenBusiness.gen_model)).where(GenBusiness.id == business_id)
)
data = result.scalars().first()
return data

@staticmethod
async def get_all_tables(db: AsyncSession, table_schema: str) -> Sequence[str]:
t = text(
stmt = text(
'select table_name as table_name '
'from information_schema.tables '
'where table_name not like "sys_gen_%" '
'and table_schema = :table_schema;'
).bindparams(table_schema=table_schema)
stmt = await db.execute(t)
return stmt.scalars().all()
result = await db.execute(stmt)
return result.scalars().all()

@staticmethod
async def get_table(db: AsyncSession, table_name: str) -> Row[tuple]:
t = text(
stmt = text(
'select table_name as table_name, table_comment as table_comment '
'from information_schema.tables '
'where table_name not like "sys_gen_%" '
'and table_name = :table_name;'
).bindparams(table_name=table_name)
stmt = await db.execute(t)
return stmt.fetchone()
result = await db.execute(stmt)
return result.fetchone()

@staticmethod
async def get_all_columns(db: AsyncSession, table_schema: str, table_name: str) -> Sequence[Row[tuple]]:
t = text(
stmt = text(
'select column_name AS column_name, '
'case when column_key = "PRI" then 1 else 0 end as is_pk, '
'case when is_nullable = "NO" or column_key = "PRI" then 0 else 1 end as is_nullable, '
Expand All @@ -46,8 +57,8 @@ async def get_all_columns(db: AsyncSession, table_schema: str, table_name: str)
'and column_name != "updated_time" '
'order by sort;'
).bindparams(table_schema=table_schema, table_name=table_name)
stmt = await db.execute(t)
return stmt.fetchall()
result = await db.execute(stmt)
return result.fetchall()


gen_dao = CRUDGen()
gen_dao: CRUDGen = CRUDGen()
2 changes: 1 addition & 1 deletion backend/app/generator/crud/crud_gen_business.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,4 +70,4 @@ async def delete(self, db: AsyncSession, pk: int) -> int:
return await self.delete_model(db, pk)


gen_business_dao = CRUDGenBusiness(GenBusiness)
gen_business_dao: CRUDGenBusiness = CRUDGenBusiness(GenBusiness)
11 changes: 10 additions & 1 deletion backend/app/generator/crud/crud_gen_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,18 @@


class CRUDGenModel(CRUDPlus[GenModel]):
async def get(self, db: AsyncSession, pk: int) -> GenModel | None:
"""
获取代码生成模型列

:return:
"""
return await self.select_model_by_id(db, pk)

async def get_by_name(self, db: AsyncSession, name: str) -> GenModel | None:
"""
通过 name 获取代码生成模型表

:param db:
:param name:
:return:
Expand Down Expand Up @@ -58,4 +67,4 @@ async def delete(self, db: AsyncSession, pk: int) -> int:
return await self.delete_model(db, pk)


gen_model_dao = CRUDGenModel(GenModel)
gen_model_dao: CRUDGenModel = CRUDGenModel(GenModel)
2 changes: 1 addition & 1 deletion backend/app/generator/model/gen_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ class GenModel(DataClassBase):
__tablename__ = 'sys_gen_model'

id: Mapped[id_key] = mapped_column(init=False)
name: Mapped[str] = mapped_column(String(50), comment='列名称')
name: Mapped[str] = mapped_column(String(50), unique=True, comment='列名称')
comment: Mapped[str | None] = mapped_column(String(255), default=None, comment='列描述')
type: Mapped[str] = mapped_column(String(20), default='str', comment='SQLA 模型列类型')
pd_type: Mapped[str] = mapped_column(String(20), default='str', comment='列类型对应的 pydantic 类型')
Expand Down
2 changes: 2 additions & 0 deletions backend/app/generator/schema/gen_business.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from pydantic import ConfigDict, Field, model_validator

from backend.app.generator.schema.gen_model import GetGenModelListDetails
from backend.common.schema import SchemaBase


Expand Down Expand Up @@ -40,3 +41,4 @@ class GetGenBusinessListDetails(GenBusinessSchemaBase):
id: int
created_time: datetime
updated_time: datetime | None = None
gen_model: list[GetGenModelListDetails] | None = None
20 changes: 13 additions & 7 deletions backend/app/generator/service/gen_model_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,12 @@


class GenModelService:
@staticmethod
async def get(*, pk: int) -> GenModel:
async with async_db_session() as db:
gen_model = await gen_model_dao.get(db, pk)
return gen_model

@staticmethod
async def get_by_business(*, business_id: int) -> Sequence[GenModel]:
async with async_db_session() as db:
Expand All @@ -20,19 +26,19 @@ async def get_by_business(*, business_id: int) -> Sequence[GenModel]:
@staticmethod
async def create(*, obj: CreateGenModelParam) -> None:
async with async_db_session.begin() as db:
gen_models = await gen_model_dao.get_all_by_business_id(db, obj.gen_business_id)
if gen_models:
if obj.name in [model.name for model in gen_models]:
raise errors.ForbiddenError(msg='禁止添加相同列到模型表')
gen_model = await gen_model_dao.get_by_name(db, obj.name)
if gen_model:
raise errors.ForbiddenError(msg='禁止添加相同列到模型表')
pd_type = sql_type_to_pydantic(obj.type)
await gen_model_dao.create(db, obj, pd_type=pd_type)

@staticmethod
async def update(*, pk: int, obj: UpdateGenModelParam) -> int:
async with async_db_session.begin() as db:
gen_models = await gen_model_dao.get_all_by_business_id(obj.gen_business_id)
if gen_models:
if obj.name in [model.name for model in gen_models]:
model = await gen_model_dao.get(db, pk)
if obj.name != model.name:
model_check = await gen_model_dao.get_by_name(db, obj.name)
if model_check:
raise errors.ForbiddenError(msg='禁止添加相同列到模型表')
pd_type = sql_type_to_pydantic(obj.type)
count = await gen_model_dao.update(db, pk, obj, pd_type=pd_type)
Expand Down
Loading
Loading