import base64
import io
import json
import numpy as np
import torch
from fastapi import FastAPI, HTTPException
from huggingface_hub import hf_hub_download
from PIL import Image
from pydantic import BaseModel
from torchvision import transforms
import timm
import segmentation_models_pytorch as smp
import albumentations as A
from albumentations.pytorch import ToTensorV2
import os
# --- 1. SETUP: Create the FastAPI app ---
app = FastAPI(title="AI Skin Lesion Analyzer API")
# --- Global variables ---
DEVICE = "cpu"
segmentation_model = None
classification_model = None
knowledge_base = None
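# ISIC 2018 / HAM10000 class indices: MEL = melanoma, NV = melanocytic nevus,
# BCC = basal cell carcinoma, AKIEC = actinic keratosis / intraepithelial carcinoma,
# BKL = benign keratosis-like lesion, DF = dermatofibroma, VASC = vascular lesion.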
idx_to_class_abbr = {0: 'MEL', 1: 'NV', 2: 'BCC', 3: 'AKIEC', 4: 'BKL', 5: 'DF', 6: 'VASC'}
transform_segment = None
transform_classify = None
class ImageRequest(BaseModel):
    image_base64: str

@app.on_event("startup")
def load_assets():
    global segmentation_model, classification_model, knowledge_base, transform_segment, transform_classify

    # If the assets were already loaded, skip reloading.
    if segmentation_model is not None and classification_model is not None and knowledge_base is not None:
        print("🔁 Models and knowledge base already loaded. Skipping reloading.")
        return
print("--> API starting up: This may take a few minutes...")
# Use /tmp for writable cache directory
cache_dir = "/tmp/models_cache"
os.makedirs(cache_dir, exist_ok=True)
# Load Segmentation Model
try:
print(" Downloading UNet segmentation model...")
seg_model_path = hf_hub_download(
repo_id="sheikh987/unet-isic2018",
filename="unet_full_data_best_model.pth",
cache_dir=cache_dir
)
segmentation_model = smp.Unet("resnet34", encoder_weights=None, in_channels=3, classes=1).to(DEVICE)
segmentation_model.load_state_dict(torch.load(seg_model_path, map_location=DEVICE))
segmentation_model.eval()
print(" ✅ Segmentation model loaded.")
except Exception as e:
print(f"!!! FATAL: Could not load segmentation model: {e}")
# Load Classification Model
try:
print(" Downloading EfficientNet classification model...")
class_model_path = hf_hub_download(
repo_id="sheikh987/efficientnet-isic",
filename="efficientnet_augmented_best.pth",
cache_dir=cache_dir
)
classification_model = timm.create_model('efficientnet_b3', pretrained=False, num_classes=7).to(DEVICE)
classification_model.load_state_dict(torch.load(class_model_path, map_location=DEVICE))
classification_model.eval()
print(" ✅ Classification model loaded.")
except Exception as e:
print(f"!!! FATAL: Could not load classification model: {e}")
# Load Knowledge Base
try:
with open('knowledge_base.json', 'r') as f:
knowledge_base = json.load(f)
print(" ✅ Knowledge base loaded.")
except Exception as e:
print(f"!!! FATAL: Could not load knowledge_base.json: {e}")
# Define Image Transforms
transform_segment = A.Compose([
A.Resize(256, 256),
A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], max_pixel_value=255.0),
ToTensorV2()
])
transform_classify = transforms.Compose([
transforms.Resize((300, 300)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
print("\n--> API is ready to accept requests.")
@app.post("/analyze/")
async def analyze_image(request: ImageRequest):
    if not all([segmentation_model, classification_model, knowledge_base]):
        raise HTTPException(status_code=503, detail="Models not loaded yet. Please retry shortly.")

    try:
        image_data = base64.b64decode(request.image_base64)
        image = Image.open(io.BytesIO(image_data)).convert("RGB")
        image_np = np.array(image)
    except Exception:
        raise HTTPException(status_code=400, detail="Invalid base64 image data provided.")

    # Stage 1: Segmentation
    augmented_seg = transform_segment(image=image_np)
    seg_input_tensor = augmented_seg['image'].unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        seg_logits = segmentation_model(seg_input_tensor)
    seg_mask = (torch.sigmoid(seg_logits) > 0.5).float().squeeze().cpu().numpy()
    if seg_mask.sum() < 200:
        return {"status": "Failed", "message": "No lesion could be clearly identified in the image."}

    # Stage 2: Crop and classify
    # The mask is predicted at 256x256, so scale the bounding box back to the
    # original image resolution before cropping.
    rows, cols = np.any(seg_mask, axis=1), np.any(seg_mask, axis=0)
    rmin, rmax = np.where(rows)[0][[0, -1]]
    cmin, cmax = np.where(cols)[0][[0, -1]]
    scale_y = image_np.shape[0] / seg_mask.shape[0]
    scale_x = image_np.shape[1] / seg_mask.shape[1]
    rmin, rmax = int(rmin * scale_y), int(rmax * scale_y)
    cmin, cmax = int(cmin * scale_x), int(cmax * scale_x)
    padding = 15
    rmin, rmax = max(0, rmin - padding), min(image_np.shape[0], rmax + padding)
    cmin, cmax = max(0, cmin - padding), min(image_np.shape[1], cmax + padding)
    cropped_image_pil = Image.fromarray(image_np[rmin:rmax, cmin:cmax])

    class_input_tensor = transform_classify(cropped_image_pil).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        class_logits = classification_model(class_input_tensor)
        probabilities = torch.nn.functional.softmax(class_logits, dim=1)
        confidence, predicted_idx = torch.max(probabilities, 1)
    confidence_percent = confidence.item() * 100

    if confidence_percent < 50.0:
        return {"status": "Inconclusive", "message": f"Model confidence ({confidence_percent:.2f}%) is below the 50% threshold."}

    # Stage 3: Return result
    predicted_abbr = idx_to_class_abbr[predicted_idx.item()]
    info = knowledge_base.get(predicted_abbr, {})
    return {
        "status": "Success",
        "prediction": info,
        "abbreviation": predicted_abbr,
        "confidence": f"{confidence_percent:.2f}%"
    }

@app.get("/")
def root():
return {"message": "AI Skin Lesion Analyzer API is running."}