mirror of
https://gitflic.ru/project/alt-gnome/karkas.git
synced 2025-10-08 21:53:15 +03:00
feat: add spam detection
This commit is contained in:
@@ -1 +1 @@
|
||||
from .main import module_init
|
||||
from .main import module_init, module_late_init
|
||||
|
@@ -2,14 +2,14 @@ import os
|
||||
|
||||
from karkas_piccolo.conf.apps import AppConfig
|
||||
|
||||
from .tables import SpamLog
|
||||
from .tables import SpamLog, VerifiedUsers
|
||||
|
||||
CURRENT_DIRECTORY = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
APP_CONFIG = AppConfig(
|
||||
app_name="standard.spam",
|
||||
migrations_folder_path=os.path.join(CURRENT_DIRECTORY, "piccolo_migrations"),
|
||||
table_classes=[SpamLog],
|
||||
table_classes=[SpamLog, VerifiedUsers],
|
||||
migration_dependencies=[],
|
||||
commands=[],
|
||||
)
|
||||
|
@@ -0,0 +1,64 @@
|
||||
from piccolo.apps.migrations.auto.migration_manager import MigrationManager
|
||||
from piccolo.columns.column_types import Boolean, Text
|
||||
from piccolo.columns.indexes import IndexMethod
|
||||
|
||||
ID = "2024-10-08T18:31:17:058814"
|
||||
VERSION = "1.20.0"
|
||||
DESCRIPTION = ""
|
||||
|
||||
|
||||
async def forwards():
    """Build the migration that creates the ``verified_users`` table.

    Registers the ``VerifiedUsers`` table and its two columns on a fresh
    ``MigrationManager`` -- ``key`` (text, primary key) and ``verified``
    (boolean, default ``False``) -- and returns that manager.
    """
    table_class_name = "VerifiedUsers"
    tablename = "verified_users"

    # Column options that are identical for both columns of the table.
    shared_params = {
        "null": False,
        "unique": False,
        "index": False,
        "index_method": IndexMethod.btree,
        "choices": None,
        "db_column_name": None,
        "secret": False,
    }

    migration = MigrationManager(
        migration_id=ID,
        app_name="standard.spam",
        description=DESCRIPTION,
    )

    migration.add_table(
        class_name=table_class_name,
        tablename=tablename,
        schema=None,
        columns=None,
    )

    # (column name, column class name, column class, column-specific options)
    column_specs = (
        ("key", "Text", Text, {"default": "", "primary_key": True}),
        ("verified", "Boolean", Boolean, {"default": False, "primary_key": False}),
    )
    for column_name, class_name, column_class, own_params in column_specs:
        migration.add_column(
            table_class_name=table_class_name,
            tablename=tablename,
            column_name=column_name,
            db_column_name=column_name,
            column_class_name=class_name,
            column_class=column_class,
            params={**own_params, **shared_params},
            schema=None,
        )

    return migration
|
@@ -1,4 +1,4 @@
|
||||
from piccolo.columns import JSON, Date, Text
|
||||
from piccolo.columns import JSON, Boolean, Date, Text
|
||||
from piccolo.table import Table
|
||||
|
||||
|
||||
@@ -6,3 +6,9 @@ class SpamLog(Table):
|
||||
message_text = Text(null=True)
|
||||
attachments = JSON()
|
||||
created_at = Date()
|
||||
|
||||
|
||||
class VerifiedUsers(Table):
|
||||
# Key format: `{chat_id}-{user_id}`
|
||||
key = Text(primary_key=True)
|
||||
verified = Boolean()
|
||||
|
@@ -7,7 +7,8 @@
|
||||
"privileged": true,
|
||||
"dependencies": {
|
||||
"required": {
|
||||
"standard.config": "^1.0.0"
|
||||
"standard.config": "^1.0.0",
|
||||
"standard.database": "^1.0.0"
|
||||
},
|
||||
"optional": {
|
||||
"standard.command_helper": "^1.0.0",
|
||||
|
@@ -1,11 +1,22 @@
|
||||
from typing import TYPE_CHECKING, Type
|
||||
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Type
|
||||
|
||||
from aiogram import Bot, Router
|
||||
from aiogram.filters import Command
|
||||
from aiogram import BaseMiddleware, Bot, Router
|
||||
from aiogram.filters import JOIN_TRANSITION, ChatMemberUpdatedFilter, Command
|
||||
|
||||
from karkas_core.modules_system.public_api import get_module, register_router
|
||||
from karkas_core.modules_system.public_api import (
|
||||
get_module,
|
||||
log,
|
||||
register_outer_chat_member_middleware,
|
||||
register_outer_message_middleware,
|
||||
register_router,
|
||||
)
|
||||
|
||||
from .db.tables import SpamLog
|
||||
if TYPE_CHECKING:
|
||||
from karkas_blocks.standard.config import IConfig
|
||||
from aiogram.types import ChatMemberUpdated, Message, TelegramObject
|
||||
from karkas_blocks.standard.filters import SimpleAdminFilter as ISimpleAdminFilter
|
||||
|
||||
from .db.tables import SpamLog, VerifiedUsers
|
||||
|
||||
try:
|
||||
register_command = get_module("standard.command_helper", "register_command")
|
||||
@@ -14,23 +25,9 @@ except Exception:
|
||||
COMMAND_HELPER_MODULE_LOADED = False
|
||||
pass
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from aiogram.types import Message
|
||||
|
||||
from karkas_blocks.standard.filters import SimpleAdminFilter as ISimpleAdminFilter
|
||||
|
||||
|
||||
async def spam(message: "Message", bot: "Bot"):
|
||||
if not message.reply_to_message:
|
||||
return
|
||||
|
||||
if message.reply_to_message.from_user.id in (message.from_user.id, bot.id):
|
||||
return
|
||||
|
||||
spam_message = message.reply_to_message
|
||||
chat_id = message.chat.id
|
||||
|
||||
message.reply_to_message.media_group_id
|
||||
async def delete_spam(spam_message: "Message", bot: "Bot"):
|
||||
chat_id = spam_message.chat.id
|
||||
|
||||
attachments = {
|
||||
"version": 1,
|
||||
@@ -52,11 +49,6 @@ async def spam(message: "Message", bot: "Bot"):
|
||||
message_id=spam_message.message_id,
|
||||
)
|
||||
|
||||
await bot.delete_message(
|
||||
chat_id=chat_id,
|
||||
message_id=message.message_id,
|
||||
)
|
||||
|
||||
await bot.ban_chat_member(
|
||||
chat_id=chat_id,
|
||||
user_id=spam_message.from_user.id,
|
||||
@@ -65,18 +57,172 @@ async def spam(message: "Message", bot: "Bot"):
|
||||
await SpamLog.insert(spam_log)
|
||||
|
||||
|
||||
async def delete_spam_by_request(message: "Message", bot: "Bot"):
    """Handle a request to treat the replied-to message as spam.

    The command must be sent as a reply. Requests that target the
    requester's own message or a message sent by the bot are ignored.
    Otherwise the replied-to message is handed to ``delete_spam`` and the
    command message itself is deleted from the chat.

    Args:
        message: The command message; its ``reply_to_message`` is the
            suspected spam.
        bot: Bot instance used for the Telegram API calls.
    """
    spam_message = message.reply_to_message
    if not spam_message:
        return

    # Never act on the requester's own messages or on the bot itself.
    if spam_message.from_user.id in (message.from_user.id, bot.id):
        return

    # delete_spam requires the bot as its second argument; calling it with
    # the message alone raises TypeError.
    await delete_spam(spam_message, bot)

    # Remove the command message as well.
    await bot.delete_message(
        chat_id=message.chat.id,
        message_id=message.message_id,
    )
|
||||
|
||||
|
||||
config: "IConfig" = get_module("standard.config", "config")
|
||||
|
||||
# Spam detection needs both the statistics module and the ML stack
# (torch + transformers). If any of them cannot be loaded, the feature is
# switched off through USER_STATS_AVAILABLE instead of failing the module.
try:
    get_user_stats = get_module("standard.statistics", "get_user_stats")

    import torch
    import torch.nn.functional as F
    from transformers import AutoModelForSequenceClassification, AutoTokenizer

    USER_STATS_AVAILABLE = True
except Exception as exc:
    # Report the real cause: the failure may come from the statistics module
    # or from a missing torch/transformers installation.
    log(f"Spam detection will not work, a required dependency is unavailable: {exc!r}")
    USER_STATS_AVAILABLE = False
|
||||
|
||||
tokenizer = None
|
||||
model = None
|
||||
|
||||
|
||||
def spam_detection_init():
    """Load the spam-classification tokenizer and model into module globals.

    Does nothing when ``USER_STATS_AVAILABLE`` is false. Otherwise the model
    identifier is read from the ``spam::spam_detection::model`` config key
    and used to populate the module-level ``tokenizer`` and ``model``.
    """
    global tokenizer, model

    if USER_STATS_AVAILABLE:
        # The tokenizer and the classifier are loaded from the same path.
        checkpoint = config.get("spam::spam_detection::model")
        tokenizer = AutoTokenizer.from_pretrained(checkpoint)
        model = AutoModelForSequenceClassification.from_pretrained(checkpoint)
|
||||
|
||||
|
||||
def spam_predict(text):
    """Classify ``text`` with the loaded spam model.

    Returns ``None`` when ``USER_STATS_AVAILABLE`` is false. Otherwise
    returns a ``(is_spam, confidence)`` tuple, where ``is_spam`` is true when
    the predicted class index is 1 and ``confidence`` is the softmax
    probability of the predicted class.
    """
    if not USER_STATS_AVAILABLE:
        return None

    # Tokenize the text, truncating it to at most 256 tokens.
    encoded = tokenizer(text, return_tensors="pt", truncation=True, max_length=256)

    # Gradients are not needed for inference; disabling them saves time and
    # memory.
    with torch.no_grad():
        logits = model(**encoded).logits

    # Softmax turns the logits into per-class probabilities.
    class_probs = F.softmax(logits, dim=1)
    # Index of the most probable class, then the probability of that class.
    best_class = torch.argmax(class_probs, dim=1).item()
    best_prob = class_probs[0, best_class].item()

    return (best_class == 1, best_prob)
|
||||
|
||||
|
||||
class SpamDetectionMiddleware(BaseMiddleware):
    """Message middleware that screens not-yet-verified users for spam.

    Only text messages are inspected, and only for senders that have a
    ``VerifiedUsers`` row with ``verified`` set to false. A message the model
    classifies as spam with enough confidence is passed to ``delete_spam``
    and is not forwarded to the handler. A sender whose message count exceeds
    ``spam::spam_detection::check_messages_count`` is marked as verified.
    """

    async def __call__(
        self,
        handler: Callable[["TelegramObject", Dict[str, Any]], Awaitable[Any]],
        event: "Message",
        data: Dict[str, Any],
    ) -> Any:
        """Run the spam check, then call ``handler`` unless the message was removed.

        Returns ``None`` without calling ``handler`` when the message was
        handled as spam; otherwise returns the handler's result.
        """
        if await self._screen_message(event):
            return None

        return await handler(event, data)

    async def _screen_message(self, event: "Message") -> bool:
        """Return True when ``event`` was removed as spam, False otherwise."""
        if not (USER_STATS_AVAILABLE and event.text):
            return False

        # Honour the registered on/off switch; it was never consulted before.
        if not config.get("spam::spam_detection::enabled"):
            return False

        # The sender is optional in the Bot API; without one there is no
        # per-user key to look up.
        if event.from_user is None:
            return False

        user_key = f"{event.chat.id}-{event.from_user.id}"
        stats = await get_user_stats(event.chat.id, event.from_user.id)

        row = (
            await VerifiedUsers.select(VerifiedUsers.verified)
            .where(VerifiedUsers.key == user_key)
            .first()
        )

        # Senders without a row, or already verified, are not checked.
        if not row or row["verified"]:
            return False

        is_spam, confidence = spam_predict(event.text)
        log(f"{event.text} is {is_spam} with confidence {confidence}")

        if is_spam and confidence >= config.get("spam::spam_detection::confidence"):
            await delete_spam(event, event.bot)
            return True

        # Enough messages seen for this sender: stop checking them.
        if stats["messages_count"] > config.get(
            "spam::spam_detection::check_messages_count"
        ):
            await VerifiedUsers.update({VerifiedUsers.verified: True}).where(
                VerifiedUsers.key == user_key
            )

        return False
|
||||
|
||||
|
||||
joinFilter = ChatMemberUpdatedFilter(JOIN_TRANSITION)
|
||||
|
||||
|
||||
class ChatMemberMiddleware(BaseMiddleware):
    """Chat-member middleware that records newly joined users as unverified.

    When the update matches ``joinFilter``, a ``VerifiedUsers`` row keyed by
    ``{chat id}-{user id}`` is created with ``verified=False``. The update is
    always passed on to ``handler``.
    """

    async def __call__(
        self,
        handler: Callable[["TelegramObject", Dict[str, Any]], Awaitable[Any]],
        event: "ChatMemberUpdated",
        data: Dict[str, Any],
    ) -> Any:
        """Record a join if there is one and return the handler's result."""
        if await joinFilter(event):
            # ``key`` is the primary key, so a plain insert raises when a
            # user who already has a row joins again. Keep the existing row
            # (and its verification state) instead of failing the update.
            await VerifiedUsers.insert(
                VerifiedUsers(
                    key=f"{event.chat.id}-{event.new_chat_member.user.id}",
                    verified=False,
                )
            ).on_conflict(action="DO NOTHING")

        return await handler(event, data)
|
||||
|
||||
|
||||
def module_init():
|
||||
config.register("spam::spam_detection::enabled", "boolean", default_value=True)
|
||||
|
||||
config.register(
|
||||
"spam::spam_detection::check_messages_count", "int", default_value=5
|
||||
)
|
||||
|
||||
config.register("spam::spam_detection::confidence", "float", default_value=0.9)
|
||||
|
||||
config.register(
|
||||
"spam::spam_detection::model", "string", default_value="RUSpam/spam_deberta_v4"
|
||||
)
|
||||
|
||||
register_outer_message_middleware(SpamDetectionMiddleware())
|
||||
register_outer_chat_member_middleware(ChatMemberMiddleware())
|
||||
|
||||
register_app_config = get_module("standard.database", "register_app_config")
|
||||
SimpleAdminFilter: "Type[ISimpleAdminFilter]" = get_module(
|
||||
"standard.filters", "SimpleAdminFilter"
|
||||
)
|
||||
|
||||
from .db import APP_CONFIG
|
||||
|
||||
register_app_config(APP_CONFIG)
|
||||
|
||||
router = Router()
|
||||
|
||||
router.message.register(spam, SimpleAdminFilter(), Command("spam"))
|
||||
router.message.register(
|
||||
delete_spam_by_request, SimpleAdminFilter(), Command("spam")
|
||||
)
|
||||
|
||||
register_router(router)
|
||||
|
||||
@@ -87,3 +233,7 @@ def module_init():
|
||||
long_description="Удалить спам и забанить пользователя. "
|
||||
"Собирает обезличенные данные для дальнейшего создания алгоритма",
|
||||
)
|
||||
|
||||
|
||||
def module_late_init():
    """Second-stage module initialisation: set up spam detection.

    Delegates to ``spam_detection_init``, which loads the tokenizer and
    model when spam detection is available.
    """
    spam_detection_init()
|
||||
|
@@ -5,7 +5,7 @@ from karkas_core.modules_system.public_api import (
|
||||
register_outer_message_middleware,
|
||||
)
|
||||
|
||||
from .main import StatisticsMiddleware
|
||||
from .main import StatisticsMiddleware, get_user_stats
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from karkas_blocks.standard.config import IConfig
|
||||
|
@@ -1,9 +1,9 @@
|
||||
from piccolo.apps.migrations.auto.migration_manager import MigrationManager
|
||||
from piccolo.columns.column_types import Date, Integer, Text
|
||||
from piccolo.columns.defaults.date import DateNow
|
||||
from piccolo.columns.column_types import Integer, Text, Timestamptz
|
||||
from piccolo.columns.defaults.timestamptz import TimestamptzNow
|
||||
from piccolo.columns.indexes import IndexMethod
|
||||
|
||||
ID = "2024-08-20T16:28:38:371951"
|
||||
ID = "2024-10-08T16:38:25:584922"
|
||||
VERSION = "1.16.0"
|
||||
DESCRIPTION = ""
|
||||
|
||||
@@ -13,18 +13,81 @@ async def forwards():
|
||||
migration_id=ID, app_name="standard.statistics", description=DESCRIPTION
|
||||
)
|
||||
|
||||
manager.add_table(
|
||||
class_name="Messages", tablename="messages", schema=None, columns=None
|
||||
)
|
||||
|
||||
manager.add_table(
|
||||
class_name="ChatStats", tablename="chat_stats", schema=None, columns=None
|
||||
)
|
||||
|
||||
manager.add_table(
|
||||
class_name="Messages", tablename="messages", schema=None, columns=None
|
||||
)
|
||||
|
||||
manager.add_table(
|
||||
class_name="UserStats", tablename="user_stats", schema=None, columns=None
|
||||
)
|
||||
|
||||
manager.add_column(
|
||||
table_class_name="ChatStats",
|
||||
tablename="chat_stats",
|
||||
column_name="chat_id",
|
||||
db_column_name="chat_id",
|
||||
column_class_name="Integer",
|
||||
column_class=Integer,
|
||||
params={
|
||||
"default": 0,
|
||||
"null": False,
|
||||
"primary_key": True,
|
||||
"unique": False,
|
||||
"index": False,
|
||||
"index_method": IndexMethod.btree,
|
||||
"choices": None,
|
||||
"db_column_name": None,
|
||||
"secret": False,
|
||||
},
|
||||
schema=None,
|
||||
)
|
||||
|
||||
manager.add_column(
|
||||
table_class_name="ChatStats",
|
||||
tablename="chat_stats",
|
||||
column_name="date",
|
||||
db_column_name="date",
|
||||
column_class_name="Timestamptz",
|
||||
column_class=Timestamptz,
|
||||
params={
|
||||
"default": TimestamptzNow(),
|
||||
"null": False,
|
||||
"primary_key": False,
|
||||
"unique": False,
|
||||
"index": False,
|
||||
"index_method": IndexMethod.btree,
|
||||
"choices": None,
|
||||
"db_column_name": None,
|
||||
"secret": False,
|
||||
},
|
||||
schema=None,
|
||||
)
|
||||
|
||||
manager.add_column(
|
||||
table_class_name="ChatStats",
|
||||
tablename="chat_stats",
|
||||
column_name="messages_count",
|
||||
db_column_name="messages_count",
|
||||
column_class_name="Integer",
|
||||
column_class=Integer,
|
||||
params={
|
||||
"default": 0,
|
||||
"null": False,
|
||||
"primary_key": False,
|
||||
"unique": False,
|
||||
"index": False,
|
||||
"index_method": IndexMethod.btree,
|
||||
"choices": None,
|
||||
"db_column_name": None,
|
||||
"secret": False,
|
||||
},
|
||||
schema=None,
|
||||
)
|
||||
|
||||
manager.add_column(
|
||||
table_class_name="Messages",
|
||||
tablename="messages",
|
||||
@@ -151,69 +214,6 @@ async def forwards():
|
||||
schema=None,
|
||||
)
|
||||
|
||||
manager.add_column(
|
||||
table_class_name="ChatStats",
|
||||
tablename="chat_stats",
|
||||
column_name="chat_id",
|
||||
db_column_name="chat_id",
|
||||
column_class_name="Integer",
|
||||
column_class=Integer,
|
||||
params={
|
||||
"default": 0,
|
||||
"null": False,
|
||||
"primary_key": True,
|
||||
"unique": False,
|
||||
"index": False,
|
||||
"index_method": IndexMethod.btree,
|
||||
"choices": None,
|
||||
"db_column_name": None,
|
||||
"secret": False,
|
||||
},
|
||||
schema=None,
|
||||
)
|
||||
|
||||
manager.add_column(
|
||||
table_class_name="ChatStats",
|
||||
tablename="chat_stats",
|
||||
column_name="date",
|
||||
db_column_name="date",
|
||||
column_class_name="Date",
|
||||
column_class=Date,
|
||||
params={
|
||||
"default": DateNow(),
|
||||
"null": False,
|
||||
"primary_key": False,
|
||||
"unique": False,
|
||||
"index": False,
|
||||
"index_method": IndexMethod.btree,
|
||||
"choices": None,
|
||||
"db_column_name": None,
|
||||
"secret": False,
|
||||
},
|
||||
schema=None,
|
||||
)
|
||||
|
||||
manager.add_column(
|
||||
table_class_name="ChatStats",
|
||||
tablename="chat_stats",
|
||||
column_name="messages_count",
|
||||
db_column_name="messages_count",
|
||||
column_class_name="Integer",
|
||||
column_class=Integer,
|
||||
params={
|
||||
"default": 0,
|
||||
"null": False,
|
||||
"primary_key": False,
|
||||
"unique": False,
|
||||
"index": False,
|
||||
"index_method": IndexMethod.btree,
|
||||
"choices": None,
|
||||
"db_column_name": None,
|
||||
"secret": False,
|
||||
},
|
||||
schema=None,
|
||||
)
|
||||
|
||||
manager.add_column(
|
||||
table_class_name="UserStats",
|
||||
tablename="user_stats",
|
||||
@@ -282,10 +282,10 @@ async def forwards():
|
||||
tablename="user_stats",
|
||||
column_name="date",
|
||||
db_column_name="date",
|
||||
column_class_name="Date",
|
||||
column_class=Date,
|
||||
column_class_name="Timestamptz",
|
||||
column_class=Timestamptz,
|
||||
params={
|
||||
"default": DateNow(),
|
||||
"default": TimestamptzNow(),
|
||||
"null": False,
|
||||
"primary_key": False,
|
||||
"unique": False,
|
@@ -1,10 +1,10 @@
|
||||
from piccolo.columns import Date, Integer, Text
|
||||
from piccolo.columns import Integer, Text, Timestamptz
|
||||
from piccolo.table import Table
|
||||
|
||||
|
||||
class ChatStats(Table):
|
||||
chat_id = Integer(primary_key=True)
|
||||
date = Date()
|
||||
date = Timestamptz()
|
||||
messages_count = Integer(default=0)
|
||||
|
||||
|
||||
@@ -29,5 +29,5 @@ class UserStats(Table):
|
||||
|
||||
chat_id = Integer()
|
||||
user_id = Integer()
|
||||
date = Date()
|
||||
date = Timestamptz()
|
||||
messages_count = Integer(default=0)
|
||||
|
1769
src/karkas_blocks/poetry.lock
generated
1769
src/karkas_blocks/poetry.lock
generated
File diff suppressed because it is too large
Load Diff
@@ -17,9 +17,17 @@ dash = "^2.17.1"
|
||||
dash-extensions = "^1.0.18"
|
||||
dash-bootstrap-components = "^1.6.0"
|
||||
|
||||
torch = { version = "^2.0", source="pytorch-cpu" }
|
||||
transformers = "^4.45.2"
|
||||
|
||||
[tool.poetry-monorepo.deps]
|
||||
enabled = true
|
||||
|
||||
[build-system]
|
||||
requires = ["poetry-core"]
|
||||
build-backend = "poetry.core.masonry.api"
|
||||
|
||||
[[tool.poetry.source]]
|
||||
name = "pytorch-cpu"
|
||||
url = "https://download.pytorch.org/whl/cpu"
|
||||
priority = "explicit"
|
||||
|
Reference in New Issue
Block a user