sheikh987 commited on
Commit
013b924
·
verified ·
1 Parent(s): e7ca9ad

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -19
app.py CHANGED
@@ -39,12 +39,16 @@ except Exception as e:
39
  # --- Download and Load Classification Model (EfficientNet) ---
40
  try:
41
  CLASS_REPO_ID = "sheikh987/efficientnet-B4"
 
42
  CLASS_MODEL_FILENAME = "efficientnet_b4_augmented_best.pth"
 
43
  print(f"--> Downloading classification model from: {CLASS_REPO_ID}")
44
  class_model_path = hf_hub_download(repo_id=CLASS_REPO_ID, filename=CLASS_MODEL_FILENAME)
45
 
46
  NUM_CLASSES = 7
47
- classification_model = timm.create_model('efficientnet_b3', pretrained=False, num_classes=NUM_CLASSES).to(DEVICE)
 
 
48
  classification_model.load_state_dict(torch.load(class_model_path, map_location=DEVICE))
49
  classification_model.eval()
50
  print(" Classification model loaded successfully.")
@@ -56,11 +60,14 @@ except Exception as e:
56
 
57
  # --- Load Knowledge Base and Labels ---
58
  try:
59
- with open('knowledge_base.json', 'r') as f:
 
60
  knowledge_base = json.load(f)
61
- print("--> Knowledge base loaded.")
62
- except FileNotFoundError:
63
- raise gr.Error("knowledge_base.json not found. Make sure it has been uploaded to the Space.")
 
 
64
 
65
  idx_to_class_abbr = {0: 'MEL', 1: 'NV', 2: 'BCC', 3: 'AKIEC', 4: 'BKL', 5: 'DF', 6: 'VASC'}
66
 
@@ -70,8 +77,9 @@ transform_segment = A.Compose([
70
  A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], max_pixel_value=255.0),
71
  ToTensorV2(),
72
  ])
 
73
  transform_classify = transforms.Compose([
74
- transforms.Resize((300, 300)),
75
  transforms.ToTensor(),
76
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
77
  ])
@@ -79,7 +87,7 @@ transform_classify = transforms.Compose([
79
  print("\n--> Application ready to accept requests.")
80
 
81
 
82
- # --- 2. DEFINE THE FULL PIPELINE FUNCTION (UPDATED) ---
83
 
84
  def full_pipeline(input_image):
85
  if input_image is None:
@@ -117,35 +125,29 @@ def full_pipeline(input_image):
117
  confidence, predicted_idx = torch.max(probabilities, 1)
118
  confidence_percent = confidence.item() * 100
119
 
120
- # SAFETY NET
121
  CONFIDENCE_THRESHOLD = 50.0
122
  if confidence_percent < CONFIDENCE_THRESHOLD:
123
  inconclusive_text = (
124
  f"**Analysis Inconclusive**\n\n"
125
- f"The AI model's confidence ({confidence_percent:.2f}%) is below the required threshold of {CONFIDENCE_THRESHOLD}%.\n\n"
126
- "This can happen if the image is blurry, has poor lighting, or shows a condition the model was not trained on.\n\n"
127
- "**--- IMPORTANT DISCLAIMER ---**\n"
128
- "This is NOT a diagnosis. Please consult a qualified dermatologist for an accurate assessment."
129
  )
130
  mask_display = Image.fromarray((seg_mask * 255).astype(np.uint8))
131
  return mask_display, cropped_image_pil, inconclusive_text
132
 
133
- # --- STAGE 3: LOOKUP and FORMAT (UPDATED) ---
134
  predicted_abbr = idx_to_class_abbr[predicted_idx.item()]
135
  info = knowledge_base.get(predicted_abbr, {})
136
 
137
- # Format the 'causes' and 'treatments' lists into clean, bulleted strings
138
- causes_list = info.get('causes', ['Specific causes not listed.'])
139
  causes_text = "\n".join([f"• {c}" for c in causes_list])
140
 
141
- treatments_list = info.get('common_treatments', ['No specific treatments listed.'])
142
  treatments_text = "\n".join([f"• {t}" for t in treatments_list])
143
 
144
- # Build the final output text using all the information
145
  info_text = (
146
  f"**Predicted Condition:** {info.get('full_name', 'N/A')} ({predicted_abbr})\n"
147
  f"**Confidence:** {confidence_percent:.2f}%\n\n"
148
- f"**Description:**\n{info.get('description', 'No description available.')}\n\n"
149
  f"**Common Causes:**\n{causes_text}\n\n"
150
  f"**Common Treatments:**\n{treatments_text}\n\n"
151
  f"**--- IMPORTANT DISCLAIMER ---**\n{info.get('disclaimer', '')}"
@@ -166,7 +168,7 @@ iface = gr.Interface(
166
  gr.Markdown(label="Analysis Result")
167
  ],
168
  title="AI Skin Lesion Analyzer",
169
- description="This tool performs a two-stage analysis on a skin lesion image. **Stage 1:** A UNet model segments the lesion. **Stage 2:** An EfficientNet model classifies the segmented lesion. \n\n**DISCLAIMER:** This is an educational tool and is NOT a substitute for professional medical advice. Always consult a qualified dermatologist for any health concerns.",
170
  allow_flagging="never"
171
  )
172
 
 
39
  # --- Download and Load Classification Model (EfficientNet) ---
40
  try:
41
  CLASS_REPO_ID = "sheikh987/efficientnet-B4"
42
+ # === Use the filename you want, which now exists in the repo ===
43
  CLASS_MODEL_FILENAME = "efficientnet_b4_augmented_best.pth"
44
+
45
  print(f"--> Downloading classification model from: {CLASS_REPO_ID}")
46
  class_model_path = hf_hub_download(repo_id=CLASS_REPO_ID, filename=CLASS_MODEL_FILENAME)
47
 
48
  NUM_CLASSES = 7
49
+ # === Define the correct model architecture (B4) to prevent size mismatch ===
50
+ classification_model = timm.create_model('efficientnet_b4', pretrained=False, num_classes=NUM_CLASSES).to(DEVICE)
51
+
52
  classification_model.load_state_dict(torch.load(class_model_path, map_location=DEVICE))
53
  classification_model.eval()
54
  print(" Classification model loaded successfully.")
 
60
 
61
  # --- Load Knowledge Base and Labels ---
62
  try:
63
+ knowledge_base_path = hf_hub_download(repo_id=CLASS_REPO_ID, filename="knowledge_base.json")
64
+ with open(knowledge_base_path, 'r') as f:
65
  knowledge_base = json.load(f)
66
+ print("--> Knowledge base loaded from Hub.")
67
+ except Exception as e:
68
+ print(f"!!! ERROR loading knowledge_base.json from Hub: {e}")
69
+ raise gr.Error("knowledge_base.json not found in the Hub repo. Make sure it has been uploaded.")
70
+
71
 
72
  idx_to_class_abbr = {0: 'MEL', 1: 'NV', 2: 'BCC', 3: 'AKIEC', 4: 'BKL', 5: 'DF', 6: 'VASC'}
73
 
 
77
  A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], max_pixel_value=255.0),
78
  ToTensorV2(),
79
  ])
80
+ # === Use the correct image size for EfficientNet-B4 (380x380) ===
81
  transform_classify = transforms.Compose([
82
+ transforms.Resize((380, 380)),
83
  transforms.ToTensor(),
84
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
85
  ])
 
87
  print("\n--> Application ready to accept requests.")
88
 
89
 
90
+ # --- 2. DEFINE THE FULL PIPELINE FUNCTION ---
91
 
92
  def full_pipeline(input_image):
93
  if input_image is None:
 
125
  confidence, predicted_idx = torch.max(probabilities, 1)
126
  confidence_percent = confidence.item() * 100
127
 
 
128
  CONFIDENCE_THRESHOLD = 50.0
129
  if confidence_percent < CONFIDENCE_THRESHOLD:
130
  inconclusive_text = (
131
  f"**Analysis Inconclusive**\n\n"
132
+ f"The AI model's confidence ({confidence_percent:.2f}%) is below the required threshold of {CONFIDENCE_THRESHOLD}%."
 
 
 
133
  )
134
  mask_display = Image.fromarray((seg_mask * 255).astype(np.uint8))
135
  return mask_display, cropped_image_pil, inconclusive_text
136
 
137
+ # STAGE 3: LOOKUP and FORMAT
138
  predicted_abbr = idx_to_class_abbr[predicted_idx.item()]
139
  info = knowledge_base.get(predicted_abbr, {})
140
 
141
+ causes_list = info.get('causes', [])
 
142
  causes_text = "\n".join([f"• {c}" for c in causes_list])
143
 
144
+ treatments_list = info.get('common_treatments', [])
145
  treatments_text = "\n".join([f"• {t}" for t in treatments_list])
146
 
 
147
  info_text = (
148
  f"**Predicted Condition:** {info.get('full_name', 'N/A')} ({predicted_abbr})\n"
149
  f"**Confidence:** {confidence_percent:.2f}%\n\n"
150
+ f"**Description:**\n{info.get('description', 'N/A')}\n\n"
151
  f"**Common Causes:**\n{causes_text}\n\n"
152
  f"**Common Treatments:**\n{treatments_text}\n\n"
153
  f"**--- IMPORTANT DISCLAIMER ---**\n{info.get('disclaimer', '')}"
 
168
  gr.Markdown(label="Analysis Result")
169
  ],
170
  title="AI Skin Lesion Analyzer",
171
+ description="A two-stage AI tool for skin lesion analysis.",
172
  allow_flagging="never"
173
  )
174