Skip to content

Commit

Permalink
pydantic user config validation
Browse files Browse the repository at this point in the history
  • Loading branch information
g0ldyy committed Jul 4, 2024
1 parent bbb257d commit ddb42c9
Show file tree
Hide file tree
Showing 6 changed files with 63 additions and 32 deletions.
4 changes: 2 additions & 2 deletions comet/api/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ async def stream(request: Request, b64config: str, type: str, id: str):

tasks = []
filtered = 0
filter_title = config["filterTitles"] if "filterTitles" in config else True # not needed when pydantic config validation system implemented
filter_title = config["filterTitles"]
for torrent in torrents:
if filter_title:
parsed_torrent = parse(
Expand All @@ -207,7 +207,7 @@ async def stream(request: Request, b64config: str, type: str, id: str):

tasks.append(get_torrent_hash(session, indexer_manager_type, torrent))

logger.info(f"{filtered} filtered torrents from Zilean API for {logName}")
logger.info(f"{filtered} filtered torrents for {logName}")

torrent_hashes = await asyncio.gather(*tasks)
torrent_hashes = list(set([hash for hash in torrent_hashes if hash]))
Expand Down
3 changes: 3 additions & 0 deletions comet/debrid/alldebrid.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
class AllDebrid:
def __init__(self):
pass
9 changes: 7 additions & 2 deletions comet/debrid/manager.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
import aiohttp

from .realdebrid import RealDebrid
from .alldebrid import AllDebrid


def getDebrid(session: aiohttp.ClientSession, config: dict):
if config["debridService"] == "realdebrid":
return RealDebrid(session, config["debridApiKey"])
debrid_service = config["debridService"]
debrid_api_key = config["debridApiKey"]
if debrid_service == "realdebrid":
return RealDebrid(session, debrid_api_key)
elif debrid_service == "alldebrid":
return AllDebrid(session, debrid_api_key)
4 changes: 2 additions & 2 deletions comet/templates/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -571,13 +571,13 @@
const button = document.querySelector("sl-button");
const alert = document.querySelector('sl-alert[variant="neutral"]');
button.addEventListener("click", () => {
const debridService = document.getElementById("debridService").value;
const debridApiKey = document.getElementById("debridApiKey").value;
const indexers = Array.from(document.getElementById("indexers").selectedOptions).map(option => option.value);
const languages = Array.from(document.getElementById("languages").selectedOptions).map(option => option.value);
const resolutions = Array.from(document.getElementById("resolutions").selectedOptions).map(option => option.value);
const maxResults = document.getElementById("maxResults").value;
const filterTitles = document.getElementById("filterTitles").checked;
const debridService = document.getElementById("debridService").value;
const debridApiKey = document.getElementById("debridApiKey").value;

const selectedLanguages = languages.length === defaultLanguages.length && languages.every((val, index) => val === defaultLanguages[index]) ? ["All"] : languages;
const selectedResolutions = resolutions.length === defaultResolutions.length && resolutions.every((val, index) => val === defaultResolutions[index]) ? ["All"] : resolutions;
Expand Down
40 changes: 15 additions & 25 deletions comet/utils/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import bencodepy

from comet.utils.logger import logger
from comet.utils.models import settings
from comet.utils.models import settings, ConfigModel

translation_table = {
"ā": "a",
Expand Down Expand Up @@ -181,26 +181,8 @@ def bytes_to_size(bytes: int):
def config_check(b64config: str):
try:
config = json.loads(base64.b64decode(b64config).decode())

if not isinstance(config["debridService"], str) or config[
"debridService"
] not in ["realdebrid"]:
return False
if not isinstance(config["debridApiKey"], str):
return False
if not isinstance(config["indexers"], list):
return False
if not isinstance(config["maxResults"], int) or config["maxResults"] < 0:
return False
if (
not isinstance(config["resolutions"], list)
or len(config["resolutions"]) == 0
):
return False
if not isinstance(config["languages"], list) or len(config["languages"]) == 0:
return False

return config
validated_config = ConfigModel(**config)
return validated_config.model_dump()
except:
return False

Expand Down Expand Up @@ -302,15 +284,23 @@ async def get_torrent_hash(
async def get_balanced_hashes(hashes: dict, config: dict):
max_results = config["maxResults"]
config_resolutions = config["resolutions"]
config_languages = {language.replace("_", " ").capitalize() for language in config["languages"]}
config_languages = {
language.replace("_", " ").capitalize() for language in config["languages"]
}
include_all_languages = "All" in config_languages
include_all_resolutions = "All" in config_resolutions
include_unknown_resolution = include_all_resolutions or "Unknown" in config_resolutions
include_unknown_resolution = (
include_all_resolutions or "Unknown" in config_resolutions
)

hashes_by_resolution = {}
for hash, hash_data in hashes.items():
hash_info = hash_data["data"]
if not include_all_languages and not hash_info["is_multi_audio"] and not any(lang in hash_info["language"] for lang in config_languages):
if (
not include_all_languages
and not hash_info["is_multi_audio"]
and not any(lang in hash_info["language"] for lang in config_languages)
):
continue

resolution = hash_info["resolution"]
Expand Down Expand Up @@ -348,7 +338,7 @@ async def get_balanced_hashes(hashes: dict, config: dict):
if missing_hashes <= 0:
break
current_count = len(balanced_hashes[resolution])
available_hashes = hash_list[current_count:current_count + missing_hashes]
available_hashes = hash_list[current_count : current_count + missing_hashes]
balanced_hashes[resolution].extend(available_hashes)
missing_hashes -= len(available_hashes)

Expand Down
35 changes: 34 additions & 1 deletion comet/utils/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from typing import List, Optional
from databases import Database
from pydantic import BaseModel, field_validator
from pydantic_settings import BaseSettings, SettingsConfigDict
from RTN import RTN, BaseRankingModel, SettingsModel

Expand All @@ -27,6 +28,39 @@ class AppSettings(BaseSettings):
CUSTOM_HEADER_HTML: Optional[str] = None


settings = AppSettings()


class ConfigModel(BaseModel):
indexers: List[str]
languages: Optional[List[str]] = ["All"]
resolutions: Optional[List[str]] = ["All"]
maxResults: Optional[int] = 0
filterTitles: Optional[bool] = True
debridService: str
debridApiKey: str

@field_validator("indexers")
def check_indexers(cls, v, values):
if not any(indexer in settings.INDEXER_MANAGER_INDEXERS for indexer in v):
raise ValueError(
f"At least one indexer must be from {settings.INDEXER_MANAGER_INDEXERS}"
)
return v

@field_validator("maxResults")
def check_max_results(cls, v):
if v < 0:
raise ValueError("maxResults cannot be less than 0")
return v

@field_validator("debridService")
def check_debrid_service(cls, v):
if v not in ["realdebrid", "realdebrid"]:
raise ValueError("Invalid debridService")
return v


class BestOverallRanking(BaseRankingModel):
uhd: int = 100
fhd: int = 90
Expand All @@ -53,5 +87,4 @@ class BestOverallRanking(BaseRankingModel):

# For use anywhere
rtn = RTN(settings=rtn_settings, ranking_model=rtn_ranking)
settings = AppSettings()
database = Database(f"sqlite:///{settings.DATABASE_PATH}")

0 comments on commit ddb42c9

Please sign in to comment.