Skip to content

Commit

Permalink
feat: change spam prediction to internal APIs
Browse files Browse the repository at this point in the history
  • Loading branch information
MrMissx committed Jul 15, 2024
1 parent 88cc20a commit 7856e32
Show file tree
Hide file tree
Showing 5 changed files with 214 additions and 135 deletions.
135 changes: 70 additions & 65 deletions anjani/internal_plugins/spam_prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,9 @@

import asyncio
import re
from datetime import datetime, time, timedelta
from hashlib import md5, sha256
from random import randint
from typing import Any, Callable, ClassVar, List, MutableMapping, Optional, Tuple
from typing import Any, Callable, ClassVar, List, Literal, MutableMapping, Optional, Tuple

from pyrogram.errors import (
ChatAdminRequired,
Expand All @@ -38,43 +37,60 @@
InlineKeyboardMarkup,
Message,
)

try:
from userbotindo import Classifier

_run_predict = True
except ImportError:
from anjani.util.types import Classifier

_run_predict = False
from pydantic import BaseModel, field_validator

from anjani import command, filters, listener, plugin, util
from anjani.core.metrics import SpamPredictionStat
from anjani.util.misc import StopPropagation


class TextLanguage(BaseModel):
    """Language section of a spam-detection response.

    Used as the type of ``SpamDetectionResponse.language``.
    """

    # NOTE(review): presumably the detected language of the text and the
    # detector's confidence in it -- confirm against the prediction API.
    language: str
    probability: float


class PredictionResult(BaseModel):
    """Spam/ham verdict whose scores are scaled by 100 during validation.

    ``spam_score`` and ``ham_score`` hold the supplied value multiplied by
    100; use :meth:`get_raw` to get the value divided back by 100.
    """

    is_spam: bool
    # NOTE(review): assumes the API supplies fractions in [0, 1] so that the
    # stored values read as percentages -- confirm against the prediction API.
    spam_score: float
    ham_score: float

    @field_validator("spam_score", "ham_score")
    @classmethod
    def calc_score(cls, value: float) -> float:
        """Scale an incoming score by 100 before it is stored on the model."""
        return value * 100

    def get_raw(self, field: Literal["ham", "spam"]) -> float:
        """Return the stored ``<field>_score`` divided by 100.

        The result can differ from the originally supplied value by
        floating-point rounding, since it is scaled up and back down.
        """
        return getattr(self, f"{field}_score") / 100


class SpamDetectionResponse(BaseModel):
    """Parsed spam-detection payload: prediction, language and processed text."""

    prediction: PredictionResult
    language: TextLanguage
    # NOTE(review): presumably the text after the API's own normalisation --
    # confirm against the prediction API.
    processed_text: str


class SpamPrediction(plugin.Plugin):
name: ClassVar[str] = "SpamPredict"
helpable: ClassVar[bool] = True
disabled: ClassVar[bool] = not _run_predict

db: util.db.AsyncCollection
user_db: util.db.AsyncCollection
setting_db: util.db.AsyncCollection
model: Classifier

__predict_cost: int = 10
__log_channel: int = -1001314588569

async def on_load(self) -> None:
self.model = Classifier()
self._api_key = self.bot.config.SPAM_PREDICTION_API
self._predict_url = self.bot.config.SPAM_PREDICTION_URL

if not self._api_key or not self._predict_url:
self.bot.unload_plugin(self)
return

self.db = self.bot.db.get_collection("SPAM_DUMP")
self.user_db = self.bot.db.get_collection("USERS")
self.setting_db = self.bot.db.get_collection("SPAM_PREDICT_SETTING")

await self.__load_model()
self.bot.loop.create_task(self.__refresh_model())

async def on_chat_migrate(self, message: Message) -> None:
await self.db.update_one(
{"chat_id": message.migrate_from_chat_id},
Expand All @@ -90,26 +106,6 @@ async def on_plugin_restore(self, chat_id: int, data: MutableMapping[str, Any])
{"chat_id": chat_id}, {"$set": data[self.name]}, upsert=True
)

async def __refresh_model(self) -> None:
    """Reload the prediction model every day at 17:00 UTC, forever.

    Each pass computes the next 17:00 UTC, sleeps until then and calls
    ``__load_model``.
    """
    run_at = time(hour=17)  # Run at 00:00 WIB
    while True:
        current = datetime.utcnow()
        target_day = current.date()
        if current.time() > run_at:
            # Today's slot has already passed; aim for tomorrow's.
            target_day = current.date() + timedelta(days=1)
        next_run = datetime.combine(target_day, run_at)
        self.log.debug("Next model refresh at %s UTC", next_run)
        delay = (next_run - current).total_seconds()
        await asyncio.sleep(delay)
        await self.__load_model()

async def __load_model(self) -> None:
    """Load the spam prediction model using the bot's HTTP client.

    If ``load_model`` raises ``RuntimeError``, a warning is logged and this
    plugin is unloaded from the bot instead of propagating the error.
    """
    self.log.info("Downloading spam prediction model!")
    try:
        await self.model.load_model(self.bot.http)
    except RuntimeError:
        self.log.warning("Failed to download prediction model!")
        # Without a model the plugin cannot predict, so it removes itself.
        self.bot.unload_plugin(self)

@staticmethod
def _build_hash(content: str) -> str:
return sha256(content.strip().encode()).hexdigest()
Expand All @@ -136,6 +132,20 @@ async def _collect_random_sample(self, proba: float, uid: Optional[int]) -> None
},
) # Do not upsert

async def check_spam(self, text: str) -> SpamDetectionResponse:
    """Ask the spam-prediction HTTP endpoint to classify *text*.

    Posts ``{"text": text}`` as JSON to ``self._predict_url`` with the API key
    in the ``x-api-key`` header and builds the result from the ``data`` member
    of the JSON reply.

    Raises:
        ValueError: if the status is not 200, or if the reply is not a JSON
            object carrying a non-empty ``data`` object. A missing key or a
            wrongly-typed payload is reported as ``ValueError`` too, rather
            than leaking ``KeyError``/``TypeError``.
    """
    async with self.bot.http.post(
        self._predict_url,
        json={"text": text},
        headers={"x-api-key": self._api_key},
    ) as resp:
        if resp.status != 200:
            raise ValueError(f"Failed to get prediction: {resp.status}")
        res = await resp.json()

    # The body is fully read at this point, so validation can happen after
    # the connection has been released.
    data = res.get("data") if isinstance(res, dict) else None
    if not data or not isinstance(data, dict):
        raise ValueError("Unexpected response")

    return SpamDetectionResponse(**data)

@listener.filters(
filters.regex(r"spam_check_(?P<value>t|f)") | filters.regex(r"spam_ban_(?P<user>.*)")
)
Expand Down Expand Up @@ -314,26 +324,25 @@ async def spam_check(self, message: Message, text: str) -> None:
except AttributeError:
user = None

text_norm = self.model.normalize(text)
if len(text_norm.split()) < 4: # Skip short messages
try:
result = await self.check_spam(text)
except ValueError:
self.bot.log.debug("Failed to get prediction")
return

response = await self.model.predict(text_norm)
await self.bot.log_stat("predicted")
SpamPredictionStat.labels("predicted").inc()
if response.size == 0:
return

probability = response[0][1]
probability = result.prediction.spam_score

await self._collect_random_sample(probability, user)
await self._collect_random_sample(result.prediction.get_raw("spam"), user)

if probability <= 0.5:
if probability <= 50:
return

content_hash = self._build_hash(text)
identifier = self._build_hex(user)
proba_str = self.model.prob_to_string(probability)
proba_str = str(probability)
msg_id = None

# only log public chat
Expand Down Expand Up @@ -372,7 +381,7 @@ async def spam_check(self, message: Message, text: str) -> None:
"proba": probability,
"msg_id": msg.id,
"date": util.time.sec(),
"text": text_norm,
"text": result.processed_text,
},
)

Expand Down Expand Up @@ -451,11 +460,6 @@ async def spam_check(self, message: Message, text: str) -> None:
)
raise StopPropagation

@command.filters(filters.staff_only)
async def cmd_update_model(self, ctx: command.Context) -> Optional[str]:
    """Reload the prediction model on demand, then reply "Done".

    Registered behind ``filters.staff_only``. Returns ``None``; the reply is
    sent through ``ctx.respond`` with ``delete_after=5``.
    """
    await self.__load_model()
    await ctx.respond("Done", delete_after=5)

@command.filters(filters.staff_only)
async def cmd_spam(self, ctx: command.Context) -> Optional[str]:
"""Manual spam detection by bot staff"""
Expand All @@ -475,13 +479,13 @@ async def cmd_spam(self, ctx: command.Context) -> Optional[str]:

identifier = self._build_hex(user_id)
content_hash = self._build_hash(content)
content_normalized = self.model.normalize(content.strip())
pred = await self.model.predict(content_normalized)
if pred.size == 0:
try:
prediction = await self.check_spam(content)
except ValueError:
return "Prediction failed"

proba = pred[0][1]
text = f"#SPAM\n\n**CPU Prediction**: `{self.model.prob_to_string(proba)}`\n"
proba = prediction.prediction.spam_score
text = f"#SPAM\n\n**CPU Prediction**: `{proba}`\n"
if identifier:
text += f"**Identifier**: `{identifier}`\n"

Expand All @@ -491,7 +495,7 @@ async def cmd_spam(self, ctx: command.Context) -> Optional[str]:
{"_id": content_hash},
{
"$set": {
"text": content_normalized,
"text": prediction.processed_text,
"spam": 1,
"ham": 0,
}
Expand Down Expand Up @@ -538,16 +542,17 @@ async def cmd_predict(self, ctx: command.Context) -> Optional[str]:
content = replied.text or replied.caption
if not content:
return await ctx.get_text("spampredict-empty")
content = self.model.normalize(content.strip())
pred = await self.model.predict(content)
await self.bot.log_stat("predicted")
if pred.size == 0:
try:
prediction = await self.check_spam(content)
except ValueError:
return await ctx.get_text("spampredict-failed")

await self.bot.log_stat("predicted")

textPrediction = (
f"**Is Spam**: {await self.model.is_spam(content)}\n"
f"**Spam Prediction**: `{self.model.prob_to_string(pred[0][1])}`\n"
f"**Ham Prediction**: `{self.model.prob_to_string(pred[0][0])}`"
f"**Is Spam**: {prediction.prediction.is_spam}\n"
f"**Spam Prediction**: `{prediction.prediction.spam_score}`\n"
f"**Ham Prediction**: `{prediction.prediction.ham_score}`"
)
await asyncio.gather(
self.bot.log_stat("predicted"),
Expand Down
6 changes: 6 additions & 0 deletions anjani/util/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ class Config:
HEALTH_CHECK_INTERVAL: Optional[int]
HEALTH_CHECK_WEBHOOK_URL: Optional[str]

SPAM_PREDICTION_URL: Optional[str]
SPAM_PREDICTION_API: Optional[str]

IS_CI: bool

def __init__(self) -> None:
Expand Down Expand Up @@ -59,6 +62,9 @@ def __init__(self) -> None:
self.HEALTH_CHECK_INTERVAL = int(getenv("HEALTH_CHECK_INTERVAL", 60))
self.HEALTH_CHECK_WEBHOOK_URL = getenv("HEALTH_CHECK_WEBHOOK_URL")

self.SPAM_PREDICTION_URL = getenv("SPAM_PREDICTION_URL")
self.SPAM_PREDICTION_API = getenv("SPAM_PREDICTION_API")

self.IS_CI = getenv("IS_CI", "false").lower() == "true"

# check if all the required variables are set
Expand Down
69 changes: 1 addition & 68 deletions anjani/util/types.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Anjani custom types"""

# Copyright (C) 2020 - 2023 UserbotIndo Team, <https://github.com/userbotindo.git>
#
# This program is free software: you can redistribute it and/or modify
Expand All @@ -14,10 +15,8 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

from abc import abstractmethod, abstractproperty
from typing import TYPE_CHECKING, Any, Callable, Protocol, TypeVar

from aiohttp import ClientSession
from pyrogram.filters import Filter

if TYPE_CHECKING:
Expand All @@ -27,74 +26,8 @@
ChatId = TypeVar("ChatId", int, None, covariant=True)
TextName = TypeVar("TextName", bound=str, covariant=True)
NoFormat = TypeVar("NoFormat", bound=bool, covariant=True)
TypeData = TypeVar("TypeData", covariant=True)
DecoratedCallable = TypeVar("DecoratedCallable", bound=Callable[..., Any])


class Instantiable(Protocol):
    """Protocol for types constructible with arbitrary positional and keyword arguments."""

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        # Placeholder body: raises if this stub itself is ever executed.
        raise NotImplementedError


class CustomFilter(Filter): # skipcq: PYL-W0223
    """Pyrogram ``Filter`` subclass that declares two extra attributes."""

    # String annotation: a forward reference to the ``Anjani`` type.
    anjani: "Anjani"
    # NOTE(review): presumably whether the filter also matches bot senders --
    # confirm where the filter is constructed.
    include_bot: bool


class NDArray(Protocol[TypeData]):
    """Structural type for an object indexable by ``int`` that exposes ``size``."""

    @abstractmethod
    def __getitem__(self, key: int) -> Any:
        """Return the item stored at integer index *key*."""
        raise NotImplementedError

    # ``abc.abstractproperty`` is deprecated since Python 3.3; stacking
    # ``property`` over ``abstractmethod`` is the documented equivalent.
    @property
    @abstractmethod
    def size(self) -> int:
        """Return the size of the array as an ``int``."""
        raise NotImplementedError


class Classifier(Protocol):
    """Structural interface of a text classifier.

    Every member is abstract and raises ``NotImplementedError`` here.
    """

    @abstractmethod
    async def predict(self, text: str, **predict_params: Any) -> NDArray[Any]:
        """Return the prediction for *text* as an ``NDArray``."""
        # NOTE(review): layout of the returned array is not defined here --
        # confirm against the concrete classifier.
        raise NotImplementedError

    @abstractmethod
    async def load_model(self, http_client: ClientSession) -> None:
        """Load the model, using *http_client* (an aiohttp ``ClientSession``)."""
        raise NotImplementedError

    @abstractmethod
    async def is_spam(self, text: str) -> bool:
        """Return whether *text* is classified as spam."""
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def normalize(text: str) -> str:
        """Return the normalised form of *text*."""
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def prob_to_string(value: float) -> str:
        """Return *value* rendered as a string."""
        raise NotImplementedError


class WebServer(Protocol):
    """Structural interface of a web server with run/stop/add_router hooks.

    Every member is abstract and raises ``NotImplementedError`` here.
    """

    @abstractmethod
    async def run(self) -> None:
        """Run the server."""
        raise NotImplementedError

    @abstractmethod
    async def stop(self) -> None:
        """Stop the server."""
        raise NotImplementedError

    @abstractmethod
    async def add_router(self, **router_param: Any) -> None:
        """Add a router described by the keyword arguments in *router_param*."""
        # NOTE(review): accepted keywords are not defined here -- confirm
        # against the concrete server implementation.
        raise NotImplementedError


class Router(Instantiable):
def get(self, *args, **kwargs) -> Callable[[DecoratedCallable], DecoratedCallable]:
raise NotImplementedError

def post(self, *args, **kwargs) -> Callable[[DecoratedCallable], DecoratedCallable]:
raise NotImplementedError

def put(self, *args, **kwargs) -> Callable[[DecoratedCallable], DecoratedCallable]:
raise NotImplementedError
Loading

0 comments on commit 7856e32

Please sign in to comment.