File size: 5,607 Bytes
aa0d9c9 935c075 aa0d9c9 935c075 cafb713 935c075 aa0d9c9 935c075 668ad74 935c075 aa0d9c9 935c075 aa0d9c9 935c075 cafb713 aa0d9c9 cafb713 668ad74 cafb713 aa0d9c9 935c075 aa0d9c9 cafb713 668ad74 cafb713 935c075 aa0d9c9 935c075 aa0d9c9 935c075 aa0d9c9 cafb713 c0aec73 668ad74 cafb713 935c075 aa0d9c9 935c075 aa0d9c9 cafb713 aa0d9c9 cafb713 aa0d9c9 668ad74 aa0d9c9 668ad74 aa0d9c9 935c075 668ad74 aa0d9c9 935c075 aa0d9c9 cafb713 aa0d9c9 935c075 668ad74 aa0d9c9 935c075 aa0d9c9 935c075 aa0d9c9 935c075 aa0d9c9 cafb713 aa0d9c9 cafb713 aa0d9c9 cafb713 935c075 668ad74 aa0d9c9 cafb713 aa0d9c9 935c075 cafb713 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 |
import base64
import io
import json
import numpy as np
import torch
from fastapi import FastAPI, HTTPException
from huggingface_hub import hf_hub_download
from PIL import Image
from pydantic import BaseModel
from torchvision import transforms
import timm
import segmentation_models_pytorch as smp
import albumentations as A
from albumentations.pytorch import ToTensorV2
import os
# --- 1. SETUP: Create the FastAPI app ---
app = FastAPI(title="AI Skin Lesion Analyzer API")

# --- Global variables ---
# All models and tensors live on this device.
DEVICE = "cpu"

# Filled in by the startup hook `load_assets`; they stay None until then,
# or permanently if the corresponding load step failed.
segmentation_model = classification_model = knowledge_base = None
transform_segment = transform_classify = None

# Classifier output index -> lesion class abbreviation. Seven entries to match
# num_classes=7; presumably the label order used at training time — confirm.
idx_to_class_abbr = dict(enumerate(('MEL', 'NV', 'BCC', 'AKIEC', 'BKL', 'DF', 'VASC')))
# JSON request body accepted by POST /analyze/.
class ImageRequest(BaseModel):
    # Base64-encoded image file bytes; the handler decodes them and opens the
    # result with PIL, converting to RGB.
    image_base64: str
@app.on_event("startup")
def load_assets():
    """Load the segmentation model, classifier, knowledge base and transforms.

    Runs once when the FastAPI app starts and populates the module-level
    globals. Each loading step is best-effort: on failure the error is
    printed and the corresponding global is left as ``None``. The process
    keeps running, and ``/analyze/`` reports 503 until everything is loaded.
    """
    global segmentation_model, classification_model, knowledge_base, transform_segment, transform_classify
    print("--> API starting up: This may take a few minutes...")

    # Use /tmp for writable cache directory
    cache_dir = "/tmp/models_cache"
    os.makedirs(cache_dir, exist_ok=True)

    # Load Segmentation Model
    try:
        print(" Downloading UNet segmentation model...")
        seg_model_path = hf_hub_download(
            repo_id="sheikh987/unet-isic2018",
            filename="unet_full_data_best_model.pth",
            cache_dir=cache_dir
        )
        # Architecture must match the checkpoint: ResNet34 encoder, RGB in, 1 mask channel out.
        segmentation_model = smp.Unet("resnet34", encoder_weights=None, in_channels=3, classes=1).to(DEVICE)
        # NOTE(review): torch.load unpickles the downloaded file; since it comes
        # from a remote repo, consider weights_only=True (torch >= 1.13).
        segmentation_model.load_state_dict(torch.load(seg_model_path, map_location=DEVICE))
        segmentation_model.eval()
        print(" ✅ Segmentation model loaded.")
    except Exception as e:
        print(f"!!! FATAL: Could not load segmentation model: {e}")

    # Load Classification Model
    try:
        print(" Downloading EfficientNet classification model...")
        class_model_path = hf_hub_download(
            # Fixed: the previous "sheikh987//efficientnet-isic" (double slash) is
            # not a valid Hub repo id, so this download always failed.
            repo_id="sheikh987/efficientnet-isic",
            filename="efficientnet_augmented_best.pth",
            cache_dir=cache_dir
        )
        classification_model = timm.create_model('efficientnet_b3', pretrained=False, num_classes=7).to(DEVICE)
        # NOTE(review): same pickle caveat as above; consider weights_only=True.
        classification_model.load_state_dict(torch.load(class_model_path, map_location=DEVICE))
        classification_model.eval()
        print(" ✅ Classification model loaded.")
    except Exception as e:
        print(f"!!! FATAL: Could not load classification model: {e}")

    # Load Knowledge Base
    try:
        # NOTE(review): path is relative to the process working directory.
        # JSON is UTF-8; be explicit rather than relying on the locale default.
        with open('knowledge_base.json', 'r', encoding='utf-8') as f:
            knowledge_base = json.load(f)
        print(" ✅ Knowledge base loaded.")
    except Exception as e:
        print(f"!!! FATAL: Could not load knowledge_base.json: {e}")

    # Define Image Transforms
    # Segmentation input: 256x256, ImageNet mean/std, HWC uint8 numpy -> CHW tensor.
    transform_segment = A.Compose([
        A.Resize(256, 256),
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], max_pixel_value=255.0),
        ToTensorV2()
    ])
    # Classification input: PIL image resized to 300x300 with ImageNet mean/std.
    transform_classify = transforms.Compose([
        transforms.Resize((300, 300)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    print("\n--> API is ready to accept requests.")
def _decode_image(image_base64):
    """Decode a base64 string (optionally a ``data:`` URL) into an RGB uint8 array.

    Raises:
        HTTPException: 400 if the payload is not decodable into an image.
    """
    payload = image_base64
    # Browser clients often send "data:image/png;base64,<payload>"; the prefix
    # is not base64 and would otherwise corrupt the decoded bytes.
    if payload.startswith("data:") and "," in payload:
        payload = payload.split(",", 1)[1]
    try:
        image_data = base64.b64decode(payload)
        image = Image.open(io.BytesIO(image_data)).convert("RGB")
    except Exception:
        raise HTTPException(status_code=400, detail="Invalid base64 image data provided.") from None
    return np.array(image)


def _segment_lesion(image_np):
    """Run the UNet on *image_np* and return a binary (0.0/1.0) float mask.

    The mask has the spatial size produced by ``transform_segment``, not the
    size of *image_np*.
    """
    augmented_seg = transform_segment(image=image_np)
    seg_input_tensor = augmented_seg['image'].unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        seg_logits = segmentation_model(seg_input_tensor)
    # (1, 1, H, W) logits -> (H, W) mask thresholded at probability 0.5.
    return (torch.sigmoid(seg_logits) > 0.5).float().squeeze().cpu().numpy()


def _mask_box_to_image_box(seg_mask, image_shape, padding=15):
    """Map the padded bounding box of *seg_mask* onto the original image.

    Padding is applied in mask pixels, then the half-open box is scaled from
    mask resolution to the image resolution given by *image_shape*.

    Returns:
        (top, bottom, left, right) half-open slice bounds into the image.
    """
    mask_h, mask_w = seg_mask.shape[:2]
    img_h, img_w = image_shape[:2]
    rows = np.flatnonzero(np.any(seg_mask, axis=1))
    cols = np.flatnonzero(np.any(seg_mask, axis=0))
    # Half-open box in mask coordinates, padded and clipped to the mask.
    top = max(0, int(rows[0]) - padding)
    bottom = min(mask_h, int(rows[-1]) + 1 + padding)
    left = max(0, int(cols[0]) - padding)
    right = min(mask_w, int(cols[-1]) + 1 + padding)
    # Scale to image space; floor/ceil so the box never shrinks, then clip.
    scale_y = img_h / mask_h
    scale_x = img_w / mask_w
    return (
        int(np.floor(top * scale_y)),
        min(img_h, int(np.ceil(bottom * scale_y))),
        int(np.floor(left * scale_x)),
        min(img_w, int(np.ceil(right * scale_x))),
    )


def _classify_crop(crop_np):
    """Classify an RGB crop; return (confidence in percent, predicted class index)."""
    class_input_tensor = transform_classify(Image.fromarray(crop_np)).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        class_logits = classification_model(class_input_tensor)
    probabilities = torch.nn.functional.softmax(class_logits, dim=1)
    confidence, predicted_idx = torch.max(probabilities, 1)
    return confidence.item() * 100, predicted_idx.item()


@app.post("/analyze/")
async def analyze_image(request: ImageRequest):
    """Segment the lesion, classify the cropped region and look up its info.

    Returns a dict whose ``status`` is "Failed" (no lesion found),
    "Inconclusive" (top-class confidence below 75%) or "Success".

    Raises:
        HTTPException: 503 if startup assets are missing; 400 on bad image data.
    """
    # Compare against None: an empty knowledge-base dict is loaded, not missing,
    # and the transforms are required as well.
    required_assets = (segmentation_model, classification_model, knowledge_base,
                       transform_segment, transform_classify)
    if any(asset is None for asset in required_assets):
        raise HTTPException(status_code=503, detail="Models not loaded yet. Please retry shortly.")

    image_np = _decode_image(request.image_base64)

    # NOTE(review): the inference below is synchronous CPU work inside an async
    # handler, so it blocks the event loop while it runs; consider offloading
    # it to a thread pool if concurrent requests matter.

    # Stage 1: Segmentation
    seg_mask = _segment_lesion(image_np)
    if seg_mask.sum() < 200:
        return {"status": "Failed", "message": "No lesion could be clearly identified in the image."}

    # Stage 2: Crop and classify. The mask is at segmentation resolution, so the
    # box must be rescaled before slicing the full-resolution image.
    top, bottom, left, right = _mask_box_to_image_box(seg_mask, image_np.shape)
    confidence_percent, predicted_index = _classify_crop(image_np[top:bottom, left:right])
    if confidence_percent < 75.0:
        return {"status": "Inconclusive", "message": f"Model confidence ({confidence_percent:.2f}%) is below the 75% threshold."}

    # Stage 3: Return result
    predicted_abbr = idx_to_class_abbr[predicted_index]
    info = knowledge_base.get(predicted_abbr, {})
    return {
        "status": "Success",
        "prediction": info,
        "abbreviation": predicted_abbr,
        "confidence": f"{confidence_percent:.2f}%"
    }
@app.get("/")
def root():
    """Liveness endpoint: answers as soon as the process is serving HTTP."""
    status_text = "AI Skin Lesion Analyzer API is running."
    return dict(message=status_text)
|