sheikh987 commited on
Commit
e7ca9ad
·
verified ·
1 Parent(s): 6e28903

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +174 -0
app.py CHANGED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py
2
+
3
+ import gradio as gr
4
+ import torch
5
+ from PIL import Image
6
+ import numpy as np
7
+ import json
8
+ from huggingface_hub import hf_hub_download
9
+
10
+ # Import necessary model libraries
11
+ import segmentation_models_pytorch as smp
12
+ import timm
13
+ import albumentations as A
14
+ from albumentations.pytorch import ToTensorV2
15
+ from torchvision import transforms
16
+
17
+ # --- 1. SETUP: Download and Load all models and data ---
18
+
19
+ print("--> Initializing application and downloading models...")
20
+ DEVICE = "cpu"
21
+
22
+ # --- Download and Load Segmentation Model (UNet) ---
23
+ try:
24
+ SEG_REPO_ID = "sheikh987/unet-isic2018"
25
+ SEG_MODEL_FILENAME = "unet_full_data_best_model.pth"
26
+ print(f"--> Downloading segmentation model from: {SEG_REPO_ID}")
27
+ seg_model_path = hf_hub_download(repo_id=SEG_REPO_ID, filename=SEG_MODEL_FILENAME)
28
+
29
+ segmentation_model = smp.Unet(encoder_name="resnet34", encoder_weights=None, in_channels=3, classes=1).to(DEVICE)
30
+ segmentation_model.load_state_dict(torch.load(seg_model_path, map_location=DEVICE))
31
+ segmentation_model.eval()
32
+ print(" Segmentation model loaded successfully.")
33
+
34
+ except Exception as e:
35
+ print(f"!!! ERROR loading segmentation model: {e}")
36
+ raise gr.Error("Failed to load the segmentation model. Check repository name and file paths.")
37
+
38
+
39
+ # --- Download and Load Classification Model (EfficientNet) ---
40
+ try:
41
+ CLASS_REPO_ID = "sheikh987/efficientnet-B4"
42
+ CLASS_MODEL_FILENAME = "efficientnet_b4_augmented_best.pth"
43
+ print(f"--> Downloading classification model from: {CLASS_REPO_ID}")
44
+ class_model_path = hf_hub_download(repo_id=CLASS_REPO_ID, filename=CLASS_MODEL_FILENAME)
45
+
46
+ NUM_CLASSES = 7
47
+ classification_model = timm.create_model('efficientnet_b3', pretrained=False, num_classes=NUM_CLASSES).to(DEVICE)
48
+ classification_model.load_state_dict(torch.load(class_model_path, map_location=DEVICE))
49
+ classification_model.eval()
50
+ print(" Classification model loaded successfully.")
51
+
52
+ except Exception as e:
53
+ print(f"!!! ERROR loading classification model: {e}")
54
+ raise gr.Error("Failed to load the classification model. Check repository name and file paths.")
55
+
56
+
57
+ # --- Load Knowledge Base and Labels ---
58
+ try:
59
+ with open('knowledge_base.json', 'r') as f:
60
+ knowledge_base = json.load(f)
61
+ print("--> Knowledge base loaded.")
62
+ except FileNotFoundError:
63
+ raise gr.Error("knowledge_base.json not found. Make sure it has been uploaded to the Space.")
64
+
65
+ idx_to_class_abbr = {0: 'MEL', 1: 'NV', 2: 'BCC', 3: 'AKIEC', 4: 'BKL', 5: 'DF', 6: 'VASC'}
66
+
67
+ # --- Define Image Transforms ---
68
+ transform_segment = A.Compose([
69
+ A.Resize(height=256, width=256),
70
+ A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], max_pixel_value=255.0),
71
+ ToTensorV2(),
72
+ ])
73
+ transform_classify = transforms.Compose([
74
+ transforms.Resize((300, 300)),
75
+ transforms.ToTensor(),
76
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
77
+ ])
78
+
79
+ print("\n--> Application ready to accept requests.")
80
+
81
+
82
+ # --- 2. DEFINE THE FULL PIPELINE FUNCTION (UPDATED) ---
83
+
84
+ def full_pipeline(input_image):
85
+ if input_image is None:
86
+ return None, None, "Please upload an image."
87
+
88
+ image_np = np.array(input_image.convert("RGB"))
89
+
90
+ # STAGE 1: SEGMENTATION
91
+ augmented_seg = transform_segment(image=image_np)
92
+ seg_input_tensor = augmented_seg['image'].unsqueeze(0).to(DEVICE)
93
+ with torch.no_grad():
94
+ seg_logits = segmentation_model(seg_input_tensor)
95
+ seg_mask = (torch.sigmoid(seg_logits) > 0.5).float().squeeze().cpu().numpy()
96
+
97
+ if seg_mask.sum() < 200:
98
+ return None, None, "Analysis Failed: No lesion could be clearly identified."
99
+
100
+ # STAGE 2: CROP and CLASSIFY
101
+ rows = np.any(seg_mask, axis=1)
102
+ cols = np.any(seg_mask, axis=0)
103
+ rmin, rmax = np.where(rows)[0][[0, -1]]
104
+ cmin, cmax = np.where(cols)[0][[0, -1]]
105
+
106
+ padding = 15
107
+ rmin, rmax = max(0, rmin - padding), min(image_np.shape[0], rmax + padding)
108
+ cmin, cmax = max(0, cmin - padding), min(image_np.shape[1], cmax + padding)
109
+
110
+ cropped_image_pil = Image.fromarray(image_np[rmin:rmax, cmin:cmax])
111
+
112
+ class_input_tensor = transform_classify(cropped_image_pil).unsqueeze(0).to(DEVICE)
113
+ with torch.no_grad():
114
+ class_logits = classification_model(class_input_tensor)
115
+ probabilities = torch.nn.functional.softmax(class_logits, dim=1)
116
+
117
+ confidence, predicted_idx = torch.max(probabilities, 1)
118
+ confidence_percent = confidence.item() * 100
119
+
120
+ # SAFETY NET
121
+ CONFIDENCE_THRESHOLD = 50.0
122
+ if confidence_percent < CONFIDENCE_THRESHOLD:
123
+ inconclusive_text = (
124
+ f"**Analysis Inconclusive**\n\n"
125
+ f"The AI model's confidence ({confidence_percent:.2f}%) is below the required threshold of {CONFIDENCE_THRESHOLD}%.\n\n"
126
+ "This can happen if the image is blurry, has poor lighting, or shows a condition the model was not trained on.\n\n"
127
+ "**--- IMPORTANT DISCLAIMER ---**\n"
128
+ "This is NOT a diagnosis. Please consult a qualified dermatologist for an accurate assessment."
129
+ )
130
+ mask_display = Image.fromarray((seg_mask * 255).astype(np.uint8))
131
+ return mask_display, cropped_image_pil, inconclusive_text
132
+
133
+ # --- STAGE 3: LOOKUP and FORMAT (UPDATED) ---
134
+ predicted_abbr = idx_to_class_abbr[predicted_idx.item()]
135
+ info = knowledge_base.get(predicted_abbr, {})
136
+
137
+ # Format the 'causes' and 'treatments' lists into clean, bulleted strings
138
+ causes_list = info.get('causes', ['Specific causes not listed.'])
139
+ causes_text = "\n".join([f"• {c}" for c in causes_list])
140
+
141
+ treatments_list = info.get('common_treatments', ['No specific treatments listed.'])
142
+ treatments_text = "\n".join([f"• {t}" for t in treatments_list])
143
+
144
+ # Build the final output text using all the information
145
+ info_text = (
146
+ f"**Predicted Condition:** {info.get('full_name', 'N/A')} ({predicted_abbr})\n"
147
+ f"**Confidence:** {confidence_percent:.2f}%\n\n"
148
+ f"**Description:**\n{info.get('description', 'No description available.')}\n\n"
149
+ f"**Common Causes:**\n{causes_text}\n\n"
150
+ f"**Common Treatments:**\n{treatments_text}\n\n"
151
+ f"**--- IMPORTANT DISCLAIMER ---**\n{info.get('disclaimer', '')}"
152
+ )
153
+
154
+ mask_display = Image.fromarray((seg_mask * 255).astype(np.uint8))
155
+
156
+ return mask_display, cropped_image_pil, info_text
157
+
158
+
159
+ # --- 3. CREATE AND LAUNCH THE GRADIO INTERFACE ---
160
+ iface = gr.Interface(
161
+ fn=full_pipeline,
162
+ inputs=gr.Image(type="pil", label="Upload Skin Image"),
163
+ outputs=[
164
+ gr.Image(type="pil", label="Segmentation Mask"),
165
+ gr.Image(type="pil", label="Cropped Lesion"),
166
+ gr.Markdown(label="Analysis Result")
167
+ ],
168
+ title="AI Skin Lesion Analyzer",
169
+ description="This tool performs a two-stage analysis on a skin lesion image. **Stage 1:** A UNet model segments the lesion. **Stage 2:** An EfficientNet model classifies the segmented lesion. \n\n**DISCLAIMER:** This is an educational tool and is NOT a substitute for professional medical advice. Always consult a qualified dermatologist for any health concerns.",
170
+ allow_flagging="never"
171
+ )
172
+
173
+ if __name__ == "__main__":
174
+ iface.launch()