# sum-soap-demo / app.py

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.utils import logging
logging.set_verbosity_debug()
logger = logging.get_logger("transformers")

# Load the model and tokenizer from the Hugging Face Hub
def load_model():
    tokenizer = AutoTokenizer.from_pretrained("omi-health/sum-small", trust_remote_code=False)
    model = AutoModelForCausalLM.from_pretrained("omi-health/sum-small", trust_remote_code=False)

    # Move the model to GPU if one is available
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = model.to(device)
    print(f"Using device: {device}")
    if device == "cuda":
        print(f"GPU: {torch.cuda.get_device_name(0)}")
        print(f"Memory allocated: {torch.cuda.memory_allocated(0) / 1024**2:.2f} MB")
    return model, tokenizer
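
# If GPU memory is tight, a half-precision load is one option. This is a sketch,
# not something the demo requires (torch_dtype is a standard from_pretrained
# argument; whether fp16 is appropriate for this model is an assumption):
#
#     model = AutoModelForCausalLM.from_pretrained(
#         "omi-health/sum-small", torch_dtype=torch.float16
#     ).to("cuda")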

def generate_soap_note(doctor_patient_conversation):
    if not doctor_patient_conversation.strip():
        return "Please enter a doctor-patient conversation."

    # Build a prompt in the model's chat format with explicit instructions
    prompt = f"""<|user|>
Please generate a structured SOAP (Subjective, Objective, Assessment, Plan) note based on the following doctor-patient conversation:
{doctor_patient_conversation}
<|assistant|>"""

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Tokenize with explicit padding and truncation settings
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=tokenizer.model_max_length,
    )
    inputs = {k: v.to(device) for k, v in inputs.items()}

    generate_ids = model.generate(
        inputs["input_ids"],
        attention_mask=inputs["attention_mask"],  # explicitly pass the attention mask
        max_length=2048,  # note: max_length counts the prompt tokens as well as new ones
        num_beams=5,
        no_repeat_ngram_size=2,
        early_stopping=True,
    )

    # Decode only the newly generated tokens so the prompt is not echoed back.
    # (Splitting the full decode on "<|assistant|>" is unreliable, since
    # skip_special_tokens=True may strip that marker from the output.)
    prompt_length = inputs["input_ids"].shape[1]
    decoded_response = tokenizer.decode(
        generate_ids[0][prompt_length:], skip_special_tokens=True
    ).strip()
    logger.debug(f"Decoded response: {decoded_response}")
    return decoded_response
# Load model and tokenizer (this will run once when the app starts)
model, tokenizer = load_model()
# Sample conversation for the example
sample_conversation = """
Doctor: Good morning, how are you feeling today?
Patient: Not so great, doctor. I've had this persistent cough for about two weeks now.
Doctor: I'm sorry to hear that. Can you tell me more about the cough? Is it dry or are you coughing up anything?
Patient: It started as a dry cough, but for the past few days I've been coughing up some yellowish phlegm.
Doctor: Do you have any other symptoms like fever, chills, or shortness of breath?
Patient: I had a fever of 100.5°F two days ago. I've been feeling more tired than usual, and sometimes it's a bit hard to catch my breath after coughing a lot.
"""

# Create the Gradio interface
demo = gr.Interface(
    fn=generate_soap_note,
    inputs=gr.Textbox(
        lines=15,
        placeholder="Enter doctor-patient conversation here...",
        label="Doctor-Patient Conversation",
        value=sample_conversation,
    ),
    outputs=gr.Textbox(
        label="Generated SOAP Note",
        lines=15,
    ),
    title="Medical SOAP Note Generator",
    description="Enter a doctor-patient conversation to generate a structured SOAP note using OMI Health's task-specific model.",
    examples=[[sample_conversation]],
    allow_flagging="never",
)
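
# Optional: if several users may hit the app at once, request queuing keeps
# concurrent generations from timing out (demo.queue() is a standard Gradio
# call; enabling it here is a suggestion, not part of the original demo):
#
#     demo.queue()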
# Launch the app
demo.launch()