File size: 2,413 Bytes
88009c2
b13ff14
 
88009c2
b13ff14
 
 
 
88009c2
b13ff14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88009c2
b13ff14
 
 
 
 
 
 
 
 
88009c2
b13ff14
8a6d080
b13ff14
8a6d080
 
b13ff14
 
 
88009c2
 
8a6d080
88009c2
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Hugging Face Hub repository holding the fine-tuned classifier and its tokenizer.
MODEL_NAME = "Lech-Iyoko/bert-symptom-checker"

# Load the classifier and its matching tokenizer from the same repo.
# The tuple is evaluated left to right, so the model is loaded first.
model, tokenizer = (
    AutoModelForSequenceClassification.from_pretrained(MODEL_NAME),
    AutoTokenizer.from_pretrained(MODEL_NAME),
)

# Model output class index -> human-readable condition name.
# The list position IS the class index, so the order below is significant
# (presumably it must match the order of the model's classifier head — confirm
# against model.config.id2label if the model is retrained).
LABEL_MAPPING = dict(enumerate([
    "Sprains and Strains", "Fractures", "Contusions (Bruises)", "Cuts and Lacerations", "Concussions",  # 0-4
    "Burns", "Dislocations", "Abrasions (Scrapes)", "Whiplash Injuries", "Eye Injuries",  # 5-9
    "Puncture Wounds", "Bites and Stings", "Back Injuries", "Broken Nose", "Knee Injuries",  # 10-14
    "Ankle Injuries", "Shoulder Injuries", "Wrist Injuries", "Chest Injuries", "Head Injuries",  # 15-19
    "Acne", "Allergies", "Alzheimer's Disease", "Anemia", "Anxiety Disorders",  # 20-24
    "Arthritis", "Asthma", "Back Pain", "Bipolar Disorder", "Bronchitis",  # 25-29
    "Cataracts", "Chickenpox", "COPD", "Common Cold", "Conjunctivitis (Pink Eye)",  # 30-34
    "Constipation", "Coronary Heart Disease", "Depression", "Diabetes Type 1", "Diabetes Type 2",  # 35-39
    "Diarrhea", "Ear Infections", "Eczema", "Fibromyalgia", "Flu",  # 40-44
    "GERD", "Gout", "Hay Fever (Allergic Rhinitis)", "Headaches", "High Blood Pressure (Hypertension)",  # 45-49
    "High Cholesterol (Hypercholesterolemia)", "IBS", "Kidney Stones", "Migraines", "Obesity",  # 50-54
    "Osteoarthritis", "Psoriasis", "UTI", "Other",  # 55-58
]))

def predict_symptom(text):
    """Classify a free-text symptom description.

    Args:
        text: The user's symptom description (the Gradio textbox value).

    Returns:
        A ``(label, confidence)`` tuple:

        - ``label`` is the condition name looked up in ``LABEL_MAPPING``,
          or ``"Unknown"`` if the predicted class index has no entry.
        - ``confidence`` is the softmax probability of that class, as a
          float in [0, 1].

        Empty, ``None`` or whitespace-only input returns
        ``("Please describe your symptoms.", 0.0)`` without running the model.
    """
    # Without this guard an empty prompt would still be tokenized and
    # classified, producing a confident-looking diagnosis for no symptoms.
    if not text or not text.strip():
        return "Please describe your symptoms.", 0.0

    tokens = tokenizer(text, padding=True, truncation=True, return_tensors="pt")
    with torch.no_grad():
        logits = model(**tokens).logits

    # Normalise over the class axis and read the top class of the first (only)
    # row. The explicit dim avoids argmax over a flattened tensor.
    # The max probability is both the prediction and its confidence.
    probabilities = torch.softmax(logits, dim=-1)[0]
    confidence, predicted_index = probabilities.max(dim=-1)
    return LABEL_MAPPING.get(predicted_index.item(), "Unknown"), confidence.item()

# Web UI: one two-line free-text box in, and two text boxes out (one per element
# of the (label, confidence) tuple returned by predict_symptom).
# Components are created in the same order as before: input first, then outputs.
_symptom_box = gr.Textbox(lines=2, placeholder="Describe your symptoms here...")
_result_boxes = [
    gr.Textbox(label="Predicted Condition"),
    gr.Textbox(label="Confidence Level"),
]

iface = gr.Interface(
    fn=predict_symptom,
    inputs=_symptom_box,
    outputs=_result_boxes,
    title="Symptom Checker",
    description="Enter your symptoms to get a predicted medical condition.",
)


if __name__ == "__main__":
    # Start the server only when run as a script; importing this module
    # builds `iface` but does not launch it.
    iface.launch()