Spaces:
Running
Running
import time | |
from typing import Final | |
import re | |
from telegram import Update | |
from telegram.ext import Application, MessageHandler, filters, ContextTypes | |
from typing import Optional | |
import random | |
import os | |
from dotenv import load_dotenv | |
from models.seq2seq.model import Seq2SeqChatbot | |
import torch | |
# Load variables from a .env file (python-dotenv) before the environment is
# read below.
load_dotenv()

# --- Telegram settings, taken from the environment ---------------------------
TOKEN: Final = os.getenv("TOKEN")
BOT_USERNAME: Final = os.getenv("BOT_USERNAME")
# int() raises TypeError at import time when CHAT_ID is not set.
CHAT_ID: Final = int(os.getenv("CHAT_ID"))

# --- Static resources --------------------------------------------------------
CHECKPOINT_PATH: Final = "models/seq2seq/checkpoint/150_checkpoint.tar"
ROMANTIKI_GIF_ID: Final = "CgACAgIAAxkBAAE4zMlojLmMwqrxG5e2rnYS2f9_PZZgVwACL2oAAjbWyUqiyR5II6u6YDYE"
# NOTE(review): same file id as ROMANTIKI_GIF_ID — confirm this is intended.
BEZUMTSI_GIF_ID: Final = "CgACAgIAAxkBAAE4zMlojLmMwqrxG5e2rnYS2f9_PZZgVwACL2oAAjbWyUqiyR5II6u6YDYE"

# --- Mutable runtime state / tunables ----------------------------------------
# Compared against time.time() in handle_message and rebound there after every
# GIF that is sent.
last_gif_sent = 1.0
# Minimum gap, in the same units as time.time() (seconds), between two GIFs.
gif_sent_cooldown = 180.0
# handle_response replies when random.random() is below this value.
response_chance = 1.0

# --- Model -------------------------------------------------------------------
torch.manual_seed(0)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Positional hyperparameters are defined by Seq2SeqChatbot; presumably they
# have to match the checkpoint loaded below — verify against the model module.
chatbot = Seq2SeqChatbot(500, 8856, 2, 2, 0.1, device)
chatbot.load_checkpoint(CHECKPOINT_PATH)
chatbot.eval_mode()
def handle_response(text: str) -> Optional[str]:
    """Return the chatbot's reply to *text*, or ``None`` when it stays silent.

    A reply is generated only when a fresh ``random.random()`` draw is below
    the module-level ``response_chance``; otherwise the model is not called.
    """
    should_reply = random.random() < response_chance
    return chatbot(text) if should_reply else None
# Whitespace, one punctuation mark, then either more whitespace or end of text.
_SPACED_PUNCTUATION = re.compile(r'\s+([,.!?;])(\s+|$)')


def _tighten_punctuation(match: "re.Match[str]") -> str:
    """Return the matched mark, plus one space if whitespace followed it."""
    mark, trailing = match.groups()
    return mark + (" " if trailing else "")


def edit_response(text: Optional[str]) -> Optional[str]:
    """Remove the whitespace that precedes ``, . ! ? ;`` in *text*.

    A mark that is followed by whitespace is rewritten as the mark plus a
    single space. A mark at the very end of the text is also attached to the
    preceding word, so ``"how are you ?"`` becomes ``"how are you?"`` instead
    of being left untouched.

    Args:
        text: Reply text, or ``None`` when there is no reply.

    Returns:
        The cleaned-up text, or ``None`` if *text* is ``None``.
    """
    if text is None:
        return None
    return _SPACED_PUNCTUATION.sub(_tighten_punctuation, text)
async def handle_message(update: Update, context: ContextTypes.DEFAULT_TYPE):
    """React to a text message posted in the configured chat.

    Updates from any chat other than ``CHAT_ID`` are ignored. Inside that
    chat:

    * if the GIF cooldown has elapsed and the lower-cased text contains
      "роман" (checked first) or "безу", the matching GIF is sent and the
      cooldown restarts;
    * otherwise the text, with ``BOT_USERNAME`` removed, is passed to the
      chatbot and any non-empty answer is sent as a reply to the message.

    Args:
        update: Incoming Telegram update.
        context: Handler context; ``context.bot`` is used to send the answer.
    """
    global last_gif_sent

    message = update.message
    # The handler can also be invoked for updates that carry no `message`
    # (e.g. edited messages or channel posts); skip them rather than fail on
    # attribute access.
    if message is None or message.text is None:
        return
    if message.chat_id != CHAT_ID:
        return

    lowered = message.text.lower()

    # Pick a GIF only while the cooldown is over; "роман" wins over "безу".
    gif_id = None
    if time.time() - last_gif_sent >= gif_sent_cooldown:
        if "роман" in lowered:
            gif_id = ROMANTIKI_GIF_ID
        elif "безу" in lowered:
            gif_id = BEZUMTSI_GIF_ID

    if gif_id is not None:
        await context.bot.send_animation(chat_id=message.chat_id, animation=gif_id)
        # Restart the cooldown only after the GIF was actually sent.
        last_gif_sent = time.time()
        return

    # BOT_USERNAME is None when the environment variable is unset; replacing
    # an empty string leaves the text unchanged.
    text = message.text.replace(BOT_USERNAME or "", "").strip().lower()
    response = edit_response(handle_response(text))
    if response:
        await context.bot.send_message(
            chat_id=message.chat_id,
            text=response,
            reply_to_message_id=message.message_id,
        )
async def error(update: Update, context: ContextTypes.DEFAULT_TYPE):
    """Print who triggered an exception, in which chat type, and the update.

    The handler must not raise itself, so the user and chat are looked up
    defensively: *update* may be ``None`` or an update without a message, in
    which case placeholder names are printed instead.

    Args:
        update: The update being processed when the error occurred, if any.
        context: Handler context; ``context.error`` holds the exception.
    """
    is_update = isinstance(update, Update)
    user = update.effective_user if is_update else None
    chat = update.effective_chat if is_update else None

    username = user.username if user is not None else "<unknown user>"
    chat_type = chat.type if chat is not None else "<unknown>"

    # The original text ended with a stray, unmatched double quote after the
    # update dump; it is dropped here.
    print(f"{username} in {chat_type} "
          f"chat caused error \"{context.error}\"\n"
          f"{update}")
def main() -> None:
    """Run the bot."""
    # Build the application around the bot token read from the environment.
    app_builder = Application.builder().token(TOKEN)
    application = app_builder.build()

    # Every text update goes to handle_message; handler exceptions go to error.
    text_handler = MessageHandler(filters.TEXT, handle_message)
    application.add_handler(text_handler)
    application.add_error_handler(error)

    polling_options = {
        "allowed_updates": Update.ALL_TYPES,
        "drop_pending_updates": True,
    }
    application.run_polling(**polling_options)
# Start the bot only when this file is executed directly, not when imported.
if __name__ == "__main__":
    print("Running main...")
    main()