"""Telegram chat bot backed by a seq2seq model.

Listens to text messages in a single configured chat. Depending on keywords,
it either replies with a GIF (rate-limited by a shared cooldown) or answers
with text generated by ``Seq2SeqChatbot``.
"""

import os
import random
import re
import time
from typing import Final, Optional

import torch
from dotenv import load_dotenv
from telegram import Update
from telegram.ext import Application, ContextTypes, MessageHandler, filters

from models.seq2seq.model import Seq2SeqChatbot

load_dotenv()


def _require_env(name: str) -> str:
    """Return environment variable *name*, failing fast with a clear message.

    Raises:
        RuntimeError: if the variable is unset or empty.
    """
    value = os.environ.get(name)
    if not value:
        raise RuntimeError(
            f"Environment variable {name} is not set (check your .env file)")
    return value


TOKEN: Final = _require_env("TOKEN")
# Stripped from incoming text before it is fed to the model.
BOT_USERNAME: Final = _require_env("BOT_USERNAME")
# The only chat the bot reacts in.
CHAT_ID: Final = int(_require_env("CHAT_ID"))
CHECKPOINT_PATH: Final = "models/seq2seq/checkpoint/150_checkpoint.tar"
ROMANTIKI_GIF_ID: Final = "CgACAgIAAxkBAAE4zMlojLmMwqrxG5e2rnYS2f9_PZZgVwACL2oAAjbWyUqiyR5II6u6YDYE"
# NOTE(review): byte-identical to ROMANTIKI_GIF_ID -- likely a copy-paste
# slip; confirm the intended Telegram file_id for this GIF.
BEZUMTSI_GIF_ID: Final = "CgACAgIAAxkBAAE4zMlojLmMwqrxG5e2rnYS2f9_PZZgVwACL2oAAjbWyUqiyR5II6u6YDYE"

# time.time() of the last GIF sent. Starting near the epoch means the first
# keyword hit is never on cooldown.
last_gif_sent = 1.0
# Minimum number of seconds between two GIFs (shared by both GIF triggers).
gif_sent_cooldown = 180.0
# Probability in [0, 1] that a non-GIF message gets a model reply.
response_chance = 1.0

# Fixed seed, presumably so that any sampling inside the model is reproducible.
torch.manual_seed(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# NOTE(review): positional hyper-parameters must match the checkpoint;
# presumably (hidden size, vocab size, encoder layers, decoder layers,
# dropout) -- confirm against Seq2SeqChatbot's signature.
chatbot = Seq2SeqChatbot(500, 8856, 2, 2, 0.1, device)
chatbot.load_checkpoint(CHECKPOINT_PATH)
chatbot.eval_mode()

# Removes whitespace before , . ! ? ; and collapses the whitespace after it to
# one space, or to nothing when the punctuation ends the string.
_PUNCT_SPACING_RE = re.compile(r"\s+([,.!?;])(\s+|$)")


def handle_response(text: str) -> Optional[str]:
    """Return the model's reply to *text* with probability ``response_chance``.

    Returns None when the random draw decides not to answer.
    """
    if random.random() < response_chance:
        # NOTE(review): synchronous inference called from an async handler;
        # it blocks the event loop while the model runs.
        return chatbot(text)
    return None


def edit_response(text: Optional[str]) -> Optional[str]:
    """Tidy spacing around punctuation in a model reply.

    ``"hi , how are you ?"`` becomes ``"hi, how are you?"``. None passes
    through unchanged.
    """
    if text is None:
        return None
    return _PUNCT_SPACING_RE.sub(
        lambda m: m.group(1) + (" " if m.group(2) else ""), text)


async def handle_message(update: Update,
                         context: ContextTypes.DEFAULT_TYPE) -> None:
    """React to a text message in the configured chat.

    If the GIF cooldown has elapsed, a message containing "роман" or "безу"
    (case-insensitive) is answered with the matching GIF. Otherwise the text,
    with the bot's username removed, is sent to the model. A non-empty reply
    is posted as a reply to the original message.
    """
    global last_gif_sent

    message = update.message
    # filters.TEXT also matches edited messages and channel posts, for which
    # update.message is None; only new plain messages are handled here.
    if message is None or message.text is None:
        return
    if message.chat_id != CHAT_ID:
        return

    lowered = message.text.lower()
    cooldown_over = time.time() - last_gif_sent >= gif_sent_cooldown

    if cooldown_over and "роман" in lowered:
        await context.bot.send_animation(
            chat_id=message.chat_id, animation=ROMANTIKI_GIF_ID)
        last_gif_sent = time.time()
    elif cooldown_over and "безу" in lowered:
        await context.bot.send_animation(
            chat_id=message.chat_id, animation=BEZUMTSI_GIF_ID)
        last_gif_sent = time.time()
    else:
        text = message.text.replace(BOT_USERNAME, '').strip().lower()
        response = edit_response(handle_response(text))
        if response:
            await context.bot.send_message(
                chat_id=message.chat_id,
                text=response,
                reply_to_message_id=message.id,
            )


async def error(update: object, context: ContextTypes.DEFAULT_TYPE) -> None:
    """Print errors raised by handlers.

    *update* may be None or a non-Update object, and the message or its
    sender may be missing, so every access is guarded to keep this handler
    from raising.
    """
    message = update.effective_message if isinstance(update, Update) else None
    if message is None:
        print(f"Update {update!r} caused error \"{context.error}\"")
        return
    user = message.from_user
    username = user.username if user is not None else None
    print(f"{username} in {message.chat.type} "
          f"chat caused error \"{context.error}\"\n"
          f"{update}")


def main() -> None:
    """Run the bot."""
    application = Application.builder().token(TOKEN).build()
    application.add_handler(MessageHandler(filters.TEXT, handle_message))
    application.add_error_handler(error)
    # Messages that arrived while the bot was offline are discarded.
    application.run_polling(allowed_updates=Update.ALL_TYPES,
                            drop_pending_updates=True)


if __name__ == '__main__':
    print("Running main...")
    main()