feat: add spam detection

This commit is contained in:
2024-10-08 19:10:54 +03:00
parent 6c9e13b5db
commit b4eac59f80
25 changed files with 4797 additions and 873 deletions

View File

@@ -1 +1 @@
from .main import module_init
from .main import module_init, module_late_init

View File

@@ -2,14 +2,14 @@ import os
from karkas_piccolo.conf.apps import AppConfig
from .tables import SpamLog
from .tables import SpamLog, VerifiedUsers
CURRENT_DIRECTORY = os.path.dirname(os.path.abspath(__file__))
APP_CONFIG = AppConfig(
app_name="standard.spam",
migrations_folder_path=os.path.join(CURRENT_DIRECTORY, "piccolo_migrations"),
table_classes=[SpamLog],
table_classes=[SpamLog, VerifiedUsers],
migration_dependencies=[],
commands=[],
)

View File

@@ -0,0 +1,64 @@
from piccolo.apps.migrations.auto.migration_manager import MigrationManager
from piccolo.columns.column_types import Boolean, Text
from piccolo.columns.indexes import IndexMethod
ID = "2024-10-08T18:31:17:058814"
VERSION = "1.20.0"
DESCRIPTION = ""
async def forwards():
manager = MigrationManager(
migration_id=ID, app_name="standard.spam", description=DESCRIPTION
)
manager.add_table(
class_name="VerifiedUsers",
tablename="verified_users",
schema=None,
columns=None,
)
manager.add_column(
table_class_name="VerifiedUsers",
tablename="verified_users",
column_name="key",
db_column_name="key",
column_class_name="Text",
column_class=Text,
params={
"default": "",
"null": False,
"primary_key": True,
"unique": False,
"index": False,
"index_method": IndexMethod.btree,
"choices": None,
"db_column_name": None,
"secret": False,
},
schema=None,
)
manager.add_column(
table_class_name="VerifiedUsers",
tablename="verified_users",
column_name="verified",
db_column_name="verified",
column_class_name="Boolean",
column_class=Boolean,
params={
"default": False,
"null": False,
"primary_key": False,
"unique": False,
"index": False,
"index_method": IndexMethod.btree,
"choices": None,
"db_column_name": None,
"secret": False,
},
schema=None,
)
return manager

View File

@@ -1,4 +1,4 @@
from piccolo.columns import JSON, Date, Text
from piccolo.columns import JSON, Boolean, Date, Text
from piccolo.table import Table
@@ -6,3 +6,9 @@ class SpamLog(Table):
message_text = Text(null=True)
attachments = JSON()
created_at = Date()
class VerifiedUsers(Table):
# Key format: `{message_chat_id}-{message_id}`
key = Text(primary_key=True)
verified = Boolean()

View File

@@ -7,7 +7,8 @@
"privileged": true,
"dependencies": {
"required": {
"standard.config": "^1.0.0"
"standard.config": "^1.0.0",
"standard.database": "^1.0.0"
},
"optional": {
"standard.command_helper": "^1.0.0",

View File

@@ -1,11 +1,22 @@
from typing import TYPE_CHECKING, Type
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Type
from aiogram import Bot, Router
from aiogram.filters import Command
from aiogram import BaseMiddleware, Bot, Router
from aiogram.filters import JOIN_TRANSITION, ChatMemberUpdatedFilter, Command
from karkas_core.modules_system.public_api import get_module, register_router
from karkas_core.modules_system.public_api import (
get_module,
log,
register_outer_chat_member_middleware,
register_outer_message_middleware,
register_router,
)
from .db.tables import SpamLog
if TYPE_CHECKING:
from karkas_blocks.standard.config import IConfig
from aiogram.types import ChatMemberUpdated, Message, TelegramObject
from karkas_blocks.standard.filters import SimpleAdminFilter as ISimpleAdminFilter
from .db.tables import SpamLog, VerifiedUsers
try:
register_command = get_module("standard.command_helper", "register_command")
@@ -14,23 +25,9 @@ except Exception:
COMMAND_HELPER_MODULE_LOADED = False
pass
if TYPE_CHECKING:
from aiogram.types import Message
from karkas_blocks.standard.filters import SimpleAdminFilter as ISimpleAdminFilter
async def spam(message: "Message", bot: "Bot"):
if not message.reply_to_message:
return
if message.reply_to_message.from_user.id in (message.from_user.id, bot.id):
return
spam_message = message.reply_to_message
chat_id = message.chat.id
message.reply_to_message.media_group_id
async def delete_spam(spam_message: "Message", bot: "Bot"):
chat_id = spam_message.chat.id
attachments = {
"version": 1,
@@ -52,11 +49,6 @@ async def spam(message: "Message", bot: "Bot"):
message_id=spam_message.message_id,
)
await bot.delete_message(
chat_id=chat_id,
message_id=message.message_id,
)
await bot.ban_chat_member(
chat_id=chat_id,
user_id=spam_message.from_user.id,
@@ -65,18 +57,172 @@ async def spam(message: "Message", bot: "Bot"):
await SpamLog.insert(spam_log)
async def delete_spam_by_request(message: "Message", bot: "Bot"):
if not message.reply_to_message:
return
if message.reply_to_message.from_user.id in (message.from_user.id, bot.id):
return
await delete_spam(message.reply_to_message)
await bot.delete_message(
chat_id=message.chat.id,
message_id=message.message_id,
)
config: "IConfig" = get_module("standard.config", "config")
try:
get_user_stats = get_module("standard.statistics", "get_user_stats")
import torch
import torch.nn.functional as F
from transformers import AutoModelForSequenceClassification, AutoTokenizer
USER_STATS_AVAILABLE = True
except Exception:
log("User stats not available, so Spam detection not will work")
USER_STATS_AVAILABLE = False
pass
tokenizer = None
model = None
def spam_detection_init():
if not USER_STATS_AVAILABLE:
return
# Загрузка модели и токенизатора
model_path = config.get("spam::spam_detection::model")
global tokenizer, model
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForSequenceClassification.from_pretrained(model_path)
def spam_predict(text):
if not USER_STATS_AVAILABLE:
return
# Токенизация текста
inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=256)
# Отключаем градиенты для ускорения и экономии памяти
with torch.no_grad():
outputs = model(**inputs)
logits = outputs.logits
# Преобразование логитов в вероятности с помощью softmax
probabilities = F.softmax(logits, dim=1)
# Находим индекс класса с максимальной вероятностью
predicted_class = torch.argmax(probabilities, dim=1).item()
# Извлекаем вероятность предсказанного класса
confidence = probabilities[0, predicted_class].item()
# Возвращаем предсказанную метку и уверенность
return (predicted_class == 1, confidence)
class SpamDetectionMiddleware(BaseMiddleware):
async def __call__(
self,
handler: Callable[["TelegramObject", Dict[str, Any]], Awaitable[Any]],
event: "Message",
data: Dict[str, Any],
) -> Any:
if USER_STATS_AVAILABLE and event.text:
stats = await get_user_stats(event.chat.id, event.from_user.id)
verified_users_row = (
await VerifiedUsers.select(VerifiedUsers.verified)
.where(
VerifiedUsers.key == f"{event.chat.id}-{event.from_user.id}",
)
.first()
)
if verified_users_row:
if not verified_users_row["verified"]:
if event.text:
is_spam, confidence = spam_predict(event.text)
log(f"{event.text} is {is_spam} with confidence {confidence}")
if is_spam:
if confidence >= config.get(
"spam::spam_detection::confidence"
):
await delete_spam(event, event.bot)
return
if stats["messages_count"] > config.get(
"spam::spam_detection::check_messages_count"
):
await VerifiedUsers.update(
{VerifiedUsers.verified: True}
).where(
VerifiedUsers.key
== f"{event.chat.id}-{event.from_user.id}",
)
result = await handler(event, data)
return result
joinFilter = ChatMemberUpdatedFilter(JOIN_TRANSITION)
class ChatMemberMiddleware(BaseMiddleware):
async def __call__(
self,
handler: Callable[["TelegramObject", Dict[str, Any]], Awaitable[Any]],
event: "ChatMemberUpdated",
data: Dict[str, Any],
) -> Any:
if await joinFilter(event):
await VerifiedUsers.insert(
VerifiedUsers(
key=f"{event.chat.id}-{event.new_chat_member.user.id}",
verified=False,
)
)
result = await handler(event, data)
return result
def module_init():
config.register("spam::spam_detection::enabled", "boolean", default_value=True)
config.register(
"spam::spam_detection::check_messages_count", "int", default_value=5
)
config.register("spam::spam_detection::confidence", "float", default_value=0.9)
config.register(
"spam::spam_detection::model", "string", default_value="RUSpam/spam_deberta_v4"
)
register_outer_message_middleware(SpamDetectionMiddleware())
register_outer_chat_member_middleware(ChatMemberMiddleware())
register_app_config = get_module("standard.database", "register_app_config")
SimpleAdminFilter: "Type[ISimpleAdminFilter]" = get_module(
"standard.filters", "SimpleAdminFilter"
)
from .db import APP_CONFIG
register_app_config(APP_CONFIG)
router = Router()
router.message.register(spam, SimpleAdminFilter(), Command("spam"))
router.message.register(
delete_spam_by_request, SimpleAdminFilter(), Command("spam")
)
register_router(router)
@@ -87,3 +233,7 @@ def module_init():
long_description="Удалить спам и забанить пользователя. "
"Собирает обезличенные данные для дальнейшего создания алгоритма",
)
def module_late_init():
spam_detection_init()

View File

@@ -5,7 +5,7 @@ from karkas_core.modules_system.public_api import (
register_outer_message_middleware,
)
from .main import StatisticsMiddleware
from .main import StatisticsMiddleware, get_user_stats
if TYPE_CHECKING:
from karkas_blocks.standard.config import IConfig

View File

@@ -1,9 +1,9 @@
from piccolo.apps.migrations.auto.migration_manager import MigrationManager
from piccolo.columns.column_types import Date, Integer, Text
from piccolo.columns.defaults.date import DateNow
from piccolo.columns.column_types import Integer, Text, Timestamptz
from piccolo.columns.defaults.timestamptz import TimestamptzNow
from piccolo.columns.indexes import IndexMethod
ID = "2024-08-20T16:28:38:371951"
ID = "2024-10-08T16:38:25:584922"
VERSION = "1.16.0"
DESCRIPTION = ""
@@ -13,18 +13,81 @@ async def forwards():
migration_id=ID, app_name="standard.statistics", description=DESCRIPTION
)
manager.add_table(
class_name="Messages", tablename="messages", schema=None, columns=None
)
manager.add_table(
class_name="ChatStats", tablename="chat_stats", schema=None, columns=None
)
manager.add_table(
class_name="Messages", tablename="messages", schema=None, columns=None
)
manager.add_table(
class_name="UserStats", tablename="user_stats", schema=None, columns=None
)
manager.add_column(
table_class_name="ChatStats",
tablename="chat_stats",
column_name="chat_id",
db_column_name="chat_id",
column_class_name="Integer",
column_class=Integer,
params={
"default": 0,
"null": False,
"primary_key": True,
"unique": False,
"index": False,
"index_method": IndexMethod.btree,
"choices": None,
"db_column_name": None,
"secret": False,
},
schema=None,
)
manager.add_column(
table_class_name="ChatStats",
tablename="chat_stats",
column_name="date",
db_column_name="date",
column_class_name="Timestamptz",
column_class=Timestamptz,
params={
"default": TimestamptzNow(),
"null": False,
"primary_key": False,
"unique": False,
"index": False,
"index_method": IndexMethod.btree,
"choices": None,
"db_column_name": None,
"secret": False,
},
schema=None,
)
manager.add_column(
table_class_name="ChatStats",
tablename="chat_stats",
column_name="messages_count",
db_column_name="messages_count",
column_class_name="Integer",
column_class=Integer,
params={
"default": 0,
"null": False,
"primary_key": False,
"unique": False,
"index": False,
"index_method": IndexMethod.btree,
"choices": None,
"db_column_name": None,
"secret": False,
},
schema=None,
)
manager.add_column(
table_class_name="Messages",
tablename="messages",
@@ -151,69 +214,6 @@ async def forwards():
schema=None,
)
manager.add_column(
table_class_name="ChatStats",
tablename="chat_stats",
column_name="chat_id",
db_column_name="chat_id",
column_class_name="Integer",
column_class=Integer,
params={
"default": 0,
"null": False,
"primary_key": True,
"unique": False,
"index": False,
"index_method": IndexMethod.btree,
"choices": None,
"db_column_name": None,
"secret": False,
},
schema=None,
)
manager.add_column(
table_class_name="ChatStats",
tablename="chat_stats",
column_name="date",
db_column_name="date",
column_class_name="Date",
column_class=Date,
params={
"default": DateNow(),
"null": False,
"primary_key": False,
"unique": False,
"index": False,
"index_method": IndexMethod.btree,
"choices": None,
"db_column_name": None,
"secret": False,
},
schema=None,
)
manager.add_column(
table_class_name="ChatStats",
tablename="chat_stats",
column_name="messages_count",
db_column_name="messages_count",
column_class_name="Integer",
column_class=Integer,
params={
"default": 0,
"null": False,
"primary_key": False,
"unique": False,
"index": False,
"index_method": IndexMethod.btree,
"choices": None,
"db_column_name": None,
"secret": False,
},
schema=None,
)
manager.add_column(
table_class_name="UserStats",
tablename="user_stats",
@@ -282,10 +282,10 @@ async def forwards():
tablename="user_stats",
column_name="date",
db_column_name="date",
column_class_name="Date",
column_class=Date,
column_class_name="Timestamptz",
column_class=Timestamptz,
params={
"default": DateNow(),
"default": TimestamptzNow(),
"null": False,
"primary_key": False,
"unique": False,

View File

@@ -1,10 +1,10 @@
from piccolo.columns import Date, Integer, Text
from piccolo.columns import Integer, Text, Timestamptz
from piccolo.table import Table
class ChatStats(Table):
chat_id = Integer(primary_key=True)
date = Date()
date = Timestamptz()
messages_count = Integer(default=0)
@@ -29,5 +29,5 @@ class UserStats(Table):
chat_id = Integer()
user_id = Integer()
date = Date()
date = Timestamptz()
messages_count = Integer(default=0)

File diff suppressed because it is too large Load Diff

View File

@@ -17,9 +17,17 @@ dash = "^2.17.1"
dash-extensions = "^1.0.18"
dash-bootstrap-components = "^1.6.0"
torch = { version = "^2.0", source="pytorch-cpu" }
transformers = "^4.45.2"
[tool.poetry-monorepo.deps]
enabled = true
[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"
[[tool.poetry.source]]
name = "pytorch-cpu"
url = "https://download.pytorch.org/whl/cpu"
priority = "explicit"