# Hugging Face Spaces page header (Space status: Paused)
import functools
import gc

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from transformers.utils import logging
# Module-level logger for the "transformers" namespace; the library's own
# verbosity is raised to DEBUG so its internal log messages are emitted.
logger = logging.get_logger("transformers")
logging.set_verbosity_debug()
# Load model directly from your Hugging Face repository
def load_model(model_id="omi-health/sum-small"):
    """Load the SOAP-note model and its tokenizer from the Hugging Face Hub.

    When CUDA is available, the model is first loaded in half precision
    (bfloat16 if the GPU supports it, otherwise float16) with
    ``device_map="auto"``. If that GPU load fails for any reason, it falls
    back to a default CPU load. Without CUDA, the model is loaded on the CPU
    directly.

    Args:
        model_id: Hub repository id of the model/tokenizer to load. Defaults
            to the OMI Health ``sum-small`` model this app was built for.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    # The tokenizer is device-independent, so it is loaded outside the GPU/CPU
    # fallback logic. Previously a tokenizer failure landed in the except
    # branch, which then returned an unbound ``tokenizer``.
    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=False)

    if not torch.cuda.is_available():
        model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=False)
        print("Using CPU (no GPU available)")
        return model, tokenizer

    try:
        # Half precision halves GPU memory versus the fp32 default; prefer
        # bfloat16 where the hardware supports it.
        half_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            trust_remote_code=False,
            torch_dtype=half_dtype,
            device_map="auto",  # Let the library decide best device mapping
        )
    except Exception as e:
        # Deliberate best-effort: any failure while loading on the GPU falls
        # back to a plain CPU load instead of crashing the app.
        print(f"Error loading model with GPU/half-precision: {e}")
        print("Falling back to CPU...")
        model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=False)
    else:
        # Diagnostics live outside the try so a failure here cannot trigger a
        # redundant second (CPU) model load.
        print(f"GPU: {torch.cuda.get_device_name(0)}")
        print(f"Memory allocated: {torch.cuda.memory_allocated(0) / 1024**2:.2f} MB")
    return model, tokenizer
# Instruction sent as the system message for every request.
SOAP_SYSTEM_PROMPT = """
Please generate a structured SOAP (Subjective, Objective, Assessment, Plan) note based on the following doctor-patient conversation:
Include all relevant details in the SOAP note, and ensure that the note is clear and concise. Address each of the following:
Subjective: Patient's reported symptoms and concerns.
Objective: Observations and findings from the doctor's examination.
Assessment: Doctor's assessment of the patient's condition.
Plan: Recommended next steps for the patient's care.
Do not include any additional information or context outside of the SOAP note. Do not include the original prompt or conversation in the output."""

# Greedy decoding settings. ``temperature`` is intentionally omitted: it only
# affects sampling, and transformers warns when it is set with do_sample=False.
GENERATION_ARGS = {
    "max_new_tokens": 500,
    "return_full_text": False,
    "do_sample": False,
    "num_beams": 1,
}


@functools.lru_cache(maxsize=1)
def _get_text_generation_pipeline():
    """Build the text-generation pipeline once and reuse it for every request."""
    # Reads the module-level ``model`` / ``tokenizer`` set at app startup.
    return pipeline("text-generation", model=model, tokenizer=tokenizer)


def generate_soap_note(doctor_patient_conversation):
    """Generate a structured SOAP note from a doctor-patient conversation.

    Args:
        doctor_patient_conversation: Conversation transcript text. ``None``,
            empty, or whitespace-only input is rejected with a message.

    Returns:
        The generated SOAP note text. If the input is empty or generation
        raises a ``RuntimeError`` (including CUDA out-of-memory), it returns a
        user-facing message string instead. Other exceptions propagate.
    """
    if not (doctor_patient_conversation or "").strip():
        return "Please enter a doctor-patient conversation."
    messages = [
        {"role": "system", "content": SOAP_SYSTEM_PROMPT},
        {"role": "user", "content": doctor_patient_conversation},
    ]
    try:
        pipe = _get_text_generation_pipeline()
        # Release unreferenced objects and cached CUDA blocks before generating
        # to leave as much memory as possible for this request.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        output = pipe(messages, **GENERATION_ARGS)
        return output[0]["generated_text"]
    except RuntimeError as e:
        if "CUDA out of memory" in str(e):
            return "Error: GPU ran out of memory. Try with a shorter conversation or on a machine with more GPU memory."
        return f"Error during generation: {str(e)}"
# Load model and tokenizer (this will run once when the app starts)
model, tokenizer = load_model()

# Sample conversation for the example
sample_conversation = """
Doctor: Good morning, how are you feeling today?
Patient: Not so great, doctor. I've had this persistent cough for about two weeks now.
Doctor: I'm sorry to hear that. Can you tell me more about the cough? Is it dry or are you coughing up anything?
Patient: It started as a dry cough, but for the past few days I've been coughing up some yellowish phlegm.
Doctor: Do you have any other symptoms like fever, chills, or shortness of breath?
Patient: I had a fever of 100.5°F two days ago. I've been feeling more tired than usual, and sometimes it's a bit hard to catch my breath after coughing a lot.
"""

# Create Gradio interface
demo = gr.Interface(
    fn=generate_soap_note,
    inputs=gr.Textbox(
        lines=15,
        placeholder="Enter doctor-patient conversation here...",
        label="Doctor-Patient Conversation",
        value=sample_conversation,
    ),
    outputs=gr.Textbox(
        label="Generated SOAP Note",
        lines=15,
    ),
    title="Medical SOAP Note Generator",
    description="Enter a doctor-patient conversation to generate a structured SOAP note using OMI Health's task-specific model.",
    examples=[[sample_conversation]],
    # NOTE(review): newer Gradio releases rename this to `flagging_mode`;
    # confirm against the Gradio version pinned for this Space.
    allow_flagging="never",
)

# Launch the app only when executed as a script (`python app.py`), so that
# importing this module does not start a web server.
if __name__ == "__main__":
    demo.launch()