ayushsinha's picture
Upload 2 files
7375de2 verified
import gradio as gr
import torch
from transformers import pipeline
# Load the NER model once at import time; every request reuses this pipeline.
model_name = "AventIQ-AI/bert-medical-entity-extraction"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# The transformers pipeline takes an integer device ordinal: 0 selects the
# first CUDA device, -1 selects the CPU. Derive it from `device` so the two
# settings can never disagree.
device_index = 0 if device.type == "cuda" else -1
print("Loading model...")
ner_pipeline = pipeline(
    "ner",
    model=model_name,
    tokenizer=model_name,
    # "simple" aggregation groups sub-word tokens into entities and exposes
    # the `entity_group` key that the formatting code reads.
    aggregation_strategy="simple",
    device=device_index,
)
# Define entity mapping based on README.
# The model emits generic "LABEL_<n>" tags; map each one to a readable name.
# LABEL_0 has no entry — presumably the non-entity ("O") class; TODO confirm
# against the model card. Labels without an entry are not shown to the user.
entity_mapping = {
    f"LABEL_{label_index}": readable_name
    for label_index, readable_name in enumerate(
        (
            "Symptom",
            "Disease",
            "Medication",
            "Treatment",
            "Anatomy",
            "Medical Procedure",
        ),
        start=1,
    )
}
def format_entities(entities, mapping):
    """Render NER pipeline output as one display line per recognised entity.

    Args:
        entities: Iterable of dicts as returned by the aggregated NER
            pipeline; each must provide ``"word"`` and ``"entity_group"``.
        mapping: Dict from raw model label to a human-readable name. Entities
            whose label is not a key of ``mapping`` are skipped.

    Returns:
        Newline-joined lines of the form ``📌 **word** → `Name` ``, or a
        warning string when no entity survives the filter.
    """
    lines = [
        # Drop "##" sub-word continuation markers from the surface text.
        f"📌 **{entity['word'].replace('##', '')}** → `{mapping[entity['entity_group']]}`"
        for entity in entities
        if entity["entity_group"] in mapping
    ]
    return "\n".join(lines) if lines else "⚠️ No relevant medical entities detected."


def extract_medical_entities(text):
    """Extract relevant medical entities from the input text.

    Args:
        text: Free-form medical text from the UI. ``None``, empty, and
            whitespace-only input are all treated as "no input".

    Returns:
        A display string: one line per entity whose label appears in
        ``entity_mapping``, or a warning message when the input is empty or
        nothing relevant was detected.
    """
    # `not text` also covers None, so `.strip()` is never called on None.
    if not text or not text.strip():
        return "⚠️ Please enter a valid medical text."
    # NOTE(review): this echoes user-supplied medical text to stdout; consider
    # removing or redacting it if inputs may contain patient data.
    print(f"Processing: {text}")
    entities = ner_pipeline(text)
    # Only labels listed in entity_mapping are kept; any other label the model
    # emits is filtered out inside format_entities.
    response = format_entities(entities, entity_mapping)
    print(f"Response: {response}")
    return response
# Create Gradio Interface.
# NOTE(review): extract_medical_entities returns Markdown (**bold**, `code`),
# but the output component is a Textbox, so the markup is shown literally;
# consider gr.Markdown if rendered output is wanted.
# NOTE(review): `theme="compact"` and `allow_flagging` depend on the installed
# Gradio version — confirm they are still accepted.
iface = gr.Interface(
    fn=extract_medical_entities,
    inputs=gr.Textbox(
        label="📝 Enter Medical Text",
        placeholder="Type or paste a medical report...",
        lines=3,
    ),
    outputs=gr.Textbox(
        label="🏥 Extracted Medical Entities",
        placeholder="Detected medical terms will appear here...",
        lines=5,
    ),
    title="🔬 Medical Entity Extraction",
    description="💉 Enter a medical-related text, and the AI will extract **diseases, symptoms, medications, and treatments.**",
    theme="compact",
    allow_flagging="never",
    examples=[
        ["The patient is diagnosed with diabetes and prescribed metformin."],
        ["Symptoms include fever, sore throat, and fatigue."],
        ["He underwent a knee replacement surgery at Mayo Clinic."],
    ],
)
def main():
    """Launch the module-level Gradio interface ``iface``."""
    iface.launch()


if __name__ == "__main__":
    main()