import gc

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from transformers.utils import logging

# Verbose transformers logging to help debug model loading and generation
logging.set_verbosity_debug()

logger = logging.get_logger("transformers")

# Load model directly from your Hugging Face repository
def load_model():
    # Load the tokenizer up front so it is available even if the fallback path below is taken
    tokenizer = AutoTokenizer.from_pretrained("omi-health/sum-small", trust_remote_code=False)

    try:
        # Try to use GPU with half precision first to save memory
        if torch.cuda.is_available():
            model = AutoModelForCausalLM.from_pretrained(
                "omi-health/sum-small",
                trust_remote_code=False,
                torch_dtype=torch.float16,  # half precision to reduce GPU memory use
                device_map="auto"  # Let the library decide best device mapping
            )
            print(f"GPU: {torch.cuda.get_device_name(0)}")
            print(f"Memory allocated: {torch.cuda.memory_allocated(0) / 1024**2:.2f} MB")
        else:
            # Fall back to CPU in full precision
            model = AutoModelForCausalLM.from_pretrained("omi-health/sum-small", trust_remote_code=False)
            print("Using CPU (no GPU available)")

    except Exception as e:
        print(f"Error loading model with GPU/half-precision: {e}")
        print("Falling back to CPU...")
        model = AutoModelForCausalLM.from_pretrained("omi-health/sum-small", trust_remote_code=False)

    return model, tokenizer

def generate_soap_note(doctor_patient_conversation):
    if not doctor_patient_conversation.strip():
        return "Please enter a doctor-patient conversation."
    
    try:
        system_prompt = """
Please generate a structured SOAP (Subjective, Objective, Assessment, Plan) note based on the following doctor-patient conversation:

Include all relevant details in the SOAP note, and ensure that the note is clear and concise. Address each of the following:
Subjective: Patient's reported symptoms and concerns.
Objective: Observations and findings from the doctor's examination.
Assessment: Doctor's assessment of the patient's condition.
Plan: Recommended next steps for the patient's care.

Do not include any additional information or context outside of the SOAP note. Do not include the original prompt or conversation in the output."""

        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": doctor_patient_conversation}
        ]
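        # Passing a role/content message list lets the text-generation pipeline apply the
        # model's chat template via the tokenizer before generation.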

        pipe = pipeline( 
            "text-generation", 
            model=model, 
            tokenizer=tokenizer, 
        ) 
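        # The pipeline is rebuilt on each request; this keeps the function self-contained
        # at the cost of a small per-call overhead.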

        generation_args = {
            "max_new_tokens": 500,
            "return_full_text": False,  # return only the generated note, not the prompt
            "do_sample": False,         # greedy decoding for deterministic output
            "num_beams": 1,
        }

        # Release any unreferenced tensors before generation
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        output = pipe(messages, **generation_args) 
        
        return output[0]['generated_text']
    
    except RuntimeError as e:
        if "CUDA out of memory" in str(e):
            return "Error: GPU ran out of memory. Try with a shorter conversation or on a machine with more GPU memory."
        else:
            return f"Error during generation: {str(e)}"
    
# Load model and tokenizer (this will run once when the app starts)
model, tokenizer = load_model()

# Sample conversation for the example
sample_conversation = """
Doctor: Good morning, how are you feeling today?
Patient: Not so great, doctor. I've had this persistent cough for about two weeks now.
Doctor: I'm sorry to hear that. Can you tell me more about the cough? Is it dry or are you coughing up anything?
Patient: It started as a dry cough, but for the past few days I've been coughing up some yellowish phlegm.
Doctor: Do you have any other symptoms like fever, chills, or shortness of breath?
Patient: I had a fever of 100.5°F two days ago. I've been feeling more tired than usual, and sometimes it's a bit hard to catch my breath after coughing a lot.
"""

# Create Gradio interface
demo = gr.Interface(
    fn=generate_soap_note,
    inputs=gr.Textbox(
        lines=15, 
        placeholder="Enter doctor-patient conversation here...",
        label="Doctor-Patient Conversation",
        value=sample_conversation
    ),
    outputs=gr.Textbox(
        label="Generated SOAP Note", 
        lines=15
    ),
    title="Medical SOAP Note Generator",
    description="Enter a doctor-patient conversation to generate a structured SOAP note using OMI Health's task-specific model.",
    examples=[[sample_conversation]],
    allow_flagging="never"
)

# Launch the app
demo.launch()
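
# To run locally (assuming gradio, torch, transformers, and accelerate are installed):
#   python app.py  (or whatever this file is named)
# Gradio serves the interface at http://127.0.0.1:7860 by default.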