# Source: Hugging Face file view (author: yunusajib; commit 5cc39c0 "update app", verified; 6.44 kB).
import gradio as gr
import os
import tempfile
import torch
from pydub import AudioSegment
import whisper
from pyannote.audio import Pipeline
from pyannote.core import Segment
from lmdeploy import pipeline as lm_pipeline
from lmdeploy import GenerationConfig, TurbomindEngineConfig
from transformers import pipeline as hf_pipeline
from presidio_analyzer import AnalyzerEngine
from presidio_anonymizer import AnonymizerEngine
# --- Configuration ---
MEDICAL_NER_MODEL = "d4data/biomedical-ner-all"
WHISPER_MODEL_SIZE = "base"  # "small" or "medium" for better accuracy
# Prefer a token supplied through the environment (e.g. a Space secret named
# HF_TOKEN) so the secret does not have to be committed to source control;
# the placeholder is kept as the fallback for backward compatibility.
DEFAULT_HF_TOKEN = os.environ.get("HF_TOKEN", "your_huggingface_token_here")
# --- Global Models ---
# Whisper, diarization and NER models start unset; load_models() fills them in.
whisper_model = diarization_pipeline = med_ner = None
# Presidio engines used by anonymize_phi(); constructed at import time.
phi_analyzer, phi_anonymizer = AnalyzerEngine(), AnonymizerEngine()
# UI label -> Hugging Face repo id for the SOAP-note generator. Insertion order
# is the order shown in the model dropdown.
# NOTE(review): the "Medical" label is cosmetic; the repo id is a
# general-purpose instruct checkpoint, not a medical fine-tune.
qwen_models = dict([
    ("Qwen Medical 7B", "Qwen/Qwen2.5-7B-Instruct-1M"),
    ("Qwen Fast 3B", "Qwen/Qwen2.5-3B-Instruct"),
])
# --- Helper Functions ---
def load_models(hf_token):
    """Load the speech, diarization and medical-NER models into module globals.

    Each model is created only while its global is still ``None``, so calling
    this again after a successful load is cheap.

    Args:
        hf_token: Hugging Face access token handed to pyannote to download the
            ``pyannote/speaker-diarization`` pipeline. A blank token is sent as
            ``None`` so a locally cached Hugging Face login can be used.

    Returns:
        The status string ``"Models loaded successfully"``.

    Raises:
        RuntimeError: if pyannote could not provide the diarization pipeline.
            ``Pipeline.from_pretrained`` returns ``None`` instead of raising
            when the download is refused (e.g. missing/invalid token or model
            conditions not accepted); storing that ``None`` would only fail
            later with an opaque "'NoneType' object is not callable".
    """
    global whisper_model, diarization_pipeline, med_ner
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Speech-to-text
    if whisper_model is None:
        whisper_model = whisper.load_model(WHISPER_MODEL_SIZE, device=device)

    # Speaker diarization
    if diarization_pipeline is None:
        loaded = Pipeline.from_pretrained(
            "pyannote/speaker-diarization",
            use_auth_token=hf_token or None,
        )
        if loaded is None:
            raise RuntimeError(
                "Could not load 'pyannote/speaker-diarization'. Check that the "
                "Hugging Face token is valid and that the model's user "
                "conditions have been accepted on huggingface.co."
            )
        diarization_pipeline = loaded

    # Medical named-entity recognition
    if med_ner is None:
        med_ner = hf_pipeline("ner", model=MEDICAL_NER_MODEL, aggregation_strategy="simple")

    return "Models loaded successfully"
def convert_audio_to_wav(input_file, sample_rate=16000):
    """Convert an audio file in any format pydub/ffmpeg can read to a mono WAV.

    Every call writes to a fresh, uniquely named temporary file. A single
    fixed path would let concurrent requests overwrite each other's audio,
    so one user's recording could be transcribed for another. The caller
    owns the returned file and should delete it when finished.

    Args:
        input_file: Path (or file-like object) accepted by
            ``AudioSegment.from_file``.
        sample_rate: Target frame rate in Hz. The default of 16 kHz is the
            rate used by Whisper and the pyannote pipeline.

    Returns:
        Path to the written WAV file.
    """
    audio = AudioSegment.from_file(input_file)
    fd, wav_path = tempfile.mkstemp(prefix="consultation_", suffix=".wav")
    os.close(fd)  # pydub opens the path itself; we only needed a unique name
    try:
        audio.set_frame_rate(sample_rate).set_channels(1).export(wav_path, format="wav")
    except Exception:
        # Do not leave an empty temp file behind when the export fails.
        os.remove(wav_path)
        raise
    return wav_path
def anonymize_phi(text):
    """Return ``text`` with the PII/PHI spans detected by Presidio replaced.

    Detection uses Presidio's English recognizers (``language="en"``). No
    operators are passed to the anonymizer, so its default replacement
    applies (by default Presidio substitutes the entity type, e.g. ``<PERSON>``).
    """
    return phi_anonymizer.anonymize(
        text,
        phi_analyzer.analyze(text=text, language="en"),
    ).text
# --- Core Processing Functions ---
def _speaker_for_segment(start, end, turns):
    """Return the diarization label that best covers the interval ``[start, end]``.

    The turn with the largest temporal overlap wins. If no turn overlaps the
    interval (e.g. the text falls in a gap between diarized turns), the turn
    whose midpoint is closest to the interval's midpoint is used instead.

    Args:
        start: Interval start, in seconds.
        end: Interval end, in seconds.
        turns: Sequence of ``(turn_start, turn_end, label)`` tuples.

    Returns:
        The chosen label, or ``"UNKNOWN"`` when ``turns`` is empty.
    """
    if not turns:
        return "UNKNOWN"
    best_label, best_overlap = None, 0.0
    for turn_start, turn_end, label in turns:
        overlap = min(end, turn_end) - max(start, turn_start)
        if overlap > best_overlap:
            best_label, best_overlap = label, overlap
    if best_label is not None:
        return best_label
    midpoint = (start + end) / 2
    nearest = min(turns, key=lambda turn: abs((turn[0] + turn[1]) / 2 - midpoint))
    return nearest[2]


def transcribe_and_diarize(audio_file, hf_token):
    """Transcribe a recording and label each transcript segment with a speaker.

    Whisper produces timestamped text segments and the pyannote pipeline
    produces speaker turns. Each text segment is attributed to the speaker
    whose turn overlaps it the most. Both models must already have been
    loaded by ``load_models``.

    Args:
        audio_file: Path to the uploaded or recorded audio (any format pydub reads).
        hf_token: Not used here; kept so existing callers keep working.

    Returns:
        ``(text, segments)``. ``text`` has one ``"[<start>s] <speaker>: <text>"``
        line per Whisper segment, and ``segments`` is Whisper's raw segment
        list. On failure this returns ``("Error: <message>", None)`` instead
        of raising.
    """
    if whisper_model is None or diarization_pipeline is None:
        return "Error: models are not loaded; call load_models() first.", None
    wav_path = None
    try:
        wav_path = convert_audio_to_wav(audio_file)
        segments = whisper_model.transcribe(wav_path)["segments"]
        diarization = diarization_pipeline(wav_path)
        # itertracks(yield_label=True) yields (segment, track, label) tuples.
        # Collect every turn once so each text segment can be matched to the
        # turn it overlaps most, rather than to whichever turn comes first.
        turns = [
            (turn.start, turn.end, label)
            for turn, _track, label in diarization.itertracks(yield_label=True)
        ]
        lines = []
        for seg in segments:
            start, end = seg["start"], seg["end"]
            speaker = _speaker_for_segment(start, end, turns)
            # Whisper segment text may begin with a space; strip it so the
            # "speaker: text" separator stays a single space.
            lines.append(f"[{start:.1f}s] {speaker}: {seg['text'].strip()}")
        return "\n".join(lines), segments
    except Exception as e:
        return f"Error: {str(e)}", None
    finally:
        # Best-effort cleanup of the intermediate WAV. Failing to delete a
        # temp file must not mask the transcription result or its error.
        if wav_path is not None:
            try:
                os.remove(wav_path)
            except OSError:
                pass
# Maps each output category to the NER ``entity_group`` labels that feed it.
# d4data/biomedical-ner-all emits labels such as "Medication",
# "Disease_disorder", "Therapeutic_procedure" and "Diagnostic_procedure".
# The upper-case DRUG/DISEASE/TREATMENT labels are kept for models that use them.
_ENTITY_LABEL_GROUPS = {
    "Drugs": frozenset({"Medication", "DRUG"}),
    "Conditions": frozenset({"Disease_disorder", "DISEASE"}),
    "Procedures": frozenset({"Therapeutic_procedure", "Diagnostic_procedure", "TREATMENT"}),
}


def extract_medical_entities(text, *, label_groups=None, ner=None):
    """Identify drugs, conditions, and procedures mentioned in ``text``.

    Args:
        text: Free text to analyse (in this app, the speaker-labelled transcript).
        label_groups: Optional mapping of output category to a collection of
            NER ``entity_group`` labels. Defaults to ``_ENTITY_LABEL_GROUPS``.
        ner: Optional token-classification callable that returns dicts with
            ``"word"`` and ``"entity_group"`` keys. Defaults to the global
            ``med_ner`` pipeline created by ``load_models``.

    Returns:
        A dict mapping each category to the matched words, in the order the
        NER model returned them. Every category is present even when empty.
        Blank text returns empty lists without calling the model.
    """
    groups = _ENTITY_LABEL_GROUPS if label_groups is None else label_groups
    if not text or not text.strip():
        return {category: [] for category in groups}
    tagger = med_ner if ner is None else ner
    entities = tagger(text)
    return {
        category: [e["word"] for e in entities if e["entity_group"] in labels]
        for category, labels in groups.items()
    }
# Instructions sent as the system message for SOAP-note generation.
_SOAP_SYSTEM_PROMPT = (
    "You are a clinical assistant. Convert this doctor-patient conversation into SOAP notes:\n"
    "- Subjective: Patient-reported symptoms\n"
    "- Objective: Clinician observations\n"
    "- Assessment: Diagnosis/differential\n"
    "- Plan: Treatment and follow-up"
)

# Single-slot cache for the lmdeploy pipeline. Building a pipeline loads the
# model weights and reserves KV-cache memory (cache_max_entry_count=0.5 is the
# fraction of free GPU memory given to the KV cache). It must therefore not be
# rebuilt on every request, and only one model is kept resident at a time.
_soap_pipe_cache = {"model_id": None, "pipe": None}


def _get_soap_pipeline(model_id):
    """Return an lmdeploy pipeline for ``model_id``, building it only when needed.

    Switching to a different model releases the previously cached pipeline
    first, so two models are never resident at once.
    """
    import gc

    if _soap_pipe_cache["pipe"] is not None and _soap_pipe_cache["model_id"] == model_id:
        return _soap_pipe_cache["pipe"]

    previous = _soap_pipe_cache["pipe"]
    _soap_pipe_cache["model_id"] = None
    _soap_pipe_cache["pipe"] = None
    if previous is not None:
        # close() is not available on every lmdeploy version; use it if present.
        close = getattr(previous, "close", None)
        if callable(close):
            close()
        del previous
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    engine_config = TurbomindEngineConfig(
        cache_max_entry_count=0.5,
        session_len=131072,
    )
    pipe = lm_pipeline(model_id, backend_config=engine_config)
    _soap_pipe_cache["model_id"] = model_id
    _soap_pipe_cache["pipe"] = pipe
    return pipe


def generate_soap_notes(transcript, model_choice, anonymize_phi_flag):
    """Generate SOAP-format clinical notes from a transcript with a Qwen model.

    Args:
        transcript: Speaker-labelled consultation transcript.
        model_choice: Key of ``qwen_models`` selecting the checkpoint.
        anonymize_phi_flag: When true, PHI is redacted with ``anonymize_phi``
            before the transcript is sent to the model.

    Returns:
        The generated notes. If the transcript is empty, a short notice is
        returned instead and the model is not called, so it cannot invent
        clinical notes from nothing.

    Raises:
        KeyError: if ``model_choice`` is not a key of ``qwen_models``.
    """
    if not transcript or not transcript.strip():
        return "No transcript available to summarize."
    if anonymize_phi_flag:
        transcript = anonymize_phi(transcript)
    pipe = _get_soap_pipeline(qwen_models[model_choice])
    messages = [
        {"role": "system", "content": _SOAP_SYSTEM_PROMPT},
        {
            "role": "user",
            "content": f"Consultation Transcript:\n{transcript}\n\nGenerate concise SOAP notes:",
        },
    ]
    response = pipe(messages, GenerationConfig(max_new_tokens=1024))
    return response.text
# --- Gradio Interface ---
def _transcribe_for_ui(audio, token):
    """Gradio handler: load the models, then transcribe and diarize ``audio``.

    Failures raise ``gr.Error``, which Gradio shows to the user. Returning the
    error as text would let the chained ``.success`` steps run on that text.

    Args:
        audio: Filepath from the ``gr.Audio`` component, or ``None`` when
            nothing was uploaded or recorded.
        token: Value of the token textbox. When blank, the server-side
            ``DEFAULT_HF_TOKEN`` is used, so the secret never has to be sent
            to the browser as a default value.

    Returns:
        ``(transcript_text, whisper_segments)`` for the two bound outputs.
    """
    if audio is None:
        raise gr.Error("Please upload or record a consultation first.")
    token = (token or "").strip() or DEFAULT_HF_TOKEN
    try:
        # Called for its side effect only. Its truthy status string must not
        # be used in an `or` chain, which would skip transcription entirely.
        load_models(token)
    except Exception as e:
        raise gr.Error(f"Failed to load models: {e}") from e
    transcript_text, segments = transcribe_and_diarize(audio, token)
    if segments is None:
        raise gr.Error(transcript_text)
    return transcript_text, segments


with gr.Blocks(title="Clinical Consultation Summarizer", theme=gr.themes.Soft()) as app:
    gr.Markdown("""# 🩺 Patient-Doctor Consultation Summarizer""")
    with gr.Row():
        with gr.Column():
            audio_input = gr.Audio(
                sources=["upload", "microphone"],
                type="filepath",
                label="Upload Consultation Recording",
            )
            # No default value: component values are sent to the browser, and
            # type="password" only masks the display.
            hf_token = gr.Textbox(
                label="Hugging Face Token",
                value="",
                placeholder="Leave blank to use the server-configured token",
                type="password",
            )
            model_choice = gr.Dropdown(
                choices=list(qwen_models.keys()),
                value="Qwen Medical 7B",
                label="Model",
            )
            anonymize_check = gr.Checkbox(
                label="Anonymize Protected Health Info (PHI)",
                value=True,
            )
            process_btn = gr.Button("Process Consultation")
        with gr.Column():
            with gr.Tabs():
                with gr.Tab("Transcript"):
                    transcript_output = gr.Textbox(label="Transcribed Conversation", lines=15)
                with gr.Tab("SOAP Notes"):
                    soap_output = gr.Textbox(label="Clinical Summary", lines=15)
                with gr.Tab("Medical Entities"):
                    entity_output = gr.JSON(label="Extracted Medical Terms")

    # Holds Whisper's raw segment list returned alongside the transcript text.
    segments_state = gr.State()

    # Processing
    transcribed = process_btn.click(
        fn=_transcribe_for_ui,
        inputs=[audio_input, hf_token],
        outputs=[transcript_output, segments_state],
    )
    # Both follow-up steps need only the transcript, so each is chained to the
    # transcription event directly. A failure in SOAP generation then no
    # longer prevents entity extraction.
    transcribed.success(
        fn=generate_soap_notes,
        inputs=[transcript_output, model_choice, anonymize_check],
        outputs=soap_output,
    )
    transcribed.success(
        fn=extract_medical_entities,
        inputs=transcript_output,
        outputs=entity_output,
    )
if __name__ == "__main__":
    # NOTE: share=True publishes the app through a public gradio.live URL.
    # This UI processes patient recordings, so set GRADIO_SHARE=0 wherever
    # the audio must stay private. Defaults match the previous hard-coded
    # values (port 7860, sharing on).
    app.launch(
        server_port=int(os.environ.get("GRADIO_SERVER_PORT", "7860")),
        share=os.environ.get("GRADIO_SHARE", "1").strip().lower() not in {"0", "false", "no"},
    )