import os
import ast
import spaces
import gradio as gr

from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from huggingface_hub import login as hf_login

import xgrammar as xgr
from pydantic import BaseModel

hf_login(token=os.getenv("HF_TOKEN"))

model_name = "gregorlied/Llama-3.2-1B-Instruct-Medical-Report-Summarization-FP32"

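# Load the fine-tuned model once at import time. `summarize` uses the
# module-level `model` for every report other than the cached default example.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    attn_implementation='eager',
    trust_remote_code=True,
)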

# Output schema for constrained decoding: one free-text field per extracted category.
class Person(BaseModel):
    life_style: str
    family_history: str
    social_history: str
    medical_surgical_history: str
    signs_symptoms: str
    comorbidities: str
    diagnostic_techniques_procedures: str
    diagnosis: str
    laboratory_values: str
    pathology: str
    pharmacological_therapy: str
    interventional_therapy: str
    patient_outcome_assessment: str
    age: str
    gender: str

config = AutoConfig.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

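# Use the model's full vocabulary size (the width of the logits), which can be
# larger than len(tokenizer) when the embedding matrix is padded.
tokenizer_info = xgr.TokenizerInfo.from_huggingface(
    tokenizer, vocab_size=config.vocab_size
)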

grammar_compiler = xgr.GrammarCompiler(tokenizer_info)
compiled_grammar = grammar_compiler.compile_json_schema(Person)
xgr_logits_processor = xgr.contrib.hf.LogitsProcessor(compiled_grammar)

default_value = "A 57-year-old male presented with fever (38.9°C), chest pain, cough, and progressive dyspnea. The patient exhibited tachypnea (34 breaths/min) and tachycardia (134 bpm). Auscultation revealed decreased breath sounds in both lung bases, with crackles on the left. A chest X-ray revealed bilateral pleural opacities and enlargement of the cardiac silhouette ( A). Echocardiography showed moderate pericardial effusion affecting the entire cardiac silhouette. Pericardiocentesis yielded 250 mL of exudative fluid. A CT scan of the chest showed pneumonia in the left lower lobe, bilateral pleural effusion, and moderate pericardial effusion ( B). Thoracentesis was performed and yielded 1,050 mL of exudative fluid. Laboratory tests yielded the following data: white blood cell count, 11.78 × 109 cells/L (84.3% neutrophils, 4.3% lymphocytes, and 9.1% monocytes); platelet count, 512 × 109/L; serum C-reactive protein, 31.27 mg/dL; serum creatinine, 0.94 mg/dL; serum sodium, 133 mEq/L; and serum potassium, 3.72 mEq/L. Examination of the pleural fluid showed a pH of 7.16, a glucose level of 4.5 mg/dL, proteins at 49.1 g/L, and an LDH content of 1,385 U/L. A urinary pneumococcal antigen test was positive. Pleural fluid culture was positive for S. pneumoniae. The patient was treated for four weeks with amoxicillin-clavulanate (2.2 g/8 h, i.v.) plus levofloxacin (500 mg twice a day), together with a nonsteroidal anti-inflammatory drug (ibuprofen, 800 mg/day), after which there was nearly complete resolution of the alterations seen on the chest X-ray and CT scan."

prompt = """You are a text extraction system for clinical reports.
Please extract relevant clinical information from the report.

### Instructions

- Use the JSON Schema given below.
- Return only a valid JSON object – no markdown, no comments.
- If no relevant facts are given for a field, set its value to "N/A".
- If multiple relevant facts are given for a field, separate them with "; ".

### JSON Schema

{
  'life_style': '',
  'family_history': '',
  'social_history': '',
  'medical_surgical_history': '',
  'signs_symptoms': '',
  'comorbidities': '',
  'diagnostic_techniques_procedures': '',
  'diagnosis': '',
  'laboratory_values': '',
  'pathology': '',
  'pharmacological_therapy': '',
  'interventional_therapy': '',
  'patient_outcome_assessment': '',
  'age': '',
  'gender': '',
}

### Clinical Report
"""

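def generate_html_tables(data, selected_fields):
    """Render the extracted fields as HTML tables, one table per category.

    `data` maps schema keys (e.g. 'signs_symptoms') to "; "-separated strings;
    `selected_fields` lists the display labels (e.g. 'Symptoms') to include.
    """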
    key_label_map = {
        'age': 'Age',
        'gender': 'Gender',
        'life_style': 'Lifestyle',
        'social_history': 'Social Background',
        'medical_surgical_history': 'Personal',
        'family_history': 'Family Members',
        'signs_symptoms': 'Symptoms',
        'comorbidities': 'Comorbid Conditions',
        'diagnostic_techniques_procedures': 'Diagnostic Procedures',
        'laboratory_values': 'Laboratory Results',
        'pathology': 'Pathology Report',
        'diagnosis': 'Diagnosis',
        'interventional_therapy': 'Interventional Therapy',
        'pharmacological_therapy': 'Pharmacological Therapy',
        'patient_outcome_assessment': 'Patient Outcome',
    }
    
    label_key_map = {v: k for k, v in key_label_map.items()}

    categories = {
        "Personal Information": ["Age", "Gender", "Lifestyle", "Social Background"],
        "Medical History": ["Personal", "Family Members"],
        "Clinical Presentation": ["Symptoms", "Comorbid Conditions"],
        "Medical Assessment": ["Diagnostic Procedures", "Laboratory Results", "Pathology Report"],
        "Diagnosis": ["Diagnosis"],
        "Treatment": ["Interventional Therapy", "Pharmacological Therapy"],
        "Patient Outcome": ["Patient Outcome"],
    }

    def format_bullets(value):
        items = [item.strip() for item in value.split(";") if item.strip()]
        if not items:
            return "<i>Not Available</i>"
        if len(items) == 1:
            return items[0]
        return "<ul style='margin: 0; padding-left: 1em'>" + "".join(f"<li>{item}</li>" for item in items) + "</ul>"

    table_style = (
        "width: 100%;"
        "table-layout: fixed;"
        # "border-collapse: collapse;"
        # "word-wrap: break-word;"
        "height: 100%;"
    )

    th_td_style_first = (
        # "height: 30px;"
        "padding: 8px;"
        "border: 1px solid #ccc;"
        "vertical-align: top;"
        "text-align: left;"
        "overflow: hidden;"
    )
    
    th_td_style_other = (
        "padding: 8px;"
        "border: 1px solid #ccc;"
        "vertical-align: top;"
        "text-align: left;"
    )

    html_tables = []
    for section, labels in categories.items():
        section_fields = [label for label in labels if label in selected_fields]
        if section_fields:
            table_html = f"<h3 style='margin-bottom: 0.5em;'>{section}</h3>"
            table_html += f"<table style='{table_style}'>"
            table_html += f"<tr><th style='{th_td_style_first}; width: 150px;'>Field</th><th style='{th_td_style_first};'>Details</th></tr>"
            for label in section_fields:
                key = label_key_map[label]
                value = data.get(key, "N/A")
                details = "<i>Not Available</i>" if value == "N/A" else format_bullets(value)
                table_html += f"<tr><td style='{th_td_style_other}; width: 150px;'><b>{label}</b></td><td style='{th_td_style_other}'>{details}</td></tr>"
            table_html += "</table>"
            html_tables.append(table_html)

    i = 0
    grouped_html = ""
    while i < len(html_tables):
        num_per_row = 2 if i < 4 else 3
        row_tables = html_tables[i:i+num_per_row]
        grouped_html += (
            "<div style='display: flex; gap: 1em; margin-bottom: 2em; align-items: stretch;'>"
        )
        for table in row_tables:
            grouped_html += (
                "<div style='flex: 1; min-width: 0; display: flex; flex-direction: column;'>"
                f"{table}"
                "</div>"
            )
        grouped_html += "</div>"
        i += num_per_row

    return f"<div style='font-family: sans-serif;'>{grouped_html}</div>"

@spaces.GPU(duration=60)
def summarize(
    text, 
    personal_info, 
    medical_history, 
    clinical_presentation, 
    medical_assessment, 
    diagnosis, 
    treatment, 
    patient_outcome,
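):
    """Extract structured fields from a clinical report and return them as HTML.

    The remaining arguments are the label lists ticked in each checkbox group.
    """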
    
    if not text.strip():
        return "Please enter some text to summarize."

    if text == default_value: 
        response = ['{"life_style": "N/A", "family_history": "N/A", "social_history": "N/A", "medical_surgical_history": "N/A", "signs_symptoms": "Fever; Chest pain; Cough; Progressive dyspnea; Tachypnea; Tachycardia; Decreased breath sounds in both lung bases; Crackles on the left", "comorbidities": "N/A", "diagnostic_techniques_procedures": "Chest X-ray; Echocardiography; Thoracentesis; Laboratory tests; Pleural fluid analysis; Urinary pneumococcal antigen test; Pleural fluid culture", "diagnosis": "Pneumonia; Pericardial effusion; S. pneumoniae infection", "laboratory_values": "White blood cell count: 11.78 \\u00d7 10^9 cells/L (84.3% neutrophils, 4.3% lymphocytes, 9.1% monocytes); Platelet count: 512 \\u00d7 10^9/L; Serum C-reactive protein: 31.27 mg/dL; Serum creatinine: 0.94 mg/dL; Serum sodium: 133 mEq/L; Serum potassium: 3.72 mEq/L; Pleural fluid pH: 7.16; Pleural fluid glucose: 4.5 mg/dL; Pleural fluid proteins: 49.1 g/L; Pleural fluid LDH: 1,385 U/L", "pathology": "N/A", "pharmacological_therapy": "Amoxicillin-clavulanate (2.2 g/8 h, i.v.); Levofloxacin (500 mg twice a day); Ibuprofen (800 mg/day)", "interventional_therapy": "Pericardiocentesis; Thoracentesis", "patient_outcome_assessment": "Nearly complete resolution of alterations on chest X-ray and CT scan", "age": "57 year", "gender": "Male"}']
    else:
        messages = [
            {"role": "system", "content": prompt.strip()},
            {"role": "user", "content": text.strip()},
        ]

        text = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=False, # only relevant for qwen
        )
    
        model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
        generated_ids = model.generate(
            input_ids=model_inputs["input_ids"],
            attention_mask = model_inputs["attention_mask"],
            max_new_tokens=2048,
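            # The xgrammar logits processor is stateful, so build a fresh one
            # for each generate() call instead of reusing the module-level instance.
            logits_processor=[xgr.contrib.hf.LogitsProcessor(compiled_grammar)],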
        )
        generated_ids = [
            output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
        ]
    
        response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

    try:
        data = ast.literal_eval(response[0])
    except Exception:
        # Fall back to an empty result if the model output cannot be parsed.
        data = {
            'life_style': 'N/A',
            'family_history': 'N/A',
            'social_history': 'N/A',
            'medical_surgical_history': 'N/A',
            'signs_symptoms': 'N/A',
            'comorbidities': 'N/A',
            'diagnostic_techniques_procedures': 'N/A',
            'diagnosis': 'N/A',
            'laboratory_values': 'N/A',
            'pathology': 'N/A',
            'pharmacological_therapy': 'N/A',
            'interventional_therapy': 'N/A',
            'patient_outcome_assessment': 'N/A',
            'age': 'N/A',
            'gender': 'N/A',
        }
        
    selected_fields = []
    selected_fields += personal_info
    selected_fields += medical_history
    selected_fields += clinical_presentation
    selected_fields += medical_assessment
    selected_fields += diagnosis
    selected_fields += treatment
    selected_fields += patient_outcome
    
    return generate_html_tables(data, selected_fields)

with gr.Blocks() as demo:

    # needs to be combined with `hf_oauth: true` in README.md
    # button = gr.LoginButton("Sign in")

    with gr.Column():
        gr.HTML("""
<div align="center">
  <img src="https://huggingface.co/spaces/gregorlied/medical-text-summarization/resolve/main/assets/LlamaMD-logo.png" alt="LlamaMD Logo" width="120" style="margin-bottom: 10px;">
  <h2><strong>LlamaMD</strong></h2>
  <p><em>Structured Information Extraction from Clinical Reports</em></p>
</div>
""")

    with gr.Tabs():
        with gr.Tab("LlamaMD"):
            with gr.Row():
                input_text = gr.Textbox(
                    label="Clinical Report", 
                    autoscroll=False,
                    lines=15, 
                    max_lines=15, 
                    placeholder="Paste your clinical report here...",
                    value=default_value,
                )
        
            with gr.Accordion("Advanced Settings", open=False):
                with gr.Row():
                    with gr.Column():
                        personal_info = gr.CheckboxGroup(
                            label="Personal Information", 
                            choices=[
                                "Age",
                                "Gender",
                                "Lifestyle", 
                                "Social Background",
                            ],
                            value=[
                                "Age",
                                "Gender",
                                "Lifestyle", 
                                "Social Background",
                            ],
                        )
                    
                    with gr.Column():
                        medical_history = gr.CheckboxGroup(
                            label="Medical History", 
                            choices=[ 
                                "Personal",
                                "Family Members",
                            ],
                            value=[
                                "Personal",
                                "Family Members",
                            ],
                        )
        
                with gr.Row():
                    with gr.Column():
                        clinical_presentation = gr.CheckboxGroup(
                            label="Clinical Presentation", 
                            choices=[
                                "Symptoms",
                                "Comorbid Conditions",
                            ],
                            value=[
                                "Symptoms",
                                "Comorbid Conditions",
                            ],
                        )
                    with gr.Column():
                        medical_assessment = gr.CheckboxGroup(
                            label="Medical Assessment", 
                            choices=[
                                "Diagnostic Procedures",
                                "Laboratory Results",
                                "Pathology Report",       
                            ],
                            value=[
                                "Diagnostic Procedures",
                                "Laboratory Results",
                                "Pathology Report",            
                            ],
                        )
            
                with gr.Row():
                    with gr.Column():
                        diagnosis = gr.CheckboxGroup(
                            label="Diagnosis", 
                            choices=[
                                "Diagnosis",
                            ],
                            value=[
                                "Diagnosis",
                            ],
                        )
                    with gr.Column():
                        treatment = gr.CheckboxGroup(
                            label="Treatment", 
                            choices=[
                                "Interventional Therapy",
                                "Pharmacological Therapy",
                            ],
                            value=[
                                "Interventional Therapy",
                                "Pharmacological Therapy",
                            ],
                        )
                    with gr.Column():
                        patient_outcome = gr.CheckboxGroup(
                            label="Patient Outcome", 
                            choices=[
                                "Patient Outcome",
                            ],
                            value=[
                                "Patient Outcome",
                            ],
                        )               
        
            with gr.Row():
                summarize_btn = gr.Button("Extract")
        
            with gr.Row():
                output_text = gr.HTML()
        
            summarize_btn.click(
                fn=summarize, 
                inputs=[input_text, personal_info, medical_history, clinical_presentation, medical_assessment, diagnosis, treatment, patient_outcome], 
                outputs=output_text, 
                show_progress=True,
            )

        with gr.Tab("Help"):
            gr.Markdown("""## Help
            
### Personal Information        

**Age**: Age of the patient.<br>
**Gender**: Gender of the patient.<br>
**Lifestyle**: Daily habits and activities of the patient (e.g. alcohol consumption, diet, smoking status).<br>
**Social Background**: Social factors of the patient (e.g. housing situation, marital status).<br>

### Medical History

**Personal**: Past medical conditions, previous surgeries or treatments of the patient.<br>
**Family Members**: Relevant medical conditions or genetic disorders in the patient’s family (e.g. cancer, heart disease).<br>

### Clinical Presentation

**Symptoms**: Current symptoms of the patient.<br>
**Comorbid Conditions**: Other medical conditions of the patient that may influence the treatment.<br>

### Medical Assessment

**Diagnostic Procedures**: Description of the diagnostic tests or procedures performed (e.g. X-rays, MRIs).<br>
**Laboratory Results**: Results from laboratory tests (e.g. blood counts, electrolyte levels).<br>
**Pathology Report**: Findings from pathological examinations (e.g. biopsy results).<br>

### Diagnosis

**Diagnosis**: All levels of diagnosis mentioned in the report.<br>

### Treatment

**Interventional Therapy**: Information on surgical or non-surgical interventions performed.<br>
**Pharmacological Therapy**: Medications prescribed to the patient.<br>

### Patient Outcome

**Patient Outcome**: Evaluation of the patient’s health status at the end of treatment.<br>
""")
            
        with gr.Tab("About"):
            gr.Markdown("""## About

LlamaMD is a project developed as part of the "NLP for Social Good" course at TU Berlin.
            
The goal of this project is to perform structured information extraction from clinical reports, giving doctors more time for their patients.
                        
The system is based on `meta-llama/Llama-3.2-1B-Instruct`, which has been fine-tuned on the ELMTEX dataset.
""")

if __name__ == "__main__":
    demo.launch()