import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.utils import logging

logging.set_verbosity_debug()
logger = logging.get_logger("transformers")


# Load the model directly from the Hugging Face Hub repository
def load_model():
    tokenizer = AutoTokenizer.from_pretrained("omi-health/sum-small", trust_remote_code=False)
    model = AutoModelForCausalLM.from_pretrained("omi-health/sum-small", trust_remote_code=False)

    # Move the model to GPU if one is available
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = model.to(device)

    print(f"Using device: {device}")
    if device == "cuda":
        print(f"GPU: {torch.cuda.get_device_name(0)}")
        print(f"Memory allocated: {torch.cuda.memory_allocated(0) / 1024**2:.2f} MB")

    return model, tokenizer


def generate_soap_note(doctor_patient_conversation):
    if not doctor_patient_conversation.strip():
        return "Please enter a doctor-patient conversation."

    # Build a chat-style prompt that instructs the model to produce a SOAP note
    prompt = f"""<|user|>
Please generate a structured SOAP (Subjective, Objective, Assessment, Plan) note based on the following doctor-patient conversation:

{doctor_patient_conversation}
<|assistant|>"""

    # Reuse the device the model was loaded onto
    device = model.device

    # Tokenize with explicit padding/truncation settings
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=tokenizer.model_max_length,
    )
    inputs = {k: v.to(device) for k, v in inputs.items()}

    generate_ids = model.generate(
        inputs["input_ids"],
        attention_mask=inputs["attention_mask"],  # Explicitly pass the attention mask
        max_length=2048,  # Counts prompt tokens plus generated tokens
        num_beams=5,
        no_repeat_ngram_size=2,
        early_stopping=True,
    )

    # Decode only the newly generated tokens so the prompt is not echoed back.
    # With skip_special_tokens=True the <|assistant|> marker may be stripped,
    # so splitting the decoded string on it is unreliable; slicing past the
    # prompt length is robust either way.
    prompt_length = inputs["input_ids"].shape[1]
    decoded_response = tokenizer.batch_decode(
        generate_ids[:, prompt_length:],
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )[0].strip()

    logger.debug(f"Decoded response: {decoded_response}")
    return decoded_response


# Load the model and tokenizer once, when the app starts
model, tokenizer = load_model()

# Sample conversation used to pre-fill the input box and as the example
sample_conversation = """
Doctor: Good morning, how are you feeling today?
Patient: Not so great, doctor. I've had this persistent cough for about two weeks now.
Doctor: I'm sorry to hear that. Can you tell me more about the cough? Is it dry or are you coughing up anything?
Patient: It started as a dry cough, but for the past few days I've been coughing up some yellowish phlegm.
Doctor: Do you have any other symptoms like fever, chills, or shortness of breath?
Patient: I had a fever of 100.5°F two days ago. I've been feeling more tired than usual, and sometimes it's a bit hard to catch my breath after coughing a lot.
"""

# Create the Gradio interface
demo = gr.Interface(
    fn=generate_soap_note,
    inputs=gr.Textbox(
        lines=15,
        placeholder="Enter doctor-patient conversation here...",
        label="Doctor-Patient Conversation",
        value=sample_conversation,
    ),
    outputs=gr.Textbox(
        label="Generated SOAP Note",
        lines=15,
    ),
    title="Medical SOAP Note Generator",
    description="Enter a doctor-patient conversation to generate a structured SOAP note using OMI Health's task-specific model.",
    examples=[[sample_conversation]],
    allow_flagging="never",
)

# Launch the app
demo.launch()