Spaces:
Running
Running
File size: 1,895 Bytes
74d4655 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 |
from typing import Final
from telegram import Update
from telegram.ext import Application, MessageHandler, filters, ContextTypes
from typing import Optional
import random
import os
from dotenv import load_dotenv
from models.seq2seq.model import Seq2SeqChatbot
import torch
# --- Configuration: secrets come from a local .env file (python-dotenv). ---
load_dotenv()
# Telegram bot API token; None if TOKEN is absent from the environment.
TOKEN: Final = os.environ.get("TOKEN")
# The bot's @username, stripped from incoming message text before inference.
BOT_USERNAME: Final = os.environ.get("BOT_USERNAME")
# The only chat the bot answers in. NOTE(review): if CHAT_ID is unset,
# int(None) raises TypeError at import time — presumably a deliberate
# fail-fast; confirm.
CHAT_ID: Final = int(os.environ.get("CHAT_ID"))
CHECKPOINT_PATH: Final = "models/seq2seq/checkpoint/150_checkpoint.tar"
# Fixed seed so model behavior is reproducible across restarts.
torch.manual_seed(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# NOTE(review): positional args assumed to be (hidden_size=500, n_layers=2,
# n_layers=2, dropout=0.1, device) — confirm against Seq2SeqChatbot's signature.
chatbot = Seq2SeqChatbot(500, 2, 2, 0.1, device)
chatbot.load_checkpoint(CHECKPOINT_PATH)
# Inference-only mode (disables dropout etc.); the bot never trains.
chatbot.eval_mode()
def handle_response(text: str, response_chance: float = 1.0) -> Optional[str]:
    """Generate a chatbot reply for *text*, or decline randomly.

    Args:
        text: Pre-processed message text (mention stripped, lowercased).
        response_chance: Probability in [0, 1] of producing a reply.
            Defaults to 1.0 (always reply), matching the original
            hard-coded behavior.

    Returns:
        The model's response string, or None when the random roll decides
        not to answer.
    """
    # random.random() is in [0, 1): a chance of 1.0 always replies,
    # 0.0 never does.
    if random.random() < response_chance:
        return chatbot(text)
    return None
async def handle_message(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
    """Reply to text messages in the configured chat; ignore everything else.

    Args:
        update: Incoming Telegram update.
        context: Handler context providing the bot instance.
    """
    message = update.message
    # run_polling requests Update.ALL_TYPES, so edited messages / channel
    # posts arrive here with update.message == None — skip them instead of
    # raising AttributeError inside the handler.
    if message is None or message.text is None:
        return
    if message.chat_id != CHAT_ID:
        return
    # Strip the bot's @mention and normalize case before inference.
    # NOTE(review): BOT_USERNAME is None when unset in .env, which would make
    # str.replace raise TypeError — confirm the env var is always present.
    text: str = message.text.replace(BOT_USERNAME, '').strip().lower()
    response: Optional[str] = handle_response(text)
    if response:
        # send_message is the PEP 8 name for the camelCase sendMessage alias.
        await context.bot.send_message(message.chat_id, response)
async def error(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
    """Log handler errors without raising from inside the error handler.

    Args:
        update: The update that caused the error; may be None or lack a
            message (e.g. edited messages, channel posts, job errors).
        context: Handler context carrying the original exception.
    """
    message = getattr(update, "message", None)
    if message is not None:
        # Original log format, preserved verbatim for the common case.
        print(f"{message.from_user.username} in {message.chat.type} "
              f"chat caused error \"{context.error}\"\n"
              f"{update}\"")
    else:
        # Non-message updates would previously crash the error handler
        # itself (AttributeError on update.message.from_user), masking the
        # real error — log a reduced line instead.
        print(f"error \"{context.error}\" caused by update\n{update}")
def main() -> None:
    """Build the Telegram application, wire up handlers, and poll forever."""
    app = Application.builder().token(TOKEN).build()

    # Route every plain-text message through the chatbot handler and send
    # handler failures to the error callback.
    app.add_handler(MessageHandler(filters.TEXT, handle_message))
    app.add_error_handler(error)

    # Blocks until interrupted; subscribes to every update type Telegram offers.
    app.run_polling(allowed_updates=Update.ALL_TYPES)
if __name__ == '__main__':
    # Quick local smoke test of the model (no Telegram): print(chatbot("test"))
    print("Running main...")
    main()