"""Gradio demo that classifies a free-text symptom description.

The text is fed to the fine-tuned sequence classifier
``Lech-Iyoko/bert-symptom-checker`` and the highest-scoring label is shown
together with its softmax probability.
"""

import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Load the model and tokenizer once at import time so every request reuses them.
MODEL_NAME = "Lech-Iyoko/bert-symptom-checker"
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Maps the classifier's output index to a human-readable condition name.
# NOTE(review): assumed to match the label order used at training time —
# confirm against model.config.id2label.
LABEL_MAPPING = {
    0: "Sprains and Strains",
    1: "Fractures",
    2: "Contusions (Bruises)",
    3: "Cuts and Lacerations",
    4: "Concussions",
    5: "Burns",
    6: "Dislocations",
    7: "Abrasions (Scrapes)",
    8: "Whiplash Injuries",
    9: "Eye Injuries",
    10: "Puncture Wounds",
    11: "Bites and Stings",
    12: "Back Injuries",
    13: "Broken Nose",
    14: "Knee Injuries",
    15: "Ankle Injuries",
    16: "Shoulder Injuries",
    17: "Wrist Injuries",
    18: "Chest Injuries",
    19: "Head Injuries",
    20: "Acne",
    21: "Allergies",
    22: "Alzheimer's Disease",
    23: "Anemia",
    24: "Anxiety Disorders",
    25: "Arthritis",
    26: "Asthma",
    27: "Back Pain",
    28: "Bipolar Disorder",
    29: "Bronchitis",
    30: "Cataracts",
    31: "Chickenpox",
    32: "COPD",
    33: "Common Cold",
    34: "Conjunctivitis (Pink Eye)",
    35: "Constipation",
    36: "Coronary Heart Disease",
    37: "Depression",
    38: "Diabetes Type 1",
    39: "Diabetes Type 2",
    40: "Diarrhea",
    41: "Ear Infections",
    42: "Eczema",
    43: "Fibromyalgia",
    44: "Flu",
    45: "GERD",
    46: "Gout",
    47: "Hay Fever (Allergic Rhinitis)",
    48: "Headaches",
    49: "High Blood Pressure (Hypertension)",
    50: "High Cholesterol (Hypercholesterolemia)",
    51: "IBS",
    52: "Kidney Stones",
    53: "Migraines",
    54: "Obesity",
    55: "Osteoarthritis",
    56: "Psoriasis",
    57: "UTI",
    58: "Other",
}


def predict_symptom(text):
    """Predict the most likely condition for a symptom description.

    Args:
        text: Free-text description of symptoms entered by the user.

    Returns:
        A ``(label, confidence)`` tuple: the predicted condition name
        (``"Unknown"`` if the predicted index is missing from
        ``LABEL_MAPPING``) and the softmax probability of that prediction as a
        float in [0, 1]. For ``None``, empty, or whitespace-only input no
        inference is run and ``("No symptoms provided", 0.0)`` is returned.
    """
    # Running the classifier on blank input would still yield a confident-looking
    # but meaningless label, so short-circuit instead.
    if text is None or not text.strip():
        return "No symptoms provided", 0.0

    tokens = tokenizer(text, padding=True, truncation=True, return_tensors="pt")
    with torch.no_grad():
        logits = model(**tokens).logits

    # A single input string gives a batch of one; take that row's distribution
    # and reduce along the label axis so the index and its probability come from
    # the same softmax (a bare argmax() would flatten across the batch).
    probabilities = torch.softmax(logits, dim=-1)[0]
    confidence, predicted_index = probabilities.max(dim=-1)
    return LABEL_MAPPING.get(predicted_index.item(), "Unknown"), confidence.item()


# Define Gradio interface
iface = gr.Interface(
    fn=predict_symptom,
    inputs=gr.Textbox(lines=2, placeholder="Describe your symptoms here..."),
    outputs=[
        gr.Textbox(label="Predicted Condition"),
        gr.Textbox(label="Confidence Level"),
    ],
    title="Symptom Checker",
    description="Enter your symptoms to get a predicted medical condition.",
)

if __name__ == "__main__":
    iface.launch()