import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from termcolor import colored

# --- Model and Tokenizer Loading ---
MODEL_PATH = "01/medical_model_rl/final"
TOKENIZER_PATH = "01/medical_model_rl/final"

print("Loading model and tokenizer...")
try:
    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_PATH, padding_side='left')
    model = AutoModelForCausalLM.from_pretrained(MODEL_PATH)
    # Keep the embedding matrix in sync with the tokenizer's vocabulary size.
    model.resize_token_embeddings(len(tokenizer))
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()
    print(colored("Model loaded successfully.", "green"))
except Exception as e:
    print(colored(f"Error loading model: {e}", "red"))
    model = None
    tokenizer = None
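
# Note: on a GPU-backed Space, loading in half precision can reduce memory use,
# e.g. AutoModelForCausalLM.from_pretrained(MODEL_PATH, torch_dtype=torch.float16).
# That is a possible variant, not part of the original setup; full precision is
# kept above to match the checkpoint as trained.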

# --- Chatbot Inference Function ---
def medical_chatbot(message, history):
    """
    Generates a response from the medical chatbot model.
    """
    if not model or not tokenizer:
        return "Error: Model is not loaded. Please check the console for errors."

    try:
        # Format the prompt
        full_prompt = f"Question: {message}\n\nAnswer:"

        # Tokenize the input and move it to the model's device
        inputs = tokenizer(full_prompt, return_tensors="pt", truncation=True).to(device)

        # Generate a response; note that max_length caps prompt + generated tokens
        with torch.no_grad():
            output_sequences = model.generate(
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                max_length=128,
                do_sample=True,
                top_k=50,
                top_p=0.95,
                num_return_sequences=1,
                pad_token_id=tokenizer.eos_token_id,
            )

        # Decode the response
        response_text = tokenizer.decode(output_sequences[0], skip_special_tokens=True)

        # Extract only the answer part
        answer = response_text.split("Answer:")[-1].strip()
        return answer
    except Exception as e:
        print(colored(f"An error occurred during inference: {e}", "red"))
        return "Sorry, I encountered an error. Please try again."

# --- Gradio UI ---
chatbot_interface = gr.ChatInterface(
    fn=medical_chatbot,
    title="Medical Chatbot",
    description="Ask any medical question, and the AI will try to answer.",
    examples=[
        ["What are the symptoms of diabetes?"],
        ["How does metformin work?"],
        ["What is the difference between a virus and a bacterium?"],
    ],
    theme="soft",
)

chatbot_interface.launch(share=True)
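
# Note: `share=True` matters mainly for local runs, where it creates a
# temporary public link; on Hugging Face Spaces the app is already hosted
# publicly, so the flag is unnecessary there.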