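"""Clinical Consultation Summarizer.

Gradio app that transcribes a doctor-patient recording with Whisper, attributes
speakers with pyannote diarization, optionally strips PHI with Presidio,
extracts medical entities with a biomedical NER model, and drafts SOAP notes
with a Qwen model served through lmdeploy.
"""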
import gradio as gr
import os
import tempfile
import torch
from pydub import AudioSegment
import whisper
from pyannote.audio import Pipeline
from pyannote.core import Segment
from lmdeploy import pipeline as lm_pipeline
from lmdeploy import GenerationConfig, TurbomindEngineConfig
from transformers import pipeline as hf_pipeline
from presidio_analyzer import AnalyzerEngine
from presidio_anonymizer import AnonymizerEngine

# --- Configuration ---
MEDICAL_NER_MODEL = "d4data/biomedical-ner-all"
WHISPER_MODEL_SIZE = "base"  # "small" or "medium" for better accuracy
DEFAULT_HF_TOKEN = os.environ.get("HF_TOKEN", "")  # set HF_TOKEN in your environment; never hardcode tokens

# --- Global Models ---
whisper_model = None
diarization_pipeline = None
med_ner = None
phi_analyzer = AnalyzerEngine()
phi_anonymizer = AnonymizerEngine()

qwen_models = {
    "Qwen Medical 7B": "Qwen/Qwen2.5-7B-Instruct-1M",
    "Qwen Fast 3B": "Qwen/Qwen2.5-3B-Instruct",
}
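# Note: the display names are only labels; both entries point to general-purpose
# Qwen2.5 Instruct checkpoints on the Hugging Face Hub, not medical fine-tunes.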

# --- Helper Functions ---
def load_models(hf_token):
    """Load all required models"""
    global whisper_model, diarization_pipeline, med_ner
    
    # Load Whisper (Speech-to-Text)
    if whisper_model is None:
        whisper_model = whisper.load_model(WHISPER_MODEL_SIZE, device="cuda" if torch.cuda.is_available() else "cpu")
    
    # Load Diarization
    if diarization_pipeline is None:
        diarization_pipeline = Pipeline.from_pretrained(
            "pyannote/speaker-diarization", 
            use_auth_token=hf_token
        )
    
    # Load Medical NER
    if med_ner is None:
        med_ner = hf_pipeline("ner", model=MEDICAL_NER_MODEL, aggregation_strategy="simple")
    
    return "Models loaded successfully"

def convert_audio_to_wav(input_file):
    """Convert any audio file to mono 16 kHz WAV, as Whisper and pyannote expect"""
    audio = AudioSegment.from_file(input_file)
    wav_path = os.path.join(tempfile.gettempdir(), "consultation.wav")
    audio.set_frame_rate(16000).set_channels(1).export(wav_path, format="wav")
    return wav_path

def anonymize_phi(text):
    """Remove personally identifiable health information"""
    results = phi_analyzer.analyze(text=text, language="en")
    anonymized = phi_anonymizer.anonymize(text, results)
    return anonymized.text
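
# Example (exact placeholders depend on Presidio's default recognizers/operators):
#   anonymize_phi("John Smith, DOB 01/02/1990, reports chest pain")
#   -> "<PERSON>, DOB <DATE_TIME>, reports chest pain"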

# --- Core Processing Functions ---
def transcribe_and_diarize(audio_file, hf_token):
    """Convert audio to text with speaker labels"""
    try:
        # Convert audio
        wav_path = convert_audio_to_wav(audio_file)
        
        # Transcribe
        transcript = whisper_model.transcribe(wav_path)["segments"]
        
        # Diarize
        diarization = diarization_pipeline(wav_path)
        
        # Combine results: label each Whisper segment with the speaker whose
        # diarization turns overlap it the most
        output = []
        for seg in transcript:
            start, end, text = seg["start"], seg["end"], seg["text"]
            cropped = diarization.crop(Segment(start, end))
            speaker = cropped.argmax() if cropped.labels() else "UNKNOWN"
            output.append(f"[{start:.1f}s] {speaker}: {text}")
        
        return "\n".join(output), transcript
    
    except Exception as e:
        return f"Error: {str(e)}", None

def extract_medical_entities(text):
    """Identify drugs, conditions, and procedures"""
    entities = med_ner(text)
    # Entity group names follow the d4data/biomedical-ner-all tag set
    return {
        "Drugs": [e["word"] for e in entities if e["entity_group"] == "Medication"],
        "Conditions": [e["word"] for e in entities if e["entity_group"] == "Disease_disorder"],
        "Procedures": [e["word"] for e in entities if e["entity_group"] == "Therapeutic_procedure"]
    }
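
# Example result shape (the extracted words depend on the NER model's output):
#   {"Drugs": ["metformin"], "Conditions": ["type 2 diabetes"], "Procedures": []}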

_qwen_pipes = {}  # cache one lmdeploy pipeline per model; rebuilding TurboMind per request is slow

def generate_soap_notes(transcript, model_choice, anonymize_phi_flag):
    """Generate structured medical notes using Qwen"""
    # Anonymize if requested
    if anonymize_phi_flag:
        transcript = anonymize_phi(transcript)
    
    # Initialize Qwen once per model choice
    if model_choice not in _qwen_pipes:
        engine_config = TurbomindEngineConfig(
            cache_max_entry_count=0.5,
            session_len=131072
        )
        _qwen_pipes[model_choice] = lm_pipeline(qwen_models[model_choice], backend_config=engine_config)
    pipe = _qwen_pipes[model_choice]
    
    # Medical prompt template
    system_prompt = """You are a clinical assistant. Convert this doctor-patient conversation into SOAP notes:
    - Subjective: Patient-reported symptoms
    - Objective: Clinician observations
    - Assessment: Diagnosis/differential
    - Plan: Treatment and follow-up"""
    
    response = pipe([{
        "role": "system", 
        "content": system_prompt
    }, {
        "role": "user",
        "content": f"Consultation Transcript:\n{transcript}\n\nGenerate concise SOAP notes:"
    }], gen_config=GenerationConfig(max_new_tokens=1024))
    
    return response.text

# --- Gradio Interface ---
with gr.Blocks(title="Clinical Consultation Summarizer", theme=gr.themes.Soft()) as app:
    gr.Markdown("""# 🩺 Patient-Doctor Consultation Summarizer""")
    
    with gr.Row():
        with gr.Column():
            audio_input = gr.Audio(
                sources=["upload", "microphone"],
                type="filepath",
                label="Upload Consultation Recording"
            )
            hf_token = gr.Textbox(
                label="Hugging Face Token",
                value=DEFAULT_HF_TOKEN,
                type="password"
            )
            model_choice = gr.Dropdown(
                choices=list(qwen_models.keys()),
                value="Qwen Medical 7B",
                label="Model"
            )
            anonymize_check = gr.Checkbox(
                label="Anonymize Protected Health Info (PHI)",
                value=True
            )
            process_btn = gr.Button("Process Consultation")
            
        with gr.Column():
            with gr.Tabs():
                with gr.Tab("Transcript"):
                    transcript_output = gr.Textbox(label="Transcribed Conversation", lines=15)
                with gr.Tab("SOAP Notes"):
                    soap_output = gr.Textbox(label="Clinical Summary", lines=15)
                with gr.Tab("Medical Entities"):
                    entity_output = gr.JSON(label="Extracted Medical Terms")
    
    # Processing: load models, transcribe with speaker labels, then chain
    # SOAP-note generation and entity extraction off the transcript
    def run_transcription(audio, token):
        load_models(token)
        text, _segments = transcribe_and_diarize(audio, token)
        return text

    process_btn.click(
        fn=run_transcription,
        inputs=[audio_input, hf_token],
        outputs=transcript_output
    ).success(
        fn=generate_soap_notes,
        inputs=[transcript_output, model_choice, anonymize_check],
        outputs=soap_output
    ).success(
        fn=extract_medical_entities,
        inputs=transcript_output,
        outputs=entity_output
    )

if __name__ == "__main__":
    app.launch(server_port=7860, share=True)
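
# To run: `python app.py`, then open the printed URL (port 7860).
# Assumes ffmpeg is on PATH (for pydub), a CUDA GPU (for lmdeploy's TurboMind
# backend), and an HF token with access to the gated pyannote diarization model.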