# Spaces: Sleeping
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
# Checkpoint used for both the tokenizer and the language model; loaded once
# at import time so every chat turn reuses the same objects.
_CHECKPOINT = "microsoft/DialoGPT-medium"
tokenizer = AutoTokenizer.from_pretrained(_CHECKPOINT)
model = AutoModelForCausalLM.from_pretrained(_CHECKPOINT)

# Module-level conversation token ids; starts out empty (None).
chat_history_ids = None
# Upper bound on tokens generated for a single reply.
_MAX_NEW_TOKENS = 200


def _history_to_turns(history):
    """Flatten a Gradio chat history into an ordered list of utterance strings.

    Accepts both formats gr.ChatInterface may pass: ``(user, bot)`` pairs
    ("tuples") or ``{"role": ..., "content": ...}`` dicts ("messages").
    Entries whose content is not a string (e.g. ``None`` placeholders or file
    payloads) are skipped.
    """
    turns = []
    for entry in history or []:
        parts = [entry.get("content")] if isinstance(entry, dict) else list(entry)
        turns.extend(part for part in parts if isinstance(part, str))
    return turns


def chat(user_input, history=None):
    """Generate one DialoGPT reply for a gr.ChatInterface turn.

    Args:
        user_input: The latest user message.
        history: Prior conversation as supplied by gr.ChatInterface, either
            ``(user, bot)`` pairs or role/content dicts. It is only read,
            never mutated. ``None`` means a fresh conversation.

    Returns:
        The generated reply text. gr.ChatInterface adds it to the visible
        transcript itself.
    """
    # Rebuild the model context from this session's own history rather than a
    # module-level tensor. Concurrent users then never share a conversation,
    # and clearing the chat in the UI really resets what the model sees.
    turns = _history_to_turns(history) + [user_input]
    # DialoGPT separates turns with the EOS token.
    context = "".join(turn + tokenizer.eos_token for turn in turns)
    bot_input_ids = tokenizer.encode(context, return_tensors="pt")

    # The prompt plus the reply must fit in the model's position embeddings.
    # Keep only the most recent tokens so long chats degrade gracefully
    # instead of producing empty replies or index errors.
    max_positions = getattr(model.config, "max_position_embeddings", 1024)
    budget = max(1, max_positions - _MAX_NEW_TOKENS)
    bot_input_ids = bot_input_ids[:, -budget:]

    output_ids = model.generate(
        bot_input_ids,
        # pad == eos here, so padding cannot be told apart from turn
        # separators. Attend to every prompt token explicitly.
        attention_mask=torch.ones_like(bot_input_ids),
        # Bound the reply length, not the total sequence length.
        max_new_tokens=_MAX_NEW_TOKENS,
        pad_token_id=tokenizer.eos_token_id,
        do_sample=True,
        top_k=50,
        top_p=0.95,
        temperature=0.7,
    )
    # Decode only the newly generated tokens that follow the prompt.
    return tokenizer.decode(
        output_ids[0, bot_input_ids.shape[-1]:], skip_special_tokens=True
    )
# Build the Gradio chat UI around `chat`.
chatbot_ui = gr.ChatInterface(
    fn=chat,
    # Emoji restored from mis-encoded text ("π€π¬" was mojibake of "🤖💬").
    title="Teen Mental Health Chatbot 🤖💬",
    description="Talk to a supportive AI. Not a replacement for professional help.",
)

# Start the web server only when this file is executed directly.
if __name__ == "__main__":
    chatbot_ui.launch()