From 07cff2a05e16b3b55a2bac7ac964537b2e565631 Mon Sep 17 00:00:00 2001 From: BiDuang Date: Wed, 17 Jul 2024 12:11:11 +0800 Subject: [PATCH] feat: PluginUtils; Plugin function; fix: Plugin loader may leak unregister source into source list --- Configs/{config.json => config.json.example} | 0 Models/database.py | 6 +-- Models/plugins.py | 32 +++++++++++-- Models/response.py | 6 +++ Models/user.py | 23 +++++++++ Routers/comic.py | 23 +++++++-- Routers/user.py | 49 +++++++++++--------- Services/Database/database.py | 2 +- Services/Modulator/manager.py | 45 +++++++++++++++--- Services/Security/user.py | 18 +++++-- Utils/convert.py | 3 +- main.py | 22 ++++++--- 12 files changed, 173 insertions(+), 56 deletions(-) rename Configs/{config.json => config.json.example} (100%) diff --git a/Configs/config.json b/Configs/config.json.example similarity index 100% rename from Configs/config.json rename to Configs/config.json.example diff --git a/Models/database.py b/Models/database.py index 33cf87e..2ea168a 100644 --- a/Models/database.py +++ b/Models/database.py @@ -1,11 +1,11 @@ -from sqlalchemy import Column, INT, TEXT, TIMESTAMP +from sqlalchemy import INT, TEXT, TIMESTAMP, Column from Services.Database.database import Base class UserDb(Base): __tablename__ = "user" - id = Column(INT, unique=True, primary_key=True, index=True, nullable=False) + id = Column(INT, unique=True, primary_key=True, index=True, nullable=False, autoincrement=True) uid = Column(TEXT, unique=True, index=True, nullable=False) username = Column(TEXT, unique=True, index=True, nullable=False) email = Column(TEXT, unique=True, nullable=False) @@ -15,7 +15,7 @@ class UserDb(Base): class PwdDb(Base): __tablename__ = "src_pwd" - id = Column(INT, unique=True, primary_key=True, index=True, nullable=False) + id = Column(INT, unique=True, primary_key=True, index=True, nullable=False, autoincrement=True) source = Column(TEXT, index=True, nullable=False) uid = Column(TEXT, index=True, nullable=False) account = Column(TEXT, nullable=False) diff --git a/Models/plugins.py b/Models/plugins.py index 8e15d8a..0d4c7e0 100644 --- a/Models/plugins.py +++ b/Models/plugins.py @@ -1,7 +1,15 @@ +import asyncio +import inspect from abc import ABC, abstractmethod +from typing import Any + +import nest_asyncio -from Models.response import StandardResponse from Models.comic import BaseComicInfo +from Models.response import StandardResponse +from Models.user import User, UserData + +nest_asyncio.apply() class BasePlugin(ABC): @@ -17,13 +25,13 @@ def search(self, keyword: str) -> list[BaseComicInfo]: pass -class IAuth: +class IAuth(ABC): @abstractmethod - def login(self, body: dict[str, str]) -> StandardResponse: + async def login(self, body: dict[str, str], user: UserData) -> StandardResponse: pass -class IShaper: +class IShaper(ABC): @abstractmethod def imager_shaper(self): pass @@ -34,7 +42,7 @@ class Plugin: version: str cnm_version: str source: list[str] - service: dict + service: dict[str, list[str]] instance: BasePlugin def __init__( @@ -43,10 +51,24 @@ def __init__( version: str, cnm_version: str, source: list[str], + service: dict[str, list[str]], instance: BasePlugin, ): self.name = name self.version = version self.cnm_version = cnm_version self.source = source + self.service = service self.instance = instance + + def try_call(self, method: str, *args: Any, **kwargs: Any) -> Any: + if hasattr(self.instance, method): + if inspect.iscoroutinefunction(getattr(self.instance, method)): + loop = asyncio.get_event_loop() + return loop.run_until_complete( + getattr(self.instance, method)(*args, **kwargs) + ) + else: + return getattr(self.instance, method)(*args, **kwargs) + else: + return None diff --git a/Models/response.py b/Models/response.py index 4e63039..93cf5a9 100644 --- a/Models/response.py +++ b/Models/response.py @@ -25,6 +25,12 @@ def limit_exceeded(request: Request, exc: RateLimitExceeded): class StandardResponse[T](BaseModel): + """ + Standard Response Model + ~~~~~~~~~~~~~~~~~~~~~ + This model is the default web response format of the API. + """ + status_code: int message: str | None data: T | None diff --git a/Models/user.py b/Models/user.py index 3e51129..5f4f971 100644 --- a/Models/user.py +++ b/Models/user.py @@ -1,4 +1,7 @@ +import json from datetime import datetime +from http.cookies import BaseCookie + from pydantic import BaseModel @@ -9,6 +12,26 @@ class User(BaseModel): created_at: datetime +class UserData: + uid: str + plugin_cookies: dict[str, BaseCookie[str]] + + def __init__(self, uid: str, plugin_cookies: dict[str, BaseCookie[str]]): + self.uid = uid + self.plugin_cookies = plugin_cookies + + def get_src_cookies(self, src: str) -> BaseCookie[str]: + if src in self.plugin_cookies: + return self.plugin_cookies[src] + else: + return BaseCookie[str]() + + def __str__(self): + return json.dumps( + {k: v.output(header="").strip() for k, v in self.plugin_cookies.items()} + ) + + class Token(BaseModel): access_token: str token_type: str diff --git a/Routers/comic.py b/Routers/comic.py index f3b8e0e..48f4a3c 100644 --- a/Routers/comic.py +++ b/Routers/comic.py @@ -1,12 +1,10 @@ -from fastapi import APIRouter, Response, Depends, HTTPException, Form, Request +from fastapi import APIRouter, Depends -from Models.user import User from Models.requests import ComicSearchReq from Models.response import ExceptionResponse, StandardResponse -from Services.Database.database import get_db -from Services.Limiter.limiter import limiter -from Services.Security.user import get_current_user +from Models.user import User, UserData from Services.Modulator.manager import plugin_manager +from Services.Security.user import get_current_user, get_user_data comic_router = APIRouter(prefix="/comic") @@ -16,6 +14,21 @@ async def search_comic(body: ComicSearchReq, user: User = Depends(get_current_us pass +@comic_router.get("/{src_id}/favor") +async def get_favor( + src_id: str, + data: dict[str, str] | None = None, + user_data: UserData = Depends(get_user_data), +): + if (source := plugin_manager.get_source(src_id)) is None: + raise ExceptionResponse.not_found + + if (resp := source.try_call("get_favor", user_data, data)) is not None: + return resp + + return StandardResponse(status_code=404, message="Not Found") + + @comic_router.get("/{src_id}/album/{album_id}") async def get_album(src_id: str, album_id: str, user: User = Depends(get_current_user)): if (source := plugin_manager.get_source(src_id)) is None: diff --git a/Routers/user.py b/Routers/user.py index 5b9644c..3c6118c 100644 --- a/Routers/user.py +++ b/Routers/user.py @@ -1,31 +1,30 @@ import logging import typing +from datetime import datetime, timedelta +from uuid import uuid4 -import sqlalchemy - -from sqlalchemy import Column, Text -from Models.plugins import IAuth -from fastapi import APIRouter, Response, Depends, HTTPException, Form, Request +from fastapi import APIRouter, Depends, Form, HTTPException, Request, Response from fastapi.security import OAuth2PasswordRequestForm from passlib.context import CryptContext +from sqlalchemy import Column, Text from sqlalchemy.orm import Session -from uuid import uuid4 -from datetime import datetime, timedelta +from Models.database import PwdDb, UserDb +from Models.plugins import IAuth +from Models.requests import SourceStorageReq +from Models.response import ExceptionResponse, StandardResponse +from Models.user import Token, User, UserData from Services.Database.database import get_db -from Services.Modulator.manager import plugin_manager from Services.Limiter.limiter import limiter +from Services.Modulator.manager import PluginUtils, plugin_manager from Services.Security.user import ( ACCESS_TOKEN_EXPIRE_MINUTES, create_access_token, - get_current_user, - encrypt_src_password, decrypt_src_password, + encrypt_src_password, + get_current_user, + get_user_data, ) -from Models.database import UserDb, PwdDb -from Models.user import Token, User -from Models.response import ExceptionResponse, StandardResponse -from Models.requests import SourceStorageReq from Utils.convert import sql_typecast user_router = APIRouter(prefix="/user") @@ -66,6 +65,7 @@ async def user_reg( @user_router.post("/login") @limiter.limit("5/minute") async def user_login( + request: Request, body: OAuth2PasswordRequestForm = Depends(), db: Session = Depends(get_db), ): @@ -99,19 +99,24 @@ async def src_login_info(src: str): @user_router.post("/{src}/login") -async def src_login(src: str, body: dict[str, str]): - if (source := plugin_manager.get_source(src)) is None or not isinstance( - source.instance, IAuth - ): - raise ExceptionResponse.not_found - +async def src_login( + response: Response, + src: str, + body: dict[str, str], + user_data: UserData = Depends(get_user_data), +): if ( - not isinstance(source.service.get("login"), list) + (source := plugin_manager.get_source(src)) is None + or not isinstance(source.instance, IAuth) + or not isinstance(source.service.get("login"), list) or source.service["login"] == [] ): raise ExceptionResponse.not_found - result = source.instance.login(body) + result = await source.instance.login(body, user_data) + response.set_cookie(key="plugin_cookies", value=user_data.__str__()) + + return result @user_router.post("/{src_id}/encrypt") diff --git a/Services/Database/database.py b/Services/Database/database.py index 15a0d5c..ecdcbe0 100644 --- a/Services/Database/database.py +++ b/Services/Database/database.py @@ -15,7 +15,7 @@ raise ValueError("Please complete the database configuration") engine = create_engine(f"{protocol}://{auth}{host}:{port}/{db}") -SessionLocal = sessionmaker(autocommit=False, bind=engine) +SessionLocal = sessionmaker(autocommit=False, bind=engine, expire_on_commit=True) Base = declarative_base() diff --git a/Services/Modulator/manager.py b/Services/Modulator/manager.py index 9b40247..362a34d 100644 --- a/Services/Modulator/manager.py +++ b/Services/Modulator/manager.py @@ -1,12 +1,16 @@ import importlib +import json import logging import os -import toml +from http.cookies import BaseCookie from pathlib import Path -from typing import Set +from typing import Set, cast + +import toml +from fastapi import Response -from Models.plugins import BasePlugin, Plugin from Configs.config import config +from Models.plugins import BasePlugin, Plugin logger = logging.getLogger(__name__) @@ -50,10 +54,9 @@ def load_plugin(self, plugin_dir: Path) -> bool: ) return False else: - logger.info(f"Registered source {src}") - self.registered_source.add(src) + logger.info(f"Registering source {src}") - module = importlib.import_module(f"Plugins.{plugin_dir.name}.main") + module = importlib.import_module(f"Plugins.{plugin_dir.name}.main") if issubclass(entry := getattr(module, plugin_dir.name), BasePlugin): instance = entry() if instance.on_load(): @@ -63,16 +66,23 @@ def load_plugin(self, plugin_dir: Path) -> bool: version=plugin_info["description"]["version"], cnm_version=plugin_info["plugin"]["cnm-version"], source=plugin_info["plugin"]["source"], + service=plugin_info["service"], instance=instance, ) ) else: raise ImportError - logger.info(f"Loaded plugin {plugin_dir.name}") + logger.info(f"Plugin {plugin_dir.name} Loaded") + self.registered_source.add(src) return True else: logger.error(f"Plugin {plugin_dir.name} is not a valid plugin") return False + except ModuleNotFoundError as module_err: + logger.error( + f"Failed to load {plugin_dir.name}, plugin requires some dependencies: {module_err.msg}" + ) + return False except FileNotFoundError: logger.error(f"Failed to load plugin {plugin_dir.name}'s information") return False @@ -103,4 +113,25 @@ def get_source(self, source: str) -> Plugin | None: return None +class PluginUtils: + @staticmethod + def load_cookies(cookies_str: str | None) -> dict[str, BaseCookie[str]]: + cookies = dict[str, BaseCookie[str]]() + if cookies_str is None or cookies_str == "": + return cookies + + try: + plugin_cookies: dict[str, str] = json.loads(cookies_str) + + for src in plugin_manager.registered_source: + if src in plugin_cookies: + cookies[src] = BaseCookie[str](plugin_cookies[src]) + else: + cookies[src] = BaseCookie[str]() + except: + logger.warning("Failed to load cookies") + + return cookies + + plugin_manager = PluginManager() diff --git a/Services/Security/user.py b/Services/Security/user.py index 2e37529..bdecceb 100644 --- a/Services/Security/user.py +++ b/Services/Security/user.py @@ -1,22 +1,24 @@ import os from datetime import timedelta, datetime, UTC +from typing import Annotated from jose import jwt, JWTError -from fastapi import Depends +from fastapi import Cookie, Depends, Request from fastapi.security import OAuth2PasswordBearer from sqlalchemy.orm import Session from Crypto.Cipher import AES from Crypto.Util.Padding import pad, unpad from Services.Database.database import get_db +from Services.Modulator.manager import PluginUtils from Models.database import UserDb -from Models.user import User +from Models.user import User, UserData from Models.response import ExceptionResponse SECRET_KEY: str | None = os.environ.get("SECRET_KEY") ALGORITHM = "HS256" ACCESS_TOKEN_EXPIRE_MINUTES = 30 -if not SECRET_KEY or SECRET_KEY.__len__() != 32: +if not SECRET_KEY or SECRET_KEY.__len__() < 32: raise ValueError( "Please set `SECRET_KEY` environment variable, you can generate one with `openssl rand -hex 32`" ) @@ -49,7 +51,6 @@ def get_current_user( user: UserDb | None = db.query(UserDb).filter(UserDb.uid == uid).first() if user is None: raise ExceptionResponse.auth - return User( uid=user.uid, email=user.email, @@ -58,6 +59,15 @@ def get_current_user( ) +def get_user_data( + plugin_cookies: Annotated[str | None, Cookie()] = None, + user: User = Depends(get_current_user), +): + return UserData( + uid=user.uid, plugin_cookies=PluginUtils.load_cookies(plugin_cookies) + ) + + def encrypt_src_password(key: str, src_pwd: str) -> str: cipher = AES.new(pad(key.encode("utf-8"), AES.block_size), AES.MODE_ECB) return cipher.encrypt(pad(src_pwd.encode("utf-8"), AES.block_size)).hex() diff --git a/Utils/convert.py b/Utils/convert.py index f68f268..a0d7856 100644 --- a/Utils/convert.py +++ b/Utils/convert.py @@ -1,5 +1,4 @@ -from sqlalchemy.sql.elements import ColumnElement -from typing import Type, cast, Any +from typing import Any, Type, cast def sql_typecast[T](value: Any, target_type: Type[T]) -> T: diff --git a/main.py b/main.py index fba3a72..d904adf 100644 --- a/main.py +++ b/main.py @@ -1,22 +1,29 @@ -import logging, fastapi +import logging + +import fastapi from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware -from slowapi.errors import RateLimitExceeded from rich.logging import RichHandler +from slowapi.errors import RateLimitExceeded -from Routers.user import user_router from Models.response import ExceptionResponse +from Routers.comic import comic_router +from Routers.user import user_router from Services.Limiter.limiter import limiter from Services.Modulator.manager import plugin_manager -logging.basicConfig(level="INFO", format="%(message)s", datefmt="[%X]", - handlers=[RichHandler(rich_tracebacks=True, tracebacks_suppress=[fastapi])]) +logging.basicConfig( + level="INFO", + format="%(message)s", + datefmt="[%X]", + handlers=[RichHandler(rich_tracebacks=True, tracebacks_suppress=[fastapi])], +) app = FastAPI() app.state.limiter = limiter -app.add_exception_handler(RateLimitExceeded, ExceptionResponse.limit_exceeded) # type: ignore +app.add_exception_handler(RateLimitExceeded, ExceptionResponse.limit_exceeded) app.add_middleware( - CORSMiddleware, # type: ignore + CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], @@ -26,3 +33,4 @@ plugin_manager.load_plugins() app.include_router(user_router) +app.include_router(comic_router)