Spaces:
Sleeping
Sleeping
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import gradio as gr
|
3 |
+
import re
|
4 |
+
from collections import Counter
|
5 |
+
from llama_cpp import Llama
|
6 |
+
from huggingface_hub import hf_hub_download
|
7 |
+
|
8 |
+
# Configuration: set environment variables for model repository and file
|
9 |
+
HF_REPO_ID = os.getenv("HF_REPO_ID", "QuantFactory/Meta-Llama-3-8B-Instruct-GGUF")
|
10 |
+
HF_MODEL_BASENAME = os.getenv("HF_MODEL_BASENAME", "Meta-Llama-3-8B-Instruct.Q8_0.gguf")
|
11 |
+
|
12 |
+
# Download or locate the quantized LLaMA model
|
13 |
+
model_path = hf_hub_download(repo_id=HF_REPO_ID, filename=HF_MODEL_BASENAME)
|
14 |
+
|
15 |
+
# Initialize LLaMA via llama_cpp
|
16 |
+
llm = Llama(
|
17 |
+
model_path=model_path,
|
18 |
+
n_threads=int(os.getenv("LLAMA_THREADS", "4")),
|
19 |
+
n_batch=int(os.getenv("LLAMA_BATCH", "256")),
|
20 |
+
n_gpu_layers=int(os.getenv("LLAMA_GPU_LAYERS", "43")),
|
21 |
+
n_ctx=int(os.getenv("LLAMA_CTX", "8192"))
|
22 |
+
)
|
23 |
+
|
24 |
+
# Prompt templates from notebook
|
25 |
+
system_prompt_few_shot = """
|
26 |
+
SYSTEM:
|
27 |
+
You are an AI medical assistant specializing in differential diagnosis.
|
28 |
+
Generate the most likely list of diagnoses based on examples.
|
29 |
+
|
30 |
+
USER: A 45-year-old male, fever, cough, fatigue.
|
31 |
+
SYSTEM: [Flu, COVID-19, Pneumonia]
|
32 |
+
|
33 |
+
USER: A 30-year-old female, severe abdominal pain, nausea.
|
34 |
+
SYSTEM: [Appendicitis, Gallstones, Gastritis]
|
35 |
+
|
36 |
+
USER: A 10-year-old female, wheezing.
|
37 |
+
SYSTEM: [Asthma, Respiratory Infection]
|
38 |
+
|
39 |
+
USER:
|
40 |
+
"""
|
41 |
+
system_prompt_cot = """
|
42 |
+
SYSTEM:
|
43 |
+
You are a medical expert performing differential diagnosis through step-by-step reasoning.
|
44 |
+
Provide intermediate reasoning and final diagnoses.
|
45 |
+
|
46 |
+
USER:
|
47 |
+
"""
|
48 |
+
system_prompt_tot = """
|
49 |
+
SYSTEM:
|
50 |
+
You are a medical expert using a tree-of-thought approach for differential diagnosis.
|
51 |
+
Construct a reasoning tree then provide final diagnoses.
|
52 |
+
|
53 |
+
USER:
|
54 |
+
"""
|
55 |
+
|
56 |
+
def lcpp_llm(prompt, max_tokens=2048, temperature=0, stop=["USER"]):
|
57 |
+
return llm(prompt=prompt, max_tokens=max_tokens, temperature=temperature, stop=stop)
|
58 |
+
|
59 |
+
def extract_list(text):
|
60 |
+
match = re.search(r'\[(.*?)\]', text)
|
61 |
+
if match:
|
62 |
+
return [item.strip() for item in match.group(1).split(",")]
|
63 |
+
return []
|
64 |
+
|
65 |
+
def determine_most_probable(few_shot, cot, tot):
|
66 |
+
counts = Counter(few_shot + cot + tot)
|
67 |
+
if not counts:
|
68 |
+
return "No Clear Diagnosis"
|
69 |
+
max_occ = max(counts.values())
|
70 |
+
for diag, cnt in counts.items():
|
71 |
+
if cnt == max_occ:
|
72 |
+
return diag
|
73 |
+
|
74 |
+
def medical_diagnosis(symptoms: str):
|
75 |
+
try:
|
76 |
+
# Generate responses
|
77 |
+
resp_few = lcpp_llm(system_prompt_few_shot + symptoms)
|
78 |
+
resp_cot = lcpp_llm(system_prompt_cot + symptoms)
|
79 |
+
resp_tot = lcpp_llm(system_prompt_tot + symptoms)
|
80 |
+
|
81 |
+
# Extract text
|
82 |
+
text_few = resp_few['choices'][0]['text'].strip()
|
83 |
+
text_cot = resp_cot['choices'][0]['text'].strip()
|
84 |
+
text_tot = resp_tot['choices'][0]['text'].strip()
|
85 |
+
|
86 |
+
# Parse lists
|
87 |
+
few = extract_list(text_few)
|
88 |
+
cot = extract_list(text_cot)
|
89 |
+
tot = extract_list(text_tot)
|
90 |
+
most = determine_most_probable(few, cot, tot)
|
91 |
+
|
92 |
+
# Format Markdown output
|
93 |
+
return f"""
|
94 |
+
### Differential Diagnosis Results
|
95 |
+
|
96 |
+
**Few-Shot Diagnoses:** {', '.join(few) if few else 'No Diagnosis'}
|
97 |
+
|
98 |
+
**Chain-of-Thought Diagnoses:** {', '.join(cot) if cot else 'No Diagnosis'}
|
99 |
+
|
100 |
+
**Tree-of-Thought Diagnoses:** {', '.join(tot) if tot else 'No Diagnosis'}
|
101 |
+
|
102 |
+
**Most Probable Diagnosis:** {most}
|
103 |
+
"""
|
104 |
+
except Exception as e:
|
105 |
+
return f"Error: {e}"
|
106 |
+
|
107 |
+
# Gradio app definition
|
108 |
+
with gr.Blocks() as demo:
|
109 |
+
gr.Markdown("# Differential Diagnosis Explorer (Local LLaMA)")
|
110 |
+
cond = gr.Textbox(label="Patient Condition", placeholder="A 35-year-old male, fever, wheezing, nausea.")
|
111 |
+
out = gr.Markdown()
|
112 |
+
btn = gr.Button("Diagnose")
|
113 |
+
btn.click(fn=medical_diagnosis, inputs=cond, outputs=out)
|
114 |
+
|
115 |
+
if __name__ == "__main__":
|
116 |
+
demo.launch(server_name="0.0.0.0", server_port=int(os.getenv('PORT', 7860)))
|