Update app.py
app.py CHANGED
@@ -39,12 +39,16 @@ except Exception as e:
 # --- Download and Load Classification Model (EfficientNet) ---
 try:
     CLASS_REPO_ID = "sheikh987/efficientnet-B4"
+    # === Use the filename you want, which now exists in the repo ===
     CLASS_MODEL_FILENAME = "efficientnet_b4_augmented_best.pth"
+
     print(f"--> Downloading classification model from: {CLASS_REPO_ID}")
     class_model_path = hf_hub_download(repo_id=CLASS_REPO_ID, filename=CLASS_MODEL_FILENAME)

     NUM_CLASSES = 7
-
+    # === Define the correct model architecture (B4) to prevent size mismatch ===
+    classification_model = timm.create_model('efficientnet_b4', pretrained=False, num_classes=NUM_CLASSES).to(DEVICE)
+
     classification_model.load_state_dict(torch.load(class_model_path, map_location=DEVICE))
     classification_model.eval()
     print(" Classification model loaded successfully.")
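For context on the size-mismatch fix above: the variant passed to `timm.create_model` must match the checkpoint, or `load_state_dict` fails. A minimal standalone sketch of this loading step, using only the repo and filename from the diff (everything else is illustrative):

```python
# Standalone sketch of the loading step above; not part of app.py itself.
import timm
import torch
from huggingface_hub import hf_hub_download

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
NUM_CLASSES = 7

ckpt_path = hf_hub_download(
    repo_id="sheikh987/efficientnet-B4",
    filename="efficientnet_b4_augmented_best.pth",
)

# The architecture must match the checkpoint: creating e.g. 'efficientnet_b0'
# here instead would make load_state_dict raise a size-mismatch RuntimeError.
model = timm.create_model("efficientnet_b4", pretrained=False, num_classes=NUM_CLASSES)
model.load_state_dict(torch.load(ckpt_path, map_location=DEVICE))
model.to(DEVICE).eval()
```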
@@ -56,11 +60,14 @@ except Exception as e:

 # --- Load Knowledge Base and Labels ---
 try:
-
+    knowledge_base_path = hf_hub_download(repo_id=CLASS_REPO_ID, filename="knowledge_base.json")
+    with open(knowledge_base_path, 'r') as f:
         knowledge_base = json.load(f)
-    print("--> Knowledge base loaded.")
-except
-
+    print("--> Knowledge base loaded from Hub.")
+except Exception as e:
+    print(f"!!! ERROR loading knowledge_base.json from Hub: {e}")
+    raise gr.Error("knowledge_base.json not found in the Hub repo. Make sure it has been uploaded.")
+

 idx_to_class_abbr = {0: 'MEL', 1: 'NV', 2: 'BCC', 3: 'AKIEC', 4: 'BKL', 5: 'DF', 6: 'VASC'}

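The formatting code later in this diff reads `full_name`, `description`, `causes`, `common_treatments`, and `disclaimer` from each entry, so `knowledge_base.json` is presumably shaped roughly like the sketch below. Only the keys are grounded in the diff; the values are placeholders:

```python
# Hypothetical shape of knowledge_base.json, keyed by class abbreviation.
example_knowledge_base = {
    "MEL": {
        "full_name": "Melanoma",
        "description": "…",
        "causes": ["…"],
        "common_treatments": ["…"],
        "disclaimer": "…",
    },
    # "NV", "BCC", "AKIEC", "BKL", "DF", "VASC" would follow the same shape.
}
```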
@@ -70,8 +77,9 @@ transform_segment = A.Compose([
     A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], max_pixel_value=255.0),
     ToTensorV2(),
 ])
+# === Use the correct image size for EfficientNet-B4 (380x380) ===
 transform_classify = transforms.Compose([
-    transforms.Resize((
+    transforms.Resize((380, 380)),
     transforms.ToTensor(),
     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
 ])
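The 380×380 resize matches the B4 backbone selected in the first hunk (380 px is the resolution EfficientNet-B4 was scaled for). A small usage sketch; it reuses `transform_classify`, `classification_model`, and `DEVICE` from app.py, and the file path is made up:

```python
# Illustrative use of the classification transform on a cropped lesion image.
import torch
from PIL import Image

crop = Image.open("lesion_crop.jpg").convert("RGB")    # hypothetical input file
x = transform_classify(crop).unsqueeze(0).to(DEVICE)   # tensor of shape [1, 3, 380, 380]
with torch.no_grad():
    logits = classification_model(x)
    probabilities = torch.softmax(logits, dim=1)
```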
@@ -79,7 +87,7 @@ transform_classify = transforms.Compose([
 print("\n--> Application ready to accept requests.")


-# --- 2. DEFINE THE FULL PIPELINE FUNCTION
+# --- 2. DEFINE THE FULL PIPELINE FUNCTION ---

 def full_pipeline(input_image):
     if input_image is None:
@@ -117,35 +125,29 @@ def full_pipeline(input_image):
     confidence, predicted_idx = torch.max(probabilities, 1)
     confidence_percent = confidence.item() * 100

-    # SAFETY NET
     CONFIDENCE_THRESHOLD = 50.0
     if confidence_percent < CONFIDENCE_THRESHOLD:
         inconclusive_text = (
             f"**Analysis Inconclusive**\n\n"
-            f"The AI model's confidence ({confidence_percent:.2f}%) is below the required threshold of {CONFIDENCE_THRESHOLD}
-            "This can happen if the image is blurry, has poor lighting, or shows a condition the model was not trained on.\n\n"
-            "**--- IMPORTANT DISCLAIMER ---**\n"
-            "This is NOT a diagnosis. Please consult a qualified dermatologist for an accurate assessment."
+            f"The AI model's confidence ({confidence_percent:.2f}%) is below the required threshold of {CONFIDENCE_THRESHOLD}%."
         )
         mask_display = Image.fromarray((seg_mask * 255).astype(np.uint8))
         return mask_display, cropped_image_pil, inconclusive_text

-    #
+    # STAGE 3: LOOKUP and FORMAT
     predicted_abbr = idx_to_class_abbr[predicted_idx.item()]
     info = knowledge_base.get(predicted_abbr, {})

-
-    causes_list = info.get('causes', ['Specific causes not listed.'])
+    causes_list = info.get('causes', [])
     causes_text = "\n".join([f"• {c}" for c in causes_list])

-    treatments_list = info.get('common_treatments', [
+    treatments_list = info.get('common_treatments', [])
     treatments_text = "\n".join([f"• {t}" for t in treatments_list])

-    # Build the final output text using all the information
     info_text = (
         f"**Predicted Condition:** {info.get('full_name', 'N/A')} ({predicted_abbr})\n"
         f"**Confidence:** {confidence_percent:.2f}%\n\n"
-        f"**Description:**\n{info.get('description', '
+        f"**Description:**\n{info.get('description', 'N/A')}\n\n"
         f"**Common Causes:**\n{causes_text}\n\n"
         f"**Common Treatments:**\n{treatments_text}\n\n"
         f"**--- IMPORTANT DISCLAIMER ---**\n{info.get('disclaimer', '')}"
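The safety net can be checked in isolation with fabricated logits (`probabilities` in the app presumably comes from a softmax over the classifier's output, as sketched here):

```python
# Toy reproduction of the confidence-threshold check with made-up logits.
import torch

logits = torch.tensor([[0.2, 0.1, 0.4, 0.1, 0.1, 0.05, 0.05]])  # fake 7-class output
probabilities = torch.softmax(logits, dim=1)
confidence, predicted_idx = torch.max(probabilities, 1)
confidence_percent = confidence.item() * 100

CONFIDENCE_THRESHOLD = 50.0
if confidence_percent < CONFIDENCE_THRESHOLD:
    print(f"Inconclusive: {confidence_percent:.2f}% < {CONFIDENCE_THRESHOLD}%")
else:
    print(f"Class index {predicted_idx.item()} at {confidence_percent:.2f}% confidence")
```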
@@ -166,7 +168,7 @@ iface = gr.Interface(
         gr.Markdown(label="Analysis Result")
     ],
     title="AI Skin Lesion Analyzer",
-    description="
+    description="A two-stage AI tool for skin lesion analysis.",
     allow_flagging="never"
 )

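Nothing in this commit changes how the app starts; on a Space the interface defined above would typically be launched with something like the following (assuming the `gr.Interface` object is named `iface`, as the hunk header indicates):

```python
# Typical launch call for the Gradio interface; not part of this diff.
if __name__ == "__main__":
    iface.launch()
```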