Update main.py
Browse files
main.py
CHANGED
@@ -1,5 +1,3 @@
|
|
1 |
-
# main.py
|
2 |
-
|
3 |
import base64
|
4 |
import io
|
5 |
import json
|
@@ -16,6 +14,7 @@ import timm
|
|
16 |
import segmentation_models_pytorch as smp
|
17 |
import albumentations as A
|
18 |
from albumentations.pytorch import ToTensorV2
|
|
|
19 |
|
20 |
# --- 1. SETUP: Create the FastAPI app ---
|
21 |
app = FastAPI(title="AI Skin Lesion Analyzer API")
|
@@ -38,13 +37,20 @@ class ImageRequest(BaseModel):
|
|
38 |
def load_assets():
|
39 |
"""Load all models and assets from Hugging Face Hub when the server starts."""
|
40 |
global segmentation_model, classification_model, knowledge_base, transform_segment, transform_classify
|
41 |
-
|
42 |
print("--> API starting up: This may take a few minutes...")
|
43 |
-
|
|
|
|
|
|
|
44 |
# Load Segmentation Model
|
45 |
try:
|
46 |
print(" Downloading UNet segmentation model...")
|
47 |
-
seg_model_path = hf_hub_download(
|
|
|
|
|
|
|
|
|
48 |
segmentation_model = smp.Unet("resnet34", encoder_weights=None, in_channels=3, classes=1).to(DEVICE)
|
49 |
segmentation_model.load_state_dict(torch.load(seg_model_path, map_location=DEVICE))
|
50 |
segmentation_model.eval()
|
@@ -55,7 +61,11 @@ def load_assets():
|
|
55 |
# Load Classification Model
|
56 |
try:
|
57 |
print(" Downloading EfficientNet classification model...")
|
58 |
-
class_model_path = hf_hub_download(
|
|
|
|
|
|
|
|
|
59 |
classification_model = timm.create_model('efficientnet_b3', pretrained=False, num_classes=7).to(DEVICE)
|
60 |
classification_model.load_state_dict(torch.load(class_model_path, map_location=DEVICE))
|
61 |
classification_model.eval()
|
@@ -70,11 +80,20 @@ def load_assets():
|
|
70 |
print(" ✅ Knowledge base loaded.")
|
71 |
except Exception as e:
|
72 |
print(f"!!! FATAL: Could not load knowledge_base.json: {e}")
|
73 |
-
|
74 |
# Define Image Transforms
|
75 |
-
transform_segment = A.Compose([
|
76 |
-
|
77 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
78 |
print("\n--> API is ready to accept requests.")
|
79 |
|
80 |
# --- 2. DEFINE THE MAIN API ENDPOINT ---
|
@@ -101,7 +120,7 @@ async def analyze_image(request: ImageRequest):
|
|
101 |
with torch.no_grad():
|
102 |
seg_logits = segmentation_model(seg_input_tensor)
|
103 |
seg_mask = (torch.sigmoid(seg_logits) > 0.5).float().squeeze().cpu().numpy()
|
104 |
-
|
105 |
if seg_mask.sum() < 200:
|
106 |
return {"status": "Failed", "message": "No lesion could be clearly identified in the image."}
|
107 |
|
@@ -118,18 +137,17 @@ async def analyze_image(request: ImageRequest):
|
|
118 |
with torch.no_grad():
|
119 |
class_logits = classification_model(class_input_tensor)
|
120 |
probabilities = torch.nn.functional.softmax(class_logits, dim=1)
|
121 |
-
|
122 |
confidence, predicted_idx = torch.max(probabilities, 1)
|
123 |
confidence_percent = confidence.item() * 100
|
124 |
-
|
125 |
-
# SAFETY NET
|
126 |
if confidence_percent < 75.0:
|
127 |
-
|
128 |
|
129 |
# STAGE 3: LOOKUP AND RETURN RESULT
|
130 |
predicted_abbr = idx_to_class_abbr[predicted_idx.item()]
|
131 |
info = knowledge_base.get(predicted_abbr, {})
|
132 |
-
|
133 |
return {
|
134 |
"status": "Success",
|
135 |
"prediction": info,
|
@@ -140,4 +158,4 @@ async def analyze_image(request: ImageRequest):
|
|
140 |
# --- Root endpoint to check if the API is alive ---
|
141 |
@app.get("/")
|
142 |
def root():
|
143 |
-
return {"message": "AI Skin Lesion Analyzer API is running."}
|
|
|
|
|
|
|
1 |
import base64
|
2 |
import io
|
3 |
import json
|
|
|
14 |
import segmentation_models_pytorch as smp
|
15 |
import albumentations as A
|
16 |
from albumentations.pytorch import ToTensorV2
|
17 |
+
import os
|
18 |
|
19 |
# --- 1. SETUP: Create the FastAPI app ---
|
20 |
app = FastAPI(title="AI Skin Lesion Analyzer API")
|
|
|
37 |
def load_assets():
|
38 |
"""Load all models and assets from Hugging Face Hub when the server starts."""
|
39 |
global segmentation_model, classification_model, knowledge_base, transform_segment, transform_classify
|
40 |
+
|
41 |
print("--> API starting up: This may take a few minutes...")
|
42 |
+
|
43 |
+
# ✅ Create safe folder to store downloaded models
|
44 |
+
os.makedirs("./models_cache", exist_ok=True)
|
45 |
+
|
46 |
# Load Segmentation Model
|
47 |
try:
|
48 |
print(" Downloading UNet segmentation model...")
|
49 |
+
seg_model_path = hf_hub_download(
|
50 |
+
repo_id="sheikh987/unet-isic2018",
|
51 |
+
filename="unet_full_data_best_model.pth",
|
52 |
+
cache_dir="./models_cache"
|
53 |
+
)
|
54 |
segmentation_model = smp.Unet("resnet34", encoder_weights=None, in_channels=3, classes=1).to(DEVICE)
|
55 |
segmentation_model.load_state_dict(torch.load(seg_model_path, map_location=DEVICE))
|
56 |
segmentation_model.eval()
|
|
|
61 |
# Load Classification Model
|
62 |
try:
|
63 |
print(" Downloading EfficientNet classification model...")
|
64 |
+
class_model_path = hf_hub_download(
|
65 |
+
repo_id="sheikh987/efficientnet-isic-classifier",
|
66 |
+
filename="efficientnet_isic_classifier_best.pth",
|
67 |
+
cache_dir="./models_cache"
|
68 |
+
)
|
69 |
classification_model = timm.create_model('efficientnet_b3', pretrained=False, num_classes=7).to(DEVICE)
|
70 |
classification_model.load_state_dict(torch.load(class_model_path, map_location=DEVICE))
|
71 |
classification_model.eval()
|
|
|
80 |
print(" ✅ Knowledge base loaded.")
|
81 |
except Exception as e:
|
82 |
print(f"!!! FATAL: Could not load knowledge_base.json: {e}")
|
83 |
+
|
84 |
# Define Image Transforms
|
85 |
+
transform_segment = A.Compose([
|
86 |
+
A.Resize(256, 256),
|
87 |
+
A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], max_pixel_value=255.0),
|
88 |
+
ToTensorV2()
|
89 |
+
])
|
90 |
+
|
91 |
+
transform_classify = transforms.Compose([
|
92 |
+
transforms.Resize((300, 300)),
|
93 |
+
transforms.ToTensor(),
|
94 |
+
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
95 |
+
])
|
96 |
+
|
97 |
print("\n--> API is ready to accept requests.")
|
98 |
|
99 |
# --- 2. DEFINE THE MAIN API ENDPOINT ---
|
|
|
120 |
with torch.no_grad():
|
121 |
seg_logits = segmentation_model(seg_input_tensor)
|
122 |
seg_mask = (torch.sigmoid(seg_logits) > 0.5).float().squeeze().cpu().numpy()
|
123 |
+
|
124 |
if seg_mask.sum() < 200:
|
125 |
return {"status": "Failed", "message": "No lesion could be clearly identified in the image."}
|
126 |
|
|
|
137 |
with torch.no_grad():
|
138 |
class_logits = classification_model(class_input_tensor)
|
139 |
probabilities = torch.nn.functional.softmax(class_logits, dim=1)
|
140 |
+
|
141 |
confidence, predicted_idx = torch.max(probabilities, 1)
|
142 |
confidence_percent = confidence.item() * 100
|
143 |
+
|
|
|
144 |
if confidence_percent < 75.0:
|
145 |
+
return {"status": "Inconclusive", "message": f"Model confidence ({confidence_percent:.2f}%) is below the 75% threshold."}
|
146 |
|
147 |
# STAGE 3: LOOKUP AND RETURN RESULT
|
148 |
predicted_abbr = idx_to_class_abbr[predicted_idx.item()]
|
149 |
info = knowledge_base.get(predicted_abbr, {})
|
150 |
+
|
151 |
return {
|
152 |
"status": "Success",
|
153 |
"prediction": info,
|
|
|
158 |
# --- Root endpoint to check if the API is alive ---
|
159 |
@app.get("/")
|
160 |
def root():
|
161 |
+
return {"message": "AI Skin Lesion Analyzer API is running."}
|