Skip to content

Commit

Permalink
Add in-memory Telegram storage
Browse files Browse the repository at this point in the history
  • Loading branch information
Saluev committed Dec 17, 2023
1 parent b5b524e commit 12c97a2
Show file tree
Hide file tree
Showing 8 changed files with 334 additions and 65 deletions.
121 changes: 121 additions & 0 deletions suppgram/bridges/inmemory_telegram.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
from dataclasses import replace
from typing import List, Optional, Any
from uuid import uuid4

from suppgram.frontends.telegram import TelegramStorage
from suppgram.frontends.telegram.storage import (
TelegramMessage,
TelegramMessageKind,
TelegramGroup,
TelegramGroupRole,
)


class InMemoryTelegramStorage(TelegramStorage):
"""In-memory implementation of [Storage][suppgram.storage.Storage] used in tests."""

def __init__(self) -> None:
self.groups: List[TelegramGroup] = []
self.messages: List[TelegramMessage] = []

async def get_group(self, telegram_chat_id: int) -> TelegramGroup:
try:
return next(g for g in self.groups if g.telegram_chat_id == telegram_chat_id)
except StopIteration:
raise ValueError

async def create_or_update_group(self, telegram_chat_id: int) -> TelegramGroup:
try:
return await self.get_group(telegram_chat_id)
except ValueError:
group = TelegramGroup(telegram_chat_id=telegram_chat_id, roles=frozenset())
self.groups.append(group)
return group

async def add_group_roles(self, telegram_chat_id: int, *roles: TelegramGroupRole):
try:
idx = next(
i for i, g in enumerate(self.groups) if g.telegram_chat_id == telegram_chat_id
)
except StopIteration:
raise ValueError
group = self.groups.pop(idx)
group = replace(group, roles=group.roles | {*roles})
self.groups.append(group)
return group

async def get_groups_by_role(self, role: TelegramGroupRole) -> List[TelegramGroup]:
return [g for g in self.groups if role in g.roles]

async def insert_message(
self,
telegram_bot_id: int,
group: TelegramGroup,
telegram_message_id: int,
kind: TelegramMessageKind,
*,
agent_id: Optional[Any] = None,
customer_id: Optional[Any] = None,
conversation_id: Optional[Any] = None,
telegram_bot_username: Optional[str] = None
) -> TelegramMessage:
message = TelegramMessage(
id=uuid4(),
telegram_bot_id=telegram_bot_id,
group=group,
telegram_message_id=telegram_message_id,
kind=kind,
agent_id=agent_id,
customer_id=customer_id,
conversation_id=conversation_id,
telegram_bot_username=telegram_bot_username,
)
self.messages.append(message)
return message

async def get_message(self, group: TelegramGroup, telegram_message_id: int) -> TelegramMessage:
try:
return next(
m
for m in self.messages
if m.group.telegram_chat_id == group.telegram_chat_id
and m.telegram_message_id == telegram_message_id
)
except StopIteration:
raise ValueError

async def get_messages(
self,
kind: TelegramMessageKind,
*,
agent_id: Optional[Any] = None,
conversation_id: Optional[Any] = None,
telegram_bot_username: Optional[str] = None
) -> List[TelegramMessage]:
return [
m
for m in self.messages
if m.kind == kind
and (agent_id is None or m.agent_id == agent_id)
and (conversation_id is None or m.conversation_id == conversation_id)
and (telegram_bot_username is None or m.telegram_bot_username == telegram_bot_username)
]

async def delete_messages(self, messages: List[TelegramMessage]):
print(messages, self.messages)
message_ids = {m.id for m in messages}
self.messages = [m for m in self.messages if m.id not in message_ids]

async def get_newer_messages_of_kind(
self, messages: List[TelegramMessage]
) -> List[TelegramMessage]:
return [
newer
for newer in self.messages
if any(
newer.group.telegram_chat_id == older.group.telegram_chat_id
and newer.telegram_message_id > older.telegram_message_id
and newer.kind == older.kind
for older in messages
)
]
21 changes: 21 additions & 0 deletions tests/bridges/test_inmemory_telegram.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from typing import Any
from uuid import uuid4

import pytest_asyncio

from suppgram.bridges.inmemory_telegram import InMemoryTelegramStorage
from suppgram.storages.inmemory import InMemoryStorage
from tests.frontends.telegram.storage import TelegramStorageTestSuite

pytest_plugins = ("pytest_asyncio",)


class TestInMemoryTelegramStorage(TelegramStorageTestSuite):
@pytest_asyncio.fixture(autouse=True)
async def _create_storage(self):
self.telegram_storage = InMemoryTelegramStorage()
self.storage = InMemoryStorage()
await self.telegram_storage.initialize()

def generate_id(self) -> Any:
return uuid4()
6 changes: 3 additions & 3 deletions tests/bridges/test_mongodb_telegram.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@
class TestMongoDBTelegramBridge(TelegramStorageTestSuite):
@pytest_asyncio.fixture(autouse=True)
async def _create_storage(self, mongodb_database, mongodb_storage):
self.storage = MongoDBTelegramBridge(mongodb_database)
self.suppgram_storage = mongodb_storage
await self.storage.initialize()
self.telegram_storage = MongoDBTelegramBridge(mongodb_database)
self.storage = mongodb_storage
await self.telegram_storage.initialize()

def generate_id(self) -> Any:
return ObjectId()
6 changes: 3 additions & 3 deletions tests/bridges/test_sqlalchemy_telegram.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@ class TestSQLAlchemyTelegramBridge(TelegramStorageTestSuite):
@pytest_asyncio.fixture(autouse=True)
async def _create_storage(self, sqlite_engine, sqlalchemy_storage):
# SQLAlchemyStorage implementation is needed for related tables to exist.
self.storage = SQLAlchemyTelegramBridge(sqlite_engine)
self.suppgram_storage = sqlalchemy_storage
await self.storage.initialize()
self.telegram_storage = SQLAlchemyTelegramBridge(sqlite_engine)
self.storage = sqlalchemy_storage
await self.telegram_storage.initialize()

@pytest.fixture(autouse=True)
def _make_generate_id(self, generate_sqlite_id: Callable[[], int]):
Expand Down
3 changes: 2 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,8 @@ async def mongodb_storage(mongodb_database) -> Storage:

@pytest.fixture(scope="session")
def generate_sqlite_id() -> Callable[[], int]:
return count(1).__next__
# Not supposed to intersect with any natively generated IDs.
return count(100000).__next__


@pytest.fixture(scope="session")
Expand Down
Loading

0 comments on commit 12c97a2

Please sign in to comment.