Spaces:
Running
on
Zero
Running
on
Zero
import gradio as gr | |
import spaces | |
import torch | |
from transformers import AutoModelForCausalLM, AutoTokenizer | |
from langgraph.graph import StateGraph, END | |
from typing import TypedDict, List, Tuple | |
import json | |
# Model configuration
# Hugging Face Hub ids. LLAMA_MODEL is the chat model that conducts the symptom
# interview; MEDITRON_MODEL produces the medicine / home-remedy suggestions once
# the conversation has reached the suggestion threshold.
LLAMA_MODEL = "meta-llama/Llama-2-7b-chat-hf"
MEDITRON_MODEL = "epfl-llm/meditron-7b"
# System prompt placed inside the Llama-2 <<SYS>> section of every prompt.
SYSTEM_PROMPT = """You are a professional virtual doctor. Your goal is to collect detailed information about the user's health condition, symptoms, medical history, medications, lifestyle, and other relevant data.
Ask 1-2 follow-up questions at a time to gather more details about:
- Detailed description of symptoms
- Duration (when did it start?)
- Severity (scale of 1-10)
- Aggravating or alleviating factors
- Related symptoms
- Medical history
- Current medications and allergies
After collecting sufficient information (4-5 exchanges), summarize findings and suggest when they should seek professional care. Do NOT make specific diagnoses or recommend specific treatments.
Respond empathetically and clearly. Always be professional and thorough."""
# Prompt template for Meditron; ``{patient_info}`` is filled with str.format,
# so any literal braces added to this text must be doubled.
# NOTE(review): the <|im_start|>/<|im_end|> markers follow a ChatML-style layout;
# whether meditron-7b (a base model) was trained on this format is not
# established here — confirm against the model card.
MEDITRON_PROMPT = """<|im_start|>system
You are a specialized medical assistant focusing ONLY on suggesting over-the-counter medicines and home remedies based on patient information.
Based on the following patient information, provide ONLY:
1. One specific over-the-counter medicine with proper adult dosing instructions
2. One practical home remedy that might help
3. Clear guidance on when to seek professional medical care
Be concise, practical, and focus only on general symptom relief. Do not diagnose. Include a disclaimer that you are not a licensed medical professional.
<|im_end|>
<|im_start|>user
Patient information: {patient_info}
<|im_end|>
<|im_start|>assistant
"""
def _load_pretrained(model_id, label):
    """Load a causal language model and its tokenizer from the Hub.

    Progress messages are printed before and after loading. The weights are
    loaded in float16 and placed with ``device_map="auto"``.

    Args:
        model_id: Hugging Face Hub repository id.
        label: Human-readable name used in the progress messages.

    Returns:
        A ``(tokenizer, model)`` pair.
    """
    print(f"Loading {label} model...")
    loaded_tokenizer = AutoTokenizer.from_pretrained(model_id)
    loaded_model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.float16,
        device_map="auto",
    )
    print(f"{label} model loaded successfully!")
    return loaded_tokenizer, loaded_model


# Load models (eagerly, at import time, so the first chat request is not delayed).
tokenizer, model = _load_pretrained(LLAMA_MODEL, "Llama-2")
meditron_tokenizer, meditron_model = _load_pretrained(MEDITRON_MODEL, "Meditron")
# Define the state for LangGraph
class ConversationState(TypedDict):
    """Dictionary threaded through every node of the medical LangGraph workflow."""

    # Not read by the graph nodes; carried through unchanged.
    messages: List[str]
    # Prior (user, assistant) turns used to rebuild the Llama-2 prompt.
    history: List[Tuple[str, str]]
    # The user message being answered on this invocation.
    current_message: str
    # User turns counted so far, including the current one.
    conversation_turns: int
    # Every user message collected so far; forwarded to Meditron.
    patient_data: List[str]
    # Reply text produced by the Llama-2 node.
    llama_response: str
    # Text returned to the UI (Llama reply, optionally plus suggestions).
    final_response: str
    # Routing flag deciding whether the Meditron node runs.
    should_get_suggestions: bool
def build_llama2_prompt(system_prompt, history, user_input):
    """Render a Llama-2-chat prompt string.

    The system prompt, wrapped in ``<<SYS>>`` tags, opens the first
    ``[INST]`` block. Each earlier exchange is emitted as
    ``user [/INST] assistant </s><s>[INST] ``. The new user message ends with
    an open ``[/INST] `` so that generation continues as the assistant.

    Args:
        system_prompt: Text placed inside the ``<<SYS>>`` section.
        history: Iterable of ``(user_msg, assistant_msg)`` pairs, oldest first.
        user_input: The latest user message.

    Returns:
        The formatted prompt string.
    """
    pieces = [f"<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n"]
    pieces.extend(
        f"{past_user} [/INST] {past_reply} </s><s>[INST] "
        for past_user, past_reply in history
    )
    pieces.append(f"{user_input} [/INST] ")
    return "".join(pieces)
def get_meditron_suggestions(patient_info):
    """Ask Meditron for OTC-medicine and home-remedy suggestions.

    Args:
        patient_info: Free-text description of the patient's situation. It is
            substituted into ``MEDITRON_PROMPT``.

    Returns:
        The newly generated text with surrounding whitespace removed. The text
        is cut at the first chat-turn marker (``<|im_end|>`` /
        ``<|im_start|>``) or EOS token, so that any turns the model invents
        after its answer are dropped.
    """
    prompt = MEDITRON_PROMPT.format(patient_info=patient_info)
    inputs = meditron_tokenizer(prompt, return_tensors="pt").to(meditron_model.device)
    prompt_len = inputs.input_ids.shape[1]

    with torch.no_grad():
        outputs = meditron_model.generate(
            inputs.input_ids,
            attention_mask=inputs.attention_mask,
            max_new_tokens=256,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
            # Explicit pad id, matching the Llama-2 generate call; avoids the
            # "Setting pad_token_id to eos_token_id" fallback on every call.
            pad_token_id=meditron_tokenizer.eos_token_id,
        )

    # Keep special tokens while decoding so the turn markers stay visible
    # whether the tokenizer treats them as special tokens or as plain text;
    # everything from the first marker onward is discarded.
    text = meditron_tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=False)
    for marker in ("<|im_end|>", "<|im_start|>", meditron_tokenizer.eos_token):
        if marker:
            text = text.split(marker, 1)[0]
    return text.strip()
# LangGraph Node Functions
def initialize_conversation(state: ConversationState) -> ConversationState:
    """Graph entry node: count the turn and record the latest user message.

    Increments ``conversation_turns`` (starting from 0 if absent) and appends
    ``current_message`` to ``patient_data``. It then sets
    ``should_get_suggestions`` once four or more turns have been seen.

    Returns:
        The same state dict, updated in place.
    """
    state["conversation_turns"] = state.get("conversation_turns", 0) + 1

    # Build a new list instead of appending in place. Callers may hold a
    # shallow copy of the state whose list object is shared with stored
    # session data, and an in-place append would leak into that store even if
    # a later node fails. A missing or None value is treated as empty.
    prior_messages = state.get("patient_data") or []
    state["patient_data"] = [*prior_messages, state["current_message"]]

    # Suggestions (and the summary request) kick in from the 4th turn on.
    state["should_get_suggestions"] = state["conversation_turns"] >= 4
    return state
def generate_llama_response(state: ConversationState) -> ConversationState:
    """Graph node: produce the assistant's next reply with Llama-2-chat.

    The node reads ``history``, ``current_message`` and
    ``conversation_turns``, and stores the decoded reply in ``llama_response``.

    From the 4th turn on, a summarisation request is appended to the current
    user message. That keeps it inside the ``[INST]`` block as an instruction
    to the model. It is not injected after ``[/INST]``, where it would become
    assistant text and be echoed back to the user.
    """
    user_turn = state["current_message"]
    if state["conversation_turns"] >= 4:
        user_turn = (
            f"{user_turn}\n\n"
            "Now summarize what you've learned and suggest when professional care may be needed."
        )
    prompt = build_llama2_prompt(SYSTEM_PROMPT, state["history"], user_turn)

    # The prompt already spells out its <s>/</s> markers, so don't let the
    # tokenizer prepend a second BOS token.
    inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(model.device)
    prompt_len = inputs.input_ids.shape[1]

    with torch.no_grad():
        outputs = model.generate(
            inputs.input_ids,
            attention_mask=inputs.attention_mask,
            max_new_tokens=512,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )

    # Decode only the newly generated tokens. Splitting the full decode on
    # "[/INST]" is fragile if the model itself emits that marker.
    state["llama_response"] = tokenizer.decode(
        outputs[0][prompt_len:], skip_special_tokens=True
    ).strip()
    return state
def generate_medicine_suggestions(state: ConversationState) -> ConversationState:
    """Graph node: append Meditron's OTC-medicine and home-care advice to the reply.

    Meditron receives every collected user message, one per line, followed by
    the Llama-2 reply labelled as a summary. The node stores the Llama-2 reply
    plus a headed suggestions section in ``final_response``.
    """
    llama_text = state["llama_response"]
    transcript = "\n".join(state["patient_data"])
    advice = get_meditron_suggestions(transcript + "\n\nSummary: " + llama_text)

    heading = "--- MEDICATION AND HOME CARE SUGGESTIONS ---"
    state["final_response"] = f"{llama_text}\n\n{heading}\n\n{advice}"
    return state
def finalize_response(state: ConversationState) -> ConversationState:
    """Graph node for turns without suggestions: publish the Llama-2 reply as-is."""
    reply = state["llama_response"]
    state["final_response"] = reply
    return state
def should_get_suggestions(state: ConversationState) -> str:
    """Router after the Llama-2 node.

    Returns the name of the next node: ``"get_suggestions"`` when the state's
    ``should_get_suggestions`` flag is truthy, otherwise ``"finalize"``.
    """
    return "get_suggestions" if state["should_get_suggestions"] else "finalize"
# Create the LangGraph workflow
def create_medical_workflow():
    """Build and compile the medical-assistant graph.

    Topology: initialize -> generate_llama -> (get_suggestions | finalize) -> END.
    The branch is chosen by ``should_get_suggestions``.

    Returns:
        The compiled LangGraph runnable.
    """
    graph = StateGraph(ConversationState)

    node_table = (
        ("initialize", initialize_conversation),
        ("generate_llama", generate_llama_response),
        ("get_suggestions", generate_medicine_suggestions),
        ("finalize", finalize_response),
    )
    for node_name, node_fn in node_table:
        graph.add_node(node_name, node_fn)

    graph.set_entry_point("initialize")
    graph.add_edge("initialize", "generate_llama")

    # The router already returns node names, so the path map is the identity.
    branch_targets = ("get_suggestions", "finalize")
    graph.add_conditional_edges(
        "generate_llama",
        should_get_suggestions,
        {target: target for target in branch_targets},
    )
    for terminal in branch_targets:
        graph.add_edge(terminal, END)

    return graph.compile()
# Initialize the workflow: compiled once at import and reused by every request.
medical_workflow = create_medical_workflow()

# Conversation state tracking (for Gradio session management).
# This dict lives at module level, so it is shared by every request this
# process serves.
conversation_states = dict()
def _history_to_pairs(history):
    """Normalise Gradio chat history into a list of (user, assistant) pairs.

    Accepts both ChatInterface formats: ``[user, assistant]`` pairs
    ("tuples") and ``{"role": ..., "content": ...}`` dicts ("messages"). A
    missing assistant reply becomes an empty string.
    """
    pairs = []
    pending_user = None
    for item in history or []:
        if isinstance(item, dict):
            role = item.get("role")
            content = item.get("content")
            if role == "user":
                pending_user = content
            elif role == "assistant" and pending_user is not None:
                pairs.append((pending_user, "" if content is None else content))
                pending_user = None
        else:
            user_msg, assistant_msg = item
            pairs.append((user_msg, "" if assistant_msg is None else assistant_msg))
    return pairs


def generate_response(message, history):
    """Gradio ChatInterface callback: run one chat turn through the LangGraph workflow.

    The turn count and the collected patient messages are derived from
    *history*, which Gradio keeps per browser session. Previously a single
    module-level "default-session" entry was shared by every visitor. That
    leaked one user's symptoms into other users' Meditron prompts and never
    reset the turn count when a chat was cleared.

    Args:
        message: The latest user message.
        history: Prior turns as supplied by gr.ChatInterface, in either the
            "tuples" or the "messages" format.

    Returns:
        The reply text to display.
    """
    pairs = _history_to_pairs(history)
    state = {
        "messages": [],
        "history": pairs,
        "current_message": message,
        # The workflow's initialize node adds one for the current message.
        "conversation_turns": len(pairs),
        # Earlier user messages; the initialize node appends the current one.
        "patient_data": [user_msg for user_msg, _ in pairs],
    }
    result = medical_workflow.invoke(state)
    return result["final_response"]
# Create the Gradio interface.
# On ZeroGPU hardware a GPU is attached only while a function wrapped by
# ``spaces.GPU`` is running. Wrapping the chat callback here is equivalent to
# decorating generate_response with ``@spaces.GPU``; outside ZeroGPU the
# wrapper has no effect. The default per-call GPU time budget applies. If two
# generations per turn exceed it, use spaces.GPU(duration=...).
demo = gr.ChatInterface(
    fn=spaces.GPU(generate_response),
    title="Medical Assistant with LangGraph & Medicine Suggestions",
    description="Tell me about your symptoms, and after gathering enough information, I'll suggest potential remedies using an AI workflow.",
    examples=[
        "I have a cough and my throat hurts",
        "I've been having headaches for a week",
        "My stomach has been hurting since yesterday",
    ],
    theme="soft",
)

if __name__ == "__main__":
    demo.launch()