Spaces:
Build error
update app
app.py CHANGED
@@ -1,112 +1,184 @@
 import gradio as gr
-import torch
 import os
-from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM
-from pyannote.audio import Pipeline as DiarizationPipeline
-import whisper
 import tempfile
-import
 from pydub import AudioSegment
-[old lines 10-73 removed; content lost in extraction except an `if` fragment]
 try:
-[old lines 75-76 removed; content lost in extraction]
 except Exception as e:
-    return
-[old lines 79-109 removed; content lost in extraction]
 
 if __name__ == "__main__":
-    app.launch(share=True)
 import gradio as gr
 import os
 import tempfile
+import torch
 from pydub import AudioSegment
+import whisper
+from pyannote.audio import Pipeline
+from pyannote.core import Segment
+from lmdeploy import pipeline as lm_pipeline
+from lmdeploy import GenerationConfig, TurbomindEngineConfig
+from transformers import pipeline as hf_pipeline
+from presidio_analyzer import AnalyzerEngine
+from presidio_anonymizer import AnonymizerEngine
+
+# --- Configuration ---
+MEDICAL_NER_MODEL = "d4data/biomedical-ner-all"
+WHISPER_MODEL_SIZE = "base"  # "small" or "medium" for better accuracy
+DEFAULT_HF_TOKEN = "your_huggingface_token_here"  # Replace with your token
+
+# --- Global Models ---
+whisper_model = None
+diarization_pipeline = None
+med_ner = None
+phi_analyzer = AnalyzerEngine()
+phi_anonymizer = AnonymizerEngine()
+
+qwen_models = {
+    "Qwen Medical 7B": "Qwen/Qwen2.5-7B-Instruct-1M",
+    "Qwen Fast 3B": "Qwen/Qwen2.5-3B-Instruct",
+}
+
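+# NOTE: pyannote/speaker-diarization is a gated model, so the token passed to
+# load_models() below must belong to an account that accepted its terms on the
+# Hub. Presidio's default AnalyzerEngine also expects a spaCy model
+# (en_core_web_lg) to be installed in the Space at build time.
+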
+# --- Helper Functions ---
+def load_models(hf_token):
+    """Load all required models"""
+    global whisper_model, diarization_pipeline, med_ner
+
+    # Load Whisper (Speech-to-Text)
+    if whisper_model is None:
+        whisper_model = whisper.load_model(WHISPER_MODEL_SIZE, device="cuda" if torch.cuda.is_available() else "cpu")
+
+    # Load Diarization
+    if diarization_pipeline is None:
+        diarization_pipeline = Pipeline.from_pretrained(
+            "pyannote/speaker-diarization",
+            use_auth_token=hf_token
+        )
+
+    # Load Medical NER
+    if med_ner is None:
+        med_ner = hf_pipeline("ner", model=MEDICAL_NER_MODEL, aggregation_strategy="simple")
+
+    return "Models loaded successfully"
+
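+# All three models above load lazily on the first button click, so the first
+# request bears the full cold-start cost (Whisper, pyannote, and NER downloads).
+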
+def convert_audio_to_wav(input_file):
+    """Convert any audio file to 16kHz WAV format"""
+    audio = AudioSegment.from_file(input_file)
+    # Unique temp file per request so concurrent sessions don't overwrite each other
+    fd, wav_path = tempfile.mkstemp(suffix=".wav")
+    os.close(fd)
+    audio.set_frame_rate(16000).export(wav_path, format="wav")
+    return wav_path
+
+def anonymize_phi(text):
+    """Remove personally identifiable health information"""
+    results = phi_analyzer.analyze(text=text, language="en")
+    anonymized = phi_anonymizer.anonymize(text, results)
+    return anonymized.text
+
+# --- Core Processing Functions ---
+def transcribe_and_diarize(audio_file, hf_token):
+    """Convert audio to text with speaker labels"""
     try:
+        # Convert audio
+        wav_path = convert_audio_to_wav(audio_file)
+
+        # Transcribe
+        transcript = whisper_model.transcribe(wav_path)["segments"]
+
+        # Diarize
+        diarization = diarization_pipeline(wav_path)
+
+        # Combine results: label each Whisper segment with the speaker whose
+        # diarized turns overlap it the most
+        output = []
+        for seg in transcript:
+            start, end, text = seg["start"], seg["end"], seg["text"]
+            overlap = diarization.crop(Segment(start, end))
+            speaker = max(overlap.labels(), key=overlap.label_duration, default="UNKNOWN")
+            output.append(f"[{start:.1f}s] {speaker}: {text}")
+
+        return "\n".join(output), transcript
+
     except Exception as e:
+        return f"Error: {str(e)}", None
+
+def extract_medical_entities(text):
+    """Identify drugs, conditions, and procedures"""
+    entities = med_ner(text)
+    # d4data/biomedical-ner-all emits entity groups such as "Medication",
+    # "Disease_disorder" and "Therapeutic_procedure"; adjust these names if
+    # the deployed model's label set differs
+    return {
+        "Drugs": [e["word"] for e in entities if e["entity_group"] == "Medication"],
+        "Conditions": [e["word"] for e in entities if e["entity_group"] == "Disease_disorder"],
+        "Procedures": [e["word"] for e in entities if e["entity_group"] == "Therapeutic_procedure"]
+    }
+
+def generate_soap_notes(transcript, model_choice, anonymize_phi_flag):
+    """Generate structured medical notes using Qwen"""
+    # Anonymize if requested
+    if anonymize_phi_flag:
+        transcript = anonymize_phi(transcript)
+
+    # Initialize Qwen (rebuilt on every call; see the note after this function)
+    engine_config = TurbomindEngineConfig(
+        cache_max_entry_count=0.5,
+        session_len=131072
+    )
+
+    pipe = lm_pipeline(qwen_models[model_choice], backend_config=engine_config)
+
+    # Medical prompt template
+    system_prompt = """You are a clinical assistant. Convert this doctor-patient conversation into SOAP notes:
+- Subjective: Patient-reported symptoms
+- Objective: Clinician observations
+- Assessment: Diagnosis/differential
+- Plan: Treatment and follow-up"""
+
+    response = pipe([{
+        "role": "system",
+        "content": system_prompt
+    }, {
+        "role": "user",
+        "content": f"Consultation Transcript:\n{transcript}\n\nGenerate concise SOAP notes:"
+    }], GenerationConfig(max_new_tokens=1024))
+
+    return response.text
+
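+# NOTE: lm_pipeline() above reloads the Qwen weights on every request. A cache
+# keyed by model choice would avoid that; a minimal sketch (names are
+# illustrative, engine_config as constructed above):
+#
+#     _pipes = {}
+#     def get_pipe(choice):
+#         if choice not in _pipes:
+#             _pipes[choice] = lm_pipeline(qwen_models[choice],
+#                                          backend_config=engine_config)
+#         return _pipes[choice]
+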
+# --- Gradio Interface ---
+with gr.Blocks(title="Clinical Consultation Summarizer", theme=gr.themes.Soft()) as app:
+    gr.Markdown("""# 🩺 Patient-Doctor Consultation Summarizer""")
+
+    with gr.Row():
+        with gr.Column():
+            audio_input = gr.Audio(
+                sources=["upload", "microphone"],
+                type="filepath",
+                label="Upload Consultation Recording"
+            )
+            hf_token = gr.Textbox(
+                label="Hugging Face Token",
+                value=DEFAULT_HF_TOKEN,
+                type="password"
+            )
+            model_choice = gr.Dropdown(
+                choices=list(qwen_models.keys()),
+                value="Qwen Medical 7B",
+                label="Model"
+            )
+            anonymize_check = gr.Checkbox(
+                label="Anonymize Protected Health Info (PHI)",
+                value=True
+            )
+            process_btn = gr.Button("Process Consultation")
+
+        with gr.Column():
+            with gr.Tabs():
+                with gr.Tab("Transcript"):
+                    transcript_output = gr.Textbox(label="Transcribed Conversation", lines=15)
+                with gr.Tab("SOAP Notes"):
+                    soap_output = gr.Textbox(label="Clinical Summary", lines=15)
+                with gr.Tab("Medical Entities"):
+                    entity_output = gr.JSON(label="Extracted Medical Terms")
+
+    # Processing chain: transcribe first; each .success() step runs only if
+    # the previous one finished without raising
+    transcript_state = gr.State()  # holds the raw Whisper segments
+
+    def process(audio, token):
+        # load_models() returns a truthy status string, so chaining it with
+        # `or` in a lambda would skip transcription; call the steps explicitly
+        load_models(token)
+        return transcribe_and_diarize(audio, token)
+
+    process_btn.click(
+        fn=process,
+        inputs=[audio_input, hf_token],
+        outputs=[transcript_output, transcript_state]
+    ).success(
+        fn=generate_soap_notes,
+        inputs=[transcript_output, model_choice, anonymize_check],
+        outputs=soap_output
+    ).success(
+        fn=extract_medical_entities,
+        inputs=transcript_output,
+        outputs=entity_output
+    )
 
 if __name__ == "__main__":
+    app.launch(server_port=7860, share=True)
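
For reference, a requirements.txt consistent with the new imports (a sketch; unpinned versions are an assumption, and a missing entry of this kind is a common cause of the "Build error" status shown above):

gradio
torch
openai-whisper
pyannote.audio
pydub
lmdeploy
transformers
presidio-analyzer
presidio-anonymizer

pydub and Whisper additionally need the ffmpeg system package (packages.txt on Spaces).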