# Symptom_Checker / app.py
# Lech-Iyoko's picture
# Update app.py
# d4fa757 verified
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# Hub repository id passed to both from_pretrained calls below.
MODEL_NAME = "Lech-Iyoko/bert-symptom-checker"

# Loaded once at import time; predict_symptom reads these module-level objects.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
# Class index -> condition name. The list is in index order, so enumerate()
# assigns each name the integer key the classifier's output index maps to.
LABEL_MAPPING = dict(enumerate([
    "Sprains and Strains",
    "Fractures",
    "Contusions (Bruises)",
    "Cuts and Lacerations",
    "Concussions",
    "Burns",
    "Dislocations",
    "Abrasions (Scrapes)",
    "Whiplash Injuries",
    "Eye Injuries",
    "Puncture Wounds",
    "Bites and Stings",
    "Back Injuries",
    "Broken Nose",
    "Knee Injuries",
    "Ankle Injuries",
    "Shoulder Injuries",
    "Wrist Injuries",
    "Chest Injuries",
    "Head Injuries",
    "Acne",
    "Allergies",
    "Alzheimer's Disease",
    "Anemia",
    "Anxiety Disorders",
    "Arthritis",
    "Asthma",
    "Back Pain",
    "Bipolar Disorder",
    "Bronchitis",
    "Cataracts",
    "Chickenpox",
    "COPD",
    "Common Cold",
    "Conjunctivitis (Pink Eye)",
    "Constipation",
    "Coronary Heart Disease",
    "Depression",
    "Diabetes Type 1",
    "Diabetes Type 2",
    "Diarrhea",
    "Ear Infections",
    "Eczema",
    "Fibromyalgia",
    "Flu",
    "GERD",
    "Gout",
    "Hay Fever (Allergic Rhinitis)",
    "Headaches",
    "High Blood Pressure (Hypertension)",
    "High Cholesterol (Hypercholesterolemia)",
    "IBS",
    "Kidney Stones",
    "Migraines",
    "Obesity",
    "Osteoarthritis",
    "Psoriasis",
    "UTI",
    "Other",
]))
def predict_symptom(text):
    """Classify a free-text symptom description.

    Args:
        text: The user's description of their symptoms.

    Returns:
        A ``(label, confidence)`` tuple. ``label`` is the condition name
        looked up in ``LABEL_MAPPING`` ("Unknown" if the predicted index has
        no entry) and ``confidence`` is the softmax probability of that class
        as a float. Blank or missing input yields ``("Unknown", 0.0)``
        without running the model.
    """
    # An empty textbox carries no information; skip inference rather than
    # report whatever class the model assigns to an empty sequence.
    if text is None or (isinstance(text, str) and not text.strip()):
        return "Unknown", 0.0

    tokens = tokenizer(text, padding=True, truncation=True, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**tokens)

    # Select the first row before taking the argmax so the index is always a
    # class index, never a position in the flattened (batch, classes) tensor.
    logits = outputs.logits[0]
    predicted_index = logits.argmax(dim=-1).item()
    confidence = torch.softmax(logits, dim=-1)[predicted_index].item()
    return LABEL_MAPPING.get(predicted_index, "Unknown"), confidence
# Web UI: a single free-text input wired to predict_symptom, whose two return
# values fill the two output fields in order.
iface = gr.Interface(
    title="Symptom Checker",
    description="Enter your symptoms to get a predicted medical condition.",
    fn=predict_symptom,
    inputs=gr.Textbox(lines=2, placeholder="Describe your symptoms here..."),
    outputs=[
        gr.Textbox(label=caption)
        for caption in ("Predicted Condition", "Confidence Level")
    ],
)

# Launch the app only when this file is executed as a script.
if __name__ == "__main__":
    iface.launch()