|
import base64 |
|
import io |
|
import json |
|
import numpy as np |
|
import torch |
|
from fastapi import FastAPI, HTTPException |
|
from huggingface_hub import hf_hub_download |
|
from PIL import Image |
|
from pydantic import BaseModel |
|
from torchvision import transforms |
|
import timm |
|
import os |
|
|
|
|
|
app = FastAPI(title="AI Skin Lesion Analyzer API")

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Heavy assets; populated by the startup hook load_assets() and None until then.
classification_model = None
knowledge_base = None
transform_classify = None

# Classifier output index -> lesion class abbreviation (also the knowledge-base key).
idx_to_class_abbr = dict(enumerate(("MEL", "NV", "BCC", "AKIEC", "BKL", "DF", "VASC")))
|
|
|
class ImageRequest(BaseModel):
    """JSON request body accepted by the ``/analyze/`` endpoint."""

    # Image file bytes encoded as base64 text; decoded by the handler with PIL.
    image_base64: str
|
|
|
def _unwrap_state_dict(checkpoint):
    """Return the bare parameter mapping held by a dict-style *checkpoint*.

    Two common packaging variants would otherwise match no keys under
    ``strict=False`` and leave the model silently randomly initialised:

    * training checkpoints that nest the weights under ``"state_dict"``,
      ``"model_state_dict"`` or ``"model"``;
    * weights saved from ``nn.DataParallel``, whose keys all carry a
      ``"module."`` prefix.
    """
    for wrapper_key in ("state_dict", "model_state_dict", "model"):
        nested = checkpoint.get(wrapper_key)
        if isinstance(nested, dict):
            checkpoint = nested
            break

    prefix = "module."
    if checkpoint and all(isinstance(k, str) and k.startswith(prefix) for k in checkpoint):
        checkpoint = {k[len(prefix):]: v for k, v in checkpoint.items()}
    return checkpoint


def _load_classification_model(cache_dir):
    """Download the EfficientNet-B3 checkpoint and return the model, on DEVICE, in eval mode.

    Raises:
        RuntimeError: if a state-dict checkpoint shares no keys with the model.
        Any exception from the download, ``torch.load`` or ``load_state_dict``.
    """
    print(" Downloading EfficientNet-B3 classification model...")
    class_model_path = hf_hub_download(
        repo_id="sheikh987/efficientnet-b3-skin",
        filename="efficientnet_b3_skin_model.pth",
        cache_dir=cache_dir,
    )
    # SECURITY NOTE(review): torch.load unpickles the file, which can execute
    # arbitrary code; only load from a repository you trust. The full-model
    # branch below needs full unpickling, so weights_only=True is not forced.
    checkpoint = torch.load(class_model_path, map_location=DEVICE)

    if not isinstance(checkpoint, dict):
        # The file holds a pickled model object rather than a state dict.
        model = checkpoint.to(DEVICE)
    else:
        model = timm.create_model(
            "efficientnet_b3", pretrained=False, num_classes=len(idx_to_class_abbr)
        ).to(DEVICE)
        result = model.load_state_dict(_unwrap_state_dict(checkpoint), strict=False)
        # strict=False tolerates mismatches, so verify something actually loaded:
        # if every model key is missing, the weights are still random.
        if len(result.missing_keys) == len(model.state_dict()):
            raise RuntimeError("checkpoint contains no parameters matching efficientnet_b3")
        if result.missing_keys or result.unexpected_keys:
            print(
                f" ⚠️ Checkpoint partially matched: {len(result.missing_keys)} missing, "
                f"{len(result.unexpected_keys)} unexpected keys."
            )

    model.eval()
    return model


def _load_knowledge_base(path="knowledge_base.json"):
    """Read the class-abbreviation -> info mapping from *path* (relative to the CWD).

    Raises:
        ValueError: if the top-level JSON value is not an object; the request
            handler looks entries up with ``dict.get``.
    """
    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)
    if not isinstance(data, dict):
        raise ValueError(f"expected a JSON object at top level, got {type(data).__name__}")
    return data


def _build_transform():
    """Return the preprocessing pipeline: resize to 300x300, to tensor, normalise."""
    return transforms.Compose([
        transforms.Resize((300, 300)),
        transforms.ToTensor(),
        # Standard ImageNet channel statistics.
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])


# NOTE(review): newer FastAPI versions deprecate on_event in favour of a
# lifespan handler; kept here to avoid changing how the app is constructed.
@app.on_event("startup")
def load_assets():
    """Load the classifier, the knowledge base and the preprocessing pipeline.

    Runs as a FastAPI startup hook and populates the module-level globals
    ``classification_model``, ``knowledge_base`` and ``transform_classify``.
    Once the model and knowledge base are loaded, further calls return
    immediately.

    The model cache directory defaults to ``/tmp/models_cache`` and can be
    overridden with the ``MODEL_CACHE_DIR`` environment variable.

    Raises:
        RuntimeError: if the model or ``knowledge_base.json`` cannot be
            loaded; the underlying exception is chained as ``__cause__``.
    """
    global classification_model, knowledge_base, transform_classify

    if classification_model is not None and knowledge_base is not None:
        print("🔁 Models and knowledge base already loaded. Skipping reloading.")
        return

    print("--> API starting up: Loading classification model and knowledge base...")

    cache_dir = os.environ.get("MODEL_CACHE_DIR", "/tmp/models_cache")
    os.makedirs(cache_dir, exist_ok=True)

    # The helper builds the model in a local and returns only on success, so the
    # global is never left pointing at a half-initialised model.
    try:
        classification_model = _load_classification_model(cache_dir)
    except Exception as e:
        raise RuntimeError(f"Failed to load classification model: {e}") from e
    print(" ✅ Classification model loaded successfully.")

    try:
        knowledge_base = _load_knowledge_base()
    except Exception as e:
        raise RuntimeError(f"Failed to load knowledge_base.json: {e}") from e
    print(" ✅ Knowledge base loaded.")

    transform_classify = _build_transform()

    print("--> API is ready to accept requests.")
|
|
|
|
|
def _decode_image(image_base64):
    """Decode a base64 string into an RGB PIL image.

    A leading ``data:<mime>;base64,`` prefix (as sent by browsers) is stripped
    first. Otherwise the lenient ``b64decode`` would drop the prefix's
    punctuation and decode its letters as image bytes.

    Raises:
        Whatever ``base64.b64decode`` or PIL raise for malformed input.
    """
    payload = image_base64.strip()
    if payload.startswith("data:") and "," in payload:
        payload = payload.split(",", 1)[1]
    image_data = base64.b64decode(payload)
    # Image.open is lazy; convert() forces the full decode so corrupt data
    # fails here instead of later inside the transform.
    return Image.open(io.BytesIO(image_data)).convert("RGB")


def _predict(image):
    """Classify *image*; return ``(confidence_percent, class_abbreviation)``.

    Synchronous and compute-bound, so the async handler runs it in a worker
    thread.
    """
    class_input_tensor = transform_classify(image).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        logits = classification_model(class_input_tensor)
        probabilities = torch.nn.functional.softmax(logits, dim=1)
    confidence, predicted_idx = torch.max(probabilities, 1)
    return confidence.item() * 100, idx_to_class_abbr[predicted_idx.item()]


@app.post("/analyze/")
async def analyze_image(request: ImageRequest):
    """Classify a base64-encoded skin-lesion image.

    Returns:
        ``{"status": "Inconclusive", "message": ...}`` when the top class
        probability is below 50%. Otherwise ``{"status": "Success",
        "prediction": <knowledge-base entry or {}>, "abbreviation": <class>,
        "confidence": "<xx.xx>%"}``.

    Raises:
        HTTPException: 503 if the assets are not loaded yet; 400 if the
            payload cannot be decoded into an image.
    """
    import asyncio  # local import keeps this endpoint self-contained

    # Explicit None checks: truthiness would misreport an empty (but loaded)
    # knowledge base, or a model type that defines __len__, as "not loaded".
    if classification_model is None or knowledge_base is None or transform_classify is None:
        raise HTTPException(status_code=503, detail="Model or knowledge base not loaded yet.")

    # Decoding and inference are blocking and compute-bound. Running them
    # directly in this coroutine would stall the event loop for every other
    # request, so hand them to the default thread-pool executor.
    loop = asyncio.get_running_loop()
    try:
        image = await loop.run_in_executor(None, _decode_image, request.image_base64)
    except Exception as e:
        raise HTTPException(status_code=400, detail="Invalid base64 image data provided.") from e

    confidence_percent, predicted_abbr = await loop.run_in_executor(None, _predict, image)

    CONFIDENCE_THRESHOLD = 50.0
    if confidence_percent < CONFIDENCE_THRESHOLD:
        return {
            "status": "Inconclusive",
            "message": f"Model confidence ({confidence_percent:.2f}%) is below the threshold of {CONFIDENCE_THRESHOLD}%.",
        }

    info = knowledge_base.get(predicted_abbr, {})
    return {
        "status": "Success",
        "prediction": info,
        "abbreviation": predicted_abbr,
        "confidence": f"{confidence_percent:.2f}%",
    }
|
|
|
@app.get("/")
def root():
    """Return a fixed message confirming the API process is reachable."""
    status_text = "AI Skin Lesion Analyzer API is running."
    return {"message": status_text}
|
|