yunusajib committed on
Commit 5cc39c0 · verified · 1 Parent(s): db349e8

update app

Files changed (1)
  1. app.py +176 -104
app.py CHANGED
@@ -1,112 +1,184 @@
  import gradio as gr
- import torch
  import os
- from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM
- from pyannote.audio import Pipeline as DiarizationPipeline
- import whisper
  import tempfile
- import shutil
  from pydub import AudioSegment
-
- # Load whisper model
- whisper_model = whisper.load_model("base") # Use "small" or "medium" if needed
-
- # Load summarization pipeline
- summarizer_tokenizer = AutoTokenizer.from_pretrained("t5-small")
- summarizer_model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")
- summarizer = pipeline("summarization", model=summarizer_model, tokenizer=summarizer_tokenizer)
-
- def convert_to_wav(input_path, output_path):
-     audio = AudioSegment.from_file(input_path)
-     audio.export(output_path, format="wav")
-
- def transcribe_audio(audio_path):
-     result = whisper_model.transcribe(audio_path, fp16=torch.cuda.is_available())
-     return result['text']
-
- def diarize_audio(audio_path, hf_token):
-     os.environ["HF_TOKEN"] = hf_token
-     diarization_pipeline = DiarizationPipeline.from_pretrained("pyannote/speaker-diarization", use_auth_token=hf_token)
-     diarization = diarization_pipeline(audio_path)
-     return diarization
-
- def combine_diarized_transcript(diarization, full_text):
-     # Basic speaker labeling using diarization and full text
-     # Note: This is a simplified alignment using time chunks only
-     chunks = []
-     for turn, _, speaker in diarization.itertracks(yield_label=True):
-         start, end = turn.start, turn.end
-         chunks.append(f"{speaker}: [from {start:.1f}s to {end:.1f}s]")
-     # Combine for display/demo
-     return "\n".join(chunks) + "\n" + full_text
-
- def summarize_text(text):
-     prefix = "summarize: " + text.strip()
-     inputs = summarizer_tokenizer.encode(prefix, return_tensors="pt", max_length=512, truncation=True)
-     summary_ids = summarizer_model.generate(inputs, max_length=100, min_length=30, length_penalty=2.0, num_beams=4, early_stopping=True)
-     return summarizer_tokenizer.decode(summary_ids[0], skip_special_tokens=True)
-
- def process_pipeline(audio_file, hf_token):
-     if not hf_token:
-         return "", "", "Error: HuggingFace token is required."
-
-     if not os.path.exists(audio_file) or os.path.getsize(audio_file) == 0:
-         return "", "", "Error: Uploaded file is missing or empty."
-
-     # Step 1: Convert to WAV if needed
-     with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_wav:
-         try:
-             sound = AudioSegment.from_file(audio_file)
-             sound.export(tmp_wav.name, format="wav")
-             tmp_path = tmp_wav.name
-         except Exception as e:
-             return "", "", f"Audio conversion failed: {str(e)}"
-
-     # Step 2: Transcription (Whisper)
-     try:
-         transcriber = pipeline("automatic-speech-recognition", model="openai/whisper-base", return_timestamps=True, device=0 if torch.cuda.is_available() else -1)
-         result = transcriber(tmp_path)
-         transcript = result["text"]
-     except Exception as e:
-         return "", "", f"Transcription failed: {str(e)}"
-
-     # Step 3: Summarization
      try:
-         summarizer = pipeline("summarization", model="facebook/bart-large-cnn", device=0 if torch.cuda.is_available() else -1)
-         summary = summarizer(transcript, max_length=130, min_length=30, do_sample=False)[0]["summary_text"]
      except Exception as e:
-         return transcript, "", f"Summarization failed: {str(e)}"
-
-     return tmp_path, transcript, summary
-
- description = """
- ### 🩺 GP Consultation Summarizer (Demo App)
-
- This app:
- 1. Transcribes short consultation audio using Whisper
- 2. Identifies who spoke when using PyAnnote speaker diarization
- 3. Combines both into a labeled transcript
- 4. Generates a short summary using T5-small
-
- ⚠️ **Note:** Best for short consultations (under 5–6 minutes).
- ⚠️ You must provide your own Hugging Face token (required for diarization).
- """
-
- app = gr.Interface(
-     fn=process_pipeline,
-     inputs=[
-         gr.Audio(type="filepath", label="Upload Consultation Audio (.wav)"),
-         gr.Textbox(label="Your Hugging Face Token", type="password")
-     ],
-     outputs=[
-         gr.Textbox(label="Raw Transcript"),
-         gr.Textbox(label="Labeled Transcript (with Speaker Info)"),
-         gr.Textbox(label="Summary")
-     ],
-     title="GP Consultation Summarizer",
-     description=description,
-     allow_flagging="never"
- )

  if __name__ == "__main__":
-     app.launch(share=True)
 
  import gradio as gr
  import os
  import tempfile
+ import torch
  from pydub import AudioSegment
+ import whisper
+ from pyannote.audio import Pipeline
+ from pyannote.core import Segment
+ from lmdeploy import pipeline as lm_pipeline
+ from lmdeploy import GenerationConfig, TurbomindEngineConfig
+ from transformers import pipeline as hf_pipeline
+ from presidio_analyzer import AnalyzerEngine
+ from presidio_anonymizer import AnonymizerEngine
+
+ # --- Configuration ---
+ MEDICAL_NER_MODEL = "d4data/biomedical-ner-all"
+ WHISPER_MODEL_SIZE = "base" # "small" or "medium" for better accuracy
+ DEFAULT_HF_TOKEN = "your_huggingface_token_here" # Replace with your token
+
+ # --- Global Models ---
+ whisper_model = None
+ diarization_pipeline = None
+ med_ner = None
+ phi_analyzer = AnalyzerEngine()
+ phi_anonymizer = AnonymizerEngine()
+
+ qwen_models = {
+     "Qwen Medical 7B": "Qwen/Qwen2.5-7B-Instruct-1M",
+     "Qwen Fast 3B": "Qwen/Qwen2.5-3B-Instruct",
+ }
+
+ # --- Helper Functions ---
+ def load_models(hf_token):
+     """Load all required models"""
+     global whisper_model, diarization_pipeline, med_ner
+
+     # Load Whisper (Speech-to-Text)
+     if whisper_model is None:
+         whisper_model = whisper.load_model(WHISPER_MODEL_SIZE, device="cuda" if torch.cuda.is_available() else "cpu")
+
+     # Load Diarization
+     if diarization_pipeline is None:
+         diarization_pipeline = Pipeline.from_pretrained(
+             "pyannote/speaker-diarization",
+             use_auth_token=hf_token
+         )
+
+     # Load Medical NER
+     if med_ner is None:
+         med_ner = hf_pipeline("ner", model=MEDICAL_NER_MODEL, aggregation_strategy="simple")
+
+     return "Models loaded successfully"
+
+ def convert_audio_to_wav(input_file):
+     """Convert any audio file to 16kHz WAV format"""
+     audio = AudioSegment.from_file(input_file)
+     wav_path = os.path.join(tempfile.gettempdir(), "consultation.wav")
+     audio.set_frame_rate(16000).export(wav_path, format="wav")
+     return wav_path
+
+ def anonymize_phi(text):
+     """Remove personally identifiable health information"""
+     results = phi_analyzer.analyze(text=text, language="en")
+     anonymized = phi_anonymizer.anonymize(text, results)
+     return anonymized.text
+
+ # --- Core Processing Functions ---
+ def transcribe_and_diarize(audio_file, hf_token):
+     """Convert audio to text with speaker labels"""
      try:
+         # Convert audio
+         wav_path = convert_audio_to_wav(audio_file)
+
+         # Transcribe
+         transcript = whisper_model.transcribe(wav_path)["segments"]
+
+         # Diarize
+         diarization = diarization_pipeline(wav_path)
+
+         # Combine results
+         output = []
+         for seg in transcript:
+             start, end, text = seg["start"], seg["end"], seg["text"]
+             # diarization is a pyannote Annotation: label the segment with the speaker who talks most within it
+             turns = diarization.crop(Segment(start, end), mode="intersection")
+             speaker = turns.argmax() if turns else "SPEAKER_UNKNOWN"
+             output.append(f"[{start:.1f}s] {speaker}: {text}")
+
+         return "\n".join(output), transcript
+
      except Exception as e:
+         return f"Error: {str(e)}", None
+
+ def extract_medical_entities(text):
+     """Identify drugs, conditions, and procedures"""
+     entities = med_ner(text)
+     return {
+         "Drugs": [e["word"] for e in entities if e["entity_group"] == "DRUG"],
+         "Conditions": [e["word"] for e in entities if e["entity_group"] == "DISEASE"],
+         "Procedures": [e["word"] for e in entities if e["entity_group"] == "TREATMENT"]
+     }
+
+ def generate_soap_notes(transcript, model_choice, anonymize_phi_flag):
+     """Generate structured medical notes using Qwen"""
+     # Anonymize if requested
+     if anonymize_phi_flag:
+         transcript = anonymize_phi(transcript)
+
+     # Initialize Qwen
+     engine_config = TurbomindEngineConfig(
+         cache_max_entry_count=0.5,
+         session_len=131072
+     )
+
+     pipe = lm_pipeline(qwen_models[model_choice], backend_config=engine_config)
+
+     # Medical prompt template
+     system_prompt = """You are a clinical assistant. Convert this doctor-patient conversation into SOAP notes:
+     - Subjective: Patient-reported symptoms
+     - Objective: Clinician observations
+     - Assessment: Diagnosis/differential
+     - Plan: Treatment and follow-up"""
+
+     response = pipe([{
+         "role": "system",
+         "content": system_prompt
+     }, {
+         "role": "user",
+         "content": f"Consultation Transcript:\n{transcript}\n\nGenerate concise SOAP notes:"
+     }], GenerationConfig(max_new_tokens=1024))
+
+     return response.text
+
+ # --- Gradio Interface ---
+ with gr.Blocks(title="Clinical Consultation Summarizer", theme=gr.themes.Soft()) as app:
+     gr.Markdown("""# 🩺 Patient-Doctor Consultation Summarizer""")
+
+     with gr.Row():
+         with gr.Column():
+             audio_input = gr.Audio(
+                 sources=["upload", "microphone"],
+                 type="filepath",
+                 label="Upload Consultation Recording"
+             )
+             hf_token = gr.Textbox(
+                 label="Hugging Face Token",
+                 value=DEFAULT_HF_TOKEN,
+                 type="password"
+             )
+             model_choice = gr.Dropdown(
+                 choices=list(qwen_models.keys()),
+                 value="Qwen Medical 7B",
+                 label="Model"
+             )
+             anonymize_check = gr.Checkbox(
+                 label="Anonymize Protected Health Info (PHI)",
+                 value=True
+             )
+             process_btn = gr.Button("Process Consultation")
+
+         with gr.Column():
+             with gr.Tabs():
+                 with gr.Tab("Transcript"):
+                     transcript_output = gr.Textbox(label="Transcribed Conversation", lines=15)
+                 with gr.Tab("SOAP Notes"):
+                     soap_output = gr.Textbox(label="Clinical Summary", lines=15)
+                 with gr.Tab("Medical Entities"):
+                     entity_output = gr.JSON(label="Extracted Medical Terms")
+
+     # Processing
+     process_btn.click(
+         # load_models() returns a truthy status string, so an `or` chain would stop there;
+         # call it for its side effect, then return the transcription result
+         fn=lambda audio, token: (load_models(token), transcribe_and_diarize(audio, token))[1],
+         inputs=[audio_input, hf_token],
+         outputs=[transcript_output, gr.State()]
+     ).success(
+         fn=generate_soap_notes,
+         inputs=[transcript_output, model_choice, anonymize_check],
+         outputs=soap_output
+     ).success(
+         fn=extract_medical_entities,
+         inputs=transcript_output,
+         outputs=entity_output
+     )

  if __name__ == "__main__":
+     app.launch(server_port=7860, share=True)
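
The interface wiring above relies on Gradio event chaining: process_btn.click(...) returns an event object, and each .success(...) step runs only if the previous handler finished without raising, reading its inputs from components the earlier step populated. A minimal standalone sketch of that pattern (component names and lambdas here are illustrative, not part of this commit):

import gradio as gr

with gr.Blocks() as demo:
    text_in = gr.Textbox(label="Input")
    step1_out = gr.Textbox(label="Step 1: uppercase")
    step2_out = gr.Textbox(label="Step 2: length")
    run_btn = gr.Button("Run")

    # Step 2 fires only after step 1 succeeds, and reads step 1's output component.
    run_btn.click(
        fn=lambda s: s.upper(), inputs=text_in, outputs=step1_out
    ).success(
        fn=lambda s: f"{len(s)} characters", inputs=step1_out, outputs=step2_out
    )

if __name__ == "__main__":
    demo.launch()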