frantics-bot / main.py
teepoat
Initial commit
74d4655 unverified
raw
history blame
1.9 kB
from typing import Final
from telegram import Update
from telegram.ext import Application, MessageHandler, filters, ContextTypes
from typing import Optional
import random
import os
from dotenv import load_dotenv
from models.seq2seq.model import Seq2SeqChatbot
import torch
# Populate os.environ from a local .env file, if one exists (python-dotenv).
load_dotenv()


def _require_env(name: str) -> str:
    """Return the value of environment variable *name*.

    Raises:
        RuntimeError: if the variable is unset or empty, so a misconfigured
            deployment fails with a readable message instead of an opaque
            ``int(None)`` TypeError.
    """
    value = os.environ.get(name)
    if not value:
        raise RuntimeError(f"Required environment variable {name!r} is not set")
    return value


TOKEN: Final = os.environ.get("TOKEN")
# Defaults to "" so stripping the bot mention with str.replace() is a no-op
# instead of raising TypeError (replace(None, ...)) when the variable is unset.
BOT_USERNAME: Final = os.environ.get("BOT_USERNAME", "")
# The only chat the bot answers in.
CHAT_ID: Final = int(_require_env("CHAT_ID"))
CHECKPOINT_PATH: Final = "models/seq2seq/checkpoint/150_checkpoint.tar"

# Fixed seed, presumably so any randomness in model inference is reproducible.
torch.manual_seed(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# NOTE(review): positional Seq2SeqChatbot hyper-parameters — presumably hidden
# size, encoder layers, decoder layers, dropout; confirm against
# models/seq2seq/model.py. They likely must match the values the checkpoint
# was trained with.
chatbot = Seq2SeqChatbot(500, 2, 2, 0.1, device)
chatbot.load_checkpoint(CHECKPOINT_PATH)
chatbot.eval_mode()
def handle_response(text: str, response_chance: float = 1.0) -> Optional[str]:
    """Return a chatbot reply to *text*, but only with probability *response_chance*.

    Args:
        text: The cleaned incoming message text passed to the model.
        response_chance: Probability of replying. ``random.random()`` returns
            a value in [0.0, 1.0), so the default 1.0 always replies (the
            original behavior) and 0.0 (or less) never does.

    Returns:
        The model's reply, or None when this message is skipped.
    """
    if random.random() >= response_chance:
        return None
    return chatbot(text)
async def handle_message(update: Update, context: ContextTypes.DEFAULT_TYPE):
    """Answer a new text message in the configured chat with a chatbot reply.

    Updates from any chat other than CHAT_ID are ignored, as are updates that
    carry no new message (``update.message`` is None, e.g. for edited messages
    or channel posts) or no text.
    """
    message = update.message
    # Without this guard an edited message crashed the handler with
    # AttributeError on ``None.chat_id``.
    if message is None or message.text is None:
        return
    if message.chat_id != CHAT_ID:
        return

    text = message.text
    # Strip the bot's @mention; guarded because BOT_USERNAME may be unset/None.
    if BOT_USERNAME:
        text = text.replace(BOT_USERNAME, '')
    text = text.strip().lower()

    # NOTE(review): handle_response runs model inference synchronously, which
    # blocks the event loop for its duration; consider asyncio.to_thread if
    # latency becomes a problem.
    response = handle_response(text)
    if response:
        await context.bot.send_message(chat_id=message.chat_id, text=response)
async def error(update: object, context: ContextTypes.DEFAULT_TYPE):
    """Print a one-line report for an exception raised while handling an update.

    ``update`` is typed as ``object``: python-telegram-bot documents that error
    handlers may receive ``None`` (or a non-``Update`` object) when the error
    did not originate from an incoming update, e.g. a network error while
    polling. The original code dereferenced ``update.message.from_user``
    unconditionally and so crashed inside the error handler itself.
    """
    if isinstance(update, Update):
        user = update.effective_user
        chat = update.effective_chat
    else:
        user = chat = None
    username = user.username if user is not None else None
    chat_type = chat.type if chat is not None else None
    # The original f-string ended with a stray, unmatched '\"'; dropped here.
    print(f"{username} in {chat_type} "
          f"chat caused error \"{context.error}\"\n"
          f"{update}")
def main() -> None:
    """Start the bot: register the text and error handlers, then long-poll Telegram."""
    app = (
        Application.builder()
        .token(TOKEN)
        .build()
    )

    text_handler = MessageHandler(filters.TEXT, handle_message)
    app.add_handler(text_handler)
    app.add_error_handler(error)

    # Request every update type from Telegram; per the PTB docs, run_polling
    # does not return until the bot is stopped.
    app.run_polling(allowed_updates=Update.ALL_TYPES)
# Script entry point: announce startup, then run the bot.
if __name__ == '__main__':
    print("Running main...")
    main()