Update main.py
Browse files
main.py
CHANGED
@@ -8,23 +8,17 @@ from huggingface_hub import hf_hub_download
|
|
8 |
from PIL import Image
|
9 |
from pydantic import BaseModel
|
10 |
from torchvision import transforms
|
11 |
-
|
12 |
import timm
|
13 |
-
import segmentation_models_pytorch as smp
|
14 |
-
import albumentations as A
|
15 |
-
from albumentations.pytorch import ToTensorV2
|
16 |
import os
|
17 |
|
18 |
# --- 1. SETUP: Create the FastAPI app ---
|
19 |
app = FastAPI(title="AI Skin Lesion Analyzer API")
|
20 |
|
21 |
# --- Global variables ---
|
22 |
-
DEVICE = "cpu"
|
23 |
-
segmentation_model = None
|
24 |
classification_model = None
|
25 |
knowledge_base = None
|
26 |
idx_to_class_abbr = {0: 'MEL', 1: 'NV', 2: 'BCC', 3: 'AKIEC', 4: 'BKL', 5: 'DF', 6: 'VASC'}
|
27 |
-
transform_segment = None
|
28 |
transform_classify = None
|
29 |
|
30 |
class ImageRequest(BaseModel):
|
@@ -32,48 +26,41 @@ class ImageRequest(BaseModel):
|
|
32 |
|
33 |
@app.on_event("startup")
|
34 |
def load_assets():
|
35 |
-
global
|
36 |
|
37 |
-
|
38 |
-
if segmentation_model is not None and classification_model is not None and knowledge_base is not None:
|
39 |
print("🔁 Models and knowledge base already loaded. Skipping reloading.")
|
40 |
return
|
41 |
|
42 |
-
print("--> API starting up:
|
43 |
|
44 |
# Use /tmp for writable cache directory
|
45 |
cache_dir = "/tmp/models_cache"
|
46 |
os.makedirs(cache_dir, exist_ok=True)
|
47 |
|
48 |
-
# Load Segmentation Model
|
49 |
-
try:
|
50 |
-
print(" Downloading UNet segmentation model...")
|
51 |
-
seg_model_path = hf_hub_download(
|
52 |
-
repo_id="sheikh987/unet-isic2018",
|
53 |
-
filename="unet_full_data_best_model.pth",
|
54 |
-
cache_dir=cache_dir
|
55 |
-
)
|
56 |
-
segmentation_model = smp.Unet("resnet34", encoder_weights=None, in_channels=3, classes=1).to(DEVICE)
|
57 |
-
segmentation_model.load_state_dict(torch.load(seg_model_path, map_location=DEVICE))
|
58 |
-
segmentation_model.eval()
|
59 |
-
print(" ✅ Segmentation model loaded.")
|
60 |
-
except Exception as e:
|
61 |
-
print(f"!!! FATAL: Could not load segmentation model: {e}")
|
62 |
-
|
63 |
# Load Classification Model
|
64 |
try:
|
65 |
-
print(" Downloading EfficientNet classification model...")
|
66 |
class_model_path = hf_hub_download(
|
67 |
-
repo_id="sheikh987/efficientnet-
|
68 |
-
filename="
|
69 |
cache_dir=cache_dir
|
70 |
)
|
71 |
-
|
72 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
73 |
classification_model.eval()
|
74 |
-
print(" ✅ Classification model loaded.")
|
75 |
except Exception as e:
|
76 |
-
|
77 |
|
78 |
# Load Knowledge Base
|
79 |
try:
|
@@ -81,69 +68,51 @@ def load_assets():
|
|
81 |
knowledge_base = json.load(f)
|
82 |
print(" ✅ Knowledge base loaded.")
|
83 |
except Exception as e:
|
84 |
-
|
85 |
-
|
86 |
-
# Define Image Transforms
|
87 |
-
transform_segment = A.Compose([
|
88 |
-
A.Resize(256, 256),
|
89 |
-
A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], max_pixel_value=255.0),
|
90 |
-
ToTensorV2()
|
91 |
-
])
|
92 |
|
|
|
93 |
transform_classify = transforms.Compose([
|
94 |
transforms.Resize((300, 300)),
|
95 |
transforms.ToTensor(),
|
96 |
-
transforms.Normalize(mean=[0.485, 0.456, 0.406],
|
|
|
97 |
])
|
98 |
|
99 |
-
print("
|
|
|
100 |
|
101 |
@app.post("/analyze/")
|
102 |
async def analyze_image(request: ImageRequest):
|
103 |
-
if not all([
|
104 |
-
raise HTTPException(status_code=503, detail="
|
105 |
|
106 |
try:
|
107 |
image_data = base64.b64decode(request.image_base64)
|
108 |
image = Image.open(io.BytesIO(image_data)).convert("RGB")
|
109 |
-
image_np = np.array(image)
|
110 |
except Exception:
|
111 |
raise HTTPException(status_code=400, detail="Invalid base64 image data provided.")
|
112 |
|
113 |
-
#
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
seg_logits = segmentation_model(seg_input_tensor)
|
118 |
-
seg_mask = (torch.sigmoid(seg_logits) > 0.5).float().squeeze().cpu().numpy()
|
119 |
-
|
120 |
-
if seg_mask.sum() < 200:
|
121 |
-
return {"status": "Failed", "message": "No lesion could be clearly identified in the image."}
|
122 |
-
|
123 |
-
# Stage 2: Crop and classify
|
124 |
-
rows, cols = np.any(seg_mask, axis=1), np.any(seg_mask, axis=0)
|
125 |
-
rmin, rmax = np.where(rows)[0][[0, -1]]
|
126 |
-
cmin, cmax = np.where(cols)[0][[0, -1]]
|
127 |
-
padding = 15
|
128 |
-
rmin, rmax = max(0, rmin - padding), min(image_np.shape[0], rmax + padding)
|
129 |
-
cmin, cmax = max(0, cmin - padding), min(image_np.shape[1], cmax + padding)
|
130 |
-
cropped_image_pil = Image.fromarray(image_np[rmin:rmax, cmin:cmax])
|
131 |
-
|
132 |
-
class_input_tensor = transform_classify(cropped_image_pil).unsqueeze(0).to(DEVICE)
|
133 |
with torch.no_grad():
|
134 |
-
|
135 |
-
probabilities = torch.nn.functional.softmax(
|
136 |
|
137 |
confidence, predicted_idx = torch.max(probabilities, 1)
|
138 |
confidence_percent = confidence.item() * 100
|
139 |
-
|
140 |
-
if confidence_percent < 50.0:
|
141 |
-
return {"status": "Inconclusive", "message": f"Model confidence ({confidence_percent:.2f}%) is below the 75% threshold."}
|
142 |
-
|
143 |
-
# Stage 3: Return result
|
144 |
predicted_abbr = idx_to_class_abbr[predicted_idx.item()]
|
145 |
info = knowledge_base.get(predicted_abbr, {})
|
146 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
147 |
return {
|
148 |
"status": "Success",
|
149 |
"prediction": info,
|
|
|
8 |
from PIL import Image
|
9 |
from pydantic import BaseModel
|
10 |
from torchvision import transforms
|
|
|
11 |
import timm
|
|
|
|
|
|
|
12 |
import os
|
13 |
|
14 |
# --- 1. SETUP: Create the FastAPI app ---
|
15 |
app = FastAPI(title="AI Skin Lesion Analyzer API")
|
16 |
|
17 |
# --- Global variables ---
|
18 |
+
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
|
19 |
classification_model = None
|
20 |
knowledge_base = None
|
21 |
idx_to_class_abbr = {0: 'MEL', 1: 'NV', 2: 'BCC', 3: 'AKIEC', 4: 'BKL', 5: 'DF', 6: 'VASC'}
|
|
|
22 |
transform_classify = None
|
23 |
|
24 |
class ImageRequest(BaseModel):
|
|
|
26 |
|
27 |
@app.on_event("startup")
|
28 |
def load_assets():
|
29 |
+
global classification_model, knowledge_base, transform_classify
|
30 |
|
31 |
+
if classification_model is not None and knowledge_base is not None:
|
|
|
32 |
print("🔁 Models and knowledge base already loaded. Skipping reloading.")
|
33 |
return
|
34 |
|
35 |
+
print("--> API starting up: Loading classification model and knowledge base...")
|
36 |
|
37 |
# Use /tmp for writable cache directory
|
38 |
cache_dir = "/tmp/models_cache"
|
39 |
os.makedirs(cache_dir, exist_ok=True)
|
40 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
41 |
# Load Classification Model
|
42 |
try:
|
43 |
+
print(" Downloading EfficientNet-B3 classification model...")
|
44 |
class_model_path = hf_hub_download(
|
45 |
+
repo_id="sheikh987/efficientnet-b3-skin",
|
46 |
+
filename="efficientnet_b3_skin_model.pth",
|
47 |
cache_dir=cache_dir
|
48 |
)
|
49 |
+
checkpoint = torch.load(class_model_path, map_location=DEVICE)
|
50 |
+
|
51 |
+
# Auto-detect state_dict vs full model
|
52 |
+
if isinstance(checkpoint, dict):
|
53 |
+
classification_model = timm.create_model(
|
54 |
+
'efficientnet_b3', pretrained=False, num_classes=7
|
55 |
+
).to(DEVICE)
|
56 |
+
classification_model.load_state_dict(checkpoint, strict=False)
|
57 |
+
else:
|
58 |
+
classification_model = checkpoint.to(DEVICE)
|
59 |
+
|
60 |
classification_model.eval()
|
61 |
+
print(" ✅ Classification model loaded successfully.")
|
62 |
except Exception as e:
|
63 |
+
raise RuntimeError(f"Failed to load classification model: {e}")
|
64 |
|
65 |
# Load Knowledge Base
|
66 |
try:
|
|
|
68 |
knowledge_base = json.load(f)
|
69 |
print(" ✅ Knowledge base loaded.")
|
70 |
except Exception as e:
|
71 |
+
raise RuntimeError(f"Failed to load knowledge_base.json: {e}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
72 |
|
73 |
+
# Define Image Transform
|
74 |
transform_classify = transforms.Compose([
|
75 |
transforms.Resize((300, 300)),
|
76 |
transforms.ToTensor(),
|
77 |
+
transforms.Normalize(mean=[0.485, 0.456, 0.406],
|
78 |
+
std=[0.229, 0.224, 0.225])
|
79 |
])
|
80 |
|
81 |
+
print("--> API is ready to accept requests.")
|
82 |
+
|
83 |
|
84 |
@app.post("/analyze/")
|
85 |
async def analyze_image(request: ImageRequest):
|
86 |
+
if not all([classification_model, knowledge_base]):
|
87 |
+
raise HTTPException(status_code=503, detail="Model or knowledge base not loaded yet.")
|
88 |
|
89 |
try:
|
90 |
image_data = base64.b64decode(request.image_base64)
|
91 |
image = Image.open(io.BytesIO(image_data)).convert("RGB")
|
|
|
92 |
except Exception:
|
93 |
raise HTTPException(status_code=400, detail="Invalid base64 image data provided.")
|
94 |
|
95 |
+
# Prepare image
|
96 |
+
class_input_tensor = transform_classify(image).unsqueeze(0).to(DEVICE)
|
97 |
+
|
98 |
+
# Classification
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
99 |
with torch.no_grad():
|
100 |
+
logits = classification_model(class_input_tensor)
|
101 |
+
probabilities = torch.nn.functional.softmax(logits, dim=1)
|
102 |
|
103 |
confidence, predicted_idx = torch.max(probabilities, 1)
|
104 |
confidence_percent = confidence.item() * 100
|
|
|
|
|
|
|
|
|
|
|
105 |
predicted_abbr = idx_to_class_abbr[predicted_idx.item()]
|
106 |
info = knowledge_base.get(predicted_abbr, {})
|
107 |
|
108 |
+
# Optional: Confidence threshold
|
109 |
+
CONFIDENCE_THRESHOLD = 50.0
|
110 |
+
if confidence_percent < CONFIDENCE_THRESHOLD:
|
111 |
+
return {
|
112 |
+
"status": "Inconclusive",
|
113 |
+
"message": f"Model confidence ({confidence_percent:.2f}%) is below the threshold of {CONFIDENCE_THRESHOLD}%."
|
114 |
+
}
|
115 |
+
|
116 |
return {
|
117 |
"status": "Success",
|
118 |
"prediction": info,
|