sheikh987 commited on
Commit
935c075
·
verified ·
1 Parent(s): 2caa3b3

Create main.py

Browse files
Files changed (1) hide show
  1. main.py +107 -0
main.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # main.py
2
+
3
+ import base64, io, json, numpy as np, torch
4
+ from fastapi import FastAPI, HTTPException
5
+ from huggingface_hub import hf_hub_download
6
+ from PIL import Image
7
+ from pydantic import BaseModel
8
+ from torchvision import transforms
9
+ import timm, segmentation_models_pytorch as smp, albumentations as A
10
+ from albumentations.pytorch import ToTensorV2
11
+
12
+ # --- 1. SETUP ---
13
+ app = FastAPI(title="AI Skin Lesion Analyzer API")
14
+
15
+ DEVICE = "cpu"
16
+ segmentation_model, classification_model, knowledge_base = None, None, None
17
+ idx_to_class_abbr = {0: 'MEL', 1: 'NV', 2: 'BCC', 3: 'AKIEC', 4: 'BKL', 5: 'DF', 6: 'VASC'}
18
+ transform_segment, transform_classify = None, None
19
+
20
+ class ImageRequest(BaseModel):
21
+ image_base64: str
22
+
23
+ @app.on_event("startup")
24
+ def load_assets():
25
+ global segmentation_model, classification_model, knowledge_base, transform_segment, transform_classify
26
+ print("--> API starting up: Downloading models...")
27
+
28
+ try:
29
+ seg_model_path = hf_hub_download("sheikh987/unet-isic2018", "unet_full_data_best_model.pth")
30
+ segmentation_model = smp.Unet("resnet34", encoder_weights=None, in_channels=3, classes=1).to(DEVICE)
31
+ segmentation_model.load_state_dict(torch.load(seg_model_path, map_location=DEVICE))
32
+ segmentation_model.eval()
33
+ print(" Segmentation model loaded.")
34
+ except Exception as e:
35
+ print(f"!!! FATAL: Could not load segmentation model: {e}")
36
+
37
+ try:
38
+ class_model_path = hf_hub_download("sheikh987/efficientnet-isic", "efficientnet_augmented_best.pth")
39
+ classification_model = timm.create_model('efficientnet_b3', pretrained=False, num_classes=7).to(DEVICE)
40
+ classification_model.load_state_dict(torch.load(class_model_path, map_location=DEVICE))
41
+ classification_model.eval()
42
+ print(" Classification model loaded.")
43
+ except Exception as e:
44
+ print(f"!!! FATAL: Could not load classification model: {e}")
45
+
46
+ with open('knowledge_base.json', 'r') as f:
47
+ knowledge_base = json.load(f)
48
+ print(" Knowledge base loaded.")
49
+
50
+ transform_segment = A.Compose([A.Resize(256, 256), A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], max_pixel_value=255.0), ToTensorV2()])
51
+ transform_classify = transforms.Compose([transforms.Resize((300, 300)), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
52
+ print("--> API ready.")
53
+
54
+ def process_image(image_np):
55
+ # STAGE 1: SEGMENTATION
56
+ aug = transform_segment(image=image_np)
57
+ tensor = aug['image'].unsqueeze(0).to(DEVICE)
58
+ with torch.no_grad():
59
+ logits = segmentation_model(tensor)
60
+ mask = (torch.sigmoid(logits) > 0.5).float().squeeze().cpu().numpy()
61
+
62
+ if mask.sum() < 200: return None
63
+
64
+ # STAGE 2: CROP
65
+ rows, cols = np.any(mask, axis=1), np.any(mask, axis=0)
66
+ rmin, rmax = np.where(rows)[0][[0, -1]]
67
+ cmin, cmax = np.where(cols)[0][[0, -1]]
68
+ padding = 15
69
+ rmin, rmax = max(0, rmin - padding), min(image_np.shape[0], rmax + padding)
70
+ cmin, cmax = max(0, cmin - padding), min(image_np.shape[1], cmax + padding)
71
+ cropped = Image.fromarray(image_np[rmin:rmax, cmin:cmax])
72
+
73
+ # STAGE 3: CLASSIFY
74
+ tensor = transform_classify(cropped).unsqueeze(0).to(DEVICE)
75
+ with torch.no_grad():
76
+ logits = classification_model(tensor)
77
+ probs = torch.nn.functional.softmax(logits, dim=1)
78
+
79
+ conf, idx = torch.max(probs, 1)
80
+ return idx.item(), conf.item()
81
+
82
+ # --- 2. API ENDPOINTS ---
83
+ @app.post("/analyze/")
84
+ async def analyze_image(request: ImageRequest):
85
+ if not all([segmentation_model, classification_model]):
86
+ raise HTTPException(status_code=503, detail="Models are not ready.")
87
+ try:
88
+ img_data = base64.b64decode(request.image_base64)
89
+ img_np = np.array(Image.open(io.BytesIO(img_data)).convert("RGB"))
90
+ except:
91
+ raise HTTPException(status_code=400, detail="Invalid base64 image.")
92
+
93
+ analysis = process_image(img_np)
94
+ if analysis is None:
95
+ return {"status": "Failed", "message": "No lesion could be identified."}
96
+
97
+ pred_idx, confidence = analysis
98
+ if confidence < 0.75:
99
+ return {"status": "Inconclusive", "message": f"Model confidence ({confidence*100:.2f}%) is below the 75% threshold."}
100
+
101
+ pred_abbr = idx_to_class_abbr[pred_idx]
102
+ info = knowledge_base.get(pred_abbr, {})
103
+ return {"status": "Success", "prediction": info, "abbreviation": pred_abbr, "confidence": f"{confidence*100:.2f}%"}
104
+
105
+ @app.get("/")
106
+ def root():
107
+ return {"message": "AI Skin Lesion Analyzer API is running."}