# Hugging Face Space status: Paused
import gradio as gr | |
import torch | |
from transformers import AutoModelForCausalLM, AutoTokenizer | |
from transformers.utils import logging | |
# This app logs its own messages through the "transformers" library logger,
# and raises that library's verbosity to DEBUG.
logger = logging.get_logger("transformers")
logging.set_verbosity_debug()
# Load model directly from your Hugging Face repository
def load_model(model_id="omi-health/sum-small"):
    """Load the SOAP-note tokenizer and causal-LM, placing the model on GPU if possible.

    Args:
        model_id: Hugging Face Hub repository id used for both the tokenizer
            and the model weights. Defaults to the repository this app was
            written for, so existing ``load_model()`` calls are unchanged.

    Returns:
        A ``(model, tokenizer)`` tuple. The model is moved to ``"cuda"`` when
        ``torch.cuda.is_available()`` is true, otherwise to ``"cpu"``.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=False)
    model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=False)

    # Many causal-LM tokenizers ship without a pad token, and any tokenizer
    # call that requests padding then raises a ValueError. Reusing EOS as the
    # pad token is the standard workaround; it is a no-op when a pad token
    # already exists.
    if tokenizer.pad_token is None and tokenizer.eos_token is not None:
        tokenizer.pad_token = tokenizer.eos_token

    # Move model to GPU if available
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = model.to(device)
    print(f"Using device: {device}")
    if device == "cuda":
        print(f"GPU: {torch.cuda.get_device_name(0)}")
        print(f"Memory allocated: {torch.cuda.memory_allocated(0) / 1024**2:.2f} MB")
    return model, tokenizer
def _build_soap_prompt(conversation):
    """Wrap *conversation* in the <|user|> ... <|assistant|> prompt this app uses."""
    return (
        "<|user|>\n"
        "Please generate a structured SOAP (Subjective, Objective, Assessment, Plan) "
        "note based on the following doctor-patient conversation:\n"
        f"{conversation}\n"
        "<|assistant|>"
    )


def generate_soap_note(doctor_patient_conversation):
    """Generate a SOAP note from a free-text doctor-patient conversation.

    Relies on the module-level ``model`` and ``tokenizer`` globals and runs
    beam search over them.

    Args:
        doctor_patient_conversation: Conversation transcript entered by the
            user. ``None``, empty and whitespace-only input are rejected.

    Returns:
        The generated SOAP note text, or a message asking for input when the
        conversation is empty.
    """
    if not doctor_patient_conversation or not doctor_patient_conversation.strip():
        return "Please enter a doctor-patient conversation."

    prompt = _build_soap_prompt(doctor_patient_conversation)

    # A single prompt never needs padding; requesting it would fail on
    # tokenizers that have no pad token. The attention mask is still returned.
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,
        max_length=tokenizer.model_max_length,
    )
    # Send the tensors to wherever the model actually lives instead of
    # re-deriving the device independently of load_model().
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
    prompt_length = inputs["input_ids"].shape[1]

    generate_ids = model.generate(
        inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        # max_new_tokens, unlike max_length, does not count the prompt, so a
        # long conversation no longer shrinks (or exhausts) the note's budget.
        max_new_tokens=1024,
        num_beams=5,
        # NOTE(review): size 2 forbids repeating *any* word pair (e.g. "the
        # patient") anywhere in the note; consider a larger value.
        no_repeat_ngram_size=2,
        early_stopping=True,
    )

    # Decode only the tokens generated after the prompt. Searching the decoded
    # text for "<|assistant|>" breaks when skip_special_tokens strips that
    # marker (the whole prompt is then echoed back) or when the user's text
    # itself contains it.
    decoded_response = tokenizer.decode(
        generate_ids[0, prompt_length:],
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    ).strip()
    # NOTE(review): this writes generated clinical text to the logs at debug
    # level; reconsider before handling real patient data.
    logger.debug("Decoded response: %s", decoded_response)
    return decoded_response
# Load model and tokenizer (this will run once when the app starts)
model, tokenizer = load_model()

# Sample conversation for the example
sample_conversation = """
Doctor: Good morning, how are you feeling today?
Patient: Not so great, doctor. I've had this persistent cough for about two weeks now.
Doctor: I'm sorry to hear that. Can you tell me more about the cough? Is it dry or are you coughing up anything?
Patient: It started as a dry cough, but for the past few days I've been coughing up some yellowish phlegm.
Doctor: Do you have any other symptoms like fever, chills, or shortness of breath?
Patient: I had a fever of 100.5°F two days ago. I've been feeling more tired than usual, and sometimes it's a bit hard to catch my breath after coughing a lot.
"""

# Create Gradio interface
demo = gr.Interface(
    fn=generate_soap_note,
    inputs=gr.Textbox(
        lines=15,
        placeholder="Enter doctor-patient conversation here...",
        label="Doctor-Patient Conversation",
        value=sample_conversation,
    ),
    outputs=gr.Textbox(
        label="Generated SOAP Note",
        lines=15,
    ),
    title="Medical SOAP Note Generator",
    description="Enter a doctor-patient conversation to generate a structured SOAP note using OMI Health's task-specific model.",
    examples=[[sample_conversation]],
    # NOTE(review): newer Gradio releases deprecate allow_flagging in favour
    # of flagging_mode="never" -- confirm against the pinned Gradio version.
    allow_flagging="never",
)

# Launch only when executed as a script (e.g. `python app.py`), so importing
# this module (tests, `gradio app.py` reload mode) does not start a server.
if __name__ == "__main__":
    demo.launch()