File size: 3,419 Bytes
b80af5b
 
9f6ac99
 
b80af5b
6d5190c
aca454d
9f6ac99
d5f0232
9f6ac99
 
 
 
 
 
 
 
 
 
aca454d
9f6ac99
aca454d
9f6ac99
 
 
 
 
 
 
 
aca454d
9f6ac99
 
aca454d
9f6ac99
 
 
645c015
9f6ac99
 
 
645c015
9f6ac99
 
645c015
9f6ac99
645c015
9f6ac99
 
 
 
8b29c0d
9f6ac99
 
 
 
 
 
 
 
 
 
 
 
8b29c0d
9f6ac99
 
 
 
 
 
 
 
 
8b29c0d
9f6ac99
 
 
 
 
 
b80af5b
8b29c0d
6d5190c
9f6ac99
 
 
8b29c0d
 
 
 
 
9f6ac99
6d5190c
b80af5b
 
9f6ac99
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Hugging Face Hub id of the chat model loaded below.
# NOTE(review): meta-llama repos on the Hub are typically gated; presumably the
# deployment has an access token configured — confirm.
MODEL_NAME = "meta-llama/Llama-2-7b-chat-hf"

# System instructions for the assistant. This text is passed to the prompt
# builder as the system prompt on every request, so any edit here changes the
# model's behavior for all conversations.
SYSTEM_PROMPT = """You are a professional virtual doctor. Your goal is to collect detailed information about the user's health condition, symptoms, medical history, medications, lifestyle, and other relevant data.

Ask 1-2 follow-up questions at a time to gather more details about:
- Detailed description of symptoms
- Duration (when did it start?)
- Severity (scale of 1-10)
- Aggravating or alleviating factors
- Related symptoms
- Medical history
- Current medications and allergies

After collecting sufficient information (4-5 exchanges), summarize findings and suggest when they should seek professional care. Do NOT make specific diagnoses or recommend specific treatments.

Respond empathetically and clearly. Always be professional and thorough."""

# Load the tokenizer and weights once, at import time, so every request handled
# by this module reuses the same objects instead of reloading them.
print("Loading model...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    # Half precision: roughly halves weight memory compared with float32.
    torch_dtype=torch.float16,
    # Let transformers/accelerate decide device placement for the weights.
    device_map="auto"
)
print("Model loaded successfully!")

# Conversation state tracking: a turn-counter dict intended to be keyed by a
# session id string.
# NOTE(review): this lives at module level, so it is shared by every user of
# the process and nothing in this file ever resets it.
conversation_turns = {}

def build_llama2_prompt(system_prompt, history, user_input):
    """Format the conversation history and user input for Llama-2 chat models.

    The prompt opens with ``<s>[INST] <<SYS>>``, the system prompt and
    ``<</SYS>>``. Each past turn is then rendered as
    ``{user} [/INST] {assistant} </s><s>[INST] ``. The new user input comes
    last, followed by ``[/INST] `` so the model continues with the assistant
    reply.

    Args:
        system_prompt: Text placed inside the ``<<SYS>>`` block.
        history: Previous turns in either of Gradio's chat-history formats:
            ``(user, assistant)`` pairs ("tuples"), or dicts with ``"role"`` and
            ``"content"`` keys ("messages"). ``None`` is treated as empty.
        user_input: The new user message to be answered.

    Returns:
        The complete prompt string.
    """
    parts = [f"<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n"]

    # Add conversation history.
    for user_msg, assistant_msg in _history_pairs(history):
        # A turn without a reply would otherwise be rendered as the text "None".
        reply = "" if assistant_msg is None else assistant_msg
        parts.append(f"{user_msg} [/INST] {reply} </s><s>[INST] ")

    # Add the current user input.
    parts.append(f"{user_input} [/INST] ")
    return "".join(parts)


def _history_pairs(history):
    """Return ``(user, assistant)`` pairs from tuples- or messages-style history."""
    pairs = []
    for item in history or []:
        if not isinstance(item, dict):
            # "tuples" format: each item is already a [user, assistant] pair.
            user_msg, assistant_msg = item
            pairs.append((user_msg, assistant_msg))
            continue

        # "messages" format: one dict per message; re-pair user/assistant turns.
        role = item.get("role")
        content = item.get("content")
        if role == "user":
            pairs.append((content, None))
        elif role == "assistant" and pairs:
            last_user, last_reply = pairs[-1]
            if last_reply is None:
                pairs[-1] = (last_user, content)
            else:
                # Several assistant messages in a row: keep them all in one reply.
                pairs[-1] = (last_user, f"{last_reply}\n{content}")
        # Other roles (e.g. "system") and an assistant message before any user
        # turn have no user turn to attach to and are skipped.
    return pairs

# Starting with this user turn (1-based, counting the current message), the
# model is explicitly asked to wrap up the intake with a summary.
_SUMMARY_AFTER_TURNS = 4
_SUMMARY_INSTRUCTION = (
    "Now summarize what you've learned and suggest when professional care may be needed."
)


@spaces.GPU
def generate_response(message, history):
    """Generate one assistant reply with the Llama-2 chat model.

    Args:
        message: The user's latest message.
        history: Previous turns as supplied by ``gr.ChatInterface``, either as
            ``(user, assistant)`` pairs or as ``{"role", "content"}`` dicts.

    Returns:
        The assistant's reply text, without the prompt or special tokens.
    """
    # Derive the turn number from the conversation's own history. A
    # module-level counter with a fixed session id is shared by every user and
    # never resets when a chat is cleared.
    prior_user_turns = sum(
        1
        for item in (history or [])
        if not isinstance(item, dict) or item.get("role") == "user"
    )
    turn = prior_user_turns + 1

    user_input = message
    if turn >= _SUMMARY_AFTER_TURNS:
        # Put the request inside the user's [INST] block. Replacing every
        # "[/INST] " in the prompt would put it into each past assistant reply
        # and pre-fill the start of the new one.
        user_input = f"{message}\n\n{_SUMMARY_INSTRUCTION}"

    prompt = build_llama2_prompt(SYSTEM_PROMPT, history, user_input)

    # The prompt already begins with a literal "<s>" (BOS); stop the tokenizer
    # from prepending a second one.
    inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(model.device)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,  # passes attention_mask along with input_ids
            max_new_tokens=512,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )

    # Decode only the newly generated tokens. Splitting the decoded prompt on
    # "[/INST]" breaks if the reply itself contains that marker.
    prompt_length = inputs["input_ids"].shape[-1]
    new_tokens = outputs[0][prompt_length:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

# Sample prompts offered in the chat UI as clickable examples.
_EXAMPLE_PROMPTS = [
    "I have a cough and my throat hurts",
    "I've been having headaches for a week",
    "My stomach has been hurting since yesterday",
]

# Chat UI: each submitted message is passed to generate_response together with
# the chat history.
demo = gr.ChatInterface(
    fn=generate_response,
    examples=_EXAMPLE_PROMPTS,
    title="Medical Assistant Chatbot",
    description="Ask about your symptoms and I'll help gather relevant information.",
    theme="soft",
)


def _main():
    """Start the Gradio server for ``demo``."""
    demo.launch()


if __name__ == "__main__":
    _main()