Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -39,14 +39,14 @@ except Exception as e:
|
|
39 |
# --- Download and Load Classification Model (EfficientNet) ---
|
40 |
try:
|
41 |
CLASS_REPO_ID = "sheikh987/efficientnet-B4"
|
42 |
-
#
|
43 |
CLASS_MODEL_FILENAME = "efficientnet_b4_augmented_best.pth"
|
44 |
|
45 |
print(f"--> Downloading classification model from: {CLASS_REPO_ID}")
|
46 |
class_model_path = hf_hub_download(repo_id=CLASS_REPO_ID, filename=CLASS_MODEL_FILENAME)
|
47 |
|
48 |
NUM_CLASSES = 7
|
49 |
-
#
|
50 |
classification_model = timm.create_model('efficientnet_b4', pretrained=False, num_classes=NUM_CLASSES).to(DEVICE)
|
51 |
|
52 |
classification_model.load_state_dict(torch.load(class_model_path, map_location=DEVICE))
|
@@ -60,6 +60,7 @@ except Exception as e:
|
|
60 |
|
61 |
# --- Load Knowledge Base and Labels ---
|
62 |
try:
|
|
|
63 |
knowledge_base_path = hf_hub_download(repo_id=CLASS_REPO_ID, filename="knowledge_base.json")
|
64 |
with open(knowledge_base_path, 'r') as f:
|
65 |
knowledge_base = json.load(f)
|
@@ -77,7 +78,7 @@ transform_segment = A.Compose([
|
|
77 |
A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], max_pixel_value=255.0),
|
78 |
ToTensorV2(),
|
79 |
])
|
80 |
-
#
|
81 |
transform_classify = transforms.Compose([
|
82 |
transforms.Resize((380, 380)),
|
83 |
transforms.ToTensor(),
|
@@ -138,16 +139,16 @@ def full_pipeline(input_image):
|
|
138 |
predicted_abbr = idx_to_class_abbr[predicted_idx.item()]
|
139 |
info = knowledge_base.get(predicted_abbr, {})
|
140 |
|
141 |
-
causes_list = info.get('causes', [])
|
142 |
causes_text = "\n".join([f"• {c}" for c in causes_list])
|
143 |
|
144 |
-
treatments_list = info.get('common_treatments', [])
|
145 |
treatments_text = "\n".join([f"• {t}" for t in treatments_list])
|
146 |
|
147 |
info_text = (
|
148 |
f"**Predicted Condition:** {info.get('full_name', 'N/A')} ({predicted_abbr})\n"
|
149 |
f"**Confidence:** {confidence_percent:.2f}%\n\n"
|
150 |
-
f"**Description:**\n{info.get('description', '
|
151 |
f"**Common Causes:**\n{causes_text}\n\n"
|
152 |
f"**Common Treatments:**\n{treatments_text}\n\n"
|
153 |
f"**--- IMPORTANT DISCLAIMER ---**\n{info.get('disclaimer', '')}"
|
@@ -168,7 +169,7 @@ iface = gr.Interface(
|
|
168 |
gr.Markdown(label="Analysis Result")
|
169 |
],
|
170 |
title="AI Skin Lesion Analyzer",
|
171 |
-
description="A two-stage AI tool for skin lesion analysis.",
|
172 |
allow_flagging="never"
|
173 |
)
|
174 |
|
|
|
39 |
# --- Download and Load Classification Model (EfficientNet) ---
|
40 |
try:
|
41 |
CLASS_REPO_ID = "sheikh987/efficientnet-B4"
|
42 |
+
# This filename MUST exactly match the file in your Hugging Face repo
|
43 |
CLASS_MODEL_FILENAME = "efficientnet_b4_augmented_best.pth"
|
44 |
|
45 |
print(f"--> Downloading classification model from: {CLASS_REPO_ID}")
|
46 |
class_model_path = hf_hub_download(repo_id=CLASS_REPO_ID, filename=CLASS_MODEL_FILENAME)
|
47 |
|
48 |
NUM_CLASSES = 7
|
49 |
+
# The architecture ('efficientnet_b4') MUST match the model weights
|
50 |
classification_model = timm.create_model('efficientnet_b4', pretrained=False, num_classes=NUM_CLASSES).to(DEVICE)
|
51 |
|
52 |
classification_model.load_state_dict(torch.load(class_model_path, map_location=DEVICE))
|
|
|
60 |
|
61 |
# --- Load Knowledge Base and Labels ---
|
62 |
try:
|
63 |
+
# It's best practice to also host the knowledge base in your repo
|
64 |
knowledge_base_path = hf_hub_download(repo_id=CLASS_REPO_ID, filename="knowledge_base.json")
|
65 |
with open(knowledge_base_path, 'r') as f:
|
66 |
knowledge_base = json.load(f)
|
|
|
78 |
A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], max_pixel_value=255.0),
|
79 |
ToTensorV2(),
|
80 |
])
|
81 |
+
# The image size (380x380) MUST match the model architecture
|
82 |
transform_classify = transforms.Compose([
|
83 |
transforms.Resize((380, 380)),
|
84 |
transforms.ToTensor(),
|
|
|
139 |
predicted_abbr = idx_to_class_abbr[predicted_idx.item()]
|
140 |
info = knowledge_base.get(predicted_abbr, {})
|
141 |
|
142 |
+
causes_list = info.get('causes', ['N/A'])
|
143 |
causes_text = "\n".join([f"• {c}" for c in causes_list])
|
144 |
|
145 |
+
treatments_list = info.get('common_treatments', ['N/A'])
|
146 |
treatments_text = "\n".join([f"• {t}" for t in treatments_list])
|
147 |
|
148 |
info_text = (
|
149 |
f"**Predicted Condition:** {info.get('full_name', 'N/A')} ({predicted_abbr})\n"
|
150 |
f"**Confidence:** {confidence_percent:.2f}%\n\n"
|
151 |
+
f"**Description:**\n{info.get('description', 'No description available.')}\n\n"
|
152 |
f"**Common Causes:**\n{causes_text}\n\n"
|
153 |
f"**Common Treatments:**\n{treatments_text}\n\n"
|
154 |
f"**--- IMPORTANT DISCLAIMER ---**\n{info.get('disclaimer', '')}"
|
|
|
169 |
gr.Markdown(label="Analysis Result")
|
170 |
],
|
171 |
title="AI Skin Lesion Analyzer",
|
172 |
+
description="A two-stage AI tool for skin lesion analysis. **DISCLAIMER:** This is an educational tool and is NOT a substitute for professional medical advice. Always consult a qualified dermatologist for any health concerns.",
|
173 |
allow_flagging="never"
|
174 |
)
|
175 |
|