sheikh987 committed
Commit babfa47 · verified · 1 Parent(s): 013b924

Update app.py

Files changed (1)
  1. app.py +8 -7
app.py CHANGED
@@ -39,14 +39,14 @@ except Exception as e:
 # --- Download and Load Classification Model (EfficientNet) ---
 try:
     CLASS_REPO_ID = "sheikh987/efficientnet-B4"
-    # === Use the filename you want, which now exists in the repo ===
+    # This filename MUST exactly match the file in your Hugging Face repo
     CLASS_MODEL_FILENAME = "efficientnet_b4_augmented_best.pth"
 
     print(f"--> Downloading classification model from: {CLASS_REPO_ID}")
     class_model_path = hf_hub_download(repo_id=CLASS_REPO_ID, filename=CLASS_MODEL_FILENAME)
 
     NUM_CLASSES = 7
-    # === Define the correct model architecture (B4) to prevent size mismatch ===
+    # The architecture ('efficientnet_b4') MUST match the model weights
     classification_model = timm.create_model('efficientnet_b4', pretrained=False, num_classes=NUM_CLASSES).to(DEVICE)
 
     classification_model.load_state_dict(torch.load(class_model_path, map_location=DEVICE))
@@ -60,6 +60,7 @@ except Exception as e:
 
 # --- Load Knowledge Base and Labels ---
 try:
+    # It's best practice to also host the knowledge base in your repo
     knowledge_base_path = hf_hub_download(repo_id=CLASS_REPO_ID, filename="knowledge_base.json")
     with open(knowledge_base_path, 'r') as f:
         knowledge_base = json.load(f)
@@ -77,7 +78,7 @@ transform_segment = A.Compose([
     A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], max_pixel_value=255.0),
     ToTensorV2(),
 ])
-# === Use the correct image size for EfficientNet-B4 (380x380) ===
+# The image size (380x380) MUST match the model architecture
 transform_classify = transforms.Compose([
     transforms.Resize((380, 380)),
     transforms.ToTensor(),
@@ -138,16 +139,16 @@ def full_pipeline(input_image):
     predicted_abbr = idx_to_class_abbr[predicted_idx.item()]
     info = knowledge_base.get(predicted_abbr, {})
 
-    causes_list = info.get('causes', [])
+    causes_list = info.get('causes', ['N/A'])
     causes_text = "\n".join([f"• {c}" for c in causes_list])
 
-    treatments_list = info.get('common_treatments', [])
+    treatments_list = info.get('common_treatments', ['N/A'])
     treatments_text = "\n".join([f"• {t}" for t in treatments_list])
 
     info_text = (
         f"**Predicted Condition:** {info.get('full_name', 'N/A')} ({predicted_abbr})\n"
        f"**Confidence:** {confidence_percent:.2f}%\n\n"
-        f"**Description:**\n{info.get('description', 'N/A')}\n\n"
+        f"**Description:**\n{info.get('description', 'No description available.')}\n\n"
         f"**Common Causes:**\n{causes_text}\n\n"
         f"**Common Treatments:**\n{treatments_text}\n\n"
         f"**--- IMPORTANT DISCLAIMER ---**\n{info.get('disclaimer', '')}"
@@ -168,7 +169,7 @@ iface = gr.Interface(
         gr.Markdown(label="Analysis Result")
     ],
     title="AI Skin Lesion Analyzer",
-    description="A two-stage AI tool for skin lesion analysis.",
+    description="A two-stage AI tool for skin lesion analysis. **DISCLAIMER:** This is an educational tool and is NOT a substitute for professional medical advice. Always consult a qualified dermatologist for any health concerns.",
     allow_flagging="never"
 )
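
The recurring point in the new comments is that the checkpoint filename, the `efficientnet_b4` architecture, and the 380x380 input size all have to agree. Below is a minimal standalone sketch, not part of app.py, that exercises exactly those constraints outside the Space; it assumes the same repo and filename shown in the diff, network access, and the `timm`, `torch`, and `huggingface_hub` packages.

```python
# Standalone sanity check (illustrative only): confirm the hosted checkpoint
# fits an EfficientNet-B4 head with 7 classes and accepts a 380x380 input.
import timm
import torch
from huggingface_hub import hf_hub_download

ckpt_path = hf_hub_download(
    repo_id="sheikh987/efficientnet-B4",
    filename="efficientnet_b4_augmented_best.pth",
)

model = timm.create_model("efficientnet_b4", pretrained=False, num_classes=7)

# strict=True (the default) raises immediately on any missing key or shape
# mismatch, which is the failure mode the commit comments warn about.
state_dict = torch.load(ckpt_path, map_location="cpu")
model.load_state_dict(state_dict, strict=True)
model.eval()

# A dummy 380x380 RGB batch; the real app produces this via transform_classify.
with torch.no_grad():
    logits = model(torch.randn(1, 3, 380, 380))

print(logits.shape)  # expected: torch.Size([1, 7])
```

If the filename, architecture, or number of classes were wrong, `load_state_dict` would fail here rather than at Space startup.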