File size: 5,861 Bytes
aa0d9c9
 
 
 
 
935c075
 
 
 
 
aa0d9c9
 
 
 
935c075
cafb713
935c075
aa0d9c9
935c075
 
668ad74
935c075
aa0d9c9
 
 
935c075
aa0d9c9
 
935c075
 
 
 
 
 
 
cafb713
718ea9a
 
 
 
 
aa0d9c9
cafb713
668ad74
 
 
cafb713
aa0d9c9
935c075
aa0d9c9
cafb713
 
 
668ad74
cafb713
935c075
 
 
aa0d9c9
935c075
 
 
aa0d9c9
935c075
aa0d9c9
cafb713
2be1cc1
c0aec73
668ad74
cafb713
935c075
 
 
aa0d9c9
935c075
 
 
aa0d9c9
 
 
 
 
 
 
cafb713
aa0d9c9
cafb713
 
 
 
 
 
 
 
 
 
 
 
aa0d9c9
 
 
 
 
668ad74
aa0d9c9
 
 
 
 
668ad74
aa0d9c9
935c075
668ad74
aa0d9c9
 
935c075
aa0d9c9
 
cafb713
aa0d9c9
 
935c075
668ad74
aa0d9c9
935c075
 
 
 
 
aa0d9c9
935c075
aa0d9c9
935c075
aa0d9c9
 
cafb713
aa0d9c9
 
cafb713
ba3cd47
cafb713
935c075
668ad74
aa0d9c9
 
cafb713
aa0d9c9
 
 
 
 
 
 
935c075
 
cafb713
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
import base64
import io
import json
import numpy as np
import torch
from fastapi import FastAPI, HTTPException
from huggingface_hub import hf_hub_download
from PIL import Image
from pydantic import BaseModel
from torchvision import transforms

import timm
import segmentation_models_pytorch as smp
import albumentations as A
from albumentations.pytorch import ToTensorV2
import os

# --- 1. SETUP: Create the FastAPI app ---
app = FastAPI(title="AI Skin Lesion Analyzer API")

# --- Global state (filled in by the startup hook) ---
# Device every model and input tensor is moved to.
DEVICE = "cpu"

# Loaded models and the class-info lookup; each stays None until loaded.
segmentation_model = classification_model = knowledge_base = None

# Classifier output index -> lesion class abbreviation
# (presumably the ISIC 2018 / HAM10000 class codes — confirm against training).
idx_to_class_abbr = dict(enumerate(("MEL", "NV", "BCC", "AKIEC", "BKL", "DF", "VASC")))

# Preprocessing pipelines for the two models; built at startup.
transform_segment = transform_classify = None

class ImageRequest(BaseModel):
    """JSON request body for ``POST /analyze/``.

    Attributes:
        image_base64: An image file (any format Pillow can open), base64-encoded.
    """

    image_base64: str

@app.on_event("startup")
def load_assets():
    """Load both models, the knowledge base and the image transforms into globals.

    Registered as a FastAPI startup hook. Each asset is loaded in its own
    ``try`` block and a failure is only printed, so the server still starts;
    the affected global then stays ``None`` and ``/analyze/`` answers 503.

    Side effects:
        Downloads model weights from the Hugging Face Hub into
        ``/tmp/models_cache``, reads ``knowledge_base.json`` and sets the
        globals ``segmentation_model``, ``classification_model``,
        ``knowledge_base``, ``transform_segment`` and ``transform_classify``.
    """
    global segmentation_model, classification_model, knowledge_base, transform_segment, transform_classify

    # If everything was already loaded before, skip the (slow) reload.
    if segmentation_model is not None and classification_model is not None and knowledge_base is not None:
        print("🔁 Models and knowledge base already loaded. Skipping reloading.")
        return

    print("--> API starting up: This may take a few minutes...")

    # Use /tmp for writable cache directory
    cache_dir = "/tmp/models_cache"
    os.makedirs(cache_dir, exist_ok=True)

    # Each model is built in a local variable and published to its global only
    # after its weights loaded successfully, so a failed load_state_dict never
    # leaves a randomly-initialised network behind that would pass /analyze/'s
    # readiness check and serve meaningless predictions.
    #
    # NOTE(review): torch.load unpickles the downloaded checkpoint, and pickle
    # can execute arbitrary code. If the installed torch supports it and the
    # checkpoints are plain state_dicts, consider passing weights_only=True.

    # Load Segmentation Model
    try:
        print("    Downloading UNet segmentation model...")
        seg_model_path = hf_hub_download(
            repo_id="sheikh987/unet-isic2018",
            filename="unet_full_data_best_model.pth",
            cache_dir=cache_dir
        )
        seg_model = smp.Unet("resnet34", encoder_weights=None, in_channels=3, classes=1).to(DEVICE)
        seg_model.load_state_dict(torch.load(seg_model_path, map_location=DEVICE))
        seg_model.eval()
        segmentation_model = seg_model
        print("    ✅ Segmentation model loaded.")
    except Exception as e:
        print(f"!!! FATAL: Could not load segmentation model: {e}")

    # Load Classification Model
    try:
        print("    Downloading EfficientNet classification model...")
        class_model_path = hf_hub_download(
            repo_id="sheikh987/efficientnet-isic",
            filename="efficientnet_augmented_best.pth",
            cache_dir=cache_dir
        )
        cls_model = timm.create_model('efficientnet_b3', pretrained=False, num_classes=7).to(DEVICE)
        cls_model.load_state_dict(torch.load(class_model_path, map_location=DEVICE))
        cls_model.eval()
        classification_model = cls_model
        print("    ✅ Classification model loaded.")
    except Exception as e:
        print(f"!!! FATAL: Could not load classification model: {e}")

    # Load Knowledge Base
    try:
        # JSON is UTF-8 by specification; don't depend on the host locale.
        # NOTE(review): the path is relative to the process working directory.
        with open('knowledge_base.json', 'r', encoding='utf-8') as f:
            knowledge_base = json.load(f)
        print("    ✅ Knowledge base loaded.")
    except Exception as e:
        print(f"!!! FATAL: Could not load knowledge_base.json: {e}")

    # Define Image Transforms
    # Segmentation input: resized to 256x256, ImageNet mean/std normalisation.
    transform_segment = A.Compose([
        A.Resize(256, 256),
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], max_pixel_value=255.0),
        ToTensorV2()
    ])

    # Classification input: resized to 300x300, ImageNet mean/std normalisation.
    transform_classify = transforms.Compose([
        transforms.Resize((300, 300)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    print("\n--> API is ready to accept requests.")

# Minimum foreground-pixel count (at the segmentation mask's resolution) for a
# lesion to count as found.
_MIN_LESION_PIXELS = 200
# Minimum top-class probability, in percent, for a prediction to be reported.
_MIN_CONFIDENCE_PERCENT = 50.0
# Padding, in original-image pixels, added around the lesion bounding box.
_CROP_PADDING = 15


def _decode_image(image_base64: str) -> np.ndarray:
    """Decode a base64 image payload into an RGB array of shape (H, W, 3).

    A browser-style ``data:<mime>;base64,`` prefix is stripped if present.

    Raises:
        HTTPException: 400 if the payload does not decode to an image Pillow can open.
    """
    payload = image_base64
    if payload.startswith("data:") and "," in payload:
        payload = payload.split(",", 1)[1]
    try:
        image_bytes = base64.b64decode(payload)
        with Image.open(io.BytesIO(image_bytes)) as image:
            return np.array(image.convert("RGB"))
    except Exception as exc:
        raise HTTPException(status_code=400, detail="Invalid base64 image data provided.") from exc


def _segment_lesion(image_np: np.ndarray) -> np.ndarray:
    """Run the UNet and return a 2-D 0.0/1.0 mask at the segmentation input size."""
    seg_input = transform_segment(image=image_np)["image"].unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        seg_logits = segmentation_model(seg_input)
    # Batch of 1 and a single output class -> squeeze to (height, width).
    return (torch.sigmoid(seg_logits) > 0.5).float().squeeze().cpu().numpy()


def _crop_to_mask(image_np: np.ndarray, seg_mask: np.ndarray, padding: int = _CROP_PADDING) -> Image.Image:
    """Crop ``image_np`` to the padded bounding box of the foreground of ``seg_mask``.

    ``seg_mask`` is at the segmentation model's input size (``transform_segment``
    resizes to 256x256), which generally differs from the original image size,
    so the box is rescaled into original-image pixels before cropping.
    Assumes ``seg_mask`` has at least one foreground pixel.
    """
    rows = np.flatnonzero(seg_mask.any(axis=1))
    cols = np.flatnonzero(seg_mask.any(axis=0))
    # Half-open box [r0, r1) x [c0, c1) in mask coordinates.
    r0, r1 = rows[0], rows[-1] + 1
    c0, c1 = cols[0], cols[-1] + 1

    img_h, img_w = image_np.shape[:2]
    scale_r = img_h / seg_mask.shape[0]
    scale_c = img_w / seg_mask.shape[1]
    top = max(0, int(np.floor(r0 * scale_r)) - padding)
    bottom = min(img_h, int(np.ceil(r1 * scale_r)) + padding)
    left = max(0, int(np.floor(c0 * scale_c)) - padding)
    right = min(img_w, int(np.ceil(c1 * scale_c)) + padding)
    return Image.fromarray(image_np[top:bottom, left:right])


def _classify_crop(crop: Image.Image) -> tuple:
    """Classify a lesion crop; return ``(class_index, confidence_percent)``."""
    class_input = transform_classify(crop).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        class_logits = classification_model(class_input)
        probabilities = torch.nn.functional.softmax(class_logits, dim=1)
    confidence, predicted_idx = torch.max(probabilities, 1)
    return predicted_idx.item(), confidence.item() * 100


@app.post("/analyze/")
async def analyze_image(request: ImageRequest):
    """Segment, crop and classify a base64-encoded skin-lesion image.

    Pipeline: decode the image, predict a lesion mask with the UNet, crop
    the original image to the mask's padded bounding box, classify the crop
    with the EfficientNet, and look the class up in the knowledge base.

    Returns:
        A dict whose ``status`` is ``"Failed"`` (no lesion found),
        ``"Inconclusive"`` (confidence below ``_MIN_CONFIDENCE_PERCENT``) or
        ``"Success"`` (with ``prediction``, ``abbreviation`` and ``confidence``).

    Raises:
        HTTPException: 503 if a model or the knowledge base is not loaded;
            400 if the payload is not valid base64 image data.
    """
    # Explicit None checks: an empty-but-loaded knowledge base is still loaded.
    if segmentation_model is None or classification_model is None or knowledge_base is None:
        raise HTTPException(status_code=503, detail="Models not loaded yet. Please retry shortly.")

    image_np = _decode_image(request.image_base64)

    # NOTE(review): the model calls below are synchronous CPU work inside an
    # async endpoint, so they block the event loop for the whole inference.
    # A plain `def` endpoint (run in FastAPI's threadpool) may be preferable.

    # Stage 1: Segmentation
    seg_mask = _segment_lesion(image_np)
    if seg_mask.sum() < _MIN_LESION_PIXELS:
        return {"status": "Failed", "message": "No lesion could be clearly identified in the image."}

    # Stage 2: Crop and classify
    predicted_index, confidence_percent = _classify_crop(_crop_to_mask(image_np, seg_mask))
    if confidence_percent < _MIN_CONFIDENCE_PERCENT:
        return {
            "status": "Inconclusive",
            "message": (
                f"Model confidence ({confidence_percent:.2f}%) is below the "
                f"{_MIN_CONFIDENCE_PERCENT:g}% threshold."
            ),
        }

    # Stage 3: Return result
    predicted_abbr = idx_to_class_abbr[predicted_index]
    info = knowledge_base.get(predicted_abbr, {})

    return {
        "status": "Success",
        "prediction": info,
        "abbreviation": predicted_abbr,
        "confidence": f"{confidence_percent:.2f}%"
    }

@app.get("/")
def root():
    """Health-check endpoint: return a fixed message saying the API is up.

    It does not check whether the models have finished loading.
    """
    status_message = "AI Skin Lesion Analyzer API is running."
    return dict(message=status_message)