diff --git a/README.md b/README.md index 14d5c10a..1a7dea9e 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # FastAPI Best Architecture -This is a base project of the FastAPI framework. +This is a base project of the FastAPI framework, in production It‘s purpose is to allow you to develop your project directly with it as your base project @@ -29,6 +29,7 @@ git clone https://github.com/wu-clan/fastapi_best_architecture.git ### 1:Tradition 1. Install dependencies + ```shell pip install -r requirements.txt ``` @@ -36,12 +37,16 @@ git clone https://github.com/wu-clan/fastapi_best_architecture.git 2. Create a database `fba`, choose utf8mb4 encode 3. Install and start Redis 4. create a `.env` file in the `backend/app/` directory + ```shell cd backend/app/ + touch .env ``` -5. Copy .env.example to .env and view `backend/app/core/conf.py`, update database configuration information + +5. Copy `.env.example` to `.env` and view `backend/app/core/conf.py`, update database configuration information 6. Perform a database migration [alembic](https://alembic.sqlalchemy.org/en/latest/tutorial.html) + ```shell cd backend/app/ @@ -51,7 +56,8 @@ git clone https://github.com/wu-clan/fastapi_best_architecture.git # Perform the migration alembic upgrade head ``` -7. Execute the backend/app/main.py file startup service + +7. Execute the `backend/app/main.py` file startup service 8. Browser access: http://127.0.0.1:8000/v1/docs --- @@ -63,6 +69,7 @@ git clone https://github.com/wu-clan/fastapi_best_architecture.git ```shell docker-compose up -d --build ``` + 2. Wait for the command to finish automatically 3. Browser access: http://127.0.0.1:8000/v1/docs diff --git a/README_zh.md b/README_zh.md new file mode 100644 index 00000000..460a0648 --- /dev/null +++ b/README_zh.md @@ -0,0 +1,90 @@ +# FastAPI 最佳架构 + +这是 FastAPI 框架的一个基础项目,在制作中 + +它的目的是让你直接用它作为你的基础项目来开发你的项目 + +支持 python3.10 及以上版本 + +## 技术栈 + +- [x] FastAPI +- [x] Pydantic +- [x] SQLAlchemy +- [x] Alembic +- [x] MySQL +- [x] Redis +- [x] APScheduler +- [x] Docker + +## 克隆 + +```shell +git clone https://github.com/wu-clan/fastapi_best_architecture.git +``` + +## 使用: + +### 1:传统 + +1. 安装依赖项 + ```shell + pip install -r requirements.txt + ``` + +2. 创建一个数据库`fba`,选择 utf8mb4 编码 +3. 安装并启动 Redis +4. 在`backend/app/`目录下创建一个`.env`文件 + ```shell + cd backend/app/ + touch .env + ``` +5. 复制 `.env.example` 到 `.env` 并查看`backend/app/core/conf.py`,更新数据库配置信息 +6. 进行数据库迁移[alembic](https://alembic.sqlalchemy.org/en/latest/tutorial.html) + ```shell + cd backend/app/ + + # 生成迁移文件 + alembic revision --autogenerate + + # 执行迁移 + alembic upgrade head + ``` +7. 执行 `backend/app/main.py` 文件启动服务 +8. 浏览器访问:http://127.0.0.1:8000/v1/docs + +--- + +### 2:Docker + +1. 在 `docker-compose.yml` 文件所在的目录中运行一键启动命令 + + ```shell + docker-compose up -d -build + ``` + +2. 等待命令自动完成 + +3. 浏览器访问:http://127.0.0.1:8000/v1/docs + +## 初始化测试数据 + +执行 `backend/app/init_test_data.py` 文件 + +## 测试 + +通过 pytest 进行测试 + +**提示**: 在测试开始前,请先执行初始化测试数据,同时,需要启动 fastapi 服务。 + +1. 首先,进入app目录 + + ```shell + cd backend/app/ + ``` + +2. 执行测试命令 + + ```shell + pytest -vs --disable-warnings + ``` diff --git a/backend/app/api/routers.py b/backend/app/api/routers.py index 8ec02b71..9e74eb30 100644 --- a/backend/app/api/routers.py +++ b/backend/app/api/routers.py @@ -3,13 +3,23 @@ from fastapi import APIRouter from backend.app.api.v1.auth import router as auth_router +from backend.app.api.v1.user import router as user_router +from backend.app.api.v1.casbin import router as casbin_router +from backend.app.api.v1.dept import router as dept_router +from backend.app.api.v1.role import router as role_router +from backend.app.api.v1.menu import router as menu_router +from backend.app.api.v1.api import router as api_router from backend.app.api.v1.task_demo import router as task_demo_router -from backend.app.api.v1.sys_config import router as sys_config_router +from backend.app.api.v1.config import router as config_router v1 = APIRouter(prefix='/v1') v1.include_router(auth_router) - +v1.include_router(user_router, prefix='/users', tags=['用户管理']) +v1.include_router(casbin_router, prefix='/casbin', tags=['权限管理']) +v1.include_router(dept_router, prefix='/depts', tags=['部门管理']) +v1.include_router(role_router, prefix='/roles', tags=['角色管理']) +v1.include_router(menu_router, prefix='/menus', tags=['菜单管理']) +v1.include_router(api_router, prefix='/apis', tags=['API管理']) +v1.include_router(config_router, prefix='/configs', tags=['系统配置']) v1.include_router(task_demo_router, prefix='/tasks', tags=['任务管理']) - -v1.include_router(sys_config_router, prefix='/configs', tags=['系统配置']) diff --git a/backend/app/api/v1/api.py b/backend/app/api/v1/api.py new file mode 100644 index 00000000..f544f6e0 --- /dev/null +++ b/backend/app/api/v1/api.py @@ -0,0 +1,7 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +from fastapi import APIRouter + +router = APIRouter() + +# TODO: 添加 api 相关接口 diff --git a/backend/app/api/v1/auth/__init__.py b/backend/app/api/v1/auth/__init__.py index f18f1114..2ceb01e7 100644 --- a/backend/app/api/v1/auth/__init__.py +++ b/backend/app/api/v1/auth/__init__.py @@ -1,8 +1,8 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- from fastapi import APIRouter -from backend.app.api.v1.auth.user import router as user_router +from backend.app.api.v1.auth.auth import router as auth_router -router = APIRouter(prefix='/auth', tags=['用户管理']) +router = APIRouter(prefix='/auth', tags=['认证']) -router.include_router(user_router, prefix='/users') +router.include_router(auth_router, prefix='/users') diff --git a/backend/app/api/v1/auth/auth.py b/backend/app/api/v1/auth/auth.py new file mode 100644 index 00000000..099e7483 --- /dev/null +++ b/backend/app/api/v1/auth/auth.py @@ -0,0 +1,32 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +from fastapi import APIRouter, Depends +from fastapi.security import OAuth2PasswordRequestForm + +from backend.app.common.jwt import DependsUser +from backend.app.common.response.response_schema import response_base +from backend.app.schemas.token import Token +from backend.app.schemas.user import Auth +from backend.app.services.user_service import UserService + +router = APIRouter() + + +@router.post('/swagger_login', summary='swagger 表单登录', description='form 格式登录,仅用于 swagger 文档调试接口') +async def swagger_user_login(form_data: OAuth2PasswordRequestForm = Depends()) -> Token: + token, user = await UserService.swagger_login(form_data) + return Token(access_token=token, user=user) + + +@router.post('/login', summary='用户登录', description='json 格式登录, 仅支持在第三方api工具调试接口, 例如: postman') +async def user_login(obj: Auth): + token, user = await UserService.login(obj) + # TODO: token 存储 + data = Token(access_token=token, user=user) + return response_base.response_200(data=data) + + +@router.post('/logout', summary='用户登出', dependencies=[DependsUser]) +async def user_logout(): + # TODO: 加入 token 黑名单 + return response_base.response_200() diff --git a/backend/app/api/v1/casbin.py b/backend/app/api/v1/casbin.py new file mode 100644 index 00000000..09fe81c3 --- /dev/null +++ b/backend/app/api/v1/casbin.py @@ -0,0 +1,7 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +from fastapi import APIRouter + +router = APIRouter() + +# TODO: 添加 casbin 相关接口 diff --git a/backend/app/api/v1/sys_config.py b/backend/app/api/v1/config.py similarity index 74% rename from backend/app/api/v1/sys_config.py rename to backend/app/api/v1/config.py index c1900049..f535cb68 100644 --- a/backend/app/api/v1/sys_config.py +++ b/backend/app/api/v1/config.py @@ -1,17 +1,18 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- -from fastapi import APIRouter +from fastapi import APIRouter, Request +from fastapi.routing import APIRoute -from backend.app.api.jwt import DependsSuperUser -from backend.app.common.response.response_schema import ResponseModel +from backend.app.common.casbin_rbac import DependsRBAC +from backend.app.common.response.response_schema import response_base from backend.app.core.conf import settings router = APIRouter() -@router.get('', summary='获取系统配置', dependencies=[DependsSuperUser]) -async def get_sys_config() -> ResponseModel: - return ResponseModel( +@router.get('', summary='获取系统配置', dependencies=[DependsRBAC]) +async def get_sys_config(): + return response_base.success( data={ 'title': settings.TITLE, 'version': settings.VERSION, @@ -49,3 +50,12 @@ async def get_sys_config() -> ResponseModel: 'middleware_access': settings.MIDDLEWARE_ACCESS, } ) + + +@router.get('/routers', summary='获取所有路由', dependencies=[DependsRBAC]) +async def get_all_route(request: Request): + data = [] + for route in request.app.routes: + if isinstance(route, APIRoute): + data.append({'path': route.path, 'name': route.name, 'summary': route.summary, 'methods': route.methods}) + return response_base.success(data={'route_list': data}) diff --git a/backend/app/api/v1/dept.py b/backend/app/api/v1/dept.py new file mode 100644 index 00000000..75cbd3cc --- /dev/null +++ b/backend/app/api/v1/dept.py @@ -0,0 +1,7 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +from fastapi import APIRouter + +router = APIRouter() + +# TODO: 添加 dept 相关接口 diff --git a/backend/app/api/v1/menu.py b/backend/app/api/v1/menu.py new file mode 100644 index 00000000..f32f6265 --- /dev/null +++ b/backend/app/api/v1/menu.py @@ -0,0 +1,7 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +from fastapi import APIRouter + +router = APIRouter() + +# TODO: 添加 menu 相关接口 diff --git a/backend/app/api/v1/role.py b/backend/app/api/v1/role.py new file mode 100644 index 00000000..d8dc2ca5 --- /dev/null +++ b/backend/app/api/v1/role.py @@ -0,0 +1,7 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +from fastapi import APIRouter + +router = APIRouter() + +# TODO: 添加 role 相关接口 diff --git a/backend/app/api/v1/auth/user.py b/backend/app/api/v1/user.py similarity index 66% rename from backend/app/api/v1/auth/user.py rename to backend/app/api/v1/user.py index 6cede6ed..037fc9f0 100644 --- a/backend/app/api/v1/auth/user.py +++ b/backend/app/api/v1/user.py @@ -1,31 +1,18 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- -from fastapi import APIRouter, Depends -from fastapi.security import OAuth2PasswordRequestForm +from fastapi import APIRouter -from backend.app.api.jwt import CurrentUser, DependsUser, DependsSuperUser -from backend.app.api.service.user_service import UserService -from backend.app.common.pagination import Page +from backend.app.common.jwt import DependsUser, CurrentUser, DependsSuperUser +from backend.app.common.pagination import paging_data, PageDepends from backend.app.common.response.response_schema import response_base -from backend.app.schemas.token import Token -from backend.app.schemas.user import CreateUser, GetUserInfo, ResetPassword, UpdateUser, Avatar, Auth +from backend.app.database.db_mysql import CurrentSession +from backend.app.schemas.user import CreateUser, GetUserInfo, ResetPassword, UpdateUser, Avatar +from backend.app.services.user_service import UserService +from backend.app.utils.serializers import select_to_json router = APIRouter() -@router.post('/swagger_login', summary='swagger 表单登录', description='form 格式登录,仅用于 swagger 文档调试接口') -async def swagger_user_login(form_data: OAuth2PasswordRequestForm = Depends()) -> Token: - token, user = await UserService.swagger_login(form_data) - return Token(access_token=token, user=user) - - -@router.post('/login', summary='用户登录', description='json 格式登录, 仅支持在第三方api工具调试接口, 例如: postman') -async def user_login(obj: Auth): - token, user = await UserService.login(obj) - data = Token(access_token=token, user=user) - return response_base.response_200(data=data) - - @router.post('/register', summary='用户注册') async def user_register(obj: CreateUser): await UserService.register(obj) @@ -41,7 +28,8 @@ async def password_reset(obj: ResetPassword): @router.get('/{username}', summary='查看用户信息', dependencies=[DependsUser]) async def userinfo(username: str): current_user = await UserService.get_userinfo(username) - return response_base.response_200(data=current_user, exclude={'password'}) + data = GetUserInfo(**select_to_json(current_user)) + return response_base.response_200(data=data, exclude={'password'}) @router.put('/{username}', summary='更新用户信息') @@ -60,9 +48,11 @@ async def update_avatar(username: str, avatar: Avatar, current_user: CurrentUser return response_base.fail() -@router.get('', summary='获取所有用户', dependencies=[DependsUser]) -async def get_all_users() -> Page[GetUserInfo]: - return await UserService.get_user_list() +@router.get('', summary='获取所有用户', dependencies=[DependsUser, PageDepends]) +async def get_all_users(db: CurrentSession): + user_list = await UserService.get_user_list() + page_data = await paging_data(db, user_list, GetUserInfo) + return response_base.response_200(data=page_data) @router.post('/{pk}/super', summary='修改用户超级权限', dependencies=[DependsSuperUser]) diff --git a/backend/app/common/casbin_rbac.py b/backend/app/common/casbin_rbac.py new file mode 100644 index 00000000..fad8614d --- /dev/null +++ b/backend/app/common/casbin_rbac.py @@ -0,0 +1,64 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +import casbin +import casbin_sqlalchemy_adapter + +from fastapi import Request, Depends + +from backend.app.common.exception.errors import AuthorizationError +from backend.app.common.jwt import CurrentUser +from backend.app.core.conf import settings +from backend.app.core.path_conf import RBAC_MODEL_CONF +from backend.app.models.sys_casbin_rule import CasbinRule + + +class RBAC: + def __init__(self): + self._CASBIN_DATABASE_URL = f'mysql+pymysql://{settings.DB_USER}:{settings.DB_PASSWORD}@{settings.DB_HOST}:{settings.DB_PORT}/{settings.DB_DATABASE}?charset={settings.DB_CHARSET}' + + def get_casbin_enforcer(self) -> casbin.Enforcer: + """ + 由于 casbin_sqlalchemy_adapter 内部使用的 SQLAlchemy 同步, 这里只能使用: mysql+pymysql + + :return: + """ + adapter = casbin_sqlalchemy_adapter.Adapter(self._CASBIN_DATABASE_URL, db_class=CasbinRule) + + enforcer = casbin.Enforcer(RBAC_MODEL_CONF, adapter) + + return enforcer + + async def rbac_verify(self, request: Request, user: CurrentUser) -> None: + """ + 权限校验,超级用户跳过校验,默认拥有所有权限 + + :param request: + :param user: + :return: + """ + user_uuid = user.user_uuid + user_roles = user.roles + role_data_scope = [role.data_scope for role in user_roles] + super_user = user.is_superuser + path = request.url.path + method = request.method + + if super_user: + return + + for ce in settings.CASBIN_EXCLUDE: + if ce['method'] == method and ce['path'] == path: + return + + if 1 in set(role_data_scope): + return + + # TODO: 通过 redis 做鉴权查询优化,减少数据库查询 + enforcer = self.get_casbin_enforcer() + if not enforcer.enforce(user_uuid, path, method): + raise AuthorizationError + + +rbac = RBAC() +# RBAC 依赖注入 +DependsRBAC = Depends(rbac.rbac_verify) diff --git a/backend/app/common/enums.py b/backend/app/common/enums.py index cc1d145c..fee9a452 100644 --- a/backend/app/common/enums.py +++ b/backend/app/common/enums.py @@ -23,3 +23,28 @@ class StrEnum(str, EnumBase): """字符串枚举""" pass + + +class MenuType(IntEnum): + """菜单类型""" + + directory = 0 + menu = 1 + button = 2 + + +class RoleDataScope(IntEnum): + """数据范围""" + + all = 1 + custom = 2 + + +class MethodType(StrEnum): + """请求方法""" + + GET = 'GET' + POST = 'POST' + PUT = 'PUT' + DELETE = 'DELETE' + PATCH = 'PATCH' diff --git a/backend/app/common/exception/exception_handler.py b/backend/app/common/exception/exception_handler.py index f14d2f25..a54ab20c 100644 --- a/backend/app/common/exception/exception_handler.py +++ b/backend/app/common/exception/exception_handler.py @@ -77,7 +77,7 @@ def validation_exception_handler(request: Request, exc: RequestValidationError): message += ( f'{data.get(field, field)} {_msg}' + ', ' if errors_len > 0 - else f'{data.get(field, field)} {_msg}' + else f'{data.get(field, field)} {_msg}' + '.' ) elif isinstance(raw_error.exc, json.JSONDecodeError): message += 'json解析失败' @@ -85,7 +85,7 @@ def validation_exception_handler(request: Request, exc: RequestValidationError): status_code=422, content=response_base.fail( code=422, - msg='请求参数非法' if len(message) == 0 else f'请求参数非法: {message[:-1]}', + msg='请求参数非法' if len(message) == 0 else f'请求参数非法: {message}', data={'errors': exc.errors()} if message == '' and settings.UVICORN_RELOAD is True else None, ), ) diff --git a/backend/app/api/jwt.py b/backend/app/common/jwt.py similarity index 91% rename from backend/app/api/jwt.py rename to backend/app/common/jwt.py index c0cb9f2c..d0456e8a 100644 --- a/backend/app/api/jwt.py +++ b/backend/app/common/jwt.py @@ -54,7 +54,7 @@ def create_access_token(data: int | Any, expires_delta: timedelta | None = None) expires = datetime.utcnow() + expires_delta else: expires = datetime.utcnow() + timedelta(settings.TOKEN_EXPIRE_MINUTES) - to_encode = {'exp': expires, 'sub': str(data)} + to_encode = {'exp': expires, 'sub': str(data[0]), 'role_ids': str(data[1])} encoded_jwt = jwt.encode(to_encode, settings.TOKEN_SECRET_KEY, settings.TOKEN_ALGORITHM) return encoded_jwt @@ -70,11 +70,12 @@ async def get_current_user(db: CurrentSession, token: str = Depends(oauth2_schem try: payload = jwt.decode(token, settings.TOKEN_SECRET_KEY, algorithms=[settings.TOKEN_ALGORITHM]) user_id = payload.get('sub') - if not user_id: + user_role = payload.get('role_ids') + if not user_id or not user_role: raise TokenError except (jwt.JWTError, ValidationError): raise TokenError - user = await UserDao.get_user_by_id(db, user_id) + user = await UserDao.get_user_with_relation(db, user_id=user_id) if not user: raise TokenError return user @@ -93,7 +94,7 @@ async def get_current_is_superuser(user: User = Depends(get_current_user)): return is_superuser -# User dependency injection +# User Annotated CurrentUser = Annotated[User, Depends(get_current_user)] CurrentSuperUser = Annotated[bool, Depends(get_current_is_superuser)] # Permission dependency injection diff --git a/backend/app/common/pagination.py b/backend/app/common/pagination.py index 9085ffd7..450ecd08 100644 --- a/backend/app/common/pagination.py +++ b/backend/app/common/pagination.py @@ -3,22 +3,26 @@ from __future__ import annotations import math -from typing import TypeVar, Generic, Sequence, Dict +from typing import TypeVar, Generic, Sequence, Dict, TYPE_CHECKING -from fastapi import Query +from fastapi import Query, Depends +from fastapi_pagination import pagination_ctx from fastapi_pagination.bases import AbstractPage, AbstractParams, RawParams +from fastapi_pagination.ext.sqlalchemy import paginate from fastapi_pagination.links.bases import create_links from pydantic import BaseModel +from pydantic.generics import GenericModel -T = TypeVar('T') +if TYPE_CHECKING: + from sqlalchemy import Select + from sqlalchemy.ext.asyncio import AsyncSession -""" -重写分页库:fastapi-pagination -使用方法:example link: https://github.com/uriyyo/fastapi-pagination/tree/main/examples -""" +T = TypeVar('T') +DataT = TypeVar('DataT') +SchemaT = TypeVar('SchemaT') -class Params(BaseModel, AbstractParams): +class _Params(BaseModel, AbstractParams): page: int = Query(1, ge=1, description='Page number') size: int = Query(20, gt=0, le=100, description='Page size') # 默认 20 条记录 @@ -29,23 +33,23 @@ def to_raw_params(self) -> RawParams: ) -class Page(AbstractPage[T], Generic[T]): - data: Sequence[T] # 数据 +class _Page(AbstractPage[T], Generic[T]): + items: Sequence[T] # 数据 total: int # 总数据数 page: int # 第n页 size: int # 每页数量 total_pages: int # 总页数 links: Dict[str, str | None] # 跳转链接 - __params_type__ = Params # 使用自定义的Params + __params_type__ = _Params # 使用自定义的Params @classmethod def create( cls, - data: Sequence[T], + items: Sequence[T], total: int, - params: Params, - ) -> Page[T]: + params: _Params, + ) -> _Page[T]: page = params.page size = params.size total_pages = math.ceil(total / params.size) @@ -58,4 +62,26 @@ def create( } ).dict() - return cls(data=data, total=total, page=params.page, size=params.size, total_pages=total_pages, links=links) + return cls(items=items, total=total, page=params.page, size=params.size, total_pages=total_pages, links=links) + + +class _PageData(GenericModel, Generic[DataT]): + page_data: DataT | None = None + + +async def paging_data(db: AsyncSession, select: Select, page_data_schema: SchemaT) -> dict: + """ + 基于 SQLAlchemy 创建分页数据 + + :param db: + :param select: + :param page_data_schema: + :return: + """ + _paginate = await paginate(db, select) + page_data = _PageData[_Page[page_data_schema]](page_data=_paginate).dict()['page_data'] + return page_data + + +# 分页依赖注入 +PageDepends = Depends(pagination_ctx(_Page)) diff --git a/backend/app/core/conf.py b/backend/app/core/conf.py index b992faab..39e1c916 100644 --- a/backend/app/core/conf.py +++ b/backend/app/core/conf.py @@ -53,7 +53,7 @@ def validator_api_url(cls, values): STATIC_FILES: bool = False # MySQL - DB_ECHO: bool = False + DB_ECHO: bool = True DB_DATABASE: str = 'fba' DB_CHARSET: str = 'utf8mb4' @@ -81,6 +81,15 @@ def validator_api_url(cls, values): MIDDLEWARE_GZIP: bool = True MIDDLEWARE_ACCESS: bool = False + # Casbin + CASBIN_RBAC_MODEL_NAME: str = 'rbac_model.conf' + CASBIN_EXCLUDE: list[dict[str, str], dict[str, str]] = [ + {'method': 'POST', 'path': '/api/v1/auth/users/swagger_login'}, + {'method': 'POST', 'path': '/api/v1/auth/users/login'}, + {'method': 'POST', 'path': '/api/v1/auth/users/register'}, + {'method': 'POST', 'path': '/api/v1/auth/users/password/reset'}, + ] + class Config: # https://docs.pydantic.dev/usage/settings/#dotenv-env-support env_file = '.env' diff --git a/backend/app/core/path_conf.py b/backend/app/core/path_conf.py index 343ab987..0dd71bd8 100644 --- a/backend/app/core/path_conf.py +++ b/backend/app/core/path_conf.py @@ -3,6 +3,8 @@ import os from pathlib import Path +from backend.app.core.conf import settings + # 获取项目根目录 # 或使用绝对路径,指到backend目录为止,例如windows:BasePath = D:\git_project\fastapi_mysql\backend BasePath = Path(__file__).resolve().parent.parent.parent @@ -12,3 +14,6 @@ # 日志文件路径 LogPath = os.path.join(BasePath, 'app', 'log') + +# RBAC model.conf 文件路径 +RBAC_MODEL_CONF = os.path.join(BasePath, 'app', 'core', settings.CASBIN_RBAC_MODEL_NAME) diff --git a/backend/app/core/rbac_model.conf b/backend/app/core/rbac_model.conf new file mode 100644 index 00000000..9ca4b928 --- /dev/null +++ b/backend/app/core/rbac_model.conf @@ -0,0 +1,14 @@ +[request_definition] +r = sub, obj, act + +[policy_definition] +p = sub, obj, act + +[role_definition] +g = _, _ + +[policy_effect] +e = some(where (p.eft == allow)) + +[matchers] +m = g(r.sub, p.sub) && r.obj == p.obj && r.act == p.act diff --git a/backend/app/core/registrar.py b/backend/app/core/registrar.py index 2aa96477..0c9bdd9e 100644 --- a/backend/app/core/registrar.py +++ b/backend/app/core/registrar.py @@ -13,7 +13,9 @@ from backend.app.common.task import scheduler from backend.app.core.conf import settings from backend.app.database.db_mysql import create_table -from backend.app.middleware.access_middle import AccessMiddleware +from backend.app.middleware.access_middleware import AccessMiddleware +from backend.app.utils.openapi import simplify_operation_ids +from backend.app.utils.health_check import ensure_unique_route_names @asynccontextmanager @@ -111,6 +113,10 @@ def register_router(app: FastAPI): """ app.include_router(v1) + # extra + ensure_unique_route_names(app) + simplify_operation_ids(app) + def register_page(app: FastAPI): """ diff --git a/backend/app/crud/crud_api.py b/backend/app/crud/crud_api.py new file mode 100644 index 00000000..87f4f401 --- /dev/null +++ b/backend/app/crud/crud_api.py @@ -0,0 +1,13 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +from backend.app.crud.base import CRUDBase +from backend.app.models import Api +from backend.app.schemas.api import CreateApi, UpdateApi + + +class CRUDApi(CRUDBase[Api, CreateApi, UpdateApi]): + # TODO: 添加 api 相关数据库操作 + pass + + +ApiDao: CRUDApi = CRUDApi(Api) diff --git a/backend/app/crud/crud_casbin.py b/backend/app/crud/crud_casbin.py new file mode 100644 index 00000000..6eaef35b --- /dev/null +++ b/backend/app/crud/crud_casbin.py @@ -0,0 +1,13 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +from backend.app.crud.base import CRUDBase +from backend.app.models import CasbinRule +from backend.app.schemas.casbin_rule import CreatePolicy, UpdatePolicy + + +class CRUDCasbin(CRUDBase[CasbinRule, CreatePolicy, UpdatePolicy]): + # TODO: 添加 casbin 相关数据库操作 + pass + + +CasbinDao: CRUDCasbin = CRUDCasbin(CasbinRule) diff --git a/backend/app/crud/crud_dept.py b/backend/app/crud/crud_dept.py new file mode 100644 index 00000000..bd5c9245 --- /dev/null +++ b/backend/app/crud/crud_dept.py @@ -0,0 +1,13 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +from backend.app.crud.base import CRUDBase +from backend.app.models import Dept +from backend.app.schemas.dept import CreateDept, UpdateDept + + +class CRUDDept(CRUDBase[Dept, CreateDept, UpdateDept]): + async def get_dept_by_id(self, db, dept_id): + return await self.get(db, dept_id) + + +DeptDao: CRUDDept = CRUDDept(Dept) diff --git a/backend/app/crud/crud_menu.py b/backend/app/crud/crud_menu.py new file mode 100644 index 00000000..ac5adfc6 --- /dev/null +++ b/backend/app/crud/crud_menu.py @@ -0,0 +1,13 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +from backend.app.crud.base import CRUDBase +from backend.app.models import Menu +from backend.app.schemas.menu import CreateMenu, UpdateMenu + + +class CRUDMenu(CRUDBase[Menu, CreateMenu, UpdateMenu]): + # TODO: 添加 menu 相关数据库操作 + pass + + +MenuDao: CRUDMenu = CRUDMenu(Menu) diff --git a/backend/app/crud/crud_role.py b/backend/app/crud/crud_role.py new file mode 100644 index 00000000..b1890152 --- /dev/null +++ b/backend/app/crud/crud_role.py @@ -0,0 +1,13 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +from backend.app.crud.base import CRUDBase +from backend.app.models import Role +from backend.app.schemas.role import CreateRole, UpdateRole + + +class CRUDRole(CRUDBase[Role, CreateRole, UpdateRole]): + async def get_role_by_id(self, db, role_id): + return await self.get(db, role_id) + + +RoleDao: CRUDRole = CRUDRole(Role) diff --git a/backend/app/crud/crud_user.py b/backend/app/crud/crud_user.py index 610b213c..5cffb880 100644 --- a/backend/app/crud/crud_user.py +++ b/backend/app/crud/crud_user.py @@ -4,11 +4,12 @@ from sqlalchemy import func, select, update, desc from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import selectinload from sqlalchemy.sql import Select -from backend.app.api import jwt +from backend.app.common import jwt from backend.app.crud.base import CRUDBase -from backend.app.models import User +from backend.app.models import User, Role from backend.app.schemas.user import CreateUser, UpdateUser, Avatar @@ -26,11 +27,25 @@ async def update_user_login_time(self, db: AsyncSession, username: str) -> int: async def create_user(self, db: AsyncSession, create: CreateUser) -> NoReturn: create.password = jwt.get_hash_password(create.password) - new_user = self.model(**create.dict()) + new_user = self.model(**create.dict(exclude={'roles'})) + role_list = [] + for role_id in create.roles: + role_list.append(await db.get(Role, role_id)) + new_user.roles.append(*role_list) db.add(new_user) - async def update_userinfo(self, db: AsyncSession, current_user: User, obj: UpdateUser) -> int: - user = await db.execute(update(self.model).where(self.model.id == current_user.id).values(**obj.dict())) + async def update_userinfo(self, db: AsyncSession, input_user: User, obj: UpdateUser) -> int: + user = await db.execute( + update(self.model).where(self.model.id == input_user.id).values(**obj.dict(exclude={'roles'})) + ) + # 删除用户所有角色 + for i in list(input_user.roles): + input_user.roles.remove(i) + # 添加用户角色 + role_list = [] + for role_id in obj.roles: + role_list.append(await db.get(Role, role_id)) + input_user.roles.append(*role_list) return user.rowcount async def update_avatar(self, db: AsyncSession, current_user: User, avatar: Avatar) -> int: @@ -51,7 +66,11 @@ async def reset_password(self, db: AsyncSession, pk: int, password: str) -> int: return user.rowcount def get_users(self) -> Select: - return select(self.model).order_by(desc(self.model.time_joined)) + return ( + select(self.model) + .options(selectinload(self.model.roles).selectinload(Role.menus)) + .order_by(desc(self.model.time_joined)) + ) async def get_user_is_super(self, db: AsyncSession, user_id: int) -> bool: user = await self.get_user_by_id(db, user_id) @@ -75,5 +94,26 @@ async def active_set(self, db: AsyncSession, user_id: int) -> int: ) return user.rowcount + async def get_user_role_ids(self, db: AsyncSession, user_id: int) -> list[int]: + user = await db.execute( + select(self.model).where(self.model.id == user_id).options(selectinload(self.model.roles)) + ) + roles_id = [role.id for role in user.scalars().first().roles] + return roles_id + + async def get_user_with_relation(self, db: AsyncSession, *, user_id: int = None, username: str = None) -> User: + where = 'condition' + if user_id: + where = 'self.model.id == user_id' + if username: + where = 'self.model.username == username' + user = await db.execute( + select(self.model) + .where(eval(where)) + .options(selectinload(self.model.dept)) + .options(selectinload(self.model.roles).joinedload(Role.menus)) + ) + return user.scalars().first() + UserDao: CRUDUser = CRUDUser(User) diff --git a/backend/app/database/db_mysql.py b/backend/app/database/db_mysql.py index 32395d8b..ecc2a637 100644 --- a/backend/app/database/db_mysql.py +++ b/backend/app/database/db_mysql.py @@ -46,7 +46,7 @@ async def get_db() -> AsyncSession: await session.close() -# Session 依赖注入 +# Session Annotated CurrentSession = Annotated[AsyncSession, Depends(get_db)] diff --git a/backend/app/init_test_data.py b/backend/app/init_test_data.py index 0f099a49..3d9bab02 100644 --- a/backend/app/init_test_data.py +++ b/backend/app/init_test_data.py @@ -5,18 +5,35 @@ from email_validator import EmailNotValidError, validate_email from faker import Faker -from backend.app.api.jwt import get_hash_password +from backend.app.common.jwt import get_hash_password from backend.app.common.log import log from backend.app.database.db_mysql import async_db_session -from backend.app.models import User +from backend.app.models import User, Role, Menu, Dept -class InitData: - """初始化数据""" +class InitTestData: + """初始化测试数据""" def __init__(self): self.fake = Faker('zh_CN') + @staticmethod + async def create_dept(): + """自动创建部门""" + async with async_db_session.begin() as db: + department_obj = Dept(name='test', create_user=1) + db.add(department_obj) + log.info('部门 test 创建成功') + + @staticmethod + async def create_role(): + """自动创建角色""" + async with async_db_session.begin() as db: + role_obj = Role(name='test', create_user=1) + role_obj.menus.append(Menu(name='test', create_user=1)) + db.add(role_obj) + log.info('角色 test 创建成功') + @staticmethod async def create_test_user(): """创建测试用户""" @@ -25,17 +42,21 @@ async def create_test_user(): email = 'test@gmail.com' user_obj = User( username=username, + nickname=username, password=get_hash_password(password), email=email, is_superuser=True, + dept_id=1, ) async with async_db_session.begin() as db: + user_obj.roles.append(await db.get(Role, 1)) db.add(user_obj) log.info(f'测试用户创建成功,账号:{username},密码:{password}') @staticmethod async def create_superuser_by_yourself(): """手动创建管理员账户""" + log.info('开始创建自定义管理员用户') print('请输入用户名:') username = input() print('请输入密码:') @@ -51,13 +72,16 @@ async def create_superuser_by_yourself(): break user_obj = User( username=username, + nickname=username, password=get_hash_password(password), email=email, is_superuser=True, + dept_id=1, ) async with async_db_session.begin() as db: + user_obj.roles.append(await db.get(Role, 1)) db.add(user_obj) - log.info(f'管理员用户创建成功,账号:{username},密码:{password}') + log.info(f'自定义管理员用户创建成功,账号:{username},密码:{password}') async def fake_user(self): """自动创建普通用户""" @@ -66,11 +90,14 @@ async def fake_user(self): email = self.fake.email() user_obj = User( username=username, + nickname=username, password=get_hash_password(password), email=email, is_superuser=False, + dept_id=1, ) async with async_db_session.begin() as db: + user_obj.roles.append(await db.get(Role, 1)) db.add(user_obj) log.info(f'普通用户创建成功,账号:{username},密码:{password}') @@ -81,12 +108,15 @@ async def fake_no_active_user(self): email = self.fake.email() user_obj = User( username=username, + nickname=username, password=get_hash_password(password), email=email, is_active=False, is_superuser=False, + dept_id=1, ) async with async_db_session.begin() as db: + user_obj.roles.append(await db.get(Role, 1)) db.add(user_obj) log.info(f'普通锁定用户创建成功,账号:{username},密码:{password}') @@ -97,11 +127,14 @@ async def fake_superuser(self): email = self.fake.email() user_obj = User( username=username, + nickname=username, password=get_hash_password(password), email=email, is_superuser=True, + dept_id=1, ) async with async_db_session.begin() as db: + user_obj.roles.append(await db.get(Role, 1)) db.add(user_obj) log.info(f'管理员用户创建成功,账号:{username},密码:{password}') @@ -112,18 +145,23 @@ async def fake_no_active_superuser(self): email = self.fake.email() user_obj = User( username=username, + nickname=username, password=get_hash_password(password), email=email, is_active=False, is_superuser=True, + dept_id=1, ) async with async_db_session.begin() as db: + user_obj.roles.append(await db.get(Role, 1)) db.add(user_obj) log.info(f'管理员锁定用户创建成功,账号:{username},密码:{password}') async def init_data(self): """自动创建数据""" log.info('⏳ 开始初始化数据') + await self.create_dept() + await self.create_role() await self.create_test_user() await self.create_superuser_by_yourself() await self.fake_user() @@ -134,6 +172,6 @@ async def init_data(self): if __name__ == '__main__': - init = InitData() + init = InitTestData() loop = asyncio.get_event_loop() loop.run_until_complete(init.init_data()) diff --git a/backend/app/middleware/access_middle.py b/backend/app/middleware/access_middleware.py similarity index 93% rename from backend/app/middleware/access_middle.py rename to backend/app/middleware/access_middleware.py index 0682d6a7..c8f621ac 100644 --- a/backend/app/middleware/access_middle.py +++ b/backend/app/middleware/access_middleware.py @@ -9,9 +9,7 @@ class AccessMiddleware(BaseHTTPMiddleware): - """ - 记录请求日志 - """ + """记录请求日志中间件""" async def dispatch(self, request: Request, call_next) -> Response: start_time = datetime.now() diff --git a/backend/app/models/__init__.py b/backend/app/models/__init__.py index eac378ee..4b8a5297 100644 --- a/backend/app/models/__init__.py +++ b/backend/app/models/__init__.py @@ -4,5 +4,10 @@ # 导入所有模型,并将 Base 放在最前面, 以便 Base 拥有它们 # imported by Alembic """ -from backend.app.database.base_class import MappedBase -from backend.app.models.user import User +from backend.app.database.base_class import MappedBase # F401 +from backend.app.models.sys_api import Api +from backend.app.models.sys_casbin_rule import CasbinRule +from backend.app.models.sys_dept import Dept +from backend.app.models.sys_menu import Menu +from backend.app.models.sys_role import Role +from backend.app.models.sys_user import User diff --git a/backend/app/models/sys_api.py b/backend/app/models/sys_api.py new file mode 100644 index 00000000..4f4a243e --- /dev/null +++ b/backend/app/models/sys_api.py @@ -0,0 +1,20 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +from sqlalchemy import String +from sqlalchemy.dialects.mysql import LONGTEXT +from sqlalchemy.orm import Mapped, mapped_column + +from backend.app.database.base_class import Base, id_key + + +class Api(Base): + """系统api""" + + __tablename__ = 'sys_api' + + id: Mapped[id_key] = mapped_column(init=False) + name: Mapped[str] = mapped_column(String(50), unique=True, comment='api名称') + method: Mapped[str] = mapped_column(String(16), comment='请求方法') + path: Mapped[str] = mapped_column(String(500), comment='api路径') + remark: Mapped[str | None] = mapped_column(LONGTEXT, comment='备注') diff --git a/backend/app/models/sys_casbin_rule.py b/backend/app/models/sys_casbin_rule.py new file mode 100644 index 00000000..d6ab0001 --- /dev/null +++ b/backend/app/models/sys_casbin_rule.py @@ -0,0 +1,25 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +from sqlalchemy import String +from sqlalchemy.dialects.mysql import LONGTEXT +from sqlalchemy.orm import Mapped, mapped_column + +from backend.app.database.base_class import id_key, MappedBase + + +class CasbinRule(MappedBase): + """ + 重写 casbin_sqlalchemy_adapter 中的 casbinRule model类, 使用自定义 MappedBase, 避免产生 alembic 迁移问题 + """ + + __tablename__ = 'sys_casbin_rule' + + id: Mapped[id_key] + ptype: Mapped[str] = mapped_column(String(255), comment='策略类型: p 或者 g') + v0: Mapped[str] = mapped_column(String(255), comment='角色 / 用户uuid') + v1: Mapped[str] = mapped_column(LONGTEXT, comment='api路径 / 角色名称') + v2: Mapped[str | None] = mapped_column(String(255), comment='请求方法') + v3: Mapped[str | None] = mapped_column(String(255)) + v4: Mapped[str | None] = mapped_column(String(255)) + v5: Mapped[str | None] = mapped_column(String(255)) diff --git a/backend/app/models/sys_dept.py b/backend/app/models/sys_dept.py new file mode 100644 index 00000000..ee6be0ed --- /dev/null +++ b/backend/app/models/sys_dept.py @@ -0,0 +1,25 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +from sqlalchemy import String +from sqlalchemy.orm import Mapped, mapped_column, relationship + +from backend.app.database.base_class import Base, id_key + + +class Dept(Base): + """部门表""" + + __tablename__ = 'sys_dept' + + id: Mapped[id_key] = mapped_column(init=False) + name: Mapped[str] = mapped_column(String(50), unique=True, comment='部门名称') + parent_id: Mapped[int] = mapped_column(default=0, comment='父部门ID') + level: Mapped[int] = mapped_column(default=0, comment='部门层级') + sort: Mapped[int] = mapped_column(default=0, comment='排序') + leader: Mapped[str | None] = mapped_column(String(20), default=None, comment='负责人') + phone: Mapped[str | None] = mapped_column(String(11), default=None, comment='手机') + email: Mapped[str | None] = mapped_column(String(50), default=None, comment='邮箱') + status: Mapped[bool] = mapped_column(default=True, comment='部门状态(0停用 1正常)') + del_flag: Mapped[bool] = mapped_column(default=True, comment='删除标志(0删除 1存在)') + # 用户部门一对多 + users: Mapped['User'] = relationship(init=False, back_populates='dept') # noqa: F821 diff --git a/backend/app/models/sys_menu.py b/backend/app/models/sys_menu.py new file mode 100644 index 00000000..3db75709 --- /dev/null +++ b/backend/app/models/sys_menu.py @@ -0,0 +1,29 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +from sqlalchemy import String +from sqlalchemy.dialects.mysql import LONGTEXT +from sqlalchemy.orm import Mapped, mapped_column, relationship + +from backend.app.database.base_class import Base, id_key +from backend.app.models.sys_role_menu import sys_role_menu + + +class Menu(Base): + """菜单表""" + + __tablename__ = 'sys_menu' + + id: Mapped[id_key] = mapped_column(init=False) + name: Mapped[str] = mapped_column(String(50), unique=True, comment='菜单名称') + parent_id: Mapped[int] = mapped_column(default=0, comment='父菜单ID') + level: Mapped[int] = mapped_column(default=0, comment='菜单层级') + sort: Mapped[int] = mapped_column(default=0, comment='显示顺序') + path: Mapped[str] = mapped_column(String(200), default='', comment='路由地址') + menu_type: Mapped[int] = mapped_column(default=0, comment='菜单类型(0目录 1菜单 2按钮)') + icon: Mapped[str | None] = mapped_column(String(100), default='#', comment='菜单图标') + remark: Mapped[str | None] = mapped_column(LONGTEXT, default=None, comment='备注') + del_flag: Mapped[bool] = mapped_column(default=True, comment='删除标志(0删除 1存在)') + # 菜单角色多对多 + roles: Mapped[list['Role']] = relationship( # noqa: F821 + init=False, secondary=sys_role_menu, back_populates='menus' + ) diff --git a/backend/app/models/sys_role.py b/backend/app/models/sys_role.py new file mode 100644 index 00000000..3f516bc0 --- /dev/null +++ b/backend/app/models/sys_role.py @@ -0,0 +1,28 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +from sqlalchemy import String +from sqlalchemy.orm import Mapped, mapped_column, relationship + +from backend.app.database.base_class import Base, id_key +from backend.app.models.sys_role_menu import sys_role_menu +from backend.app.models.sys_user_role import sys_user_role + + +class Role(Base): + """角色表""" + + __tablename__ = 'sys_role' + + id: Mapped[id_key] = mapped_column(init=False) + name: Mapped[str] = mapped_column(String(20), unique=True, comment='角色名称') + sort: Mapped[int] = mapped_column(default=0, comment='显示顺序') + data_scope: Mapped[int | None] = mapped_column(default=2, comment='数据范围(1:全部数据权限 2:自定数据权限)') + del_flag: Mapped[bool] = mapped_column(default=True, comment='删除标志(0删除 1存在)') + # 角色用户多对多 + users: Mapped[list['User']] = relationship( # noqa: F821 + init=False, secondary=sys_user_role, back_populates='roles' + ) + # 角色菜单多对多 + menus: Mapped[list['Menu']] = relationship( # noqa: F821 + init=False, secondary=sys_role_menu, back_populates='roles' + ) diff --git a/backend/app/models/sys_role_menu.py b/backend/app/models/sys_role_menu.py new file mode 100644 index 00000000..73835df7 --- /dev/null +++ b/backend/app/models/sys_role_menu.py @@ -0,0 +1,13 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +from sqlalchemy import Table, Column, ForeignKey, INT, Integer + +from backend.app.database.base_class import MappedBase + +sys_role_menu = Table( + 'sys_role_menu', + MappedBase.metadata, + Column('id', INT, primary_key=True, unique=True, index=True, autoincrement=True, comment='主键ID'), + Column('role_id', Integer, ForeignKey('sys_role.id', ondelete='CASCADE'), primary_key=True, comment='角色ID'), + Column('menu_id', Integer, ForeignKey('sys_menu.id', ondelete='CASCADE'), primary_key=True, comment='菜单ID'), +) diff --git a/backend/app/models/user.py b/backend/app/models/sys_user.py similarity index 50% rename from backend/app/models/user.py rename to backend/app/models/sys_user.py index 4a9123cc..a19b35ff 100644 --- a/backend/app/models/user.py +++ b/backend/app/models/sys_user.py @@ -1,11 +1,13 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- from datetime import datetime +from typing import Union -from sqlalchemy import func, String -from sqlalchemy.orm import Mapped, mapped_column +from sqlalchemy import func, String, ForeignKey +from sqlalchemy.orm import Mapped, mapped_column, relationship from backend.app.database.base_class import use_uuid, id_key, DataClassBase +from backend.app.models.sys_user_role import sys_user_role class User(DataClassBase): @@ -14,13 +16,21 @@ class User(DataClassBase): __tablename__ = 'sys_user' id: Mapped[id_key] = mapped_column(init=False) - uid: Mapped[str] = mapped_column(String(50), init=False, insert_default=use_uuid, unique=True, comment='唯一标识') + user_uuid: Mapped[str] = mapped_column(String(50), init=False, insert_default=use_uuid, unique=True) username: Mapped[str] = mapped_column(String(20), unique=True, index=True, comment='用户名') + nickname: Mapped[str] = mapped_column(String(20), unique=True, comment='昵称') password: Mapped[str] = mapped_column(String(255), comment='密码') email: Mapped[str] = mapped_column(String(50), unique=True, index=True, comment='邮箱') - is_superuser: Mapped[bool] = mapped_column(default=False, comment='超级权限') - is_active: Mapped[bool] = mapped_column(default=True, comment='用户账号状态') + is_superuser: Mapped[bool] = mapped_column(default=False, comment='超级权限(0否 1是)') + is_active: Mapped[bool] = mapped_column(default=True, comment='用户账号状态(0停用 1正常)') avatar: Mapped[str | None] = mapped_column(String(255), default=None, comment='头像') - mobile_number: Mapped[str | None] = mapped_column(String(11), default=None, comment='手机号') + phone: Mapped[str | None] = mapped_column(String(11), default=None, comment='手机号') time_joined: Mapped[datetime] = mapped_column(init=False, default=func.now(), comment='注册时间') last_login: Mapped[datetime | None] = mapped_column(init=False, onupdate=func.now(), comment='上次登录') + # 用户部门一对多 + dept_id: Mapped[int | None] = mapped_column(ForeignKey('sys_dept.id'), default=None, comment='部门关联ID') + dept: Mapped[Union['Dept', None]] = relationship(init=False, back_populates='users') # noqa: F821 + # 用户角色多对多 + roles: Mapped[list['Role']] = relationship( # noqa: F821 + init=False, secondary=sys_user_role, back_populates='users' + ) diff --git a/backend/app/models/sys_user_role.py b/backend/app/models/sys_user_role.py new file mode 100644 index 00000000..c29d16fe --- /dev/null +++ b/backend/app/models/sys_user_role.py @@ -0,0 +1,13 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +from sqlalchemy import Table, Column, ForeignKey, INT, Integer + +from backend.app.database.base_class import MappedBase + +sys_user_role = Table( + 'sys_user_role', + MappedBase.metadata, + Column('id', INT, primary_key=True, unique=True, index=True, autoincrement=True, comment='主键ID'), + Column('user_id', Integer, ForeignKey('sys_user.id', ondelete='CASCADE'), primary_key=True, comment='用户ID'), + Column('role_id', Integer, ForeignKey('sys_role.id', ondelete='CASCADE'), primary_key=True, comment='角色ID'), +) diff --git a/backend/app/schemas/api.py b/backend/app/schemas/api.py new file mode 100644 index 00000000..97e82912 --- /dev/null +++ b/backend/app/schemas/api.py @@ -0,0 +1,43 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +from curses.ascii import isupper +from datetime import datetime + +from pydantic import BaseModel, Field, validator + +from backend.app.common.enums import MethodType + + +class ApiBase(BaseModel): + name: str + method: str = Field(default=MethodType.GET, description='请求方法') + path: str = Field(..., description='api路径') + remark: str | None = None + + @validator('method') + def check_method(cls, v): + if not isupper(v): + raise ValueError('请求方式必须大写') + allow_method = MethodType.get_member_values() + if v not in allow_method: + raise ValueError(f'请求方式不合法, 仅支持: {allow_method}') + return v + + +class CreateApi(ApiBase): + pass + + +class UpdateApi(ApiBase): + pass + + +class GetAllApi(ApiBase): + id: int + create_user: int + update_user: int = None + created_time: datetime + updated_time: datetime | None = None + + class Config: + orm_mode = True diff --git a/backend/app/schemas/casbin_rule.py b/backend/app/schemas/casbin_rule.py new file mode 100644 index 00000000..74e9b381 --- /dev/null +++ b/backend/app/schemas/casbin_rule.py @@ -0,0 +1,52 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +from curses.ascii import isupper + +from pydantic import BaseModel, Field, validator + +from backend.app.common.enums import MethodType + + +class RBACBase(BaseModel): + sub: str = Field(..., description='用户uuid / 角色') + + +class CreatePolicy(RBACBase): + path: str = Field(..., description='api路径') + method: str = Field(default=MethodType.GET, description='请求方法') + + @validator('method') + def check_method(cls, v): + if not isupper(v): + raise ValueError('请求方式必须大写') + allow_method = MethodType.get_member_values() + if v not in allow_method: + raise ValueError(f'请求方式不合法, 仅支持: {allow_method}') + return v + + +class UpdatePolicy(CreatePolicy): + pass + + +class DeletePolicy(CreatePolicy): + pass + + +class UserRole(BaseModel): + uuid: str = Field(..., description='用户uuid') + role: str = Field(..., description='角色') + + +class GetAllPolicy(BaseModel): + id: int + ptype: str + v0: str + v1: str + v2: str | None = None + v3: str | None = None + v4: str | None = None + v5: str | None = None + + class Config: + orm_mode = True diff --git a/backend/app/schemas/dept.py b/backend/app/schemas/dept.py new file mode 100644 index 00000000..5d31b436 --- /dev/null +++ b/backend/app/schemas/dept.py @@ -0,0 +1,36 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +from datetime import datetime + +from pydantic import BaseModel, Field + + +class DeptBase(BaseModel): + name: str + parent_id: int = Field(default=0, ge=0, description='菜单父级ID') + level: int = Field(default=0, ge=0, description='菜单层级') + sort: int = Field(default=0, ge=0, description='排序') + leader: str | None = None + phone: str | None = None + email: str | None = None + status: bool + del_flag: bool + + +class CreateDept(DeptBase): + pass + + +class UpdateDept(DeptBase): + pass + + +class GetAllDept(DeptBase): + id: int + create_user: int + update_user: int = None + created_time: datetime + updated_time: datetime | None = None + + class Config: + orm_mode = True diff --git a/backend/app/schemas/menu.py b/backend/app/schemas/menu.py new file mode 100644 index 00000000..98beaa43 --- /dev/null +++ b/backend/app/schemas/menu.py @@ -0,0 +1,44 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +from datetime import datetime + +from pydantic import BaseModel, Field, validator + +from backend.app.common.enums import MenuType + + +class MenuBase(BaseModel): + name: str + parent_id: int = Field(default=0, ge=0, description='菜单父级ID') + level: int = Field(default=0, ge=0, description='菜单层级') + sort: int = Field(default=0, ge=0, description='排序') + path: str = Field(..., description='路由地址') + menu_type: int = Field(default=MenuType.directory, ge=0, description='菜单类型(0目录 1菜单 2按钮)') + icon: str | None = None + remark: str | None = None + del_flag: bool + + @validator('menu_type') + def check_menu_type(cls, v): + if v not in MenuType.get_member_values(): + raise ValueError('菜单类型只能是0、1、2') + return v + + +class CreateMenu(MenuBase): + pass + + +class UpdateMenu(MenuBase): + pass + + +class GetAllMenu(MenuBase): + id: int + create_user: int + update_user: int = None + created_time: datetime + updated_time: datetime | None = None + + class Config: + orm_mode = True diff --git a/backend/app/schemas/role.py b/backend/app/schemas/role.py new file mode 100644 index 00000000..3f5f091f --- /dev/null +++ b/backend/app/schemas/role.py @@ -0,0 +1,41 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +from datetime import datetime + +from pydantic import BaseModel, Field, validator + +from backend.app.common.enums import RoleDataScope +from backend.app.schemas.menu import GetAllMenu + + +class RoleBase(BaseModel): + name: str + sort: int = Field(default=0, ge=0, description='排序') + data_scope: int | None = Field(default=RoleDataScope.custom, description='数据范围(1:全部数据权限 2:自定数据权限)') # noqa: E501 + del_flag: bool + + @validator('data_scope') + def check_data_scope(cls, v): + if v not in RoleDataScope.get_member_values(): + raise ValueError('数据范围只能是1或2') + return v + + +class CreateRole(RoleBase): + menu_ids: list[int] + + +class UpdateRole(RoleBase): + menu_ids: list[int] + + +class GetAllRole(RoleBase): + id: int + create_user: int + update_user: int = None + created_time: datetime + updated_time: datetime | None = None + menus: list[GetAllMenu] + + class Config: + orm_mode = True diff --git a/backend/app/schemas/token.py b/backend/app/schemas/token.py index f8e8a638..cfc7678e 100644 --- a/backend/app/schemas/token.py +++ b/backend/app/schemas/token.py @@ -2,10 +2,10 @@ # -*- coding: utf-8 -*- from pydantic import BaseModel -from backend.app.schemas.user import GetUserInfo +from backend.app.schemas.user import GetUserInfoNoRelation class Token(BaseModel): access_token: str token_type: str = 'Bearer' - user: GetUserInfo + user: GetUserInfoNoRelation diff --git a/backend/app/schemas/user.py b/backend/app/schemas/user.py index 7094fc5a..28a6bfe2 100644 --- a/backend/app/schemas/user.py +++ b/backend/app/schemas/user.py @@ -2,7 +2,10 @@ # -*- coding: utf-8 -*- from datetime import datetime -from pydantic import BaseModel, Field, HttpUrl +from pydantic import BaseModel, HttpUrl, Field + +from backend.app.schemas.dept import GetAllDept +from backend.app.schemas.role import GetAllRole class Auth(BaseModel): @@ -11,22 +14,31 @@ class Auth(BaseModel): class CreateUser(Auth): + dept_id: int + roles: list[int] + nickname: str email: str = Field(..., example='user@example.com') -class UpdateUser(BaseModel): +class _UserInfoBase(BaseModel): + dept_id: int username: str - email: str - mobile_number: str | None = None + nickname: str + email: str = Field(..., example='user@example.com') + phone: str | None = None + + +class UpdateUser(_UserInfoBase): + roles: list[int] class Avatar(BaseModel): - url: HttpUrl = Field(..., description='头像地址') + url: HttpUrl = Field(..., description='头像 http 地址') -class GetUserInfo(UpdateUser): +class GetUserInfoNoRelation(_UserInfoBase): id: int - uid: str + user_uuid: str avatar: str | None = None is_active: bool is_superuser: bool @@ -37,7 +49,15 @@ class Config: orm_mode = True +class GetUserInfo(GetUserInfoNoRelation): + dept: GetAllDept | None = None + roles: list[GetAllRole] + + class Config: + orm_mode = True + + class ResetPassword(BaseModel): - id: int = Field(..., example='1', description='用户ID') + id: int password1: str password2: str diff --git a/backend/app/api/service/__init__.py b/backend/app/services/__init__.py similarity index 100% rename from backend/app/api/service/__init__.py rename to backend/app/services/__init__.py diff --git a/backend/app/services/api_service.py b/backend/app/services/api_service.py new file mode 100644 index 00000000..ccaf40ce --- /dev/null +++ b/backend/app/services/api_service.py @@ -0,0 +1,7 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + + +class ApiService: + # TODO: 添加 api 相关服务 + pass diff --git a/backend/app/services/casbin_service.py b/backend/app/services/casbin_service.py new file mode 100644 index 00000000..c763dac7 --- /dev/null +++ b/backend/app/services/casbin_service.py @@ -0,0 +1,7 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + + +class CasbinService: + # TODO: 添加 casbin 相关服务 + pass diff --git a/backend/app/services/dept_service.py b/backend/app/services/dept_service.py new file mode 100644 index 00000000..a6b70eb5 --- /dev/null +++ b/backend/app/services/dept_service.py @@ -0,0 +1,7 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + + +class DeptService: + # TODO: 添加 dept 相关服务 + pass diff --git a/backend/app/services/menu_service.py b/backend/app/services/menu_service.py new file mode 100644 index 00000000..778e575b --- /dev/null +++ b/backend/app/services/menu_service.py @@ -0,0 +1,7 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + + +class MenuService: + # TODO: 添加 menu 相关服务 + pass diff --git a/backend/app/services/role_service.py b/backend/app/services/role_service.py new file mode 100644 index 00000000..bbb2e951 --- /dev/null +++ b/backend/app/services/role_service.py @@ -0,0 +1,7 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + + +class RoleService: + # TODO: 添加 role 相关服务 + pass diff --git a/backend/app/api/service/user_service.py b/backend/app/services/user_service.py similarity index 81% rename from backend/app/api/service/user_service.py rename to backend/app/services/user_service.py index 70f7539b..8b786c0c 100644 --- a/backend/app/api/service/user_service.py +++ b/backend/app/services/user_service.py @@ -2,10 +2,11 @@ # -*- coding: utf-8 -*- from email_validator import validate_email, EmailNotValidError from fastapi.security import OAuth2PasswordRequestForm -from fastapi_pagination.ext.sqlalchemy import paginate -from backend.app.api import jwt +from backend.app.common import jwt from backend.app.common.exception import errors +from backend.app.crud.crud_dept import DeptDao +from backend.app.crud.crud_role import RoleDao from backend.app.crud.crud_user import UserDao from backend.app.database.db_mysql import async_db_session from backend.app.models import User @@ -26,10 +27,12 @@ async def swagger_login(form_data: OAuth2PasswordRequestForm): raise errors.AuthorizationError(msg='该用户已被锁定,无法登录') # 更新登陆时间 await UserDao.update_user_login_time(db, form_data.username) + # 查询用户角色 + user_role_ids = await UserDao.get_user_role_ids(db, current_user.id) # 获取最新用户信息 user = await UserDao.get_user_by_id(db, current_user.id) # 创建token - access_token = jwt.create_access_token(user.id) + access_token = jwt.create_access_token([user.id, user_role_ids]) return access_token, user @staticmethod @@ -42,12 +45,10 @@ async def login(obj: Auth): raise errors.AuthorizationError(msg='密码错误') elif not current_user.is_active: raise errors.AuthorizationError(msg='该用户已被锁定,无法登录') - # 更新登陆时间 await UserDao.update_user_login_time(db, obj.username) - # 获取最新用户信息 + user_role_ids = await UserDao.get_user_role_ids(db, current_user.id) user = await UserDao.get_user_by_id(db, current_user.id) - # 创建token - access_token = jwt.create_access_token(user.id) + access_token = jwt.create_access_token([user.id, user_role_ids]) return access_token, user @staticmethod @@ -63,6 +64,13 @@ async def register(obj: CreateUser): validate_email(obj.email, check_deliverability=False).email except EmailNotValidError: raise errors.ForbiddenError(msg='邮箱格式错误') + dept = await DeptDao.get_dept_by_id(db, obj.dept_id) + if not dept: + raise errors.NotFoundError(msg='部门不存在') + for role_id in obj.roles: + role = await RoleDao.get_role_by_id(db, role_id) + if not role: + raise errors.NotFoundError(msg='角色不存在') await UserDao.create_user(db, obj) @staticmethod @@ -77,7 +85,7 @@ async def pwd_reset(obj: ResetPassword): @staticmethod async def get_userinfo(username: str): async with async_db_session() as db: - user = await UserDao.get_user_by_username(db, username) + user = await UserDao.get_user_with_relation(db, username=username) if not user: raise errors.NotFoundError(msg='用户不存在') return user @@ -88,7 +96,7 @@ async def update(*, username: str, current_user: User, obj: UpdateUser): if not current_user.is_superuser: if not username == current_user.username: raise errors.AuthorizationError - input_user = await UserDao.get_user_by_username(db, username) + input_user = await UserDao.get_user_with_relation(db, username=username) if not input_user: raise errors.NotFoundError(msg='用户不存在') if input_user.username != obj.username: @@ -103,9 +111,16 @@ async def update(*, username: str, current_user: User, obj: UpdateUser): validate_email(obj.email, check_deliverability=False).email except EmailNotValidError: raise errors.ForbiddenError(msg='邮箱格式错误') - if obj.mobile_number is not None: - if not re_verify.is_mobile(obj.mobile_number): + if obj.phone is not None: + if not re_verify.is_phone(obj.phone): raise errors.ForbiddenError(msg='手机号码输入有误') + dept = await DeptDao.get_dept_by_id(db, obj.dept_id) + if not dept: + raise errors.NotFoundError(msg='部门不存在') + for role_id in obj.roles: + role = await RoleDao.get_role_by_id(db, role_id) + if not role: + raise errors.NotFoundError(msg='角色不存在') count = await UserDao.update_userinfo(db, input_user, obj) return count @@ -123,9 +138,7 @@ async def update_avatar(*, username: str, current_user: User, avatar: Avatar): @staticmethod async def get_user_list(): - async with async_db_session() as db: - user_select = UserDao.get_users() - return await paginate(db, user_select) + return UserDao.get_users() @staticmethod async def update_permission(pk: int): diff --git a/backend/app/test/__init__.py b/backend/app/tests/__init__.py similarity index 100% rename from backend/app/test/__init__.py rename to backend/app/tests/__init__.py diff --git a/backend/app/test/conftest.py b/backend/app/tests/conftest.py similarity index 100% rename from backend/app/test/conftest.py rename to backend/app/tests/conftest.py diff --git a/backend/app/tests/test_auth.py b/backend/app/tests/test_auth.py new file mode 100644 index 00000000..148ab5cb --- /dev/null +++ b/backend/app/tests/test_auth.py @@ -0,0 +1,26 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +import sys + +import pytest +from httpx import AsyncClient + +sys.path.append('../../') + +from backend.app.core.conf import settings # noqa: E402 +from backend.app.main import app # noqa: E402 + + +class TestAuth: + pytestmark = pytest.mark.anyio + + async def test_login(self): + async with AsyncClient( + app=app, headers={'accept': 'application/json', 'Content-Type': 'application/json'} + ) as client: + response = await client.post( + url=f'http://{settings.UVICORN_HOST}:{settings.UVICORN_PORT}/v1/auth/users/login', + json={'username': 'test', 'password': 'test'}, + ) + assert response.status_code == 200 + assert response.json()['data']['token_type'] == 'Bearer' diff --git a/backend/app/test/test_auth.py b/backend/app/tests/test_user.py similarity index 75% rename from backend/app/test/test_auth.py rename to backend/app/tests/test_user.py index 90e1bef2..d930ab33 100644 --- a/backend/app/test/test_auth.py +++ b/backend/app/tests/test_user.py @@ -13,26 +13,16 @@ from backend.app.common.redis import redis_client # noqa: E402 -class TestAuth: +class TestUser: pytestmark = pytest.mark.anyio faker = Faker(locale='zh_CN') - users_api_base_url = f'http://{settings.UVICORN_HOST}:{settings.UVICORN_PORT}/v1/auth/users' + users_api_base_url = f'http://{settings.UVICORN_HOST}:{settings.UVICORN_PORT}/v1/users' @property async def get_token(self): token = await redis_client.get('test_token') return token - async def test_login(self): - async with AsyncClient( - app=app, headers={'accept': 'application/json', 'Content-Type': 'application/json'} - ) as client: - response = await client.post( - url=f'{self.users_api_base_url}/login', json={'username': 'test', 'password': 'test'} - ) - assert response.status_code == 200 - assert response.json()['data']['token_type'] == 'Bearer' - async def test_register(self): async with AsyncClient( app=app, headers={'accept': 'application/json', 'Content-Type': 'application/json'} @@ -41,8 +31,11 @@ async def test_register(self): url=f'{self.users_api_base_url}/register', json={ 'username': f'{self.faker.user_name()}', + 'nickname': f'{self.faker.name()}', 'password': f'{self.faker.password()}', 'email': f'{self.faker.email()}', + 'dept_id': 1, + 'roles': [1], }, ) assert response.status_code == 200 @@ -67,6 +60,6 @@ async def test_get_all_users(self): response = await client.get(url=f'{self.users_api_base_url}?page=1&size=20') assert response.status_code == 200 r_json = response.json() - assert isinstance(r_json['data'], list) - assert isinstance(r_json['links'], dict) - assert isinstance(r_json['links']['self'], str) + assert isinstance(r_json['data']['items'], list) + assert isinstance(r_json['data']['links'], dict) + assert isinstance(r_json['data']['links']['self'], str) diff --git a/backend/app/utils/health_check.py b/backend/app/utils/health_check.py new file mode 100644 index 00000000..1dd98ca9 --- /dev/null +++ b/backend/app/utils/health_check.py @@ -0,0 +1,19 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +from fastapi import FastAPI +from fastapi.routing import APIRoute + + +def ensure_unique_route_names(app: FastAPI) -> None: + """ + 检查路由名称是否唯一 + + :param app: + :return: + """ + temp_routes = set() + for route in app.routes: + if isinstance(route, APIRoute): + if route.name in temp_routes: + raise ValueError(f'Non-unique route name: {route.name}') + temp_routes.add(route.name) diff --git a/backend/app/utils/openapi.py b/backend/app/utils/openapi.py new file mode 100644 index 00000000..fc53e82c --- /dev/null +++ b/backend/app/utils/openapi.py @@ -0,0 +1,16 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +from fastapi import FastAPI +from fastapi.routing import APIRoute + + +def simplify_operation_ids(app: FastAPI) -> None: + """ + 简化操作 ID,以便生成的客户端具有更简单的 api 函数名称 + + :param app: + :return: + """ + for route in app.routes: + if isinstance(route, APIRoute): + route.operation_id = route.name diff --git a/backend/app/utils/re_verify.py b/backend/app/utils/re_verify.py index a5d27d8b..b85f3f9f 100644 --- a/backend/app/utils/re_verify.py +++ b/backend/app/utils/re_verify.py @@ -33,7 +33,7 @@ def match_string(pattern, text) -> bool: return False -def is_mobile(text: str) -> bool: +def is_phone(text: str) -> bool: """ 检查手机号码 diff --git a/backend/app/utils/serializers.py b/backend/app/utils/serializers.py new file mode 100644 index 00000000..c601ad8c --- /dev/null +++ b/backend/app/utils/serializers.py @@ -0,0 +1,48 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +from decimal import Decimal + +from sqlalchemy.sql import Select + + +def select_to_dict(obj: Select) -> dict: + """ + Serialize SQLAlchemy Select to dict + + :param obj: + :return: + """ + obj_dict = {} + for column in obj.__table__.columns.keys(): + val = getattr(obj, column) + if isinstance(val, Decimal): + val = float(val) + obj_dict[column] = val + return obj_dict + + +def select_to_list(obj: list) -> list: + """ + Serialize SQLAlchemy Select to list + + :param obj: + :return: + """ + ret_list = [] + for _ in obj: + ret_dict = select_to_dict(_) + ret_list.append(ret_dict) + return ret_list + + +def select_to_json(obj: Select) -> dict: + """ + Serialize SQLAlchemy Select to json + + :param obj: + :return: + """ + obj_dict = obj.__dict__ + if '_sa_instance_state' in obj_dict: + del obj_dict['_sa_instance_state'] + return obj_dict diff --git a/requirements.txt b/requirements.txt index f21785ea..ec8eed10 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,6 +5,8 @@ alembic==1.7.4 APScheduler==3.8.1 asyncmy==0.2.5 bcrypt==3.2.2 +casbin==1.18.2 +casbin_sqlalchemy_adapter==0.5.1 cryptography==39.0.1 email-validator==1.1.3 Faker==9.7.1 @@ -18,6 +20,7 @@ passlib==1.7.4 path==15.1.2 pre-commit==3.2.2 pydantic==1.10.5 +pymysql==0.9.3 pytest==7.2.2 pytest-pretty==1.2.0 python-jose==3.3.0