File size: 4,341 Bytes
b80af5b
 
1cf7fb2
b80af5b
aca454d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1cf7fb2
 
 
 
 
b80af5b
1cf7fb2
 
 
b80af5b
1cf7fb2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b80af5b
1cf7fb2
 
 
b80af5b
1cf7fb2
 
 
 
 
 
aca454d
 
 
 
1cf7fb2
aca454d
 
 
 
1cf7fb2
 
aca454d
1cf7fb2
 
 
 
 
 
 
 
aca454d
1cf7fb2
b80af5b
1cf7fb2
 
aca454d
 
b80af5b
1cf7fb2
 
 
 
 
 
b80af5b
1cf7fb2
 
 
 
b80af5b
 
1cf7fb2
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import gradio as gr
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer

# System prompt injected at the top of every generated prompt (see
# generate_response). It instructs the model to gather information over
# several turns before recommending anything, and to always close with a
# medical disclaimer. Runtime string — edit with care: the model's behavior
# follows this text verbatim.
SYSTEM_PROMPT = """
You are a knowledgeable medical assistant. Follow these steps in order:

1) INITIAL ASSESSMENT: First, warmly greet the user and ask about their primary concern.

2) ASK FOLLOW-UP QUESTIONS: For any health concern mentioned, systematically gather information by asking 1-2 specific follow-up questions at a time about:
   - Detailed description of symptoms
   - Duration (when did it start?)
   - Severity (scale of 1-10)
   - Aggravating or alleviating factors
   - Related symptoms
   - Medical history
   - Current medications and allergies
   - Family history of similar conditions

3) SUMMARIZE FINDINGS: Once you have gathered sufficient information (at least 4-5 exchanges with the user), organize what you've learned into clear categories:
   - Symptoms
   - Duration
   - Severity
   - Possible Causes
   - Medications/Allergies
   - Family History

4) PROVIDE RECOMMENDATIONS: Only after gathering comprehensive information, suggest:
   - One specific OTC medicine with proper adult dosing
   - One practical home remedy
   - When they should seek professional medical care

5) END WITH DISCLAIMER: Always end with a clear medical disclaimer that you are not a licensed medical professional and your suggestions are not a substitute for professional medical advice.

IMPORTANT: Do not skip ahead to recommendations without gathering comprehensive information through multiple exchanges. Your primary goal is information gathering through thoughtful questions.
"""

# UI label -> Hugging Face Hub repo id. Keys are what the dropdown shows and
# what load_model() receives; values are passed to from_pretrained().
# NOTE(review): Llama-2-7b-chat-hf is a gated repo — access/auth presumably
# configured on the Space; confirm before relying on that option.
MODELS = {
    "TinyLlama-1.1B": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    "Llama-2-7b": "meta-llama/Llama-2-7b-chat-hf"
}

# Process-wide caches so each model/tokenizer pair is loaded at most once
# (populated lazily by load_model, keyed by the UI label).
loaded_models = {}
loaded_tokenizers = {}

def load_model(model_name):
    """Return the ``(model, tokenizer)`` pair for *model_name*, loading on first use.

    Results are memoized in the module-level ``loaded_models`` /
    ``loaded_tokenizers`` caches, so repeated calls for the same name are cheap.

    Args:
        model_name: A key of ``MODELS`` (e.g. ``"TinyLlama-1.1B"``).

    Returns:
        Tuple of ``(AutoModelForCausalLM, AutoTokenizer)``.

    Raises:
        KeyError: If *model_name* is not a configured model, with a message
            listing the valid choices (previously this surfaced as a bare
            ``KeyError`` from the dict lookup).
    """
    if model_name not in MODELS:
        raise KeyError(
            f"Unknown model {model_name!r}; expected one of {sorted(MODELS)}"
        )
    if model_name not in loaded_models:
        print(f"Loading {model_name}...")
        model_path = MODELS[model_name]
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        model = AutoModelForCausalLM.from_pretrained(
            model_path,
            torch_dtype="auto",   # pick the checkpoint's native dtype
            device_map="auto",    # place on GPU automatically when available
        )
        loaded_models[model_name] = model
        loaded_tokenizers[model_name] = tokenizer
        print(f"{model_name} loaded successfully!")
    return loaded_models[model_name], loaded_tokenizers[model_name]

# Pre-load the smaller model at import time so the first user request does not
# pay the cold-start download/load cost. Deliberate module-level side effect —
# this runs when the Space boots the script.
print("Pre-loading TinyLlama model...")
load_model("TinyLlama-1.1B")

@spaces.GPU  # Required by ZeroGPU: GPU is attached only for the duration of this call
def generate_response(message, history, model_choice):
    """Generate one assistant reply for the chat interface.

    Args:
        message: The user's latest message (str).
        history: Prior turns as ``[(user, assistant), ...]`` pairs.
            NOTE(review): assumes Gradio's tuple-style history format —
            confirm against the installed Gradio version's ChatInterface.
        model_choice: Key of ``MODELS`` selecting which model to use.

    Returns:
        The model's reply with surrounding whitespace stripped.
    """
    # Load (or fetch from cache) the selected model.
    model, tokenizer = load_model(model_choice)

    # Build a flat text prompt: system prompt, then alternating turns,
    # ending with an open "Assistant:" for the model to complete.
    turns = [SYSTEM_PROMPT, "\n"]
    for human, assistant in history:
        turns.append(f"User: {human}\nAssistant: {assistant}\n")
    turns.append(f"User: {message}\nAssistant:")
    formatted_prompt = "".join(turns)

    inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(
        inputs["input_ids"],
        # Fix: pass the attention mask explicitly — omitting it makes
        # transformers infer one (warning, and wrong if pad == eos).
        attention_mask=inputs["attention_mask"],
        max_new_tokens=512,
        temperature=0.7,
        top_p=0.9,
        do_sample=True,
        # Fix: chat checkpoints like TinyLlama/Llama-2 define no pad token;
        # without this, generate() warns and falls back each call.
        pad_token_id=tokenizer.eos_token_id,
    )
    # Decode only the newly generated tokens (slice off the prompt).
    response = tokenizer.decode(
        outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True
    )

    return response.strip()

# Build the Gradio UI: a model selector plus a chat interface that routes
# every message through generate_response with the selected model.
with gr.Blocks() as demo:
    gr.Markdown("# Medical Assistant Chatbot")
    gr.Markdown("This chatbot uses LLM models to provide medical information and assistance. Please note that this is not a substitute for professional medical advice.")

    with gr.Row():
        model_dropdown = gr.Dropdown(
            choices=list(MODELS.keys()),
            value="TinyLlama-1.1B",
            label="Select Model"
        )

    chatbot = gr.ChatInterface(
        # Fix: generate_response already has the signature
        # (message, history, model_choice); the lambda wrapper was redundant.
        fn=generate_response,
        additional_inputs=[model_dropdown],
    )

if __name__ == "__main__":
    demo.launch()