diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 99496f0..5b1e925 100755 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,10 +1,12 @@ repos: - repo: https://github.com/pre-commit/mirrors-mypy - rev: v1.3.0 + rev: v1.8.0 hooks: - id: mypy exclude: deprecated|deploy|venv - args: [--python-version=3.11, --ignore-missing-imports, --install-types, --non-interactive,] + additional_dependencies: + - sqlmodel==0.0.22 + args: [--python-version=3.12, --ignore-missing-imports, --install-types, --non-interactive,] - repo: local hooks: @@ -20,7 +22,7 @@ repos: entry: black language: python types: [ python ] - args: [ --line-length=88, --target-version=py311 ] + args: [ --line-length=88, --target-version=py312, --force-exclude, src/db/migrations/*] # For upgrade python code syntax for newer versions of the language - id: pyupgrade diff --git a/Dockerfile b/Dockerfile index ec3ebdf..53ed948 100755 --- a/Dockerfile +++ b/Dockerfile @@ -1,29 +1,4 @@ -#FROM python:3.11-slim-bookworm -# -## Update the package list and install required packages -#RUN apt-get update && \ -# apt-get install -y git && \ -# apt-get clean && \ -# rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/* -# -## Copy the application source code to the container and set the working directory -#COPY ./requirements.txt /app/requirements.txt -#COPY ./src /app/src -# -## Copy alembic files -#COPY alembic.ini alembic-upgrade.bash /app/ -# -#WORKDIR /app -# -## Install Python dependencies -#RUN pip3 install --no-cache-dir --upgrade pip && \ -# pip3 install --no-cache-dir -r requirements.txt -# -## Set the user to non-root -#USER 1000 - - -FROM python:3.11-alpine +FROM python:3.12-alpine # Set up environment variables for non-interactive installation ENV PYTHONDONTWRITEBYTECODE=1 \ diff --git a/README.MD b/README.MD index a1f1efa..eec85d8 100644 --- a/README.MD +++ b/README.MD @@ -1,10 +1,17 @@ -# telegram-youtube-notifier-bot +# telegram-stream-notifier-bot (former telegram-youtube-notifier-bot) > ⚠️ Warning: project is still under development, use with caution. -Simple Youtube LiveStreams notifier in telegram based on [youtube-dlp](https://github.com/yt-dlp/yt-dlp) and [aiogram](https://github.com/aiogram/aiogram) . +Simple LiveStreams notifier in telegram based on [aiogram](https://github.com/aiogram/aiogram), [aiogram-dialog](https://github.com/Tishka17/aiogram_dialog) and [aps-schedule](https://github.com/agronholm/apscheduler). -Use this bot to receive periodic reports on live broadcasts on YouTube and generate and send the report to telegram. +### Available platforms + +| Platform | Based on | Status | +|-----------------|-------------------------------------------------|----------------------------------------------------------------------| +| Youtube | [youtube-dlp](https://github.com/yt-dlp/yt-dlp) | ✅ | +| Twitch | [twitch-api](https://github.com/Teekeks/pyTwitchAPI) | ✅ | + +Use this bot to receive periodic reports on live broadcasts on stream platforms and generate and send the report to telegram. The current version of the bot works in the telegram channel [НАСРАНО](https://t.me/HACPAH1). @@ -16,7 +23,7 @@ The current version of the bot works in the telegram channel [НАСРАНО](ht 3. Regularly checking active channels using [aps-schedule](https://github.com/agronholm/apscheduler). -4. Getting information about streams using [youtube-dlp](https://github.com/yt-dlp/yt-dlp). +4. Getting information about streams using packages from [list](#Available-platforms). 
## Commands @@ -26,19 +33,18 @@ The current version of the bot works in the telegram channel [НАСРАНО](ht | channels | User | Channels Administration | | cancel | User | Clear current fsm-state (if error) | | add_user | Superuser | Add user | -| add_channels | Superuser | Bulk channels importing from file | | scheduler_start | Superuser | Start scheduling periodic tasks of fetching information about streams | | scheduler_pause | Superuser | Pause scheduling periodic tasks of fetching information about streams | ## Database schema - + ## Deploy ### Install from source -> Tested on Ubuntu 22.04, python 3.11 +> Tested on Ubuntu 22.04, python 3.12 Just copy source code: @@ -96,7 +102,7 @@ Or you can do it step by step: `kubectl apply -f ` -## Setup Cookies +## Setup Cookies for Youtube You can use own cookies for yt-dlp extractors. diff --git a/db-schema.png b/db-schema.png new file mode 100644 index 0000000..e66b682 Binary files /dev/null and b/db-schema.png differ diff --git a/deploy/docker/example.config.yaml b/deploy/docker/example.config.yaml index d45c742..8e56f67 100755 --- a/deploy/docker/example.config.yaml +++ b/deploy/docker/example.config.yaml @@ -8,7 +8,7 @@ chat_id: 123456789 temp_chat_id: 987654321 # Report customization -# With Jinja2 template you can access to ChanelDescription object via 'channel' variable +# With Jinja2 template you can access to ChanelDescription object via 'channel_listing' variable # #class ChannelDescription(BaseModel): # url: str @@ -51,3 +51,12 @@ start_scheduler: False # interval between checks in seconds interval_s: 60 +# For notification from Twitch +# You can obtain from https://dev.twitch.tv/console +# twitch: +# app_id: xxxxxxxxxxxxxxxxxxxxxxxxxx +# app_secret: xxxxxxxxxxxxxxxxxxxxxxxxxx + +# Path to cookie file for youtube-dlp +# youtube: +# cookies_filepath: /path/to/cookies.txt diff --git a/deploy/docker/example.docker-compose.yml b/deploy/docker/example.docker-compose.yml index 70e492f..395b559 100755 --- a/deploy/docker/example.docker-compose.yml +++ b/deploy/docker/example.docker-compose.yml @@ -7,6 +7,8 @@ services: restart: always volumes: - ./config.yaml:/app/config.yaml + # youtube cookies optional + # make sure that has same path at config - ./cookies.txt:/app/cookies.txt:touch - ./youtube-notifier-bot.db:/app/youtube-notifier-bot.db - /etc/timezone:/etc/timezone:ro diff --git a/deploy/k8s/01_configmap.yaml b/deploy/k8s/01_configmap.yaml index ed105cf..363d671 100644 --- a/deploy/k8s/01_configmap.yaml +++ b/deploy/k8s/01_configmap.yaml @@ -46,6 +46,14 @@ data: start_scheduler: True interval_s: 300 + # For notification from Twitch + # You can obtain from https://dev.twitch.tv/console + # twitch: + # app_id: xxxxxxxxxxxxxxxxxxxxxxxxxx + # app_secret: xxxxxxxxxxxxxxxxxxxxxxxxxx + # Path to cookie file for youtube-dlp + # youtube: + # cookies_filepath: /path/to/cookies.txt cookies.txt: | # Netscape HTTP Cookie File # This file is generated by yt-dlp. Do not edit. 
\ No newline at end of file diff --git a/erd-generator.bash b/erd-generator.bash new file mode 100644 index 0000000..7fb8761 --- /dev/null +++ b/erd-generator.bash @@ -0,0 +1 @@ +eralchemy -i sqlite:///youtube-notifier-bot.db -o db-schema.png --exclude-tables alembic_version --exclude-columns created_at updated_at \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index e7afaf4..aee9aba 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,19 +1,20 @@ -aiogram==3.10.0 +aiogram==3.13.1 aiogram-dialog==2.2.0 pydantic==2.8.2 APScheduler==3.10.4 yt-dlp==2024.8.6 +twitchapi==4.3.1 sulguk==0.8.0 jinja2==3.1.4 pyyaml==6.0.2 pydantic_settings==2.4.0 structlog==24.4.0 -orjson==3.10.5 +orjson==3.10.7 uvloop==0.20.0 -SQLAlchemy==2.0.31 +SQLAlchemy==2.0.35 alembic==1.13.2 aiosqlite==0.20.0 asyncclick==8.1.7.2 anyio==4.4.0 aiocache==0.12.2 -sqlalchemy-data-model-visualizer==0.1.3 +sqlmodel==0.0.22 \ No newline at end of file diff --git a/requirements_dev.txt b/requirements_dev.txt index 40c6d40..cce62aa 100644 --- a/requirements_dev.txt +++ b/requirements_dev.txt @@ -2,7 +2,8 @@ pre-commit==3.8.0 black==24.8.0 autoflake==2.3.1 reorder-python-imports==3.13.0 -pyupgrade==3.16.0 +pyupgrade==3.17.0 pytest==8.3.2 pytest-cov==5.0.0 +eralchemy==1.5.0 git+https://github.com/aio-libs/sort-all.git \ No newline at end of file diff --git a/src/bot/commands.py b/src/bot/commands.py index 87d9cee..b6005c5 100644 --- a/src/bot/commands.py +++ b/src/bot/commands.py @@ -4,9 +4,8 @@ def superuser_commands() -> list[BotCommand]: commands = [ BotCommand(command="add_user", description="Add user to bot."), - BotCommand(command="add_channels", description="Add channels from file."), - BotCommand(command="scheduler_start", description="Start scheduler"), - BotCommand(command="scheduler_pause", description="Stop scheduler"), + BotCommand(command="scheduler_start", description="Start scheduler."), + BotCommand(command="scheduler_pause", description="Stop scheduler."), ] return commands @@ -14,7 +13,7 @@ def superuser_commands() -> list[BotCommand]: def user_commands() -> list[BotCommand]: commands = [ BotCommand(command="cancel", description="Clear current state."), - BotCommand(command="add_channel", description="Add channel."), + BotCommand(command="add_channel", description="Add channel_listing."), BotCommand( command="channels", description="Start Channels Administration Dialog." 
), diff --git a/src/bot/dialogs/__init__.py b/src/bot/dialogs/__init__.py index 070b648..90bd579 100644 --- a/src/bot/dialogs/__init__.py +++ b/src/bot/dialogs/__init__.py @@ -1,10 +1,10 @@ from aiogram import Dispatcher -from .channel import channel_dialog +from .user import register_user_dialogs def register_dialogs(dp: Dispatcher) -> None: - dp.include_router(channel_dialog) + register_user_dialogs(dp=dp) __all__ = ["register_dialogs"] diff --git a/src/bot/dialogs/user/__init__.py b/src/bot/dialogs/user/__init__.py new file mode 100644 index 0000000..fb377f9 --- /dev/null +++ b/src/bot/dialogs/user/__init__.py @@ -0,0 +1,12 @@ +from aiogram import Dispatcher + +from .channel import channel_create_dialog +from .channel import channels_list_dialog + + +def register_user_dialogs(dp: Dispatcher) -> None: + dp.include_router(channel_create_dialog) + dp.include_router(channels_list_dialog) + + +__all__ = ["register_user_dialogs"] diff --git a/src/bot/dialogs/user/channel/__init__.py b/src/bot/dialogs/user/channel/__init__.py new file mode 100644 index 0000000..f429c6c --- /dev/null +++ b/src/bot/dialogs/user/channel/__init__.py @@ -0,0 +1,5 @@ +from .add import channel_create_dialog +from .list import channels_list_dialog + + +__all__ = ["channel_create_dialog", "channels_list_dialog"] diff --git a/src/bot/dialogs/user/channel/add/__init__.py b/src/bot/dialogs/user/channel/add/__init__.py new file mode 100644 index 0000000..97de74e --- /dev/null +++ b/src/bot/dialogs/user/channel/add/__init__.py @@ -0,0 +1,16 @@ +from aiogram_dialog import Dialog +from aiogram_dialog import LaunchMode + +from .windows import label_handler_window +from .windows import select_channel_type_window +from .windows import url_handler_window + +channel_create_dialog = Dialog( + select_channel_type_window(), + url_handler_window(), + label_handler_window(), + launch_mode=LaunchMode.STANDARD, +) + + +__all__ = ["channel_create_dialog"] diff --git a/src/bot/dialogs/user/channel/add/constants.py b/src/bot/dialogs/user/channel/add/constants.py new file mode 100644 index 0000000..f124e4e --- /dev/null +++ b/src/bot/dialogs/user/channel/add/constants.py @@ -0,0 +1,18 @@ +from src.db.models.channel_type import ChannelType +from src.utils import kick_channel_url_validator +from src.utils import twitch_channel_url_validator +from src.utils import youtube_channel_url_validator + +url_validators: dict = { + ChannelType.YOUTUBE: youtube_channel_url_validator, + ChannelType.TWITCH: twitch_channel_url_validator, + ChannelType.KICK: kick_channel_url_validator, +} + +url_examples: dict = { + ChannelType.YOUTUBE: "https://www.youtube.com/@username", + ChannelType.TWITCH: "https://www.twitch.tv/username", + ChannelType.KICK: "empty", +} + +__all__ = ["url_examples", "url_validators"] diff --git a/src/bot/dialogs/user/channel/add/getters.py b/src/bot/dialogs/user/channel/add/getters.py new file mode 100644 index 0000000..bdbad91 --- /dev/null +++ b/src/bot/dialogs/user/channel/add/getters.py @@ -0,0 +1,45 @@ +from aiogram_dialog import DialogManager + +from .constants import url_examples +from .constants import url_validators +from src.db import DataAccessLayer +from src.db.models.channel_type import ChannelType + + +async def select_channel_type_window_getter(dialog_manager: DialogManager, **kwargs): + + dal: DataAccessLayer = dialog_manager.start_data["dal"] + + channels = await dal.get_channels() + + dialog_manager.dialog_data["dal"] = dal + dialog_manager.dialog_data["channels"] = channels + 
dialog_manager.dialog_data["channel_types"] = ChannelType.list() + + return { + "channel_types": dialog_manager.dialog_data["channel_types"], + } + + +async def url_handler_window_getter(dialog_manager: DialogManager, **kwargs): + + selected_channel_type = dialog_manager.dialog_data["selected_channel_type"] + url_validator = url_validators[selected_channel_type] + url_example = url_examples[selected_channel_type] + + dialog_manager.dialog_data["url_validator"] = url_validator + dialog_manager.dialog_data["url_example"] = url_example + + return {"url_example": url_example} + + +async def label_handler_window_getter(dialog_manager: DialogManager, **kwargs): + + return {"url": dialog_manager.dialog_data["url"]} + + +__all__ = [ + "label_handler_window_getter", + "select_channel_type_window_getter", + "url_handler_window_getter", +] diff --git a/src/bot/dialogs/user/channel/add/handlers.py b/src/bot/dialogs/user/channel/add/handlers.py new file mode 100644 index 0000000..75dce08 --- /dev/null +++ b/src/bot/dialogs/user/channel/add/handlers.py @@ -0,0 +1,87 @@ +from aiogram.types import Message +from aiogram_dialog import DialogManager +from aiogram_dialog.widgets.input import MessageInput +from sulguk import SULGUK_PARSE_MODE + +from src.bot.states import ChannelCreateSG +from src.db import DataAccessLayer +from src.db.models import ChannelModel +from src.db.models.channel_type import ChannelType + + +async def url_handler( + message: Message, + message_input: MessageInput, + dialog_manager: DialogManager, + **kwargs, +) -> None: + + url = message.text.lower().strip() + url_validator = dialog_manager.dialog_data["url_validator"] + dal: DataAccessLayer = dialog_manager.dialog_data["dal"] + + if url_validator(url): + dialog_manager.dialog_data["url"] = url + + obj = await dal.channel_dao.get_first(**{"url": url}) + if obj: + await message.answer( + f"❌ Channel {url} already exists with label {obj.label}.
" + f"Please set different url or /cancel for reject.", + parse_mode=SULGUK_PARSE_MODE, + ) + return + + await dialog_manager.switch_to(ChannelCreateSG.url_selected) + + +async def label_handler( + message: Message, + message_input: MessageInput, + dialog_manager: DialogManager, + **kwargs, +) -> None: + label = message.text.strip() + url = dialog_manager.dialog_data["url"] + if url and label: + dal: DataAccessLayer = dialog_manager.dialog_data["dal"] + + user_obj = await dal.get_user_by_attr(**{"user_id": message.from_user.id}) + if user_obj: + + selected_channel_type = dialog_manager.dialog_data["selected_channel_type"] + channel_type = ChannelType(selected_channel_type) + channel_type_obj, _ = await dal.channel_type_dao.get_or_create( + type=channel_type + ) + + channel = ChannelModel( + url=url, + label=label, + enabled=True, + user=user_obj, + type=channel_type_obj, + ) + + result = await dal.create_channel(obj=channel) + if result: + await message.answer( + f"✅ Created.
" + f"#id: {result.id}
" + f"label: {result.label}
" + f"url:{result.url}
", + parse_mode=SULGUK_PARSE_MODE, + ) + else: + await message.answer( + f" ❌ Cannot add channel with parameters:
" + f"label: {label}
" + f"url:{url}

" + f"Contact admins.", + parse_mode=SULGUK_PARSE_MODE, + ) + + await dialog_manager.done() + + +__all__ = ["label_handler", "url_handler"] diff --git a/src/bot/dialogs/user/channel/add/on_click.py b/src/bot/dialogs/user/channel/add/on_click.py new file mode 100644 index 0000000..b14cb56 --- /dev/null +++ b/src/bot/dialogs/user/channel/add/on_click.py @@ -0,0 +1,37 @@ +from contextlib import suppress +from typing import Any + +from aiogram.exceptions import TelegramBadRequest +from aiogram.types import CallbackQuery +from aiogram_dialog import ChatEvent +from aiogram_dialog import DialogManager +from aiogram_dialog.widgets.kbd import ( + Button, +) + + +async def on_finish( + callback: CallbackQuery, button: Button, manager: DialogManager +) -> None: + if manager.has_context(): + with suppress(TelegramBadRequest): + await callback.message.delete() + + await manager.done() + + +async def on_select_channel_type( + callback: ChatEvent, + select: Any, + manager: DialogManager, + item_id: str, +): + + manager.dialog_data["selected_channel_type"] = item_id + await manager.next() + + +__all__ = [ + "on_finish", + "on_select_channel_type", +] diff --git a/src/bot/dialogs/user/channel/add/windows.py b/src/bot/dialogs/user/channel/add/windows.py new file mode 100644 index 0000000..a3141ed --- /dev/null +++ b/src/bot/dialogs/user/channel/add/windows.py @@ -0,0 +1,73 @@ +from aiogram.enums import ContentType +from aiogram_dialog import Window +from aiogram_dialog.widgets.input import MessageInput +from aiogram_dialog.widgets.kbd import Button +from aiogram_dialog.widgets.kbd import Select +from aiogram_dialog.widgets.text import Const +from aiogram_dialog.widgets.text import Format +from aiogram_dialog.widgets.text import Multi +from sulguk import SULGUK_PARSE_MODE + +from .getters import label_handler_window_getter +from .getters import select_channel_type_window_getter +from .getters import url_handler_window_getter +from .handlers import label_handler +from .handlers import url_handler +from .on_click import on_finish +from .on_click import on_select_channel_type +from src.bot.states import ChannelCreateSG + + +def select_channel_type_window(): + return Window( + Const( + "Select channel type", + ), + Select( + Format("{item}"), + items="channel_types", + item_id_getter=lambda x: x, + id="channel_types", + on_click=on_select_channel_type, + ), + Button(Const("❌ Exit"), id="finish", on_click=on_finish), + state=ChannelCreateSG.start, + getter=select_channel_type_window_getter, + parse_mode=SULGUK_PARSE_MODE, + ) + + +def url_handler_window(): + return Window( + Multi( + Const( + "✅ OK, now you can set the URL. Please use the template below" + ), + Format("{url_example}"), + sep="
", + ), + MessageInput(url_handler, content_types=[ContentType.TEXT]), + Button(Const("❌ Exit"), id="finish", on_click=on_finish), + state=ChannelCreateSG.type_selected, + getter=url_handler_window_getter, + parse_mode=SULGUK_PARSE_MODE, + ) + + +def label_handler_window(): + return Window( + Multi( + Format( + "URL set to {url}, enter display name or /cancel for reject" + ), + sep="
", + ), + MessageInput(label_handler, content_types=[ContentType.TEXT]), + Button(Const("❌ Exit"), id="finish", on_click=on_finish), + state=ChannelCreateSG.url_selected, + getter=label_handler_window_getter, + parse_mode=SULGUK_PARSE_MODE, + ) + + +__all__ = ["label_handler_window", "select_channel_type_window", "url_handler_window"] diff --git a/src/bot/dialogs/channel/__init__.py b/src/bot/dialogs/user/channel/list/__init__.py similarity index 83% rename from src/bot/dialogs/channel/__init__.py rename to src/bot/dialogs/user/channel/list/__init__.py index 503a960..2de2d75 100644 --- a/src/bot/dialogs/channel/__init__.py +++ b/src/bot/dialogs/user/channel/list/__init__.py @@ -6,7 +6,7 @@ from .windows import turn_off_window from .windows import turn_on_window -channel_dialog = Dialog( +channels_list_dialog = Dialog( scroll_window(), delete_window(), turn_on_window(), @@ -15,4 +15,4 @@ ) -__all__ = ["channel_dialog"] +__all__ = ["channels_list_dialog"] diff --git a/src/bot/dialogs/channel/constants.py b/src/bot/dialogs/user/channel/list/constants.py similarity index 100% rename from src/bot/dialogs/channel/constants.py rename to src/bot/dialogs/user/channel/list/constants.py diff --git a/src/bot/dialogs/channel/getters.py b/src/bot/dialogs/user/channel/list/getters.py similarity index 95% rename from src/bot/dialogs/channel/getters.py rename to src/bot/dialogs/user/channel/list/getters.py index a8088ed..a3c9905 100644 --- a/src/bot/dialogs/channel/getters.py +++ b/src/bot/dialogs/user/channel/list/getters.py @@ -1,7 +1,7 @@ from aiogram_dialog import DialogManager -from ...filters import UserRole from .constants import ID_STUB_SCROLL +from src.bot.filters import UserRole from src.db import DataAccessLayer diff --git a/src/bot/dialogs/channel/on_click.py b/src/bot/dialogs/user/channel/list/on_click.py similarity index 68% rename from src/bot/dialogs/channel/on_click.py rename to src/bot/dialogs/user/channel/list/on_click.py index 01aceba..125fc19 100644 --- a/src/bot/dialogs/channel/on_click.py +++ b/src/bot/dialogs/user/channel/list/on_click.py @@ -9,9 +9,9 @@ Button, ) -from src.bot.states import ChannelsSG +from src.bot.states import ChannelsListSG from src.db import DataAccessLayer -from src.dto import ChannelRetrieveDTO +from src.db.models import ChannelModel async def on_finish( @@ -25,18 +25,18 @@ async def on_finish( async def on_delete(callback: CallbackQuery, button: Button, manager: DialogManager): - await manager.switch_to(ChannelsSG.delete) + await manager.switch_to(ChannelsListSG.delete) async def on_perform_delete( callback: CallbackQuery, button: Button, manager: DialogManager ): index = manager.dialog_data["current_page"] - channel: ChannelRetrieveDTO = manager.dialog_data["channels"][index] + channel: ChannelModel = manager.dialog_data["channels"][index] if channel.id: dal: DataAccessLayer = manager.start_data["dal"] - result = await dal.delete_channel_by_id(_id=channel.id) + result = await dal.delete_channel_by_id(id=channel.id) if result: await callback.answer("Success.") @@ -45,22 +45,22 @@ async def on_perform_delete( else: await callback.answer("Cannot delete row.") - await manager.switch_to(state=ChannelsSG.scrolling) + await manager.switch_to(state=ChannelsListSG.scrolling) async def on_turn_off(callback: CallbackQuery, button: Button, manager: DialogManager): - await manager.switch_to(ChannelsSG.turn_off) + await manager.switch_to(ChannelsListSG.turn_off) async def on_turn_on(callback: CallbackQuery, button: Button, manager: DialogManager): - await 
manager.switch_to(ChannelsSG.turn_on) + await manager.switch_to(ChannelsListSG.turn_on) async def on_perform_update( callback: CallbackQuery, button: Button, manager: DialogManager ): index = manager.dialog_data["current_page"] - channel: ChannelRetrieveDTO = manager.dialog_data["channels"][index] + channel: ChannelModel = manager.dialog_data["channels"][index] if channel.id: dal: DataAccessLayer = manager.start_data["dal"] @@ -74,9 +74,16 @@ async def on_perform_update( await callback.answer("Unknown callback data") return - result = await dal.update_channel_by_id( - _id=channel.id, data={"enabled": enabled} - ) + channel_obj = await dal.channel_dao.get_first(**{"id": channel.id}) + + if not channel_obj: + await callback.answer( + "❌ Channel does not exist. Please contact administrator." + ) + return + + channel_obj.enabled = enabled + result = await dal.update_channel_by_id(obj=channel_obj) if result: await callback.answer("Success.") @@ -85,7 +92,7 @@ async def on_perform_update( else: await callback.answer("Cannot update row.") - await manager.switch_to(state=ChannelsSG.scrolling) + await manager.switch_to(state=ChannelsListSG.scrolling) __all__ = [ diff --git a/src/bot/dialogs/channel/widgets.py b/src/bot/dialogs/user/channel/list/widgets.py similarity index 100% rename from src/bot/dialogs/channel/widgets.py rename to src/bot/dialogs/user/channel/list/widgets.py diff --git a/src/bot/dialogs/channel/windows.py b/src/bot/dialogs/user/channel/list/windows.py similarity index 88% rename from src/bot/dialogs/channel/windows.py rename to src/bot/dialogs/user/channel/list/windows.py index b7ec96f..126f58a 100644 --- a/src/bot/dialogs/channel/windows.py +++ b/src/bot/dialogs/user/channel/list/windows.py @@ -15,7 +15,6 @@ from aiogram_dialog.widgets.text import Multi from sulguk import SULGUK_PARSE_MODE -from ...filters import UserRole from .constants import ID_STUB_SCROLL from .getters import scroll_getter from .on_click import on_delete @@ -25,7 +24,8 @@ from .on_click import on_turn_off from .on_click import on_turn_on from .widgets import Viewer -from src.bot.states import ChannelsSG +from src.bot.states import ChannelsListSG +from src.db.models.user_role import UserRole def scroll_window(): @@ -77,14 +77,14 @@ def scroll_window(): Button(Const("❌ Exit"), id="finish", on_click=on_finish), when=~F["is_empty"], ), - state=ChannelsSG.scrolling, + state=ChannelsListSG.scrolling, getter=scroll_getter, parse_mode=SULGUK_PARSE_MODE, ) SWITCH_TO_SCROLLING = SwitchTo( - text=Const("🔙 No, return me back."), state=ChannelsSG.scrolling, id="back" + text=Const("🔙 No, return me back."), state=ChannelsListSG.scrolling, id="back" ) @@ -93,14 +93,14 @@ def delete_window(): Const("Are you sure?"), Row( Button( - Const("✅ Yes, delete this channel."), + Const("✅ Yes, delete this channel_listing."), id="delete", on_click=on_perform_delete, ), SWITCH_TO_SCROLLING, ), Button(Const("❌ Exit"), id="finish", on_click=on_finish), - state=ChannelsSG.delete, + state=ChannelsListSG.delete, getter=scroll_getter, ) @@ -110,14 +110,14 @@ def turn_on_window(): Const("Are you sure?"), Row( Button( - Const("✅ Yes, turn on this channel."), + Const("✅ Yes, turn on this channel_listing."), id="on", on_click=on_perform_update, ), SWITCH_TO_SCROLLING, ), Button(Const("❌ Exit"), id="finish", on_click=on_finish), - state=ChannelsSG.turn_on, + state=ChannelsListSG.turn_on, getter=scroll_getter, ) @@ -127,14 +127,14 @@ def turn_off_window(): Const("Are you sure?"), Row( Button( - Const("✅ Yes, turn off this channel."), + Const("✅ 
Yes, turn off this channel_listing."), id="off", on_click=on_perform_update, ), SWITCH_TO_SCROLLING, ), Button(Const("❌ Exit"), id="finish", on_click=on_finish), - state=ChannelsSG.turn_off, + state=ChannelsListSG.turn_off, getter=scroll_getter, ) diff --git a/src/bot/filters/role/role.py b/src/bot/filters/role.py similarity index 93% rename from src/bot/filters/role/role.py rename to src/bot/filters/role.py index 9b2f962..fce28fb 100755 --- a/src/bot/filters/role/role.py +++ b/src/bot/filters/role.py @@ -4,7 +4,7 @@ from aiogram.filters.base import Filter from aiogram.types import Message -from .model import UserRole +from src.db.models.user_role import UserRole class RoleFilter(Filter): diff --git a/src/bot/filters/role/__init__.py b/src/bot/filters/role/__init__.py deleted file mode 100755 index afc4599..0000000 --- a/src/bot/filters/role/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from .model import UserRole -from .role import RoleFilter - -__all__ = ["RoleFilter", "UserRole"] diff --git a/src/bot/filters/role/model.py b/src/bot/filters/role/model.py deleted file mode 100755 index 26751fb..0000000 --- a/src/bot/filters/role/model.py +++ /dev/null @@ -1,16 +0,0 @@ -from enum import Enum - - -class UserRole(Enum): - SUPERUSER = "superuser" - USER = "user" - UNKNOWN = "unknown" - - def __repr__(self) -> str: - return self.value - - def __str__(self) -> str: - return self.value - - -__all__ = ["UserRole"] diff --git a/src/bot/handlers/superuser/__init__.py b/src/bot/handlers/superuser/__init__.py index 5e08ee1..6443fca 100644 --- a/src/bot/handlers/superuser/__init__.py +++ b/src/bot/handlers/superuser/__init__.py @@ -1,13 +1,11 @@ from aiogram import Router -from .add_channels import add_channels_router from .add_user import add_user_router from .scheduler import scheduler_router superuser_router = Router(name="superuser") superuser_router.include_router(add_user_router) superuser_router.include_router(scheduler_router) -superuser_router.include_router(add_channels_router) __all__ = ["superuser_router"] diff --git a/src/bot/handlers/superuser/add_channels.py b/src/bot/handlers/superuser/add_channels.py deleted file mode 100644 index 180148f..0000000 --- a/src/bot/handlers/superuser/add_channels.py +++ /dev/null @@ -1,103 +0,0 @@ -from io import TextIOWrapper -from typing import BinaryIO -from typing import Optional - -from aiogram import Bot -from aiogram import F -from aiogram import Router -from aiogram.enums import ContentType -from aiogram.filters import Command -from aiogram.filters import StateFilter -from aiogram.fsm.context import FSMContext -from aiogram.fsm.state import State -from aiogram.types import File -from aiogram.types import Message -from sulguk import SULGUK_PARSE_MODE - -from ....db import DataAccessLayer -from ....dto import ChannelCreateDTO -from ....dto import ChannelRetrieveDTO -from ...filters import RoleFilter -from ...filters import UserRole -from ...states import ChannelsSG -from src.utils import youtube_channel_url_validator - -add_channels_router = Router(name="add_channels") - - -@add_channels_router.message( - Command("add_channels"), - RoleFilter(role=[UserRole.SUPERUSER]), - State(state="*"), -) -async def add_channels(message: Message, state: FSMContext, **kwargs) -> None: - await message.answer( - text="Upload a file with format:
" - "url[TAB]label[END_ROW]
" - "every single line == channel to insert.
" - "Set channel url in format: https://www.youtube.com/@username" - "Enter /cancel for exit
", - parse_mode=SULGUK_PARSE_MODE, - ) - await state.set_state(ChannelsSG.bulk_channels) - - -@add_channels_router.message( - F.content_type == ContentType.DOCUMENT, - RoleFilter(role=[UserRole.SUPERUSER]), - StateFilter(ChannelsSG.bulk_channels), -) -async def channel_file_handler( - message: Message, state: FSMContext, bot: Bot, dal: DataAccessLayer, **kwargs -) -> None: - file_id = message.document.file_id - file: File = await bot.get_file(file_id=file_id) - - if file.file_size >= 10 * 1024 * 1024: - await message.answer( - text="File too big. Try again with filesize lower then 10 mb.", - parse_mode=SULGUK_PARSE_MODE, - ) - return - - user_schema = await dal.get_user_by_attr(**{"user_id": message.from_user.id}) - if user_schema: - _: BinaryIO = await bot.download_file(file.file_path) - channels: list[ChannelCreateDTO] = [] - with TextIOWrapper(_, encoding="utf-8") as text_io: - for line in text_io: - line = line.strip() - splitted_line = line.split("\t") - if len(splitted_line) != 2: - await message.answer( - text=f"Malformed line {line[0:255]}. ", - parse_mode=SULGUK_PARSE_MODE, - ) - return - - if not youtube_channel_url_validator(splitted_line[0]): - await message.answer( - text=f"Error url validation: {splitted_line[0]}
" - "Set channel url in format: https://www.youtube.com/@username", - parse_mode=SULGUK_PARSE_MODE, - ) - return - - channel = ChannelCreateDTO( - url=splitted_line[0], - label=splitted_line[1], - enabled=True, - user_id=user_schema.id, - ) - channels.append(channel) - - for channel in channels: - result: Optional[ChannelRetrieveDTO] = await dal.create_channel( - channel_schema=channel - ) - await message.answer(f"{str(result)}") - - await state.clear() - - -__all__ = ["add_channels_router"] diff --git a/src/bot/handlers/superuser/add_user.py b/src/bot/handlers/superuser/add_user.py index 8249e54..c118aed 100644 --- a/src/bot/handlers/superuser/add_user.py +++ b/src/bot/handlers/superuser/add_user.py @@ -18,7 +18,8 @@ from ...filters import UserRole from ...states import UsersSG from src.db import DataAccessLayer -from src.dto import UserCreateDTO +from src.db.models import UserModel +from src.db.models import UserRoleModel from src.logger import logger add_user_router = Router(name="acl") @@ -62,23 +63,23 @@ async def handle_user( user_id = message.user_shared.user_id chat: Chat = await bot.get_chat(chat_id=user_id) - user_dto = UserCreateDTO( + user = UserModel( user_id=user_id, username=chat.username, firstname=chat.first_name, lastname=chat.last_name, - is_superuser=False, + role=UserRoleModel(role=UserRole.USER), ) - result = await dal.create_user(user_schema=user_dto) + result = await dal.create_user(obj=user) if result: await bot.set_my_commands( - user_commands(), scope=BotCommandScopeChat(chat_id=user_dto.user_id) + user_commands(), scope=BotCommandScopeChat(chat_id=user.user_id) ) await message.answer(text="Success. User added.") else: - await message.answer(text="Error during create. Try again.") + await message.answer(text="❌ Error during create. 
Try again.") except (Exception,) as ex: await logger.aerror(str(ex)) diff --git a/src/bot/handlers/user/__init__.py b/src/bot/handlers/user/__init__.py index c2928c2..d0e8bd2 100644 --- a/src/bot/handlers/user/__init__.py +++ b/src/bot/handlers/user/__init__.py @@ -1,12 +1,12 @@ from aiogram import Router from .add_channel import add_channel_router -from .scroll_channels import scroll_channel_router +from .list_channels import list_channels_router user_router = Router(name="user") user_router.include_router(add_channel_router) -user_router.include_router(scroll_channel_router) +user_router.include_router(list_channels_router) __all__ = ["user_router"] diff --git a/src/bot/handlers/user/add_channel.py b/src/bot/handlers/user/add_channel.py index dbce812..d8dfc63 100644 --- a/src/bot/handlers/user/add_channel.py +++ b/src/bot/handlers/user/add_channel.py @@ -1,20 +1,14 @@ -from aiogram import F from aiogram import Router -from aiogram.exceptions import TelegramAPIError from aiogram.filters import Command -from aiogram.filters import StateFilter -from aiogram.fsm.context import FSMContext from aiogram.fsm.state import State from aiogram.types import Message -from sulguk import SULGUK_PARSE_MODE +from aiogram_dialog import DialogManager +from aiogram_dialog import StartMode from ....db import DataAccessLayer -from ....dto import ChannelCreateDTO from ...filters import RoleFilter from ...filters import UserRole -from ...states import ChannelsSG -from src.logger import logger -from src.utils import youtube_channel_url_validator +from ...states import ChannelCreateSG add_channel_router = Router(name="add_channel") @@ -25,85 +19,20 @@ RoleFilter(role=[UserRole.USER, UserRole.SUPERUSER]), State(state="*"), ) -async def add_channel(message: Message, state: FSMContext, **kwargs) -> None: - await message.answer( - text="Set channel url in format: https://www.youtube.com/@username", - parse_mode=SULGUK_PARSE_MODE, - ) - await state.set_state(ChannelsSG.input_url) - - -@add_channel_router.message( - StateFilter(ChannelsSG.input_url), - RoleFilter(role=[UserRole.USER, UserRole.SUPERUSER]), - F.text, -) -async def url_handler(message: Message, state: FSMContext, **kwargs) -> None: - url = message.text.lower().strip() - if youtube_channel_url_validator(url): - await state.update_data(url=url) - await message.answer( - f"URL set to {url}, enter display name or /cancel for reject", - parse_mode=SULGUK_PARSE_MODE, - ) - await state.set_state(ChannelsSG.input_label) - else: - await message.answer( - text="Set channel url in format: https://www.youtube.com/@username", - parse_mode=SULGUK_PARSE_MODE, - ) - - -@add_channel_router.message( - StateFilter(ChannelsSG.input_label), - RoleFilter(role=[UserRole.USER, UserRole.SUPERUSER]), - F.text, -) -async def label_handler( - message: Message, state: FSMContext, dal: DataAccessLayer, **kwargs +async def start_add_channel_dialog( + message: Message, + dialog_manager: DialogManager, + dal: DataAccessLayer, + role: UserRole, + **kwargs, ) -> None: - label = message.text.strip() - user_data = await state.get_data() - if user_data: - try: - url = user_data.get("url", None) - if url: - user_schema = await dal.get_user_by_attr( - **{"user_id": message.from_user.id} - ) - if user_schema: - channel_schema = ChannelCreateDTO( - url=url, - label=label, - enabled=True, - user_id=user_schema.id, - ) - - result = await dal.create_channel(channel_schema=channel_schema) - if result: - await message.answer( - f"Created.
" - f"#id: {result.id}
" - f"label: {result.label}
" - f"url:{result.url}
", - parse_mode=SULGUK_PARSE_MODE, - ) - else: - await message.answer( - f"Cannot add channel with parameters:
" - f"label: {label}
" - f"url:{url}

" - f"Contact admins.", - parse_mode=SULGUK_PARSE_MODE, - ) - except TelegramAPIError as ex: - await logger.aerror(ex.message) - await message.reply( - text=f"❌ TelegramAPIError. Notify administrators. Thank you!" - ) - finally: - await state.clear() - await message.answer("States cleared.") + await dialog_manager.start( + ChannelCreateSG.start, + mode=StartMode.RESET_STACK, + data={ + "dal": dal, + }, + ) __all__ = ["add_channel_router"] diff --git a/src/bot/handlers/user/scroll_channels.py b/src/bot/handlers/user/list_channels.py similarity index 78% rename from src/bot/handlers/user/scroll_channels.py rename to src/bot/handlers/user/list_channels.py index 5c82139..60a3cdb 100644 --- a/src/bot/handlers/user/scroll_channels.py +++ b/src/bot/handlers/user/list_channels.py @@ -8,12 +8,12 @@ from ....db import DataAccessLayer from ...filters import RoleFilter from ...filters import UserRole -from ...states import ChannelsSG +from ...states import ChannelsListSG -scroll_channel_router = Router(name="scroll_channel") +list_channels_router = Router(name="list_channel") -@scroll_channel_router.message( +@list_channels_router.message( Command("channels"), RoleFilter(role=[UserRole.USER, UserRole.SUPERUSER]), State(state="*"), @@ -26,10 +26,10 @@ async def start_channels_dialog( **kwargs, ): await dialog_manager.start( - ChannelsSG.scrolling, + ChannelsListSG.scrolling, mode=StartMode.RESET_STACK, data={"dal": dal, "role": role}, ) -__all__ = ["scroll_channel_router"] +__all__ = ["list_channels_router"] diff --git a/src/bot/middlewares/role.py b/src/bot/middlewares/role.py index af5fbd7..aa9b339 100755 --- a/src/bot/middlewares/role.py +++ b/src/bot/middlewares/role.py @@ -7,9 +7,9 @@ from aiogram.dispatcher.middlewares.base import BaseMiddleware from aiogram.types import Message -from src.bot.filters.role import UserRole from src.db import DataAccessLayer -from src.dto import UserCreateDTO +from src.db.models import UserModel +from src.db.models.user_role import UserRole class RoleMiddleware(BaseMiddleware): @@ -40,11 +40,11 @@ async def __call__( else: user_id = user.id - _user: Optional[UserCreateDTO] = await dal.get_user_by_attr( + _user: Optional[UserModel] = await dal.get_user_by_attr( **{"user_id": user_id} ) if _user: - if _user.is_superuser: + if _user.role.role == UserRole.SUPERUSER: data["role"] = UserRole.SUPERUSER else: data["role"] = UserRole.USER diff --git a/src/bot/states.py b/src/bot/states.py index 50a5fe4..87bdac9 100644 --- a/src/bot/states.py +++ b/src/bot/states.py @@ -2,18 +2,21 @@ from aiogram.fsm.state import StatesGroup -class ChannelsSG(StatesGroup): - input_url = State() - input_label = State() +class ChannelsListSG(StatesGroup): scrolling = State() delete = State() turn_on = State() turn_off = State() - bulk_channels = State() + + +class ChannelCreateSG(StatesGroup): + start = State() + type_selected = State() + url_selected = State() class UsersSG(StatesGroup): promote = State() -__all__ = ["ChannelsSG", "UsersSG"] +__all__ = ["ChannelCreateSG", "ChannelsListSG", "UsersSG"] diff --git a/src/bot/utils/setup_bot.py b/src/bot/utils/setup_bot.py index d1fe4c4..355b4dd 100644 --- a/src/bot/utils/setup_bot.py +++ b/src/bot/utils/setup_bot.py @@ -2,6 +2,7 @@ from aiogram.types import BotCommandScopeChat from sulguk import AiogramSulgukMiddleware +from ...constants import VERSION from ..commands import superuser_commands from ..commands import user_commands from src.config import BotConfig @@ -35,6 +36,7 @@ async def setup_bot( user_commands() + superuser_commands(), 
scope=BotCommandScopeChat(chat_id=_id), ) + await bot.send_message(chat_id=_id, text=f"Starting bot, version: {VERSION}") await bot.delete_webhook() return bot diff --git a/src/cli.py b/src/cli.py index 79b2e36..c301fba 100644 --- a/src/cli.py +++ b/src/cli.py @@ -7,8 +7,9 @@ from apscheduler.schedulers.asyncio import AsyncIOScheduler from .db import DataAccessLayer -from .dto import UserCreateDTO -from .dto import UserRetrieveDTO +from .db.models import UserModel +from .db.models import UserRoleModel +from .db.models.user_role import UserRole from .logger import logger from src.bot import setup_bot from src.bot import setup_dispatcher @@ -27,11 +28,13 @@ async def create_super_user(telegram_id: int) -> None: dal = DataAccessLayer() - user_dto = UserCreateDTO(user_id=telegram_id, is_superuser=True, is_admin=True) + role = UserRoleModel(role=UserRole.SUPERUSER) - result = await dal.create_user(user_schema=user_dto) + user = UserModel(user_id=telegram_id, role=role) - if isinstance(result, UserRetrieveDTO) and result.id is not None: + result = await dal.create_user(obj=user) + + if isinstance(result, UserModel) and result.id is not None: await logger.ainfo("User created.") else: await logger.error("Cannot create user.") @@ -46,7 +49,9 @@ async def run_bot() -> None: superusers: list[int] = await dal.get_users(superusers=True) if not await dal.is_superusers_exists(): - await logger.aerror("You must to create superuser before start.") + await logger.aerror( + "You must to create superuser before start. Restart with --telegram_id YOUR_TELEGRAM_ID option" + ) return bot: Bot = await setup_bot( diff --git a/src/config/models.py b/src/config/models.py index 709437f..4f03f26 100644 --- a/src/config/models.py +++ b/src/config/models.py @@ -17,6 +17,20 @@ class BotConfig(BaseSettings): token: SecretStr +class YoutubeCredentials(BaseSettings): + cookies_filepath: str + + +class TwitchCredentials(BaseSettings): + """ + Credentials for TwitchApi + https://dev.twitch.tv/docs/authentication/register-app/ + """ + + app_id: str + app_secret: str + + class Config(BaseSettings): """ All in one config @@ -28,6 +42,8 @@ class Config(BaseSettings): report: Report start_scheduler: bool interval_s: int + twitch: Optional[TwitchCredentials] = None + youtube: Optional[YoutubeCredentials] = None __all__ = ["BotConfig", "Config"] diff --git a/src/constants.py b/src/constants.py index f2ab57f..7f2926a 100755 --- a/src/constants.py +++ b/src/constants.py @@ -4,17 +4,14 @@ CONFIG_FILE_PATH: str = os.environ.get( "CONFIG_FILE_PATH", os.path.join(ROOT_DIR, "config.yaml") ) -COOKIES_FILE_PATH: str = os.environ.get( - "COOKIES_FILE_PATH", os.path.join(ROOT_DIR, "cookies.txt") -) + SQLITE_DATABASE_FILE_PATH: str = os.environ.get( "SQLITE_DATABASE_FILE_PATH", os.path.join(ROOT_DIR, "youtube-notifier-bot.db") ) -VERSION: str = "2024-08-23.20" +VERSION: str = "2024-09-26.07" __all__ = [ "CONFIG_FILE_PATH", - "COOKIES_FILE_PATH", "ROOT_DIR", "SQLITE_DATABASE_FILE_PATH", "VERSION", diff --git a/src/db/dal.py b/src/db/dal.py index 9a746b1..12b3050 100644 --- a/src/db/dal.py +++ b/src/db/dal.py @@ -1,36 +1,38 @@ import asyncio import os -from typing import cast from typing import Optional from sqlalchemy.ext.asyncio import AsyncSession from ..constants import SQLITE_DATABASE_FILE_PATH -from ..dto import ChannelCreateDTO -from ..dto import ChannelRetrieveDTO -from ..dto import MessageLogCreateDTO -from ..dto import MessageLogRetrieveDTO -from ..dto import UserCreateDTO -from ..dto import UserRetrieveDTO from .dao import ChannelDAO 
+from .dao import ChannelTypeDAO from .dao import MessageLogDAO -from .dao import UserRepo +from .dao import UserDAO +from .dao import UserRoleDAO from .exceptions import DatabaseDoesNotExist -from .models import ChannelORM -from .models import MessageLogORM -from .models import UserORM +from .models import ChannelModel +from .models import MessageLogModel +from .models import UserModel +from .models import UserRoleModel +from .models.user_role import UserRole from .session import session_maker class DataAccessLayer: def __init__(self, session: Optional[AsyncSession] = None) -> None: + if session is None: self.__create_session() else: self.__session = session self.__sqlite_exists() - self.__init_repo() + self.channel_dao = ChannelDAO(session=self.__session) + self.channel_type_dao = ChannelTypeDAO(session=self.__session) + self.message_log_dao = MessageLogDAO(session=self.__session) + self.user_dao = UserDAO(session=self.__session) + self.user_role_dao = UserRoleDAO(session=self.__session) @staticmethod def __sqlite_exists(): @@ -51,74 +53,62 @@ def __del__(self): if self.__session: asyncio.create_task(self.__session.close()) - def __init_repo(self) -> None: + async def create_user(self, obj: UserModel) -> Optional[UserModel]: """ :return: """ - self.__user_repo = UserRepo( - session=self.__session, schema=UserRetrieveDTO, model_orm=UserORM - ) - self.__channel_repo = ChannelDAO( - session=self.__session, - schema=ChannelRetrieveDTO, - model_orm=ChannelORM, - ) - self.__message_log_repo = MessageLogDAO( - session=self.__session, - schema=MessageLogRetrieveDTO, - model_orm=MessageLogORM, + + user_role = obj.role + user_role_instance, _ = await self.user_role_dao.get_or_create( + role=user_role.role ) + obj.role = user_role_instance - async def create_user( - self, user_schema: UserCreateDTO - ) -> Optional[UserRetrieveDTO]: - """ - :param user_schema: - :return: - """ - return await self.__user_repo.create(user_schema=user_schema) + return await self.user_dao.create(obj=obj) - async def get_user_by_pk(self, pk: int) -> Optional[UserRetrieveDTO]: + async def get_user_by_pk(self, pk: int) -> Optional[UserModel]: """ :param pk: :return: """ - return await self.__user_repo.get_by_pk(pk=pk) + return await self.user_dao.get_first(id=pk) - async def get_user_by_attr(self, **kwargs) -> Optional[UserRetrieveDTO]: + async def get_user_by_attr(self, *args, **kwargs) -> Optional[UserModel]: """ :param kwargs: :return: """ - return await self.__user_repo.get_by_attr(**kwargs) + return await self.user_dao.get_first(*args, **kwargs) - async def list_users_by_attr(self, **kwargs) -> list[UserRetrieveDTO]: + async def list_users_by_attr(self, *args, **kwargs) -> list[UserModel]: """ :param kwargs: :return: """ - users: list[UserRetrieveDTO] = cast( - list[UserRetrieveDTO], await self.__user_repo.list_by_attrs(**kwargs) - ) - return users + return list(await self.user_dao.get_many(*args, **kwargs)) async def is_superusers_exists(self) -> bool: """ :return: """ - return bool(await self.list_users_by_attr(**{"is_superuser": True})) + return bool( + await self.list_users_by_attr(UserRoleModel.role == UserRole.SUPERUSER) + ) async def get_users(self, superusers: bool = False) -> list[int]: """ :param superusers: :return: """ - users: list[UserRetrieveDTO] + users: list[UserModel] if superusers: - users = await self.list_users_by_attr(**{"is_superuser": True}) + role = UserRole.SUPERUSER else: - users = await self.list_users_by_attr(**{"is_superuser": False}) + role = UserRole.USER + + role_instance, _ = await 
self.user_role_dao.get_or_create(role=role) + users = await self.list_users_by_attr(**{"role": role_instance}) user_ids: list[int] = [user.user_id for user in users] return user_ids @@ -127,59 +117,55 @@ async def get_last_published_message_id(self) -> Optional[int]: """ :return: """ - message_log_dto: Optional[MessageLogRetrieveDTO] = ( - await self.__message_log_repo.get_by_attr() - ) - if message_log_dto: - return message_log_dto.message_id + message_log = await self.message_log_dao.get_first() + if message_log: + return message_log.message_id return None - async def create_message( - self, message_log_schema: MessageLogCreateDTO - ) -> Optional[MessageLogRetrieveDTO]: + async def create_message(self, obj: MessageLogModel) -> Optional[MessageLogModel]: """ - :param message_log_schema: + :param obj: :return: """ - return await self.__message_log_repo.create( - message_log_schema=message_log_schema - ) + return await self.message_log_dao.create(obj=obj) - async def create_channel( - self, channel_schema: ChannelCreateDTO - ) -> Optional[ChannelRetrieveDTO]: + async def create_channel(self, obj: ChannelModel) -> Optional[ChannelModel]: """ - :param channel_schema: :return: """ - return await self.__channel_repo.create(channel_schema=channel_schema) - async def get_channels(self, **kwargs) -> list[ChannelRetrieveDTO]: + channel_type = obj.type + channel_type_instance, _ = await self.channel_type_dao.get_or_create( + type=channel_type.type + ) + obj.type = channel_type_instance + + return await self.channel_dao.create(obj=obj) + + async def get_channels(self, *args, **kwargs) -> list[ChannelModel]: """ :param kwargs: :return: """ - channels: list[ChannelRetrieveDTO] = cast( - list[ChannelRetrieveDTO], await self.__channel_repo.list_by_attrs(**kwargs) - ) - return channels + return list(await self.channel_dao.get_many(*args, **kwargs)) - async def delete_channel_by_id(self, _id: int) -> Optional[int]: + async def delete_channel_by_id(self, id: int) -> Optional[int]: """ - :param _id: + :param id: :return: """ - return await self.__channel_repo.delete_by_pk(pk=_id) + return await self.channel_dao.delete(id=id) - async def update_channel_by_id(self, _id: int, data: dict) -> Optional[int]: + async def update_channel_by_id(self, obj: ChannelModel) -> Optional[int]: """ - :param data: - :param _id: :return: """ - return await self.__channel_repo.update_by_pk(pk=_id, data=data) + channel = await self.channel_dao.update(obj=obj) + if channel: + return channel.id + return None -__all__ = ["DataAccessLayer"] + __all__ = ["DataAccessLayer"] diff --git a/src/db/dao/__init__.py b/src/db/dao/__init__.py index e510865..1bfae14 100644 --- a/src/db/dao/__init__.py +++ b/src/db/dao/__init__.py @@ -1,5 +1,7 @@ -from .channels import ChannelDAO +from .channel import ChannelDAO +from .channel_type import ChannelTypeDAO from .message_log import MessageLogDAO -from .users import UserRepo +from .user import UserDAO +from .user_role import UserRoleDAO -__all__ = ["ChannelDAO", "MessageLogDAO", "UserRepo"] +__all__ = ["ChannelDAO", "ChannelTypeDAO", "MessageLogDAO", "UserDAO", "UserRoleDAO"] diff --git a/src/db/dao/base.py b/src/db/dao/base.py index 421b71a..b12db1b 100644 --- a/src/db/dao/base.py +++ b/src/db/dao/base.py @@ -1,161 +1,106 @@ -from abc import ABC -from typing import Any +from abc import abstractmethod +from typing import Generic from typing import Optional +from typing import Sequence from typing import Type +from typing import TypeVar -from pydantic import TypeAdapter -from sqlalchemy import Delete 
-from sqlalchemy import inspect -from sqlalchemy import ScalarResult -from sqlalchemy import Select -from sqlalchemy import Update +from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.orm import selectinload +from sqlmodel import select +from sqlmodel import SQLModel -from src.db.exceptions import ColumnDoesNotExist -from src.db.models import ModelORM -from src.dto import DTO +from src.logger import logger +T = TypeVar("T", bound=SQLModel) -class DAO(ABC): - def __init__( - self, session: AsyncSession, model_orm: Type[ModelORM], schema: Type[DTO] - ) -> None: + +class BaseDAO(Generic[T]): + def __init__(self, session: AsyncSession, model: Type[T]): self.session = session - self.model_orm = model_orm - self.schema = schema + self.model = model @property - def relations(self): - """ - Get all relations - :return: - """ - return { - tuple_name_column[0]: tuple_name_column[1] - for tuple_name_column in inspect(self.model_orm).relationships.items() - } - - async def scalars(self, statement): - """ - :param statement: sqlalchemy statement, for example sqlalchemy.select - :return: - """ - for relation in self.relations.values(): - statement = statement.options(selectinload(relation)) - - return await self.session.scalars(statement) - - async def create(self, *args, **kwargs) -> Any: - raise NotImplementedError() - - def __generate_predicate(self, **kwargs) -> list: - where_clause: list = [] - - for key in kwargs.keys(): - if not hasattr(self.model_orm, key): - raise ColumnDoesNotExist(column=key, table=self.model_orm.__tablename__) - else: - col = getattr(self.model_orm, key) - where_clause.append(col == kwargs[key]) - - return where_clause - - async def __get_by_attrs(self, **kwargs) -> ScalarResult: - """ - :param kwargs: - :return: - """ - stm = ( - Select(self.model_orm) - .where(*self.__generate_predicate(**kwargs)) - .order_by(self.model_orm.id.desc()) - ) - result: ScalarResult = await self.scalars(statement=stm) - return result - - async def list_by_attrs(self, **kwargs) -> list[DTO]: - """ - :param kwargs: - :return: - """ - result: ScalarResult = await self.__get_by_attrs(**kwargs) - all_results = result.all() - ta = TypeAdapter(list[self.schema]) # type: ignore - dto_objects = ta.validate_python(all_results) - - return dto_objects - - async def get_by_pk(self, pk: int) -> Optional[DTO]: - """ - :param pk: - :return: - """ - stm = Select(self.model_orm).where(self.model_orm.id == pk) - - result: ScalarResult = await self.scalars(statement=stm) - model: Optional[ModelORM] = result.first() - - if model: - dto = self.schema.model_validate(model) - return dto - - return None - - async def get_by_attr(self, **kwargs) -> Optional[DTO]: - """ - :param kwargs: - :return: - """ - result: ScalarResult = await self.__get_by_attrs(**kwargs) - model: Optional[ModelORM] = result.first() - - if model: - dto = self.schema.model_validate(model) - return dto - - return None - - async def delete_by_pk(self, pk: int) -> Optional[int]: - stm = ( - Delete(self.model_orm) - .where(self.model_orm.id == pk) - .returning(self.model_orm.id) - ) - result: ScalarResult = await self.session.scalars(statement=stm) - deleted_id = result.first() - - await self.session.commit() - return deleted_id - - async def delete_by_attr(self, **kwargs) -> list[int]: - """ - :param kwargs: - :return: - """ - stm = ( - Delete(self.model_orm) - .where(*self.__generate_predicate(**kwargs)) - .returning(self.model_orm.id) - ) - - result: ScalarResult = await 
self.session.scalars(statement=stm) - - await self.session.commit() - return list(result.all()) - - async def update_by_pk(self, pk: int, data: dict) -> Optional[int]: - stm = ( - Update(self.model_orm) - .where(self.model_orm.id == pk) - .values(**data) - .returning(self.model_orm.id) - ) - result: ScalarResult = await self.session.scalars(statement=stm) - updated_id = result.first() - - await self.session.commit() - return updated_id - - -__all__ = ["DAO"] + def __prepare_select_statement(self): + statement = select(self.model) + return statement + + @abstractmethod + async def get_many(self, *args, **kwargs) -> Sequence[T]: + try: + statement = self.__prepare_select_statement.where(*args).filter_by(**kwargs) + results = await self.session.execute(statement) + return results.scalars().all() + except SQLAlchemyError as e: + await logger.aerror( + f"Error searching {self.model.__name__} with attributes {kwargs}: {e}" + ) + return [] + + @abstractmethod + async def get_first(self, *args, **kwargs) -> Optional[T]: + try: + statement = self.__prepare_select_statement.where(*args).filter_by(**kwargs) + result = await self.session.execute(statement) + return result.scalars().first() + except SQLAlchemyError as e: + await logger.aerror( + f"Error searching for one {self.model.__name__} with attributes {kwargs}: {e}" + ) + return None + + @abstractmethod + async def create(self, obj: T) -> Optional[T]: + try: + self.session.add(obj) + await self.session.commit() + await self.session.refresh(obj) + return obj + except SQLAlchemyError as e: + await logger.aerror(f"Error creating {self.model.__name__}: {e}") + await self.session.rollback() + return None + + @abstractmethod + async def get_or_create(self, **kwargs) -> tuple[T, bool]: + + instance = await self.get_first(**kwargs) + if instance: + return instance, False + + instance = self.model(**kwargs) + obj = await self.create(obj=instance) + + if not obj: + raise + + return obj, True + + @abstractmethod + async def update(self, obj: T) -> Optional[T]: + try: + self.session.add(obj) + await self.session.commit() + await self.session.refresh(obj) + return obj + except SQLAlchemyError as e: + await logger.aerror(f"Error updating {self.model.__name__}: {e}") + await self.session.rollback() + return None + + @abstractmethod + async def delete(self, id: int) -> bool: + try: + obj = await self.session.get(self.model, id) + if obj: + await self.session.delete(obj) + await self.session.commit() + return True + return False + except SQLAlchemyError as e: + await logger.aerror(f"Error deleting {self.model.__name__}: {e}") + await self.session.rollback() + return False + + +__all__ = ["BaseDAO"] diff --git a/src/db/dao/channel.py b/src/db/dao/channel.py new file mode 100644 index 0000000..4804602 --- /dev/null +++ b/src/db/dao/channel.py @@ -0,0 +1,62 @@ +from typing import Optional +from typing import Sequence + +from sqlalchemy.exc import SQLAlchemyError +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import joinedload +from sqlmodel import select + +from .base import BaseDAO +from src.db.models import ChannelModel +from src.logger import logger + + +class ChannelDAO(BaseDAO[ChannelModel]): + def __init__(self, session: AsyncSession): + super().__init__(session, ChannelModel) + + @property + def __prepare_select_statement(self): + statement = ( + select(self.model) + .options(joinedload(self.model.user)) + .options(joinedload(self.model.type)) + ) + return statement + + async def get_many(self, *args, **kwargs) -> 
Sequence[ChannelModel]: + try: + statement = self.__prepare_select_statement.where(*args).filter_by(**kwargs) + results = await self.session.execute(statement) + return results.scalars().all() + except SQLAlchemyError as e: + await logger.aerror( + f"Error searching {self.model.__name__} with attributes {kwargs}: {e}" + ) + return [] + + async def get_first(self, *args, **kwargs) -> Optional[ChannelModel]: + try: + statement = self.__prepare_select_statement.where(*args).filter_by(**kwargs) + result = await self.session.execute(statement) + return result.scalars().first() + except SQLAlchemyError as e: + await logger.aerror( + f"Error searching for one {self.model.__name__} with attributes {kwargs}: {e}" + ) + return None + + async def create(self, obj: ChannelModel) -> Optional[ChannelModel]: + return await super().create(obj) + + async def get_or_create(self, **kwargs) -> tuple[ChannelModel, bool]: + return await super().get_or_create(**kwargs) + + async def update(self, obj: ChannelModel) -> Optional[ChannelModel]: + return await super().update(obj) + + async def delete(self, id: int) -> bool: + return await super().delete(id) + + +__all__ = ["ChannelDAO"] diff --git a/src/db/dao/channel_type.py b/src/db/dao/channel_type.py new file mode 100644 index 0000000..dc4a944 --- /dev/null +++ b/src/db/dao/channel_type.py @@ -0,0 +1,33 @@ +from typing import Optional + +from sqlalchemy.ext.asyncio import AsyncSession +from typing_extensions import Sequence + +from .base import BaseDAO +from src.db.models import ChannelTypeModel + + +class ChannelTypeDAO(BaseDAO[ChannelTypeModel]): + def __init__(self, session: AsyncSession): + super().__init__(session, ChannelTypeModel) + + async def get_first(self, *args, **kwargs) -> Optional[ChannelTypeModel]: + return await super().get_first(*args, **kwargs) + + async def get_many(self, *args, **kwargs) -> Sequence[ChannelTypeModel]: + return await super().get_many(*args, **kwargs) + + async def create(self, obj: ChannelTypeModel) -> Optional[ChannelTypeModel]: + return await super().create(obj) + + async def get_or_create(self, **kwargs) -> tuple[ChannelTypeModel, bool]: + return await super().get_or_create(**kwargs) + + async def update(self, obj: ChannelTypeModel) -> Optional[ChannelTypeModel]: + return await super().update(obj) + + async def delete(self, id: int) -> bool: + return await super().delete(id) + + +__all__ = ["ChannelTypeDAO"] diff --git a/src/db/dao/channels.py b/src/db/dao/channels.py deleted file mode 100644 index e09461f..0000000 --- a/src/db/dao/channels.py +++ /dev/null @@ -1,33 +0,0 @@ -from typing import Optional - -from sqlalchemy import CursorResult - -from ..models import ChannelORM -from .base import DAO -from .utils import sqlite_async_upsert -from src.dto import ChannelCreateDTO -from src.dto import ChannelRetrieveDTO - - -class ChannelDAO(DAO): - async def create( - self, channel_schema: ChannelCreateDTO - ) -> Optional[ChannelRetrieveDTO]: - result: CursorResult = await sqlite_async_upsert( - session=self.session, - model=ChannelORM, - data=channel_schema.model_dump(), - index_col="url", - ) - - channel_dto: Optional[ChannelRetrieveDTO] - - if result.lastrowid: - channel_dto = await self.get_by_pk(pk=result.lastrowid) - else: - channel_dto = await self.get_by_attr(url=channel_schema.url) - - return channel_dto - - -__all__ = ["ChannelDAO"] diff --git a/src/db/dao/message_log.py b/src/db/dao/message_log.py index baecdda..cf5b34e 100644 --- a/src/db/dao/message_log.py +++ b/src/db/dao/message_log.py @@ -1,40 +1,39 @@ from typing 
import Optional +from typing import Sequence -from sqlalchemy import CursorResult - -from ..models import MessageLogORM -from .base import DAO -from .utils import sqlite_async_upsert -from src.dto import MessageLogCreateDTO -from src.dto import MessageLogRetrieveDTO - - -class MessageLogDAO(DAO): - async def create( - self, message_log_schema: MessageLogCreateDTO - ) -> Optional[MessageLogRetrieveDTO]: - """ - :param message_log_schema: - :return: - """ - - result: CursorResult = await sqlite_async_upsert( - session=self.session, - model=MessageLogORM, - data=message_log_schema.model_dump(), - index_col="message_id", - ) - - message_log_dto: Optional[MessageLogRetrieveDTO] - - if result.lastrowid: - message_log_dto = await self.get_by_pk(pk=result.lastrowid) - else: - message_log_dto = await self.get_by_attr( - message_id=message_log_schema.message_id - ) - - return message_log_dto +from sqlalchemy.ext.asyncio import AsyncSession +from sqlmodel import select + +from .base import BaseDAO +from src.db.models import MessageLogModel + + +class MessageLogDAO(BaseDAO[MessageLogModel]): + def __init__(self, session: AsyncSession): + super().__init__(session, MessageLogModel) + + @property + def __prepare_select_statement(self): + statement = select(self.model).order_by(self.model.id) + return statement + + async def get_first(self, *args, **kwargs) -> Optional[MessageLogModel]: + return await super().get_first(*args, **kwargs) + + async def get_many(self, *args, **kwargs) -> Sequence[MessageLogModel]: + return await super().get_many(*args, **kwargs) + + async def create(self, obj: MessageLogModel) -> Optional[MessageLogModel]: + return await super().create(obj) + + async def get_or_create(self, **kwargs) -> tuple[MessageLogModel, bool]: + return await super().get_or_create(**kwargs) + + async def update(self, obj: MessageLogModel) -> Optional[MessageLogModel]: + return await super().update(obj) + + async def delete(self, id: int) -> bool: + return await super().delete(id) __all__ = ["MessageLogDAO"] diff --git a/src/db/dao/user.py b/src/db/dao/user.py new file mode 100644 index 0000000..abf9693 --- /dev/null +++ b/src/db/dao/user.py @@ -0,0 +1,58 @@ +from typing import Optional +from typing import Sequence + +from sqlalchemy.exc import SQLAlchemyError +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import joinedload +from sqlmodel import select + +from .base import BaseDAO +from src.db.models import UserModel +from src.logger import logger + + +class UserDAO(BaseDAO[UserModel]): + def __init__(self, session: AsyncSession): + super().__init__(session, UserModel) + + @property + def __prepare_select_statement(self): + statement = select(self.model).options(joinedload(self.model.role)) + return statement + + async def get_many(self, *args, **kwargs) -> Sequence[UserModel]: + try: + statement = self.__prepare_select_statement.where(*args).filter_by(**kwargs) + results = await self.session.execute(statement) + return results.scalars().all() + except SQLAlchemyError as e: + await logger.aerror( + f"Error searching {self.model.__name__} with attributes {kwargs}: {e}" + ) + return [] + + async def get_first(self, *args, **kwargs) -> Optional[UserModel]: + try: + statement = self.__prepare_select_statement.where(*args).filter_by(**kwargs) + result = await self.session.execute(statement) + return result.scalars().first() + except SQLAlchemyError as e: + await logger.aerror( + f"Error searching for one {self.model.__name__} with attributes {kwargs}: {e}" + ) + return None + + async def 
create(self, obj: UserModel) -> Optional[UserModel]: + return await super().create(obj) + + async def get_or_create(self, **kwargs) -> tuple[UserModel, bool]: + return await super().get_or_create(**kwargs) + + async def update(self, obj: UserModel) -> Optional[UserModel]: + return await super().update(obj) + + async def delete(self, id: int) -> bool: + return await super().delete(id) + + +__all__ = ["UserDAO"] diff --git a/src/db/dao/user_role.py b/src/db/dao/user_role.py new file mode 100644 index 0000000..14f4644 --- /dev/null +++ b/src/db/dao/user_role.py @@ -0,0 +1,33 @@ +from typing import Optional +from typing import Sequence + +from sqlalchemy.ext.asyncio import AsyncSession + +from .base import BaseDAO +from src.db.models import UserRoleModel + + +class UserRoleDAO(BaseDAO[UserRoleModel]): + def __init__(self, session: AsyncSession): + super().__init__(session, UserRoleModel) + + async def get_first(self, *args, **kwargs) -> Optional[UserRoleModel]: + return await super().get_first(*args, **kwargs) + + async def get_many(self, *args, **kwargs) -> Sequence[UserRoleModel]: + return await super().get_many(*args, **kwargs) + + async def create(self, obj: UserRoleModel) -> Optional[UserRoleModel]: + return await super().create(obj) + + async def get_or_create(self, **kwargs) -> tuple[UserRoleModel, bool]: + return await super().get_or_create(**kwargs) + + async def update(self, obj: UserRoleModel) -> Optional[UserRoleModel]: + return await super().update(obj) + + async def delete(self, id: int) -> bool: + return await super().delete(id) + + +__all__ = ["UserRoleDAO"] diff --git a/src/db/dao/users.py b/src/db/dao/users.py deleted file mode 100644 index 4e53453..0000000 --- a/src/db/dao/users.py +++ /dev/null @@ -1,35 +0,0 @@ -from typing import Optional - -from sqlalchemy import CursorResult - -from ..models import UserORM -from .base import DAO -from .utils import sqlite_async_upsert -from src.dto import UserCreateDTO -from src.dto import UserRetrieveDTO - - -class UserRepo(DAO): - async def create(self, user_schema: UserCreateDTO) -> Optional[UserRetrieveDTO]: - """ - :param user_schema: - :return: - """ - result: CursorResult = await sqlite_async_upsert( - session=self.session, - model=UserORM, - data=user_schema.model_dump(), - index_col="user_id", - ) - - user_dto: Optional[UserRetrieveDTO] - - if result.lastrowid: - user_dto = await self.get_by_pk(pk=result.lastrowid) - else: - user_dto = await self.get_by_attr(user_id=user_schema.user_id) - - return user_dto - - -__all__ = ["UserRepo"] diff --git a/src/db/dao/utils/__init__.py b/src/db/dao/utils/__init__.py deleted file mode 100644 index eb28af2..0000000 --- a/src/db/dao/utils/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from .sqlite_async_upsert import sqlite_async_upsert - - -__all__ = ["sqlite_async_upsert"] diff --git a/src/db/dao/utils/sqlite_async_upsert.py b/src/db/dao/utils/sqlite_async_upsert.py deleted file mode 100644 index bb07d71..0000000 --- a/src/db/dao/utils/sqlite_async_upsert.py +++ /dev/null @@ -1,49 +0,0 @@ -import copy -from typing import Any -from typing import Type - -from sqlalchemy import CursorResult -from sqlalchemy.dialects.sqlite import insert -from sqlalchemy.ext.asyncio import AsyncSession - - -# source -# https://docs.sqlalchemy.org/en/20/dialects/sqlite.html#insert-on-conflict-upsert - - -async def sqlite_async_upsert( - session: AsyncSession, - model: Type[Any], - data: dict[str, Any], - index_col: str, - pk_col: str = "id", -) -> CursorResult: - """ - :param pk_col: - :param session: - :param 
model: - :param data: - :param index_col: - :return: - """ - - data_without_index = copy.deepcopy(data) - - if index_col in data_without_index.keys(): - data_without_index.pop(index_col) - - if pk_col in data_without_index.keys(): - data_without_index.pop(pk_col) - - insert_stmt = insert(model).values(**data) - do_update_stmt = insert_stmt.on_conflict_do_update( - index_elements=[index_col], set_=data_without_index - ) - - result = await session.execute(do_update_stmt) - await session.commit() - - return result - - -__all__ = ["sqlite_async_upsert"] diff --git a/src/db/migrations/env.py b/src/db/migrations/env.py index 1b36f85..88deba8 100644 --- a/src/db/migrations/env.py +++ b/src/db/migrations/env.py @@ -3,30 +3,17 @@ from alembic import context from sqlalchemy.engine import Connection +from sqlmodel import SQLModel +from src.db.models import * # noqa from src.db.session import engine -# this is the Alembic Config object, which provides -# access to the values within the .ini file in use. config = context.config -# Interpret the config file for Python logging. -# This line sets up loggers basically. if config.config_file_name is not None: fileConfig(config.config_file_name) -# add your model's MetaData object here -# for 'autogenerate' support -# from myapp import mymodel -# target_metadata = mymodel.Base.metadata -from src.db.models.mixins.base import ModelORM - -target_metadata = ModelORM.metadata - -# other values from the config, defined by the needs of env.py, -# can be acquired: -# my_important_option = config.get_main_option("my_important_option") -# ... etc. +target_metadata = SQLModel.metadata def run_migrations_offline() -> None: @@ -65,14 +52,6 @@ async def run_async_migrations() -> None: and associate a connection with the context. """ - - # default - # connectable = async_engine_from_config( - # config.get_section(config.config_ini_section, {}), - # prefix="sqlalchemy.", - # poolclass=pool.NullPool, - # ) - # edited connectable = engine diff --git a/src/db/migrations/utils.py b/src/db/migrations/utils.py new file mode 100644 index 0000000..2a11f5a --- /dev/null +++ b/src/db/migrations/utils.py @@ -0,0 +1,14 @@ +from alembic import op +from sqlalchemy import inspect + + +# source +# https://stackoverflow.com/a/71624331 +def column_exists(table_name, column_name): + bind = op.get_context().bind + insp = inspect(bind) + columns = insp.get_columns(table_name) + return any(c["name"] == column_name for c in columns) + + +__all__ = ["column_exists"] diff --git a/src/db/migrations/versions/9218e49d7217_init.py b/src/db/migrations/versions/9218e49d7217_init.py new file mode 100644 index 0000000..229203c --- /dev/null +++ b/src/db/migrations/versions/9218e49d7217_init.py @@ -0,0 +1,99 @@ +"""init + +Revision ID: 9218e49d7217 +Revises: +Create Date: 2024-09-23 12:20:41.097968 + +""" +from typing import Sequence +from typing import Union + +import sqlalchemy as sa +import sqlmodel +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = '9218e49d7217' +down_revision: Union[str, None] = None +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + op.create_table('channel_types', + sa.Column('id', sa.Integer(), nullable=False), + sa.Column('created_at', sa.DateTime(), nullable=False), + sa.Column('updated_at', sa.DateTime(), nullable=True), + sa.Column('type', sa.Enum('YOUTUBE', 'TWITCH', 'KICK', name='channeltype', length=15), nullable=False), + sa.PrimaryKeyConstraint('id'), + sa.UniqueConstraint('type') + ) + op.create_table('message_logs', + sa.Column('id', sa.Integer(), nullable=False), + sa.Column('created_at', sa.DateTime(), nullable=False), + sa.Column('updated_at', sa.DateTime(), nullable=True), + sa.Column('message_id', sa.Integer(), nullable=False), + sa.Column('text', sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.PrimaryKeyConstraint('id') + ) + op.create_index(op.f('ix_message_logs_message_id'), 'message_logs', ['message_id'], unique=False) + op.create_table('user_roles', + sa.Column('id', sa.Integer(), nullable=False), + sa.Column('created_at', sa.DateTime(), nullable=False), + sa.Column('updated_at', sa.DateTime(), nullable=True), + sa.Column('role', sa.Enum('USER', 'SUPERUSER', 'UNKNOWN', name='userrole', length=15), nullable=False), + sa.PrimaryKeyConstraint('id'), + sa.UniqueConstraint('role') + ) + op.create_table('users', + sa.Column('id', sa.Integer(), nullable=False), + sa.Column('created_at', sa.DateTime(), nullable=False), + sa.Column('updated_at', sa.DateTime(), nullable=True), + sa.Column('user_id', sa.Integer(), nullable=False), + sa.Column('username', sqlmodel.sql.sqltypes.AutoString(length=255), nullable=True), + sa.Column('firstname', sqlmodel.sql.sqltypes.AutoString(length=255), nullable=True), + sa.Column('lastname', sqlmodel.sql.sqltypes.AutoString(length=255), nullable=True), + sa.Column('user_role_id', sa.Integer(), nullable=False), + sa.ForeignKeyConstraint(['user_role_id'], ['user_roles.id'], ), + sa.PrimaryKeyConstraint('id') + ) + op.create_index(op.f('ix_users_firstname'), 'users', ['firstname'], unique=False) + op.create_index(op.f('ix_users_lastname'), 'users', ['lastname'], unique=False) + op.create_index(op.f('ix_users_user_id'), 'users', ['user_id'], unique=True) + op.create_index(op.f('ix_users_username'), 'users', ['username'], unique=False) + op.create_table('channels', + sa.Column('id', sa.Integer(), nullable=False), + sa.Column('created_at', sa.DateTime(), nullable=False), + sa.Column('updated_at', sa.DateTime(), nullable=True), + sa.Column('url', sqlmodel.sql.sqltypes.AutoString(length=255), nullable=False), + sa.Column('label', sqlmodel.sql.sqltypes.AutoString(length=255), nullable=False), + sa.Column('enabled', sa.Boolean(), nullable=False), + sa.Column('user_id', sa.Integer(), nullable=True), + sa.Column('channel_type_id', sa.Integer(), nullable=False), + sa.ForeignKeyConstraint(['channel_type_id'], ['channel_types.id'], ), + sa.ForeignKeyConstraint(['user_id'], ['users.id'], ), + sa.PrimaryKeyConstraint('id') + ) + op.create_index(op.f('ix_channels_enabled'), 'channels', ['enabled'], unique=False) + op.create_index(op.f('ix_channels_label'), 'channels', ['label'], unique=False) + op.create_index(op.f('ix_channels_url'), 'channels', ['url'], unique=True) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + op.drop_index(op.f('ix_channels_url'), table_name='channels') + op.drop_index(op.f('ix_channels_label'), table_name='channels') + op.drop_index(op.f('ix_channels_enabled'), table_name='channels') + op.drop_table('channels') + op.drop_index(op.f('ix_users_username'), table_name='users') + op.drop_index(op.f('ix_users_user_id'), table_name='users') + op.drop_index(op.f('ix_users_lastname'), table_name='users') + op.drop_index(op.f('ix_users_firstname'), table_name='users') + op.drop_table('users') + op.drop_table('user_roles') + op.drop_index(op.f('ix_message_logs_message_id'), table_name='message_logs') + op.drop_table('message_logs') + op.drop_table('channel_types') + # ### end Alembic commands ### diff --git a/src/db/migrations/versions/da6a2af82a8b_init.py b/src/db/migrations/versions/da6a2af82a8b_init.py deleted file mode 100644 index 634351e..0000000 --- a/src/db/migrations/versions/da6a2af82a8b_init.py +++ /dev/null @@ -1,123 +0,0 @@ -from typing import Sequence -from typing import Union - -import sqlalchemy as sa -from alembic import op - - -# revision identifiers, used by Alembic. -revision: str = "da6a2af82a8b" -down_revision: Union[str, None] = None -branch_labels: Union[str, Sequence[str], None] = None -depends_on: Union[str, Sequence[str], None] = None - - -def upgrade() -> None: - # ### commands auto generated by Alembic - please adjust! ### - try: - op.create_table( - "message_logs", - sa.Column("id", sa.Integer(), autoincrement=True, nullable=False), - sa.Column("message_id", sa.BigInteger(), nullable=False), - sa.Column("text", sa.String(), nullable=False), - sa.Column("created_at", sa.TIMESTAMP(), nullable=False), - sa.Column("updated_at", sa.TIMESTAMP(), nullable=False), - sa.PrimaryKeyConstraint("id"), - sa.UniqueConstraint("message_id", name="ix_uniq_message_id"), - ) - op.create_index( - op.f("ix_message_logs_message_id"), - "message_logs", - ["message_id"], - unique=False, - ) - except: - pass - - try: - op.create_table( - "users", - sa.Column("id", sa.Integer(), autoincrement=True, nullable=False), - sa.Column("user_id", sa.BigInteger(), nullable=False), - sa.Column("username", sa.String(length=255), nullable=True), - sa.Column("firstname", sa.String(length=255), nullable=True), - sa.Column("lastname", sa.String(length=255), nullable=True), - sa.Column("is_superuser", sa.Boolean(), nullable=True), - sa.Column("created_at", sa.TIMESTAMP(), nullable=False), - sa.Column("updated_at", sa.TIMESTAMP(), nullable=False), - sa.PrimaryKeyConstraint("id"), - sa.UniqueConstraint("user_id", name="ix_uniq_telegram_user_id"), - ) - op.create_index( - op.f("ix_users_firstname"), "users", ["firstname"], unique=False - ) - op.create_index( - op.f("ix_users_is_superuser"), "users", ["is_superuser"], unique=False - ) - op.create_index(op.f("ix_users_lastname"), "users", ["lastname"], unique=False) - op.create_index(op.f("ix_users_user_id"), "users", ["user_id"], unique=False) - op.create_index(op.f("ix_users_username"), "users", ["username"], unique=False) - except: - pass - - try: - op.create_table( - "channels", - sa.Column("id", sa.Integer(), autoincrement=True, nullable=False), - sa.Column("url", sa.String(length=255), nullable=False), - sa.Column("label", sa.String(length=255), nullable=False), - sa.Column("enabled", sa.Boolean(), nullable=True), - sa.Column("created_at", sa.TIMESTAMP(), nullable=False), - sa.Column("updated_at", sa.TIMESTAMP(), nullable=False), - sa.Column("user_id", sa.Integer(), nullable=False), - sa.ForeignKeyConstraint( - ["user_id"], ["users.id"], 
onupdate="CASCADE", ondelete="CASCADE" - ), - sa.PrimaryKeyConstraint("id"), - sa.UniqueConstraint("url", name="ix_uniq_url"), - ) - op.create_index( - op.f("ix_channels_enabled"), "channels", ["enabled"], unique=False - ) - op.create_index(op.f("ix_channels_label"), "channels", ["label"], unique=False) - op.create_index(op.f("ix_channels_url"), "channels", ["url"], unique=False) - except: - pass - try: - op.create_table( - "channel_errors", - sa.Column("id", sa.Integer(), autoincrement=True, nullable=False), - sa.Column("error", sa.String(length=512), nullable=False), - sa.Column("created_at", sa.TIMESTAMP(), nullable=False), - sa.Column("updated_at", sa.TIMESTAMP(), nullable=False), - sa.Column("channel_id", sa.Integer(), nullable=False), - sa.ForeignKeyConstraint( - ["channel_id"], ["channels.id"], onupdate="CASCADE", ondelete="CASCADE" - ), - sa.PrimaryKeyConstraint("id"), - ) - op.create_index( - op.f("ix_channel_errors_error"), "channel_errors", ["error"], unique=False - ) - except: - pass - # ### end Alembic commands ### - - -def downgrade() -> None: - # ### commands auto generated by Alembic - please adjust! ### - op.drop_index(op.f("ix_channel_errors_error"), table_name="channel_errors") - op.drop_table("channel_errors") - op.drop_index(op.f("ix_channels_url"), table_name="channels") - op.drop_index(op.f("ix_channels_label"), table_name="channels") - op.drop_index(op.f("ix_channels_enabled"), table_name="channels") - op.drop_table("channels") - op.drop_index(op.f("ix_users_username"), table_name="users") - op.drop_index(op.f("ix_users_user_id"), table_name="users") - op.drop_index(op.f("ix_users_lastname"), table_name="users") - op.drop_index(op.f("ix_users_is_superuser"), table_name="users") - op.drop_index(op.f("ix_users_firstname"), table_name="users") - op.drop_table("users") - op.drop_index(op.f("ix_message_logs_message_id"), table_name="message_logs") - op.drop_table("message_logs") - # ### end Alembic commands ### diff --git a/src/db/migrations/versions/f78d1309ec45_init.py b/src/db/migrations/versions/f78d1309ec45_init.py deleted file mode 100644 index edd0c1d..0000000 --- a/src/db/migrations/versions/f78d1309ec45_init.py +++ /dev/null @@ -1,43 +0,0 @@ -from typing import Sequence -from typing import Union - -import sqlalchemy as sa -from alembic import op - - -# revision identifiers, used by Alembic. -revision: str = "f78d1309ec45" -down_revision: Union[str, None] = "da6a2af82a8b" -branch_labels: Union[str, Sequence[str], None] = None -depends_on: Union[str, Sequence[str], None] = None - - -def upgrade() -> None: - # ### commands auto generated by Alembic - please adjust! ### - op.drop_index("ix_channel_errors_error", table_name="channel_errors") - op.drop_table("channel_errors") - op.drop_index("ix_users_is_admin", table_name="users") - op.drop_column("users", "is_admin") - # ### end Alembic commands ### - - -def downgrade() -> None: - # ### commands auto generated by Alembic - please adjust! 
### - op.add_column("users", sa.Column("is_admin", sa.BOOLEAN(), nullable=True)) - op.create_index("ix_users_is_admin", "users", ["is_admin"], unique=False) - op.create_table( - "channel_errors", - sa.Column("id", sa.INTEGER(), nullable=False), - sa.Column("error", sa.VARCHAR(length=512), nullable=False), - sa.Column("created_at", sa.TIMESTAMP(), nullable=False), - sa.Column("updated_at", sa.TIMESTAMP(), nullable=False), - sa.Column("channel_id", sa.INTEGER(), nullable=False), - sa.ForeignKeyConstraint( - ["channel_id"], ["channels.id"], onupdate="CASCADE", ondelete="CASCADE" - ), - sa.PrimaryKeyConstraint("id"), - ) - op.create_index( - "ix_channel_errors_error", "channel_errors", ["error"], unique=False - ) - # ### end Alembic commands ### diff --git a/src/db/models/__init__.py b/src/db/models/__init__.py index 1c9ff9d..62bbfaa 100644 --- a/src/db/models/__init__.py +++ b/src/db/models/__init__.py @@ -1,15 +1,14 @@ -from .channel import ChannelORM -from .channel import ChannelORMRelatedModel -from .message_log import MessageLogORM -from .mixins.base import ModelORM -from .user import UserORM -from .user import UserORMRelatedModel +from .channel import ChannelModel +from .channel_type import ChannelTypeModel +from .message_log import MessageLogModel +from .user import UserModel +from .user_role import UserRoleModel + __all__ = [ - "ChannelORM", - "ChannelORMRelatedModel", - "MessageLogORM", - "ModelORM", - "UserORM", - "UserORMRelatedModel", + "ChannelModel", + "ChannelTypeModel", + "MessageLogModel", + "UserModel", + "UserRoleModel", ] diff --git a/src/db/models/channel.py b/src/db/models/channel.py index 30bedb8..b598c6b 100644 --- a/src/db/models/channel.py +++ b/src/db/models/channel.py @@ -1,53 +1,66 @@ -from sqlalchemy import Boolean +from datetime import datetime +from typing import Optional +from typing import TYPE_CHECKING + from sqlalchemy import Column -from sqlalchemy import ForeignKey -from sqlalchemy import Integer -from sqlalchemy import String -from sqlalchemy import UniqueConstraint -from sqlalchemy.orm import relationship +from sqlalchemy import DateTime +from sqlmodel import Field +from sqlmodel import Relationship +from sqlmodel import SQLModel + -from .mixins import ModelORM -from .mixins import RepresentationMixin -from .mixins import TimestampsMixin -from .user import UserORMRelatedModel +if TYPE_CHECKING: + from .channel_type import ChannelTypeModel + from .user import UserModel -class ChannelORM(ModelORM, TimestampsMixin, UserORMRelatedModel, RepresentationMixin): - """ - Model for storing YT channels - """ +class ChannelModel(SQLModel, table=True): __tablename__ = "channels" - __table_args__ = (UniqueConstraint("url", name="ix_uniq_url"),) - # some features with autoincrement - # https://docs.sqlalchemy.org/en/20/dialects/sqlite.html#allowing-autoincrement-behavior-sqlalchemy-types-other-than-integer-integer - id = Column( - Integer, - primary_key=True, - autoincrement=True, + id: Optional[int] = Field(default=None, primary_key=True) + created_at: Optional[datetime] = Field( + sa_column=Column( + DateTime, + default=datetime.utcnow, + nullable=False, + ) + ) + updated_at: Optional[datetime] = Field( + sa_column=Column( + DateTime, + default=datetime.utcnow, + onupdate=datetime.utcnow, + ) ) - url = Column(String(length=255), nullable=False, index=True) - label = Column(String(length=255), nullable=False, index=True) - enabled = Column(Boolean, default=True, index=True) - user = relationship( - "UserORM", - backref="channels", - foreign_keys="ChannelORM.user_id", 
- uselist=False, - lazy="selectin", + url: str = Field(max_length=255, nullable=False, index=True, unique=True) + label: str = Field(max_length=255, nullable=False, index=True) + enabled: bool = Field(nullable=False, index=True) + user_id: int | None = Field(default=None, foreign_key="users.id") + channel_type_id: int | None = Field( + default=None, foreign_key="channel_types.id", nullable=False ) + user: "UserModel" = Relationship(back_populates="channels") + type: "ChannelTypeModel" = Relationship(back_populates="channels") -class ChannelORMRelatedModel: - __abstract__ = True + def to_html(self) -> str: - channel_id = Column( - ForeignKey( - f"{ChannelORM.__tablename__}.id", ondelete="CASCADE", onupdate="CASCADE" - ), - nullable=False, - ) + user_attribute_list = [self.user.username, self.user.user_id] + attribute = next(item for item in user_attribute_list if item is not None) + user_link = f'{attribute}' + + return ( + f"📺 Selected channel:
" + f"├──type: {self.type.type}
" + f"├──enabled: {self.enabled}
" + f"├──id: {self.id}
" + f"├──label: {self.label}
" + f"├──url: {self.url}
" + f"├──added by: {user_link}
" + f"├──added at: {self.created_at}
" + f"└──last modified at: {self.updated_at}
" + ) -__all__ = ["ChannelORM", "ChannelORMRelatedModel"] +__all__ = ["ChannelModel"] diff --git a/src/db/models/channel_type.py b/src/db/models/channel_type.py new file mode 100644 index 0000000..9cc6409 --- /dev/null +++ b/src/db/models/channel_type.py @@ -0,0 +1,61 @@ +import enum +from datetime import datetime +from typing import Optional +from typing import TYPE_CHECKING + +from sqlalchemy import Column +from sqlalchemy import DateTime +from sqlalchemy import Enum +from sqlmodel import Field +from sqlmodel import Relationship +from sqlmodel import SQLModel + + +if TYPE_CHECKING: + from .channel import ChannelModel + + +class ChannelType(enum.StrEnum): + YOUTUBE = "YOUTUBE" + TWITCH = "TWITCH" + KICK = "KICK" + + @classmethod + def list(cls): + return list(map(lambda c: c.value, cls)) + + +class ChannelTypeModel(SQLModel, table=True): + + __tablename__ = "channel_types" + id: Optional[int] = Field(default=None, primary_key=True) + created_at: Optional[datetime] = Field( + sa_column=Column( + DateTime, + default=datetime.utcnow, + nullable=False, + ) + ) + updated_at: Optional[datetime] = Field( + sa_column=Column( + DateTime, + default=datetime.utcnow, + onupdate=datetime.utcnow, + ) + ) + type: ChannelType = Field( + sa_column=Column( + Enum(ChannelType, length=15), + default=ChannelType.YOUTUBE, + nullable=False, + index=False, + unique=True, + ) + ) + + channels: list["ChannelModel"] = Relationship( + back_populates="type", cascade_delete=True + ) + + +__all__ = ["ChannelType", "ChannelTypeModel"] diff --git a/src/db/models/message_log.py b/src/db/models/message_log.py index d13b219..93a90c1 100644 --- a/src/db/models/message_log.py +++ b/src/db/models/message_log.py @@ -1,32 +1,32 @@ -from sqlalchemy import BigInteger -from sqlalchemy import Column -from sqlalchemy import Integer -from sqlalchemy import String -from sqlalchemy import UniqueConstraint +from datetime import datetime +from typing import Optional -from .mixins import ModelORM -from .mixins import RepresentationMixin -from .mixins import TimestampsMixin +from sqlalchemy import Column +from sqlalchemy import DateTime +from sqlmodel import Field +from sqlmodel import SQLModel -class MessageLogORM(ModelORM, TimestampsMixin, RepresentationMixin): - """ - Model for storing id of actual post message - """ +class MessageLogModel(SQLModel, table=True): __tablename__ = "message_logs" - - __table_args__ = (UniqueConstraint("message_id", name="ix_uniq_message_id"),) - - # some features with autoincrement - # https://docs.sqlalchemy.org/en/20/dialects/sqlite.html#allowing-autoincrement-behavior-sqlalchemy-types-other-than-integer-integer - id = Column( - Integer, - primary_key=True, - autoincrement=True, + id: Optional[int] = Field(default=None, primary_key=True) + created_at: Optional[datetime] = Field( + sa_column=Column( + DateTime, + default=datetime.utcnow, + nullable=False, + ) + ) + updated_at: Optional[datetime] = Field( + sa_column=Column( + DateTime, + default=datetime.utcnow, + onupdate=datetime.utcnow, + ) ) - message_id = Column(BigInteger, nullable=False, index=True) - text = Column(String, nullable=False, index=False) + message_id: int = Field(index=True) + text: str = Field(nullable=False, index=False) -__all__ = ["MessageLogORM"] +__all__ = ["MessageLogModel"] diff --git a/src/db/models/mixins/__init__.py b/src/db/models/mixins/__init__.py deleted file mode 100644 index a1122fd..0000000 --- a/src/db/models/mixins/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -from .base import ModelORM -from .repr import 
RepresentationMixin -from .timestamp import TimestampsMixin - -__all__ = ["ModelORM", "RepresentationMixin", "TimestampsMixin"] diff --git a/src/db/models/mixins/base.py b/src/db/models/mixins/base.py deleted file mode 100644 index 44ffe32..0000000 --- a/src/db/models/mixins/base.py +++ /dev/null @@ -1,9 +0,0 @@ -from sqlalchemy.ext.asyncio import AsyncAttrs -from sqlalchemy.orm import DeclarativeBase - - -class ModelORM(AsyncAttrs, DeclarativeBase): - pass - - -__all_ = ["ModelORM"] diff --git a/src/db/models/mixins/repr.py b/src/db/models/mixins/repr.py deleted file mode 100644 index 45428dc..0000000 --- a/src/db/models/mixins/repr.py +++ /dev/null @@ -1,29 +0,0 @@ -from typing import Any - -from sqlalchemy.orm.exc import DetachedInstanceError - - -class RepresentationMixin: - """Mixin for pretty display""" - - __abstract__ = True - - def _repr(self, **fields: dict[str, Any]) -> str: - """ - Helper for __repr__ - """ - field_strings = [] - at_least_one_attached_attribute = False - for key, field in fields.items(): - try: - field_strings.append(f"{key}={field!r}") - except DetachedInstanceError: - field_strings.append(f"{key}=DetachedInstanceError") - else: - at_least_one_attached_attribute = True - if at_least_one_attached_attribute: - return f"<{self.__class__.__name__}({','.join(field_strings)})>" - return f"<{self.__class__.__name__} {id(self)}>" - - -__all__ = ["RepresentationMixin"] diff --git a/src/db/models/mixins/timestamp.py b/src/db/models/mixins/timestamp.py deleted file mode 100644 index b14ee4a..0000000 --- a/src/db/models/mixins/timestamp.py +++ /dev/null @@ -1,32 +0,0 @@ -# source -# https://github.com/absent1706/sqlalchemy-mixins/blob/master/sqlalchemy_mixins/timestamp.py -import datetime - -import sqlalchemy as sa - - -class TimestampsMixin: - """Mixin that define timestamp columns.""" - - __abstract__ = True - - __created_at_name__ = "created_at" - __updated_at_name__ = "updated_at" - __datetime_func__ = datetime.datetime.utcnow - - created_at = sa.Column( - __created_at_name__, - sa.TIMESTAMP(timezone=False), - default=__datetime_func__, - nullable=False, - ) - - updated_at = sa.Column( - __updated_at_name__, - sa.TIMESTAMP(timezone=False), - default=__datetime_func__, - onupdate=__datetime_func__, - nullable=False, - ) - - __all__ = ["TimestampsMixin"] diff --git a/src/db/models/user.py b/src/db/models/user.py index b53aa52..c8cca4f 100644 --- a/src/db/models/user.py +++ b/src/db/models/user.py @@ -1,47 +1,55 @@ -from sqlalchemy import BigInteger -from sqlalchemy import Boolean +from datetime import datetime +from typing import Optional +from typing import TYPE_CHECKING + from sqlalchemy import Column -from sqlalchemy import ForeignKey -from sqlalchemy import Integer -from sqlalchemy import String -from sqlalchemy import UniqueConstraint +from sqlalchemy import DateTime +from sqlmodel import Field +from sqlmodel import Relationship +from sqlmodel import SQLModel -from .mixins import ModelORM -from .mixins import RepresentationMixin -from .mixins import TimestampsMixin +if TYPE_CHECKING: + from .channel import ChannelModel + from .user_role import UserRoleModel -class UserORM(ModelORM, TimestampsMixin, RepresentationMixin): - """ - Model for storing TG users - """ - __tablename__ = "users" +class UserModel(SQLModel, table=True): - __table_args__ = (UniqueConstraint("user_id", name="ix_uniq_telegram_user_id"),) - # some features with autoincrement - # 
https://docs.sqlalchemy.org/en/20/dialects/sqlite.html#allowing-autoincrement-behavior-sqlalchemy-types-other-than-integer-integer - id = Column( - Integer, - primary_key=True, - autoincrement=True, + __tablename__ = "users" + id: Optional[int] = Field(default=None, primary_key=True) + created_at: Optional[datetime] = Field( + sa_column=Column( + DateTime, + default=datetime.utcnow, + nullable=False, + ) + ) + updated_at: Optional[datetime] = Field( + sa_column=Column( + DateTime, + default=datetime.utcnow, + onupdate=datetime.utcnow, + ) + ) + user_id: int = Field(index=True, unique=True, nullable=False) + username: str = Field(max_length=255, nullable=True, index=True) + firstname: str = Field(max_length=255, nullable=True, index=True) + lastname: str = Field(max_length=255, nullable=True, index=True) + user_role_id: int | None = Field( + default=None, foreign_key="user_roles.id", nullable=False ) - user_id = Column(BigInteger, nullable=False, index=True) - username = Column(String(length=255), nullable=True, index=True) - firstname = Column(String(length=255), nullable=True, index=True) - lastname = Column(String(length=255), nullable=True, index=True) - is_superuser = Column(Boolean, default=False, index=True) - - -class UserORMRelatedModel: - __abstract__ = True - user_id = Column( - ForeignKey( - f"{UserORM.__tablename__}.id", ondelete="CASCADE", onupdate="CASCADE" - ), - nullable=False, + role: "UserRoleModel" = Relationship( + back_populates="users", ) + channels: list["ChannelModel"] = Relationship( + back_populates="user", + ) + + @property + def get_url_generated_by_id(self) -> str: + return f"tg://openmessage?user_id={self.user_id}" -__all__ = ["UserORM", "UserORMRelatedModel"] +__all__ = ["UserModel"] diff --git a/src/db/models/user_role.py b/src/db/models/user_role.py new file mode 100644 index 0000000..ec9ae4a --- /dev/null +++ b/src/db/models/user_role.py @@ -0,0 +1,56 @@ +import enum +from datetime import datetime +from typing import Optional +from typing import TYPE_CHECKING + +from sqlalchemy import Column +from sqlalchemy import DateTime +from sqlalchemy import Enum +from sqlmodel import Field +from sqlmodel import Relationship +from sqlmodel import SQLModel + + +if TYPE_CHECKING: + from .user import UserModel + + +class UserRole(enum.StrEnum): + USER = "USER" + SUPERUSER = "SUPERUSER" + UNKNOWN = "UNKNOWN" + + +class UserRoleModel(SQLModel, table=True): + + __tablename__ = "user_roles" + + id: Optional[int] = Field(default=None, primary_key=True) + created_at: Optional[datetime] = Field( + sa_column=Column( + DateTime, + default=datetime.utcnow, + nullable=False, + ) + ) + updated_at: Optional[datetime] = Field( + sa_column=Column( + DateTime, + default=datetime.utcnow, + onupdate=datetime.utcnow, + ) + ) + role: "UserRole" = Field( + sa_column=Column( + Enum(UserRole, length=15), + default=UserRole.USER, + nullable=False, + index=False, + unique=True, + ) + ) + + users: list["UserModel"] = Relationship(back_populates="role", cascade_delete=True) + + +__all__ = ["UserRole", "UserRoleModel"] diff --git a/src/db/svg_schema/__init__.py b/src/db/svg_schema/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/src/db/svg_schema/db-schema b/src/db/svg_schema/db-schema deleted file mode 100644 index bc3c28b..0000000 --- a/src/db/svg_schema/db-schema +++ /dev/null @@ -1,189 +0,0 @@ - - - - - - - - %3 - - - - UserORM - - - - - UserORM - - - id - - - INTEGER (PK) - - - user_id - - - BIGINT (Index) - - - username - - - VARCHAR(255) (Index) - - - firstname - - - 
VARCHAR(255) (Index) - - - lastname - - - VARCHAR(255) (Index) - - - is_superuser - - - BOOLEAN (Index) - - - created_at - - - TIMESTAMP () - - - updated_at - - - TIMESTAMP () - - - - - - ChannelORM - - - - - ChannelORM - - - id - - - INTEGER (PK) - - - url - - - VARCHAR(255) (Index) - - - label - - - VARCHAR(255) (Index) - - - enabled - - - BOOLEAN (Index) - - - created_at - - - TIMESTAMP () - - - updated_at - - - TIMESTAMP () - - - user_id - - - INTEGER () - - - - - - UserORM->ChannelORM - - - - - - - channels - - - - ChannelORM->UserORM - - - - - - - user - - - - MessageLogORM - - - - - MessageLogORM - - - id - - - INTEGER (PK) - - - message_id - - - BIGINT (Index) - - - text - - - VARCHAR () - - - created_at - - - TIMESTAMP () - - - updated_at - - - TIMESTAMP () - - - - - diff --git a/src/db/svg_schema/db-schema.svg b/src/db/svg_schema/db-schema.svg deleted file mode 100644 index 48be8b5..0000000 --- a/src/db/svg_schema/db-schema.svg +++ /dev/null @@ -1,183 +0,0 @@ - - - - - - -%3 - - - -UserORM - - - -UserORM - - -id - - -INTEGER (PK) - - -user_id - - -BIGINT (Index) - - -username - - -VARCHAR(255) (Index) - - -firstname - - -VARCHAR(255) (Index) - - -lastname - - -VARCHAR(255) (Index) - - -is_superuser - - -BOOLEAN (Index) - - -created_at - - -TIMESTAMP () - - -updated_at - - -TIMESTAMP () - - - - - -ChannelORM - - - -ChannelORM - - -id - - -INTEGER (PK) - - -url - - -VARCHAR(255) (Index) - - -label - - -VARCHAR(255) (Index) - - -enabled - - -BOOLEAN (Index) - - -created_at - - -TIMESTAMP () - - -updated_at - - -TIMESTAMP () - - -user_id - - -INTEGER () - - - - - -UserORM->ChannelORM - - - - - -channels - - - -ChannelORM->UserORM - - - - - -user - - - -MessageLogORM - - - -MessageLogORM - - -id - - -INTEGER (PK) - - -message_id - - -BIGINT (Index) - - -text - - -VARCHAR () - - -created_at - - -TIMESTAMP () - - -updated_at - - -TIMESTAMP () - - - - - diff --git a/src/db/svg_schema/generator.py b/src/db/svg_schema/generator.py deleted file mode 100644 index 7cfcdcf..0000000 --- a/src/db/svg_schema/generator.py +++ /dev/null @@ -1,17 +0,0 @@ -from sqlalchemy_data_model_visualizer import add_web_font_and_interactivity -from sqlalchemy_data_model_visualizer import generate_data_model_diagram - -from src.db.models import ChannelORM -from src.db.models import MessageLogORM -from src.db.models import UserORM - - -def generate(): - models = [UserORM, MessageLogORM, ChannelORM] - generate_data_model_diagram(models=models, output_file="db-schema", add_labels=True) - add_web_font_and_interactivity( - input_svg_file="db-schema.svg", output_svg_file="db-schema" - ) - - -generate() diff --git a/src/dto/__init__.py b/src/dto/__init__.py deleted file mode 100644 index c2bb947..0000000 --- a/src/dto/__init__.py +++ /dev/null @@ -1,21 +0,0 @@ -from .base import DTO -from .channel import ChannelCreateDTO -from .channel import ChannelRetrieveDTO -from .message_log import MessageLogCreateDTO -from .message_log import MessageLogRetrieveDTO -from .user import UserCreateDTO -from .user import UserRetrieveDTO -from .youtube_videoinfo import YoutubeErrorInfoDTO -from .youtube_videoinfo import YoutubeVideoInfoDTO - -__all__ = [ - "ChannelCreateDTO", - "ChannelRetrieveDTO", - "DTO", - "MessageLogCreateDTO", - "MessageLogRetrieveDTO", - "UserCreateDTO", - "UserRetrieveDTO", - "YoutubeErrorInfoDTO", - "YoutubeVideoInfoDTO", -] diff --git a/src/dto/channel.py b/src/dto/channel.py deleted file mode 100644 index c6031cf..0000000 --- a/src/dto/channel.py +++ /dev/null @@ -1,55 +0,0 @@ -from datetime import datetime - 
-from pydantic import ConfigDict -from pydantic import Field -from pydantic import field_validator - -from .base import DTO -from .user import UserRetrieveDTO - - -class ChannelBaseDTO(DTO): - url: str = Field(max_length=255) - label: str = Field(max_length=255) - user_id: int - enabled: bool - - model_config = ConfigDict(from_attributes=True) - - @field_validator("url") - def url_in_username(cls, v: str) -> str: - if "/@" not in v: - raise ValueError( - "Url must be contain @username: https://www.youtube.com/@username" - ) - - return v.lower() - - -class ChannelCreateDTO(ChannelBaseDTO): ... - - -class ChannelRetrieveDTO(ChannelCreateDTO): - id: int - user: UserRetrieveDTO - created_at: datetime - updated_at: datetime - - def to_html(self) -> str: - user_attribute_list = [self.user.username, self.user.user_id] - attribute = next(item for item in user_attribute_list if item is not None) - user_link = f'{attribute}' - - return ( - f"📺 Selected channel:
" - f"├──enabled: {self.enabled}
" - f"├──id: {self.id}
" - f"├──label: {self.label}
" - f"├──url: {self.url}
" - f"├──added by: {user_link}
" - f"├──added at: {self.created_at}
" - f"└──last modified at: {self.updated_at}
" - ) - - -__all__ = ["ChannelCreateDTO", "ChannelRetrieveDTO"] diff --git a/src/dto/message_log.py b/src/dto/message_log.py deleted file mode 100644 index 0ce5535..0000000 --- a/src/dto/message_log.py +++ /dev/null @@ -1,25 +0,0 @@ -from datetime import datetime - -from pydantic import ConfigDict -from pydantic import Field - -from .base import DTO - - -class MessageLogBaseDTO(DTO): - message_id: int - text: str - updated_at: datetime = Field(default=datetime.utcnow()) - - model_config = ConfigDict(from_attributes=True) - - -class MessageLogCreateDTO(MessageLogBaseDTO): - ... - - -class MessageLogRetrieveDTO(MessageLogCreateDTO): - id: int - - -__all__ = ["MessageLogCreateDTO", "MessageLogRetrieveDTO"] diff --git a/src/dto/user.py b/src/dto/user.py deleted file mode 100644 index 5e75cbf..0000000 --- a/src/dto/user.py +++ /dev/null @@ -1,30 +0,0 @@ -from typing import Optional - -from pydantic import ConfigDict - -from .base import DTO - - -class UserBaseDTO(DTO): - user_id: int - username: Optional[str] = None - firstname: Optional[str] = None - lastname: Optional[str] = None - is_superuser: bool - - @property - def get_url_generated_by_id(self) -> str: - return f"tg://openmessage?user_id={self.user_id}" - - model_config = ConfigDict(from_attributes=True) - - -class UserCreateDTO(UserBaseDTO): - ... - - -class UserRetrieveDTO(UserCreateDTO): - id: int - - -__all__ = ["UserCreateDTO", "UserRetrieveDTO"] diff --git a/src/scheduler/jobs/telegram_notify_job/data_fetcher/__init__.py b/src/scheduler/jobs/telegram_notify_job/data_fetcher/__init__.py index 57527d9..c4af938 100644 --- a/src/scheduler/jobs/telegram_notify_job/data_fetcher/__init__.py +++ b/src/scheduler/jobs/telegram_notify_job/data_fetcher/__init__.py @@ -1,3 +1,4 @@ -from .fetcher import async_fetch_livestreams +from .twitch import async_twitch_fetch_livestreams +from .youtube import async_youtube_fetch_livestreams -__all__ = ["async_fetch_livestreams"] +__all__ = ["async_twitch_fetch_livestreams", "async_youtube_fetch_livestreams"] diff --git a/src/scheduler/jobs/telegram_notify_job/data_fetcher/twitch/__init__.py b/src/scheduler/jobs/telegram_notify_job/data_fetcher/twitch/__init__.py new file mode 100644 index 0000000..4da0ddb --- /dev/null +++ b/src/scheduler/jobs/telegram_notify_job/data_fetcher/twitch/__init__.py @@ -0,0 +1,4 @@ +from .fetcher import async_twitch_fetch_livestreams + + +__all__ = ["async_twitch_fetch_livestreams"] diff --git a/src/scheduler/jobs/telegram_notify_job/data_fetcher/twitch/fetcher.py b/src/scheduler/jobs/telegram_notify_job/data_fetcher/twitch/fetcher.py new file mode 100644 index 0000000..5ead091 --- /dev/null +++ b/src/scheduler/jobs/telegram_notify_job/data_fetcher/twitch/fetcher.py @@ -0,0 +1,84 @@ +import asyncio +import operator +from datetime import datetime +from typing import Optional + +from dateutil.tz import tzutc +from twitchAPI.helper import first +from twitchAPI.object.api import Stream +from twitchAPI.twitch import Twitch + +from src.db.models import ChannelModel +from src.logger import logger +from src.scheduler.jobs.telegram_notify_job.data_fetcher.utils import make_time_readable +from src.scheduler.jobs.telegram_notify_job.dto import ErrorVideoInfo +from src.scheduler.jobs.telegram_notify_job.dto import VideoInfo +from src.utils import extract_twitch_username + + +async def async_fetch_livestream( + channel: ChannelModel, twitch: Twitch +) -> Optional[VideoInfo] | ErrorVideoInfo: + """ + :param twitch: + :param channel: + :return: + """ + await 
logger.ainfo(channel.model_dump_json()) + + live_stream = None + try: + + username = extract_twitch_username(channel.url) + if not username: + raise Exception(f"Cannot extract username for {channel.url}") + data: Optional[Stream] = await first( + twitch.get_streams(user_login=[username], first=1, stream_type="live") + ) + + if data: + concurrent_view_count = data.viewer_count + duration = make_time_readable( + (datetime.now(tz=tzutc()) - data.started_at).seconds + ) + + live_stream = VideoInfo( + url=channel.url, + label=channel.label, + concurrent_view_count=concurrent_view_count, + duration=duration, + ) + + except Exception as ex: + await logger.aerror(f"Fetching info error: {channel.url} {ex}") + return ErrorVideoInfo(channel=channel.model_dump(), ex_message=str(ex)) + + return live_stream + + +async def async_twitch_fetch_livestreams( + channels: list[ChannelModel], twitch: Twitch +) -> tuple[list[VideoInfo], list[ErrorVideoInfo]]: + """ + :param twitch: + :param channels: + :return: + """ + tasks = [ + async_fetch_livestream(channel=channel, twitch=twitch) for channel in channels + ] + + data = await asyncio.gather(*tasks) + + errors = [stream for stream in data if isinstance(stream, ErrorVideoInfo)] + + live_streams = [stream for stream in data if isinstance(stream, VideoInfo)] + + live_streams = sorted( + live_streams, key=operator.attrgetter("concurrent_view_count"), reverse=True + ) + + return live_streams, errors + + +__all__ = ["async_twitch_fetch_livestreams"] diff --git a/src/scheduler/jobs/telegram_notify_job/data_fetcher/youtube/__init__.py b/src/scheduler/jobs/telegram_notify_job/data_fetcher/youtube/__init__.py new file mode 100644 index 0000000..e7df973 --- /dev/null +++ b/src/scheduler/jobs/telegram_notify_job/data_fetcher/youtube/__init__.py @@ -0,0 +1,3 @@ +from .fetcher import async_youtube_fetch_livestreams + +__all__ = ["async_youtube_fetch_livestreams"] diff --git a/src/scheduler/jobs/telegram_notify_job/data_fetcher/fetcher.py b/src/scheduler/jobs/telegram_notify_job/data_fetcher/youtube/fetcher.py similarity index 72% rename from src/scheduler/jobs/telegram_notify_job/data_fetcher/fetcher.py rename to src/scheduler/jobs/telegram_notify_job/data_fetcher/youtube/fetcher.py index 8896262..d7a1079 100644 --- a/src/scheduler/jobs/telegram_notify_job/data_fetcher/fetcher.py +++ b/src/scheduler/jobs/telegram_notify_job/data_fetcher/youtube/fetcher.py @@ -5,17 +5,17 @@ import yt_dlp -from .utils import make_time_readable +from src.db.models import ChannelModel from src.decorators import wrap_sync_to_async -from src.dto import ChannelRetrieveDTO -from src.dto import YoutubeErrorInfoDTO -from src.dto import YoutubeVideoInfoDTO from src.logger import logger +from src.scheduler.jobs.telegram_notify_job.data_fetcher.utils import make_time_readable +from src.scheduler.jobs.telegram_notify_job.dto import ErrorVideoInfo +from src.scheduler.jobs.telegram_notify_job.dto import VideoInfo def fetch_live_stream( - channel: ChannelRetrieveDTO, ydl: yt_dlp.YoutubeDL -) -> Optional[YoutubeVideoInfoDTO] | YoutubeErrorInfoDTO: + channel: ChannelModel, ydl: yt_dlp.YoutubeDL +) -> Optional[VideoInfo] | ErrorVideoInfo: """ :param ydl: :param channel: @@ -54,7 +54,7 @@ def fetch_live_stream( int(datetime.now().timestamp() - release_timestamp) ) url = live_info["original_url"] - live_stream = YoutubeVideoInfoDTO( + live_stream = VideoInfo( url=url, label=channel.label, like_count=like_count, @@ -65,7 +65,7 @@ def fetch_live_stream( except Exception as ex: logger.error(f"Fetching info error: 
{channel.url} {ex}") - return YoutubeErrorInfoDTO(channel=channel, ex_message=str(ex)) + return ErrorVideoInfo(channel=channel.model_dump(), ex_message=str(ex)) return live_stream @@ -73,9 +73,9 @@ def fetch_live_stream( async_fetch_livestream = wrap_sync_to_async(fetch_live_stream) -async def async_fetch_livestreams( - channels: list[ChannelRetrieveDTO], ydl: yt_dlp.YoutubeDL -) -> tuple[list[YoutubeVideoInfoDTO], list[YoutubeErrorInfoDTO]]: +async def async_youtube_fetch_livestreams( + channels: list[ChannelModel], ydl: yt_dlp.YoutubeDL +) -> tuple[list[VideoInfo], list[ErrorVideoInfo]]: """ :param ydl: :param channels: @@ -83,19 +83,17 @@ async def async_fetch_livestreams( """ tasks = [async_fetch_livestream(channel=channel, ydl=ydl) for channel in channels] - live_streams = await asyncio.gather(*tasks) + data = await asyncio.gather(*tasks) - errors = [ - stream for stream in live_streams if isinstance(stream, YoutubeErrorInfoDTO) - ] + errors = [stream for stream in data if isinstance(stream, ErrorVideoInfo)] + + live_streams = [stream for stream in data if isinstance(stream, VideoInfo)] - live_streams = [ - stream for stream in live_streams if isinstance(stream, YoutubeVideoInfoDTO) - ] live_streams = sorted( live_streams, key=operator.attrgetter("concurrent_view_count"), reverse=True ) + return live_streams, errors -__all__ = ["async_fetch_livestreams"] +__all__ = ["async_youtube_fetch_livestreams"] diff --git a/src/scheduler/jobs/telegram_notify_job/dto/__init__.py b/src/scheduler/jobs/telegram_notify_job/dto/__init__.py new file mode 100644 index 0000000..7c612e6 --- /dev/null +++ b/src/scheduler/jobs/telegram_notify_job/dto/__init__.py @@ -0,0 +1,4 @@ +from .videoinfo import ErrorVideoInfo +from .videoinfo import VideoInfo + +__all__ = ["ErrorVideoInfo", "VideoInfo"] diff --git a/src/dto/base.py b/src/scheduler/jobs/telegram_notify_job/dto/base.py similarity index 100% rename from src/dto/base.py rename to src/scheduler/jobs/telegram_notify_job/dto/base.py diff --git a/src/dto/youtube_videoinfo.py b/src/scheduler/jobs/telegram_notify_job/dto/videoinfo.py similarity index 53% rename from src/dto/youtube_videoinfo.py rename to src/scheduler/jobs/telegram_notify_job/dto/videoinfo.py index 1e5b857..6b45448 100644 --- a/src/dto/youtube_videoinfo.py +++ b/src/scheduler/jobs/telegram_notify_job/dto/videoinfo.py @@ -1,20 +1,19 @@ from typing import Optional from .base import DTO -from .channel import ChannelRetrieveDTO -class YoutubeVideoInfoDTO(DTO): +class ErrorVideoInfo(DTO): + channel: dict + ex_message: str + + +class VideoInfo(DTO): url: str label: str - like_count: Optional[int] = None concurrent_view_count: Optional[int] = None duration: Optional[str] = None + like_count: Optional[int] = None -class YoutubeErrorInfoDTO(DTO): - channel: ChannelRetrieveDTO - ex_message: str - - -__all__ = ["YoutubeErrorInfoDTO", "YoutubeVideoInfoDTO"] +__all__ = ["ErrorVideoInfo", "VideoInfo"] diff --git a/src/scheduler/jobs/telegram_notify_job/notifier/notify.py b/src/scheduler/jobs/telegram_notify_job/notifier/notify.py index d4a1934..3c0930e 100644 --- a/src/scheduler/jobs/telegram_notify_job/notifier/notify.py +++ b/src/scheduler/jobs/telegram_notify_job/notifier/notify.py @@ -7,14 +7,17 @@ from aiogram.exceptions import TelegramNetworkError from aiogram.utils.chat_action import ChatActionSender from sulguk import SULGUK_PARSE_MODE +from twitchAPI.twitch import Twitch -from ..data_fetcher import async_fetch_livestreams +from ..data_fetcher import async_twitch_fetch_livestreams +from 
..data_fetcher import async_youtube_fetch_livestreams +from ..dto import ErrorVideoInfo +from ..dto import VideoInfo from ..report_generator import generate_jinja_report from .utils import check_if_need_send_instead_of_edit from src.db import DataAccessLayer -from src.dto import MessageLogCreateDTO -from src.dto import YoutubeErrorInfoDTO -from src.dto import YoutubeVideoInfoDTO +from src.db.models import MessageLogModel +from src.db.models.channel_type import ChannelType from src.logger import logger @@ -26,6 +29,7 @@ async def notify( empty_template: Optional[str], report_template: str, dal: DataAccessLayer, + twitch: Optional[Twitch] = None, ) -> None: """ :param dal: @@ -35,21 +39,42 @@ async def notify( :param ydl: :param bot: :param chat_id: + :param twitch :return: """ # get channels channels = await dal.get_channels(enabled=True) - data: tuple[ - list[YoutubeVideoInfoDTO], list[YoutubeErrorInfoDTO] - ] = await async_fetch_livestreams(channels=channels, ydl=ydl) + youtube_channels = [ + channel for channel in channels if channel.type.type == ChannelType.YOUTUBE + ] + twitch_channels = [ + channel for channel in channels if channel.type.type == ChannelType.TWITCH + ] + + data: tuple[list[VideoInfo], list[ErrorVideoInfo]] = ( + await async_youtube_fetch_livestreams(channels=youtube_channels, ydl=ydl) + ) live_list, errors = data + if twitch: + _twitch = await twitch + + twitch_data: tuple[list[VideoInfo], list[ErrorVideoInfo]] = ( + await async_twitch_fetch_livestreams( + channels=twitch_channels, twitch=_twitch + ) + ) + + twitch_live_list, twitch_errors = twitch_data + live_list.extend(twitch_live_list) + errors.extend(twitch_errors) + # logging errors for error in errors: - await logger.aerror(f"Error with {error.channel.id}: {error.ex_message}") + await logger.aerror(f"Error with {error.channel['id']}: {error.ex_message}") await logger.ainfo(f"Live list length {len(live_list)}") @@ -147,9 +172,7 @@ async def notify( if message_id: await dal.create_message( - message_log_schema=MessageLogCreateDTO( - message_id=message_id, text=message_text - ) + obj=MessageLogModel(message_id=message_id, text=message_text) ) diff --git a/src/scheduler/jobs/telegram_notify_job/report_generator.py b/src/scheduler/jobs/telegram_notify_job/report_generator.py index 0cec07d..864d6c6 100644 --- a/src/scheduler/jobs/telegram_notify_job/report_generator.py +++ b/src/scheduler/jobs/telegram_notify_job/report_generator.py @@ -2,11 +2,11 @@ from jinja2 import Template -from src.dto import YoutubeVideoInfoDTO +from src.scheduler.jobs.telegram_notify_job.dto import VideoInfo def generate_jinja_report( - data: list[YoutubeVideoInfoDTO], report_template: str, empty_template: Optional[str] + data: list[VideoInfo], report_template: str, empty_template: Optional[str] ) -> Optional[str]: """ :param empty_template: diff --git a/src/scheduler/utils.py b/src/scheduler/utils.py index 46ee85b..7a1e40c 100644 --- a/src/scheduler/utils.py +++ b/src/scheduler/utils.py @@ -7,9 +7,9 @@ from aiogram import Bot from apscheduler.schedulers.asyncio import AsyncIOScheduler from apscheduler.triggers.interval import IntervalTrigger +from twitchAPI.twitch import Twitch from src.config import Config -from src.constants import COOKIES_FILE_PATH from src.db import DataAccessLayer from src.scheduler.jobs.telegram_notify_job import notify @@ -25,8 +25,14 @@ def setup_scheduler(conf: Config, bot: Bot, dal: DataAccessLayer) -> AsyncIOSche scheduler = AsyncIOScheduler() cookiefile: Optional[TextIO] + try: - cookiefile = 
open(file=COOKIES_FILE_PATH, encoding="utf-8") + youtube = conf.youtube + if youtube: + cookies_filepath = youtube.cookies_filepath + cookiefile = open(file=cookies_filepath, encoding="utf-8") + else: + cookiefile = None except FileNotFoundError: cookiefile = None @@ -41,6 +47,11 @@ def setup_scheduler(conf: Config, bot: Bot, dal: DataAccessLayer) -> AsyncIOSche "extractor_args": {"youtubetab": {"skip": "authcheck"}}, } + if conf.twitch: + twitch = Twitch(app_id=conf.twitch.app_id, app_secret=conf.twitch.app_secret) + else: + twitch = None + ydl = yt_dlp.YoutubeDL(ydl_opts) notify_kwargs = { @@ -51,6 +62,7 @@ def setup_scheduler(conf: Config, bot: Bot, dal: DataAccessLayer) -> AsyncIOSche "report_template": conf.report.template, "empty_template": conf.report.empty, "dal": dal, + "twitch": twitch, } scheduler.add_job( notify, @@ -62,19 +74,6 @@ def setup_scheduler(conf: Config, bot: Bot, dal: DataAccessLayer) -> AsyncIOSche next_run_time=datetime.now(), ) - # TODO TEMPRORARY COMMENT - # JUST CLARIFY ERROR - # auto_turn_off_kwargs = {"dal": dal} - # scheduler.add_job( - # auto_turn_off, - # trigger=IntervalTrigger(seconds=conf.interval_s * 2), - # kwargs=auto_turn_off_kwargs, - # replace_existing=True, - # max_instances=1, - # coalesce=True, - # next_run_time=datetime.now(), - # ) - return scheduler diff --git a/src/utils.py b/src/utils.py index a513b8a..59da213 100644 --- a/src/utils.py +++ b/src/utils.py @@ -1,13 +1,39 @@ import re +from typing import Optional YOUTUBE_USERNAME_CHANNEL_LINK_PATTERN = re.compile( r"^https?://(?:www\.)?youtube\.com/@[\w-]+/?$" ) +TWITCH_USERNAME_CHANNEL_LINK_PATTERN = re.compile( + r"^https?://(?:www\.)?twitch\.tv/([\w-]+)/?$" +) + -def youtube_channel_url_validator(link: str): +def youtube_channel_url_validator(link: str) -> bool: match = re.match(YOUTUBE_USERNAME_CHANNEL_LINK_PATTERN, link) return bool(match) -__all__ = ["youtube_channel_url_validator"] +def twitch_channel_url_validator(link: str) -> bool: + match = re.match(TWITCH_USERNAME_CHANNEL_LINK_PATTERN, link) + return bool(match) + + +def extract_twitch_username(link: str) -> Optional[str]: + match = re.match(TWITCH_USERNAME_CHANNEL_LINK_PATTERN, link) + if match: + return match.group(1) + return None + + +def kick_channel_url_validator(link: str) -> bool: + return False + + +__all__ = [ + "extract_twitch_username", + "kick_channel_url_validator", + "twitch_channel_url_validator", + "youtube_channel_url_validator", +]
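
A few usage sketches for pieces introduced in this diff follow; none of them are part of the change itself, and each makes the assumptions called out in its comments. First, the generic `BaseDAO[T]` and its per-model subclasses: this sketch assumes an aiosqlite URL and plain `async_sessionmaker` wiring, which may differ from how the bot's `DataAccessLayer` actually builds sessions, and it creates the tables directly instead of running the Alembic migration.

```python
# Usage sketch only (not part of the diff). The database URL, session wiring and
# create_all() call are assumptions; the real bot goes through DataAccessLayer
# and the Alembic migration instead.
import asyncio

from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
from sqlmodel import SQLModel

from src.db.dao.channel import ChannelDAO
from src.db.dao.channel_type import ChannelTypeDAO
from src.db.models import ChannelModel
from src.db.models.channel_type import ChannelType

DB_URL = "sqlite+aiosqlite:///./sketch.db"  # illustrative path


async def main() -> None:
    engine = create_async_engine(DB_URL)
    session_factory = async_sessionmaker(engine, expire_on_commit=False)

    # Create the SQLModel tables for this throwaway database
    async with engine.begin() as conn:
        await conn.run_sync(SQLModel.metadata.create_all)

    async with session_factory() as session:
        # get_or_create() returns (instance, created) as defined in BaseDAO
        channel_type, _created = await ChannelTypeDAO(session).get_or_create(
            type=ChannelType.TWITCH
        )
        # create() adds, commits and refreshes the object, or returns None on error
        channel = await ChannelDAO(session).create(
            ChannelModel(
                url="https://www.twitch.tv/example",
                label="example",
                enabled=True,
                channel_type_id=channel_type.id,
            )
        )
        print(channel)


if __name__ == "__main__":
    asyncio.run(main())
```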
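The new `column_exists()` helper in `src/db/migrations/utils.py` can guard hand-written Alembic operations so they stay idempotent. A hypothetical migration body using it is shown below; the `description` column is purely illustrative and not part of this diff, and the import path assumes the migrations folder is importable as a package.

```python
# Hypothetical Alembic migration body using the column_exists() helper;
# the "description" column on "channels" is illustrative, not a real revision.
import sqlalchemy as sa
from alembic import op

from src.db.migrations.utils import column_exists


def upgrade() -> None:
    # Only add the column if a previous, partially applied run did not create it
    if not column_exists("channels", "description"):
        op.add_column(
            "channels",
            sa.Column("description", sa.String(length=255), nullable=True),
        )


def downgrade() -> None:
    if column_exists("channels", "description"):
        op.drop_column("channels", "description")
```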
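Finally, the Twitch fetcher added under `data_fetcher/twitch/` authenticates through pyTwitchAPI app credentials, awaited the same way `notify()` awaits the `Twitch` instance. A rough standalone sketch with placeholder credentials and an in-memory channel object:

```python
# Standalone sketch (not part of the diff); credentials and the channel are placeholders.
import asyncio

from twitchAPI.twitch import Twitch

from src.db.models import ChannelModel
from src.scheduler.jobs.telegram_notify_job.data_fetcher import (
    async_twitch_fetch_livestreams,
)


async def main() -> None:
    # Awaiting Twitch() performs app authentication with the given credentials
    twitch = await Twitch(app_id="xxxxxxxx", app_secret="xxxxxxxx")

    channels = [
        ChannelModel(
            url="https://www.twitch.tv/example", label="example", enabled=True
        )
    ]
    live_streams, errors = await async_twitch_fetch_livestreams(
        channels=channels, twitch=twitch
    )
    for stream in live_streams:
        print(stream.url, stream.concurrent_view_count, stream.duration)
    for error in errors:
        print(error.ex_message)


asyncio.run(main())
```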