# Medical-Bot / app.py
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import gradio as gr
# Load the pre-trained DialoGPT tokenizer and model
tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")
model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-medium")
# Global chat history
chat_history_ids = None
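# Note: this single module-level tensor is shared across every user session of the Gradio app,
# so concurrent users effectively share one conversation context.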
def chat(user_input, history):
    global chat_history_ids

    # Tokenize the user input and append the end-of-sequence token
    new_input_ids = tokenizer.encode(user_input + tokenizer.eos_token, return_tensors='pt')

    # Append to the token-level chat history
    bot_input_ids = torch.cat([chat_history_ids, new_input_ids], dim=-1) if chat_history_ids is not None else new_input_ids

    # Generate a response with controlled sampling
    with torch.no_grad():
        chat_history_ids = model.generate(
            bot_input_ids,
            max_length=500,  # counts prompt + response tokens; kept short for safety
            pad_token_id=tokenizer.eos_token_id,
            do_sample=True,
            top_k=50,
            top_p=0.95,
            temperature=0.7,
        )

    # Decode only the newly generated tokens
    response = tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True)

    # gr.ChatInterface expects the bot reply as a string; it manages the display history itself
    return response
# Create a Gradio ChatInterface
chatbot_ui = gr.ChatInterface(
    fn=chat,
    title="Teen Mental Health Chatbot 🤖💬",
    description="Talk to a supportive AI. Not a replacement for professional help.",
)
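# Note: clearing the chat in the UI resets only Gradio's on-screen history, not the global
# chat_history_ids, so the model will still condition on earlier turns.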
# Launch the app (required!)
if __name__ == "__main__":
    chatbot_ui.launch()