Skip to content

Commit

Permalink
feat: PluginUtils; Plugin function; fix: Plugin loader may leak unreg…
Browse files Browse the repository at this point in the history
…ister source into source list
  • Loading branch information
BiDuang committed Jul 17, 2024
1 parent e321b15 commit 07cff2a
Show file tree
Hide file tree
Showing 12 changed files with 173 additions and 56 deletions.
File renamed without changes.
6 changes: 3 additions & 3 deletions Models/database.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from sqlalchemy import Column, INT, TEXT, TIMESTAMP
from sqlalchemy import INT, TEXT, TIMESTAMP, Column

from Services.Database.database import Base


class UserDb(Base):
__tablename__ = "user"
id = Column(INT, unique=True, primary_key=True, index=True, nullable=False)
id = Column(INT, unique=True, primary_key=True, index=True, nullable=False, autoincrement=True)
uid = Column(TEXT, unique=True, index=True, nullable=False)
username = Column(TEXT, unique=True, index=True, nullable=False)
email = Column(TEXT, unique=True, nullable=False)
Expand All @@ -15,7 +15,7 @@ class UserDb(Base):

class PwdDb(Base):
__tablename__ = "src_pwd"
id = Column(INT, unique=True, primary_key=True, index=True, nullable=False)
id = Column(INT, unique=True, primary_key=True, index=True, nullable=False, autoincrement=True)
source = Column(TEXT, index=True, nullable=False)
uid = Column(TEXT, index=True, nullable=False)
account = Column(TEXT, nullable=False)
Expand Down
32 changes: 27 additions & 5 deletions Models/plugins.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,15 @@
import asyncio
import inspect
from abc import ABC, abstractmethod
from typing import Any

import nest_asyncio

from Models.response import StandardResponse
from Models.comic import BaseComicInfo
from Models.response import StandardResponse
from Models.user import User, UserData

nest_asyncio.apply()


class BasePlugin(ABC):
Expand All @@ -17,13 +25,13 @@ def search(self, keyword: str) -> list[BaseComicInfo]:
pass


class IAuth:
class IAuth(ABC):
@abstractmethod
def login(self, body: dict[str, str]) -> StandardResponse:
async def login(self, body: dict[str, str], user: UserData) -> StandardResponse:
pass


class IShaper:
class IShaper(ABC):
@abstractmethod
def imager_shaper(self):
pass
Expand All @@ -34,7 +42,7 @@ class Plugin:
version: str
cnm_version: str
source: list[str]
service: dict
service: dict[str, list[str]]
instance: BasePlugin

def __init__(
Expand All @@ -43,10 +51,24 @@ def __init__(
version: str,
cnm_version: str,
source: list[str],
service: dict[str, list[str]],
instance: BasePlugin,
):
self.name = name
self.version = version
self.cnm_version = cnm_version
self.source = source
self.service = service
self.instance = instance

def try_call(self, method: str, *args: Any, **kwargs: Any) -> Any:
if hasattr(self.instance, method):
if inspect.iscoroutinefunction(getattr(self.instance, method)):
loop = asyncio.get_event_loop()
return loop.run_until_complete(
getattr(self.instance, method)(*args, **kwargs)
)
else:
return getattr(self.instance, method)(*args, **kwargs)
else:
return None
6 changes: 6 additions & 0 deletions Models/response.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,12 @@ def limit_exceeded(request: Request, exc: RateLimitExceeded):


class StandardResponse[T](BaseModel):
"""
Standard Response Model
~~~~~~~~~~~~~~~~~~~~~
This model is the default web response format of the API.
"""

status_code: int
message: str | None
data: T | None
Expand Down
23 changes: 23 additions & 0 deletions Models/user.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
import json
from datetime import datetime
from http.cookies import BaseCookie

from pydantic import BaseModel


Expand All @@ -9,6 +12,26 @@ class User(BaseModel):
created_at: datetime


class UserData:
uid: str
plugin_cookies: dict[str, BaseCookie[str]]

def __init__(self, uid: str, plugin_cookies: dict[str, BaseCookie[str]]):
self.uid = uid
self.plugin_cookies = plugin_cookies

def get_src_cookies(self, src: str) -> BaseCookie[str]:
if src in self.plugin_cookies:
return self.plugin_cookies[src]
else:
return BaseCookie[str]()

def __str__(self):
return json.dumps(
{k: v.output(header="").strip() for k, v in self.plugin_cookies.items()}
)


class Token(BaseModel):
access_token: str
token_type: str
23 changes: 18 additions & 5 deletions Routers/comic.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
from fastapi import APIRouter, Response, Depends, HTTPException, Form, Request
from fastapi import APIRouter, Depends

from Models.user import User
from Models.requests import ComicSearchReq
from Models.response import ExceptionResponse, StandardResponse
from Services.Database.database import get_db
from Services.Limiter.limiter import limiter
from Services.Security.user import get_current_user
from Models.user import User, UserData
from Services.Modulator.manager import plugin_manager
from Services.Security.user import get_current_user, get_user_data

comic_router = APIRouter(prefix="/comic")

Expand All @@ -16,6 +14,21 @@ async def search_comic(body: ComicSearchReq, user: User = Depends(get_current_us
pass


@comic_router.get("/{src_id}/favor")
async def get_favor(
src_id: str,
data: dict[str, str] | None = None,
user_data: UserData = Depends(get_user_data),
):
if (source := plugin_manager.get_source(src_id)) is None:
raise ExceptionResponse.not_found

if (resp := source.try_call("get_favor", user_data, data)) is not None:
return resp

return StandardResponse(status_code=404, message="Not Found")


@comic_router.get("/{src_id}/album/{album_id}")
async def get_album(src_id: str, album_id: str, user: User = Depends(get_current_user)):
if (source := plugin_manager.get_source(src_id)) is None:
Expand Down
49 changes: 27 additions & 22 deletions Routers/user.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,30 @@
import logging
import typing
from datetime import datetime, timedelta
from uuid import uuid4

import sqlalchemy

from sqlalchemy import Column, Text
from Models.plugins import IAuth
from fastapi import APIRouter, Response, Depends, HTTPException, Form, Request
from fastapi import APIRouter, Depends, Form, HTTPException, Request, Response
from fastapi.security import OAuth2PasswordRequestForm
from passlib.context import CryptContext
from sqlalchemy import Column, Text
from sqlalchemy.orm import Session
from uuid import uuid4
from datetime import datetime, timedelta

from Models.database import PwdDb, UserDb
from Models.plugins import IAuth
from Models.requests import SourceStorageReq
from Models.response import ExceptionResponse, StandardResponse
from Models.user import Token, User, UserData
from Services.Database.database import get_db
from Services.Modulator.manager import plugin_manager
from Services.Limiter.limiter import limiter
from Services.Modulator.manager import PluginUtils, plugin_manager
from Services.Security.user import (
ACCESS_TOKEN_EXPIRE_MINUTES,
create_access_token,
get_current_user,
encrypt_src_password,
decrypt_src_password,
encrypt_src_password,
get_current_user,
get_user_data,
)
from Models.database import UserDb, PwdDb
from Models.user import Token, User
from Models.response import ExceptionResponse, StandardResponse
from Models.requests import SourceStorageReq
from Utils.convert import sql_typecast

user_router = APIRouter(prefix="/user")
Expand Down Expand Up @@ -66,6 +65,7 @@ async def user_reg(
@user_router.post("/login")
@limiter.limit("5/minute")
async def user_login(
request: Request,
body: OAuth2PasswordRequestForm = Depends(),
db: Session = Depends(get_db),
):
Expand Down Expand Up @@ -99,19 +99,24 @@ async def src_login_info(src: str):


@user_router.post("/{src}/login")
async def src_login(src: str, body: dict[str, str]):
if (source := plugin_manager.get_source(src)) is None or not isinstance(
source.instance, IAuth
):
raise ExceptionResponse.not_found

async def src_login(
response: Response,
src: str,
body: dict[str, str],
user_data: UserData = Depends(get_user_data),
):
if (
not isinstance(source.service.get("login"), list)
(source := plugin_manager.get_source(src)) is None
or not isinstance(source.instance, IAuth)
or not isinstance(source.service.get("login"), list)
or source.service["login"] == []
):
raise ExceptionResponse.not_found

result = source.instance.login(body)
result = await source.instance.login(body, user_data)
response.set_cookie(key="plugin_cookies", value=user_data.__str__())

return result


@user_router.post("/{src_id}/encrypt")
Expand Down
2 changes: 1 addition & 1 deletion Services/Database/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
raise ValueError("Please complete the database configuration")

engine = create_engine(f"{protocol}://{auth}{host}:{port}/{db}")
SessionLocal = sessionmaker(autocommit=False, bind=engine)
SessionLocal = sessionmaker(autocommit=False, bind=engine, expire_on_commit=True)
Base = declarative_base()


Expand Down
45 changes: 38 additions & 7 deletions Services/Modulator/manager.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
import importlib
import json
import logging
import os
import toml
from http.cookies import BaseCookie
from pathlib import Path
from typing import Set
from typing import Set, cast

import toml
from fastapi import Response

from Models.plugins import BasePlugin, Plugin
from Configs.config import config
from Models.plugins import BasePlugin, Plugin

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -50,10 +54,9 @@ def load_plugin(self, plugin_dir: Path) -> bool:
)
return False
else:
logger.info(f"Registered source {src}")
self.registered_source.add(src)
logger.info(f"Registering source {src}")

module = importlib.import_module(f"Plugins.{plugin_dir.name}.main")
module = importlib.import_module(f"Plugins.{plugin_dir.name}.main")
if issubclass(entry := getattr(module, plugin_dir.name), BasePlugin):
instance = entry()
if instance.on_load():
Expand All @@ -63,16 +66,23 @@ def load_plugin(self, plugin_dir: Path) -> bool:
version=plugin_info["description"]["version"],
cnm_version=plugin_info["plugin"]["cnm-version"],
source=plugin_info["plugin"]["source"],
service=plugin_info["service"],
instance=instance,
)
)
else:
raise ImportError
logger.info(f"Loaded plugin {plugin_dir.name}")
logger.info(f"Plugin {plugin_dir.name} Loaded")
self.registered_source.add(src)
return True
else:
logger.error(f"Plugin {plugin_dir.name} is not a valid plugin")
return False
except ModuleNotFoundError as module_err:
logger.error(
f"Failed to load {plugin_dir.name}, plugin requires some dependencies: {module_err.msg}"
)
return False
except FileNotFoundError:
logger.error(f"Failed to load plugin {plugin_dir.name}'s information")
return False
Expand Down Expand Up @@ -103,4 +113,25 @@ def get_source(self, source: str) -> Plugin | None:
return None


class PluginUtils:
@staticmethod
def load_cookies(cookies_str: str | None) -> dict[str, BaseCookie[str]]:
cookies = dict[str, BaseCookie[str]]()
if cookies_str is None or cookies_str == "":
return cookies

try:
plugin_cookies: dict[str, str] = json.loads(cookies_str)

for src in plugin_manager.registered_source:
if src in plugin_cookies:
cookies[src] = BaseCookie[str](plugin_cookies[src])
else:
cookies[src] = BaseCookie[str]()
except:
logger.warning("Failed to load cookies")

return cookies


plugin_manager = PluginManager()
Loading

0 comments on commit 07cff2a

Please sign in to comment.