Lech-Iyoko commited on
Commit
b13ff14
·
verified ·
1 Parent(s): ca24f90

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +40 -10
app.py CHANGED
@@ -1,18 +1,48 @@
1
  import gradio as gr
2
- import requests
 
3
 
4
- API_URL = "https://curly-space-waffle-x7g59v6qg7vcp6g9-8000.app.github.dev/predict"
 
 
 
5
 
6
- def predict_symptoms(symptoms):
7
- response = requests.post(API_URL, json={"symptoms": symptoms})
8
- return response.json().get("prediction", "Error: Could not fetch prediction.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
 
 
 
 
 
 
 
 
 
10
  iface = gr.Interface(
11
- fn=predict_symptoms,
12
- inputs=gr.Textbox(placeholder="Enter symptoms..."),
13
- outputs="text",
14
- title="AI-Powered Symptom Checker",
15
- description="Enter your symptoms and get possible conditions based on AI predictions."
 
 
 
16
  )
17
 
18
  if __name__ == "__main__":
 
1
  import gradio as gr
2
+ import torch
3
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
4
 
5
+ # Load the model and tokenizer
6
+ MODEL_NAME = "Lech-Iyoko/bert-symptom-checker"
7
+ model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
8
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
9
 
10
+ # Define label mapping
11
+ LABEL_MAPPING = {
12
+ 0: "Sprains and Strains", 1: "Fractures", 2: "Contusions (Bruises)",
13
+ 3: "Cuts and Lacerations", 4: "Concussions", 5: "Burns", 6: "Dislocations",
14
+ 7: "Abrasions (Scrapes)", 8: "Whiplash Injuries", 9: "Eye Injuries", 10: "Puncture Wounds",
15
+ 11: "Bites and Stings", 12: "Back Injuries", 13: "Broken Nose", 14: "Knee Injuries",
16
+ 15: "Ankle Injuries", 16: "Shoulder Injuries", 17: "Wrist Injuries", 18: "Chest Injuries",
17
+ 19: "Head Injuries", 20: "Acne", 21: "Allergies", 22: "Alzheimer's Disease", 23: "Anemia",
18
+ 24: "Anxiety Disorders", 25: "Arthritis", 26: "Asthma", 27: "Back Pain", 28: "Bipolar Disorder",
19
+ 29: "Bronchitis", 30: "Cataracts", 31: "Chickenpox", 32: "COPD", 33: "Common Cold",
20
+ 34: "Conjunctivitis (Pink Eye)", 35: "Constipation", 36: "Coronary Heart Disease",
21
+ 37: "Depression", 38: "Diabetes Type 1", 39: "Diabetes Type 2", 40: "Diarrhea",
22
+ 41: "Ear Infections", 42: "Eczema", 43: "Fibromyalgia", 44: "Flu", 45: "GERD", 46: "Gout",
23
+ 47: "Hay Fever (Allergic Rhinitis)", 48: "Headaches", 49: "High Blood Pressure (Hypertension)",
24
+ 50: "High Cholesterol (Hypercholesterolemia)", 51: "IBS", 52: "Kidney Stones", 53: "Migraines",
25
+ 54: "Obesity", 55: "Osteoarthritis", 56: "Psoriasis", 57: "UTI", 58: "Other"
26
+ }
27
 
28
+ def predict_symptom(text):
29
+ tokens = tokenizer(text, padding=True, truncation=True, return_tensors="pt")
30
+ with torch.no_grad():
31
+ outputs = model(**tokens)
32
+ predicted_index = outputs.logits.argmax().item()
33
+ confidence = torch.softmax(outputs.logits, dim=1)[0][predicted_index].item()
34
+ return LABEL_MAPPING.get(predicted_index, "Unknown"), confidence
35
+
36
+ # Define Gradio interface
37
  iface = gr.Interface(
38
+ fn=predict_symptom,
39
+ inputs=gr.inputs.Textbox(lines=2, placeholder="Describe your symptoms here..."),
40
+ outputs=[
41
+ gr.outputs.Textbox(label="Predicted Condition"),
42
+ gr.outputs.Textbox(label="Confidence Level")
43
+ ],
44
+ title="Symptom Checker",
45
+ description="Enter your symptoms to get a predicted medical condition."
46
  )
47
 
48
  if __name__ == "__main__":