Spaces:
Sleeping
Sleeping
import gradio as gr | |
import torch | |
from transformers import AutoTokenizer, AutoModelForSequenceClassification | |
# Identifier handed to `from_pretrained` (a Hugging Face Hub repo id).
MODEL_NAME = "Lech-Iyoko/bert-symptom-checker"

# Both objects are created once when this module is imported and then shared
# by every prediction request.
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# Classifier output index -> human-readable condition name.
# `predict_symptom` looks up the argmax index here, so list position IS the
# label id: entries may only be added/removed in lockstep with the model.
# Positions 0-19 are injury categories, 20-57 medical conditions, 58 a catch-all.
LABEL_MAPPING = dict(enumerate([
    # 0-9
    "Sprains and Strains",
    "Fractures",
    "Contusions (Bruises)",
    "Cuts and Lacerations",
    "Concussions",
    "Burns",
    "Dislocations",
    "Abrasions (Scrapes)",
    "Whiplash Injuries",
    "Eye Injuries",
    # 10-19
    "Puncture Wounds",
    "Bites and Stings",
    "Back Injuries",
    "Broken Nose",
    "Knee Injuries",
    "Ankle Injuries",
    "Shoulder Injuries",
    "Wrist Injuries",
    "Chest Injuries",
    "Head Injuries",
    # 20-29
    "Acne",
    "Allergies",
    "Alzheimer's Disease",
    "Anemia",
    "Anxiety Disorders",
    "Arthritis",
    "Asthma",
    "Back Pain",
    "Bipolar Disorder",
    "Bronchitis",
    # 30-39
    "Cataracts",
    "Chickenpox",
    "COPD",
    "Common Cold",
    "Conjunctivitis (Pink Eye)",
    "Constipation",
    "Coronary Heart Disease",
    "Depression",
    "Diabetes Type 1",
    "Diabetes Type 2",
    # 40-49
    "Diarrhea",
    "Ear Infections",
    "Eczema",
    "Fibromyalgia",
    "Flu",
    "GERD",
    "Gout",
    "Hay Fever (Allergic Rhinitis)",
    "Headaches",
    "High Blood Pressure (Hypertension)",
    # 50-58
    "High Cholesterol (Hypercholesterolemia)",
    "IBS",
    "Kidney Stones",
    "Migraines",
    "Obesity",
    "Osteoarthritis",
    "Psoriasis",
    "UTI",
    "Other",
]))
def predict_symptom(text):
    """Classify a free-text symptom description with the loaded model.

    Args:
        text: The symptom description typed into the Gradio textbox.

    Returns:
        A ``(label, confidence)`` tuple. ``label`` is the condition name from
        ``LABEL_MAPPING`` (``"Unknown"`` if the predicted index is not in the
        mapping) and ``confidence`` is the softmax probability of that class
        as a Python float.

    Raises:
        gr.Error: If ``text`` is empty or only whitespace. Running the model on
            no input would still yield a confident-looking "diagnosis", so the
            UI shows a message instead.
    """
    if not text or not text.strip():
        raise gr.Error("Please describe your symptoms before submitting.")

    tokens = tokenizer(text, padding=True, truncation=True, return_tensors="pt")
    with torch.no_grad():
        # One input string -> batch of one; take its row of class logits.
        logits = model(**tokens).logits[0]

    predicted_index = int(logits.argmax(dim=-1))
    confidence = torch.softmax(logits, dim=-1)[predicted_index].item()
    return LABEL_MAPPING.get(predicted_index, "Unknown"), confidence
# Gradio UI wiring. `predict_symptom` returns a (label, confidence) pair, so
# there is one output component per element, in that order.
iface = gr.Interface(
    fn=predict_symptom,
    inputs=gr.Textbox(
        lines=2,
        label="Symptoms",
        placeholder="Describe your symptoms here...",
    ),
    outputs=[
        gr.Textbox(label="Predicted Condition"),
        # Confidence is a float softmax probability; a numeric component
        # shows it rounded rather than as a long raw float string.
        gr.Number(label="Confidence Level", precision=3),
    ],
    title="Symptom Checker",
    # A medical-prediction tool must state that it is not a diagnosis.
    description=(
        "Enter your symptoms to get a predicted medical condition. "
        "This tool is for informational purposes only and is not a "
        "substitute for professional medical advice, diagnosis, or treatment."
    ),
)
def _main():
    """Start the Gradio server that hosts the symptom-checker interface."""
    iface.launch()


if __name__ == "__main__":
    _main()