File size: 3,548 Bytes
8730736
74d4655
8730736
 
74d4655
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8730736
 
 
 
 
0efd01c
 
8730736
74d4655
 
 
99bafa0
74d4655
 
 
99bafa0
74d4655
8730736
74d4655
 
 
 
 
0efd01c
 
 
8730736
 
 
 
74d4655
 
0efd01c
8730736
 
 
 
0efd01c
8730736
 
 
 
 
 
 
 
 
 
 
0efd01c
 
74d4655
 
 
 
 
 
 
99bafa0
74d4655
 
8730736
 
74d4655
 
 
0efd01c
74d4655
 
 
 
 
 
 
99bafa0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import time
from typing import Final
import requests
import re
from telegram import Update
from telegram.ext import Application, MessageHandler, filters, ContextTypes
from typing import Optional
import random
import os
from dotenv import load_dotenv
from models.seq2seq.model import Seq2SeqChatbot
import torch

# Populate os.environ from a local .env file (python-dotenv) before the
# settings below are read.
load_dotenv()

# Bot credentials and the single chat the bot reacts to, taken from the
# environment.
# NOTE(review): os.environ.get returns None for an unset variable, so the
# int(...) conversion of CHAT_ID raises TypeError at import time when CHAT_ID
# is missing; TOKEN and BOT_USERNAME silently become None instead.
TOKEN: Final = os.environ.get("TOKEN")
BOT_USERNAME: Final = os.environ.get("BOT_USERNAME")
CHAT_ID: Final = int(os.environ.get("CHAT_ID"))

# Path handed to chatbot.load_checkpoint() below.
CHECKPOINT_PATH: Final = "models/seq2seq/checkpoint/150_checkpoint.tar"

# Identifiers passed as `animation=` to send_animation in handle_message.
romantiki_gif_id = "CgACAgIAAxkBAAE4zMlojLmMwqrxG5e2rnYS2f9_PZZgVwACL2oAAjbWyUqiyR5II6u6YDYE"
bezumtsi_gif_id = "CgACAgIAAxkBAAE4zMtojLmiH_CGW5cT7G0QVXHR7D4g6wAC53UAApkBmEmM-VxqunRc6zYE"

# time.time() stamps of the last insult / last GIF; handle_message rewrites
# them through `global`. The tiny initial value lets the first trigger pass
# its cooldown check immediately.
last_maxim_insult = 1.0
last_gif_sent = 1.0
# Minimum gap, in seconds (compared against time.time() differences), between
# two insults and between two GIFs (both GIFs share one timer).
maxim_insult_cooldown = 180.0
gif_sent_cooldown = 180.0

# Fix torch's RNG seed and pick the GPU when CUDA is available, else the CPU.
torch.manual_seed(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build the model and restore its weights once, at import time.
# NOTE(review): the positional arguments are defined by
# models/seq2seq/model.py (not visible here); presumably they must match the
# configuration the checkpoint was trained with -- TODO confirm.
chatbot = Seq2SeqChatbot(500, 8856, 2, 2, 0.1, device)
chatbot.load_checkpoint(CHECKPOINT_PATH)
# presumably switches the model to inference mode -- verify in Seq2SeqChatbot
chatbot.eval_mode()


def handle_response(text: str, *, response_chance: float = 0.02) -> Optional[str]:
    """Randomly decide whether to answer *text* and, if so, ask the chatbot.

    Args:
        text: The already-normalised message text to feed to the model.
        response_chance: Probability, in ``[0.0, 1.0]``, that a reply is
            generated for this call. The default keeps the historical 2 %
            rate; ``0.0`` (or less) disables replies and ``1.0`` (or more)
            makes every call produce one.

    Returns:
        Whatever the module-level ``chatbot`` returns for *text* when the
        random draw succeeds, otherwise ``None``.
    """
    # random.random() is in [0.0, 1.0), so the strict comparison means a
    # chance of 0.0 can never fire and a chance of 1.0 always does.
    if random.random() < response_chance:
        return chatbot(text)
    return None


# One or more "<whitespace><mark>" tokens, followed by whitespace or the end
# of the string. Group 1 is the spaced-out run of marks, group 2 the trailing
# whitespace (empty when the run ends the string).
_PUNCTUATION_RUN = re.compile(r'((?:\s+[,.!?;])+)(\s+|$)')


def _tighten_punctuation(match: "re.Match[str]") -> str:
    """Rebuild one matched run of punctuation marks without the gaps."""
    marks = re.sub(r'\s+', '', match.group(1))
    # Keep exactly one separating space when more text (or whitespace)
    # followed the run; add nothing when the run ended the string.
    return marks + (' ' if match.group(2) else '')


def edit_response(text: Optional[str]) -> Optional[str]:
    """Attach free-standing punctuation marks to the preceding word.

    Whitespace in front of ``, . ! ? ;`` is removed and the whitespace after
    the mark is collapsed to a single space, e.g. ``"a , b"`` -> ``"a, b"``.
    Unlike a plain "whitespace, mark, whitespace" pattern this also handles a
    mark at the very end of the string (``"ok ?"`` -> ``"ok?"``) and several
    marks in a row (``"wow ! ! !"`` -> ``"wow!!!"``).

    A mark that is glued to the following text (``"a ,b"``) is left alone.

    Args:
        text: The raw reply, or ``None`` when no reply was generated.

    Returns:
        The tidied reply, or ``None`` if *text* was ``None``.
    """
    if text is None:
        return None
    return _PUNCTUATION_RUN.sub(_tighten_punctuation, text)


async def handle_message(update: Update, context: ContextTypes.DEFAULT_TYPE):
    """React to one text message from the configured chat.

    At most one action is taken per message, in this priority order:

    1. A fixed reply to the user ``WhoReadThisWillDie``, at most once per
       ``maxim_insult_cooldown`` seconds.
    2. A GIF when the text contains "роман" or "безу" (case-insensitive);
       both GIFs share the ``gif_sent_cooldown`` timer.
    3. Otherwise the text, with ``BOT_USERNAME`` removed, stripped and
       lowercased, is offered to ``handle_response``; if that yields a reply
       it is cleaned by ``edit_response`` and sent as a reply to the message.

    Updates without a regular ``message`` and messages from any chat other
    than ``CHAT_ID`` are ignored.

    Args:
        update: The incoming Telegram update.
        context: Handler context; only ``context.bot`` is used.
    """
    global last_maxim_insult, last_gif_sent

    message = update.message
    # A text filter also lets edited messages and channel posts through; for
    # those update.message is None, so bail out instead of raising.
    if message is None or message.text is None:
        return
    if message.chat_id != CHAT_ID:
        return

    now = time.time()
    lowered = message.text.lower()
    sender = message.from_user  # may be None for some message kinds

    if (sender is not None
            and sender.username == "WhoReadThisWillDie"
            and now - last_maxim_insult >= maxim_insult_cooldown):
        # The timer is armed before sending, so a failed send still counts.
        last_maxim_insult = time.time()
        await context.bot.send_message(
            chat_id=message.chat_id,
            text="Максим, иди нахуй",
            reply_to_message_id=message.message_id,
        )
        return

    if now - last_gif_sent >= gif_sent_cooldown:
        if "роман" in lowered:
            gif_id = romantiki_gif_id
        elif "безу" in lowered:
            gif_id = bezumtsi_gif_id
        else:
            gif_id = None
        if gif_id is not None:
            await context.bot.send_animation(chat_id=message.chat_id, animation=gif_id)
            # Armed only after a successful send, unlike the insult timer.
            last_gif_sent = time.time()
            return

    # Fallback: maybe let the model answer. BOT_USERNAME is None when the
    # environment variable is unset, and str.replace(None, ...) would raise.
    text = message.text
    if BOT_USERNAME:
        text = text.replace(BOT_USERNAME, '')
    response = edit_response(handle_response(text.strip().lower()))
    if response:
        await context.bot.send_message(
            chat_id=message.chat_id,
            text=response,
            reply_to_message_id=message.message_id,
        )


async def error(update: Update, context: ContextTypes.DEFAULT_TYPE):
    """Print a one-off diagnostic for an exception raised while handling *update*.

    The handler must never raise itself, so every attribute is looked up
    defensively: *update* may be ``None`` or carry no ``message`` (and a
    message may have no ``from_user``). Missing pieces are printed as
    ``None``.

    Args:
        update: The update being processed when the error occurred, if any.
        context: Handler context; ``context.error`` holds the exception.
    """
    message = getattr(update, "message", None)
    username = getattr(getattr(message, "from_user", None), "username", None)
    chat_type = getattr(getattr(message, "chat", None), "type", None)
    # The previous format string ended with a stray, unbalanced quote after
    # the update dump; it is dropped here.
    print(f"{username} in {chat_type} "
          f"chat caused error \"{context.error}\"\n"
          f"{update}")


def main() -> None:
    """Run the bot.

    Discards the backlog of updates that accumulated while the bot was
    offline, registers the text-message handler and then blocks in the
    polling loop until the process is stopped.

    Raises:
        requests.RequestException: If the initial backlog request fails or
            does not complete within the timeout.
    """
    # Per the Bot API, a negative offset makes getUpdates start from the end
    # of the queue and forget all earlier updates. The timeout bounds the
    # call: without one, requests can wait forever on a stalled connection
    # and the bot would never start.
    requests.post(
        f"https://api.telegram.org/bot{TOKEN}/getUpdates?offset=-1",
        timeout=10,
    )

    application = Application.builder().token(TOKEN).build()

    application.add_handler(MessageHandler(filters.TEXT, handle_message))
    # The `error` coroutine is defined in this file but intentionally left
    # unregistered:
    # application.add_error_handler(error)

    application.run_polling(allowed_updates=Update.ALL_TYPES)


# Script entry point: start the bot only when this file is executed directly,
# not when it is imported.
if __name__ == '__main__':
    print("Running main...")
    # Manual smoke test of the loaded model, left disabled:
    # print(chatbot("test"))
    main()