# skintest / main.py
# Last commit: 4bb7ae0 (verified) by sheikh987 — "Update main.py"; file size 4.26 kB.
# (Remaining header text was file-viewer navigation: "raw", "history", "blame".)
import base64
import io
import json
import numpy as np
import torch
from fastapi import FastAPI, HTTPException
from huggingface_hub import hf_hub_download
from PIL import Image
from pydantic import BaseModel
from torchvision import transforms
import timm
import os
# --- 1. SETUP: Create the FastAPI app ---
app = FastAPI(title="AI Skin Lesion Analyzer API")

# --- Global variables ---
# Run inference on the GPU when one is visible to torch, otherwise on the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Placeholders filled in by load_assets() at startup; None means "not loaded".
classification_model = knowledge_base = transform_classify = None

# Model output index -> class abbreviation. The abbreviation is also the key
# used to look up an entry in knowledge_base.
idx_to_class_abbr = dict(
    enumerate(('MEL', 'NV', 'BCC', 'AKIEC', 'BKL', 'DF', 'VASC'))
)
class ImageRequest(BaseModel):
    """JSON request body accepted by the /analyze/ endpoint."""

    # Base64 text of the image file; analyze_image decodes it and opens the
    # resulting bytes with PIL.
    image_base64: str
def _load_classification_model(cache_dir):
    """Download the skin-lesion checkpoint and build the classifier.

    Args:
        cache_dir: Writable directory handed to ``hf_hub_download`` as cache.

    Returns:
        The classification model, moved to ``DEVICE`` and switched to eval
        mode.
    """
    class_model_path = hf_hub_download(
        repo_id="sheikh987/efficientnet-b3-skin",
        filename="efficientnet_b3_skin_model.pth",
        cache_dir=cache_dir,
    )
    # NOTE(security): depending on the installed torch version's
    # ``weights_only`` default, torch.load may unpickle arbitrary objects.
    # Only point this at a checkpoint repository you trust.
    checkpoint = torch.load(class_model_path, map_location=DEVICE)

    if not isinstance(checkpoint, dict):
        # The file holds a fully pickled model object rather than weights.
        model = checkpoint.to(DEVICE)
    else:
        # Training scripts often wrap the weights, e.g. {"state_dict": ...}.
        # Unwrap that layer so the real tensors are loaded instead of being
        # silently ignored by strict=False.
        for wrapper_key in ("state_dict", "model_state_dict"):
            nested = checkpoint.get(wrapper_key)
            if isinstance(nested, dict):
                checkpoint = nested
                break

        model = timm.create_model(
            'efficientnet_b3', pretrained=False, num_classes=7
        ).to(DEVICE)
        # strict=False is kept for backward compatibility, but a mismatch is
        # reported: with missing keys the model keeps randomly initialised
        # weights for those layers and its predictions are unreliable.
        load_result = model.load_state_dict(checkpoint, strict=False)
        if load_result.missing_keys or load_result.unexpected_keys:
            print(
                " ⚠️ Checkpoint/model key mismatch: "
                f"{len(load_result.missing_keys)} missing, "
                f"{len(load_result.unexpected_keys)} unexpected."
            )

    model.eval()
    return model


@app.on_event("startup")
def load_assets():
    """Load the classifier, knowledge base and image transform at startup.

    Populates the module globals ``classification_model``,
    ``knowledge_base`` and ``transform_classify``. The globals are assigned
    only after every asset has loaded, so they are either all set or all
    left untouched. Calling this again once everything is loaded is a no-op.

    Raises:
        RuntimeError: If the model or ``knowledge_base.json`` cannot be
            loaded; the original exception is attached as ``__cause__``.
    """
    global classification_model, knowledge_base, transform_classify

    already_loaded = (
        classification_model is not None
        and knowledge_base is not None
        and transform_classify is not None
    )
    if already_loaded:
        print("🔁 Models and knowledge base already loaded. Skipping reloading.")
        return

    print("--> API starting up: Loading classification model and knowledge base...")

    # Use /tmp for writable cache directory
    cache_dir = "/tmp/models_cache"
    os.makedirs(cache_dir, exist_ok=True)

    # Load Classification Model
    try:
        print(" Downloading EfficientNet-B3 classification model...")
        loaded_model = _load_classification_model(cache_dir)
        print(" ✅ Classification model loaded successfully.")
    except Exception as e:
        raise RuntimeError(f"Failed to load classification model: {e}") from e

    # Load Knowledge Base
    try:
        # Explicit UTF-8: otherwise decoding depends on the platform locale
        # and non-ASCII text in the file can fail or be garbled.
        with open('knowledge_base.json', 'r', encoding='utf-8') as f:
            loaded_knowledge_base = json.load(f)
        print(" ✅ Knowledge base loaded.")
    except Exception as e:
        raise RuntimeError(f"Failed to load knowledge_base.json: {e}") from e

    # Define Image Transform: resize to 300x300, convert to a tensor, then
    # normalise with the usual ImageNet channel statistics.
    preprocessing = transforms.Compose([
        transforms.Resize((300, 300)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    # Publish everything together so request handlers never observe a
    # partially initialised state.
    classification_model = loaded_model
    knowledge_base = loaded_knowledge_base
    transform_classify = preprocessing
    print("--> API is ready to accept requests.")
def _decode_request_image(image_base64):
    """Turn a base64 string (optionally a data URI) into an RGB PIL image.

    Args:
        image_base64: Base64 text of an image file. A leading
            ``data:<mime>;base64,`` prefix is accepted and stripped.

    Returns:
        The decoded image converted to RGB mode.

    Raises:
        HTTPException: 400 if the text cannot be decoded or the bytes are
            not an image PIL can open.
    """
    payload = image_base64.strip()
    if payload.startswith("data:"):
        # Browser-style data URI: keep only what follows the first comma.
        # Plain base64 never contains ':' or ',', so this cannot affect
        # payloads that were valid before.
        payload = payload.partition(",")[2]

    try:
        image_data = base64.b64decode(payload)
        return Image.open(io.BytesIO(image_data)).convert("RGB")
    except Exception as exc:
        raise HTTPException(
            status_code=400, detail="Invalid base64 image data provided."
        ) from exc


@app.post("/analyze/")
async def analyze_image(request: ImageRequest):
    """Classify the submitted image and return the matching knowledge entry.

    Args:
        request: Body carrying the base64-encoded image.

    Returns:
        A dict with ``status`` "Success" plus ``prediction``,
        ``abbreviation`` and ``confidence`` when the top-class probability
        is at least 50%, otherwise ``status`` "Inconclusive" with an
        explanatory ``message``.

    Raises:
        HTTPException: 503 while the assets are not loaded, 400 when the
            image payload is invalid.
    """
    # Compare against None explicitly: an empty (but loaded) knowledge base
    # is falsy and would otherwise make the endpoint return 503 forever.
    # The transform is checked too, since calling None would be a 500.
    if (
        classification_model is None
        or knowledge_base is None
        or transform_classify is None
    ):
        raise HTTPException(status_code=503, detail="Model or knowledge base not loaded yet.")

    image = _decode_request_image(request.image_base64)

    # Prepare image: add a leading batch dimension and move it to DEVICE.
    class_input_tensor = transform_classify(image).unsqueeze(0).to(DEVICE)

    # Classification
    # NOTE(review): the forward pass runs synchronously inside this async
    # handler, so the event loop is blocked for its duration.
    with torch.no_grad():
        logits = classification_model(class_input_tensor)
        probabilities = torch.nn.functional.softmax(logits, dim=1)
        confidence, predicted_idx = torch.max(probabilities, 1)

    confidence_percent = confidence.item() * 100

    # Optional: Confidence threshold
    CONFIDENCE_THRESHOLD = 50.0
    if confidence_percent < CONFIDENCE_THRESHOLD:
        return {
            "status": "Inconclusive",
            "message": f"Model confidence ({confidence_percent:.2f}%) is below the threshold of {CONFIDENCE_THRESHOLD}%."
        }

    # Resolve the class details only for results that are actually reported.
    predicted_abbr = idx_to_class_abbr[predicted_idx.item()]
    info = knowledge_base.get(predicted_abbr, {})

    return {
        "status": "Success",
        "prediction": info,
        "abbreviation": predicted_abbr,
        "confidence": f"{confidence_percent:.2f}%"
    }
@app.get("/")
def root():
    """Return a fixed message showing that the API process is responding."""
    status_payload = dict(message="AI Skin Lesion Analyzer API is running.")
    return status_payload