Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,18 +1,48 @@
|
|
| 1 |
import gradio as gr
|
| 2 |
-
import
|
|
|
|
| 3 |
|
| 4 |
-
|
|
|
|
|
|
|
|
|
|
| 5 |
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
iface = gr.Interface(
|
| 11 |
-
fn=
|
| 12 |
-
inputs=gr.Textbox(placeholder="
|
| 13 |
-
outputs=
|
| 14 |
-
|
| 15 |
-
|
|
|
|
|
|
|
|
|
|
| 16 |
)
|
| 17 |
|
| 18 |
if __name__ == "__main__":
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
+
import torch
|
| 3 |
+
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
| 4 |
|
| 5 |
+
# Load the model and tokenizer
|
| 6 |
+
MODEL_NAME = "Lech-Iyoko/bert-symptom-checker"
|
| 7 |
+
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
|
| 8 |
+
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
|
| 9 |
|
| 10 |
+
# Define label mapping
|
| 11 |
+
LABEL_MAPPING = {
|
| 12 |
+
0: "Sprains and Strains", 1: "Fractures", 2: "Contusions (Bruises)",
|
| 13 |
+
3: "Cuts and Lacerations", 4: "Concussions", 5: "Burns", 6: "Dislocations",
|
| 14 |
+
7: "Abrasions (Scrapes)", 8: "Whiplash Injuries", 9: "Eye Injuries", 10: "Puncture Wounds",
|
| 15 |
+
11: "Bites and Stings", 12: "Back Injuries", 13: "Broken Nose", 14: "Knee Injuries",
|
| 16 |
+
15: "Ankle Injuries", 16: "Shoulder Injuries", 17: "Wrist Injuries", 18: "Chest Injuries",
|
| 17 |
+
19: "Head Injuries", 20: "Acne", 21: "Allergies", 22: "Alzheimer's Disease", 23: "Anemia",
|
| 18 |
+
24: "Anxiety Disorders", 25: "Arthritis", 26: "Asthma", 27: "Back Pain", 28: "Bipolar Disorder",
|
| 19 |
+
29: "Bronchitis", 30: "Cataracts", 31: "Chickenpox", 32: "COPD", 33: "Common Cold",
|
| 20 |
+
34: "Conjunctivitis (Pink Eye)", 35: "Constipation", 36: "Coronary Heart Disease",
|
| 21 |
+
37: "Depression", 38: "Diabetes Type 1", 39: "Diabetes Type 2", 40: "Diarrhea",
|
| 22 |
+
41: "Ear Infections", 42: "Eczema", 43: "Fibromyalgia", 44: "Flu", 45: "GERD", 46: "Gout",
|
| 23 |
+
47: "Hay Fever (Allergic Rhinitis)", 48: "Headaches", 49: "High Blood Pressure (Hypertension)",
|
| 24 |
+
50: "High Cholesterol (Hypercholesterolemia)", 51: "IBS", 52: "Kidney Stones", 53: "Migraines",
|
| 25 |
+
54: "Obesity", 55: "Osteoarthritis", 56: "Psoriasis", 57: "UTI", 58: "Other"
|
| 26 |
+
}
|
| 27 |
|
| 28 |
+
def predict_symptom(text):
|
| 29 |
+
tokens = tokenizer(text, padding=True, truncation=True, return_tensors="pt")
|
| 30 |
+
with torch.no_grad():
|
| 31 |
+
outputs = model(**tokens)
|
| 32 |
+
predicted_index = outputs.logits.argmax().item()
|
| 33 |
+
confidence = torch.softmax(outputs.logits, dim=1)[0][predicted_index].item()
|
| 34 |
+
return LABEL_MAPPING.get(predicted_index, "Unknown"), confidence
|
| 35 |
+
|
| 36 |
+
# Define Gradio interface
|
| 37 |
iface = gr.Interface(
|
| 38 |
+
fn=predict_symptom,
|
| 39 |
+
inputs=gr.inputs.Textbox(lines=2, placeholder="Describe your symptoms here..."),
|
| 40 |
+
outputs=[
|
| 41 |
+
gr.outputs.Textbox(label="Predicted Condition"),
|
| 42 |
+
gr.outputs.Textbox(label="Confidence Level")
|
| 43 |
+
],
|
| 44 |
+
title="Symptom Checker",
|
| 45 |
+
description="Enter your symptoms to get a predicted medical condition."
|
| 46 |
)
|
| 47 |
|
| 48 |
if __name__ == "__main__":
|