# Spaces: Running on Zero
from typing import Annotated, Any, Dict, List

import gradio as gr
import torch
from langgraph.graph import END, START, StateGraph
from langgraph.graph.message import add_messages
from transformers import AutoModelForCausalLM, AutoTokenizer
from typing_extensions import TypedDict
# Model configuration
# Hugging Face Hub repository IDs of the two models loaded below.
LLAMA_MODEL = "meta-llama/Llama-2-7b-chat-hf"
MEDITRON_MODEL = "epfl-llm/meditron-7b"
# System prompt for the Llama-2 intake conversation; build_llama2_prompt places
# it inside the <<SYS>> block at the start of every prompt.
SYSTEM_PROMPT = """You are a professional virtual doctor. Your goal is to collect detailed information about the user's health condition, symptoms, medical history, medications, lifestyle, and other relevant data.
Ask 1-2 follow-up questions at a time to gather more details about:
- Detailed description of symptoms
- Duration (when did it start?)
- Severity (scale of 1-10)
- Aggravating or alleviating factors
- Related symptoms
- Medical history
- Current medications and allergies
After collecting sufficient information (4-5 exchanges), summarize findings and suggest when they should seek professional care. Do NOT make specific diagnoses or recommend specific treatments.
Respond empathetically and clearly. Always be professional and thorough."""
# Prompt template for Meditron. "{patient_info}" is filled in with str.format,
# so any other literal braces added to this template must be doubled.
# NOTE(review): meditron-7b is published as a pretrained base model; confirm it
# follows this <|im_start|>/<|im_end|> (ChatML-style) template reliably.
MEDITRON_PROMPT = """<|im_start|>system
You are a specialized medical assistant focusing ONLY on suggesting over-the-counter medicines and home remedies based on patient information.
Based on the following patient information, provide ONLY:
1. One specific over-the-counter medicine with proper adult dosing instructions
2. One practical home remedy that might help
3. Clear guidance on when to seek professional medical care
Be concise, practical, and focus only on general symptom relief. Do not diagnose. Include a disclaimer that you are not a licensed medical professional.
<|im_end|>
<|im_start|>user
Patient information: {patient_info}
<|im_end|>
<|im_start|>assistant
"""
def _load_causal_lm(label, repo_id):
    """Load a tokenizer/model pair from the Hub, reporting progress on stdout.

    Args:
        label: Human-readable name used in the progress messages.
        repo_id: Hugging Face Hub repository to load from.

    Returns:
        ``(tokenizer, model)`` with the model in float16.
    """
    print(f"Loading {label} model...")
    tok = AutoTokenizer.from_pretrained(repo_id)
    if tok.pad_token is None:
        # No pad token defined: reuse EOS so generate() has a pad id to use.
        tok.pad_token = tok.eos_token
    lm = AutoModelForCausalLM.from_pretrained(
        repo_id,
        torch_dtype=torch.float16,
        device_map="auto",  # let Transformers/Accelerate choose device placement
    )
    print(f"{label} model loaded successfully!")
    return tok, lm


# Both models are loaded once at import time and shared by every request.
tokenizer, model = _load_causal_lm("Llama-2", LLAMA_MODEL)
meditron_tokenizer, meditron_model = _load_causal_lm("Meditron", MEDITRON_MODEL)
# Define the state for our LangGraph
class ChatbotState(TypedDict):
    """State passed between the chatbot graph's nodes during one invocation."""

    # Conversation so far. The add_messages reducer merges each node's returned
    # messages into this list instead of replacing it.
    messages: Annotated[List, add_messages]
    # Number of user turns seen; chat_response seeds it with the prior-turn
    # count and process_user_input adds one for the current message.
    turn_count: int
    # Text of every user message in order, later fed to Meditron.
    patient_info: List[str]
# Function to build Llama-2 prompt
def build_llama2_prompt(messages):
    """Render the conversation as a single Llama-2 chat prompt string.

    Messages are assumed to alternate user/assistant starting with a user
    turn; the last message is the user turn awaiting a reply. Each item only
    needs a ``content`` attribute.
    """
    # Even positions are user turns (closed by [/INST]); odd positions are
    # assistant turns (closed by </s> and a fresh <s>[INST]).
    closers = (" [/INST] ", " </s><s>[INST] ")
    parts = [f"<s>[INST] <<SYS>>\n{SYSTEM_PROMPT}\n<</SYS>>\n\n"]
    parts.extend(
        f"{turn.content}{closers[idx % 2]}"
        for idx, turn in enumerate(messages[:-1])
    )
    parts.append(f"{messages[-1].content} [/INST] ")
    return "".join(parts)
# Function to get Llama-2 response
def get_llama2_response(prompt, turn_count):
    """Generate the assistant's next reply from Llama-2.

    Args:
        prompt: A Llama-2 chat prompt as produced by ``build_llama2_prompt``.
            It must start with its own literal ``<s>`` and end with the
            current user turn's ``[/INST]`` marker.
        turn_count: Number of user turns so far. From turn 4 on, an
            instruction to summarize is appended to the current user turn.

    Returns:
        Only the newly generated reply text, with special tokens removed and
        surrounding whitespace stripped.
    """
    if turn_count >= 4:
        # Put the instruction inside the *last* user turn, just before its
        # closing [/INST]. A global replace would also rewrite every earlier
        # turn and pre-fill the assistant reply with the instruction text.
        head, marker, tail = prompt.rpartition("[/INST]")
        if marker:
            prompt = (
                f"{head.rstrip()}\n\nNow summarize what you've learned and "
                f"suggest when professional care may be needed. {marker}{tail}"
            )
    # The prompt already carries a literal "<s>"; letting the tokenizer add
    # its own BOS too would feed the model a doubled BOS token.
    inputs = tokenizer(
        prompt, return_tensors="pt", add_special_tokens=False
    ).to(model.device)
    with torch.no_grad():
        outputs = model.generate(
            inputs.input_ids,
            attention_mask=inputs.attention_mask,
            max_new_tokens=512,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
            pad_token_id=tokenizer.pad_token_id,
        )
    # Decode only the generated continuation. Splitting the full decode on
    # "[/INST]" breaks whenever that marker appears in the reply itself.
    new_tokens = outputs[0][inputs.input_ids.shape[1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
# Function to get Meditron suggestions
def get_meditron_suggestions(patient_info):
    """Ask Meditron for OTC-medicine and home-remedy suggestions.

    Args:
        patient_info: Free-text patient description, substituted into
            ``MEDITRON_PROMPT``.

    Returns:
        The assistant-turn text generated by Meditron, cut at the first
        ChatML turn marker (if any) and stripped of surrounding whitespace.
    """
    prompt = MEDITRON_PROMPT.format(patient_info=patient_info)
    inputs = meditron_tokenizer(prompt, return_tensors="pt").to(meditron_model.device)
    with torch.no_grad():
        outputs = meditron_model.generate(
            inputs.input_ids,
            attention_mask=inputs.attention_mask,
            max_new_tokens=256,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
            pad_token_id=meditron_tokenizer.pad_token_id,
        )
    # Decode only the tokens generated after the prompt.
    suggestion = meditron_tokenizer.decode(
        outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True
    )
    # If the model writes the turn markers out as text, it may end its turn
    # and go on to invent further turns. Keep only the first assistant turn.
    for marker in ("<|im_end|>", "<|im_start|>"):
        suggestion = suggestion.split(marker, 1)[0]
    return suggestion.strip()
# Define LangGraph nodes
def process_user_input(state: ChatbotState) -> ChatbotState:
    """LangGraph node: record the newest user message and bump the turn count.

    Returns a partial state update; the lists in ``state`` are not mutated.
    """
    text = state["messages"][-1].content
    updated_info = state["patient_info"] + [text]
    return {"patient_info": updated_info, "turn_count": state["turn_count"] + 1}
def generate_llama_response(state: ChatbotState) -> ChatbotState:
    """LangGraph node: ask Llama-2 for the next assistant turn.

    Returns a partial update holding one assistant message, which the
    add_messages reducer appends to ``messages``.
    """
    reply = get_llama2_response(
        build_llama2_prompt(state["messages"]),
        state["turn_count"],
    )
    assistant_message = {"role": "assistant", "content": reply}
    return {"messages": [assistant_message]}
def check_turn_count(state: ChatbotState) -> str:
    """Routing function evaluated after the generate_response node.

    Returns "add_suggestions" once four or more user turns have been seen,
    otherwise "continue".
    """
    return "add_suggestions" if state["turn_count"] >= 4 else "continue"
def add_medicine_suggestions(state: ChatbotState) -> ChatbotState:
    """LangGraph node: append Meditron's medicine/home-care advice.

    Sends every recorded user message, plus the latest assistant reply
    (labelled "Summary"), to Meditron. Returns a new assistant message
    combining that reply with the suggestions under a banner.
    """
    summary = state["messages"][-1].content
    conversation = "\n".join(state["patient_info"])
    suggestions = get_meditron_suggestions(conversation + "\n\nSummary: " + summary)
    banner = "--- MEDICATION AND HOME CARE SUGGESTIONS ---"
    combined = f"{summary}\n\n{banner}\n\n{suggestions}"
    return {"messages": [{"role": "assistant", "content": combined}]}
# Build the LangGraph
def build_graph():
    """Build and compile the chatbot's LangGraph.

    Flow: START -> process_input -> generate_response. Then
    ``check_turn_count`` routes either to add_suggestions (from turn 4 on)
    or directly to END. add_suggestions always proceeds to END.

    Returns:
        The compiled graph, ready for ``invoke``.
    """
    graph = StateGraph(ChatbotState)

    graph.add_node("process_input", process_user_input)
    graph.add_node("generate_response", generate_llama_response)
    graph.add_node("add_suggestions", add_medicine_suggestions)

    graph.add_edge(START, "process_input")
    graph.add_edge("process_input", "generate_response")
    # END is imported from langgraph.graph together with START.
    graph.add_conditional_edges(
        "generate_response",
        check_turn_count,
        {"add_suggestions": "add_suggestions", "continue": END},
    )
    graph.add_edge("add_suggestions", END)
    return graph.compile()


# Initialize the graph once at import time; chat_response reuses it.
chatbot_graph = build_graph()
# Function for Gradio interface
def _history_to_pairs(history):
    """Normalize Gradio chat history into a list of [user_text, bot_text] pairs.

    Accepts both the "tuples" format ([[user, bot], ...]) and the "messages"
    format ([{"role": ..., "content": ...}, ...]). A missing bot reply (None)
    becomes "" so user/assistant turns still strictly alternate.
    """
    pairs = []
    for entry in history or []:
        if isinstance(entry, dict):
            role, content = entry.get("role"), entry.get("content")
            if role == "user":
                pairs.append([content, None])
            elif role == "assistant" and pairs:
                previous = pairs[-1][1]
                pairs[-1][1] = content if previous is None else f"{previous}\n\n{content}"
            # An assistant message before any user turn cannot be paired; skip it.
        else:
            user_msg, bot_msg = entry
            pairs.append([user_msg, bot_msg])
    return [[user, "" if bot is None else bot] for user, bot in pairs]


def chat_response(message, history):
    """Gradio ChatInterface callback: run one user turn through the LangGraph.

    Args:
        message: The new user message text.
        history: Prior conversation as passed by Gradio, in either the
            "tuples" or the "messages" format.

    Returns:
        Content of the last assistant message produced by the graph: the
        Llama-2 reply, with Meditron suggestions appended from turn 4 on.
    """
    pairs = _history_to_pairs(history)

    messages = []
    for user_msg, bot_msg in pairs:
        messages.append({"role": "user", "content": user_msg})
        messages.append({"role": "assistant", "content": bot_msg})
    messages.append({"role": "user", "content": message})

    state = {
        "messages": messages,
        # Prior user turns only; process_user_input counts the current one.
        "turn_count": len(pairs),
        "patient_info": [user_msg for user_msg, _ in pairs],
    }
    result = chatbot_graph.invoke(state)
    return result["messages"][-1].content
# Create the Gradio interface: a chat UI that routes every message to chat_response.
_example_prompts = [
    "I have a cough and my throat hurts",
    "I've been having headaches for a week",
    "My stomach has been hurting since yesterday",
]

demo = gr.ChatInterface(
    fn=chat_response,
    theme="soft",
    examples=_example_prompts,
    title="Medical Assistant with LangGraph",
    description="Tell me about your symptoms, and after gathering enough information, I'll suggest potential remedies.",
)

if __name__ == "__main__":
    demo.launch()