sheikh987 commited on
Commit
aa0d9c9
·
verified ·
1 Parent(s): 42e72fa

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +89 -53
main.py CHANGED
@@ -1,107 +1,143 @@
1
  # main.py
2
 
3
- import base64, io, json, numpy as np, torch
 
 
 
 
4
  from fastapi import FastAPI, HTTPException
5
  from huggingface_hub import hf_hub_download
6
  from PIL import Image
7
  from pydantic import BaseModel
8
  from torchvision import transforms
9
- import timm, segmentation_models_pytorch as smp, albumentations as A
 
 
 
 
10
  from albumentations.pytorch import ToTensorV2
11
 
12
- # --- 1. SETUP ---
13
  app = FastAPI(title="AI Skin Lesion Analyzer API")
14
 
 
15
  DEVICE = "cpu"
16
- segmentation_model, classification_model, knowledge_base = None, None, None
 
 
17
  idx_to_class_abbr = {0: 'MEL', 1: 'NV', 2: 'BCC', 3: 'AKIEC', 4: 'BKL', 5: 'DF', 6: 'VASC'}
18
- transform_segment, transform_classify = None, None
 
19
 
 
20
  class ImageRequest(BaseModel):
21
  image_base64: str
22
 
 
23
  @app.on_event("startup")
24
  def load_assets():
 
25
  global segmentation_model, classification_model, knowledge_base, transform_segment, transform_classify
26
- print("--> API starting up: Downloading models...")
27
 
 
 
 
28
  try:
29
- seg_model_path = hf_hub_download("sheikh987/unet-isic2018", "unet_full_data_best_model.pth")
 
30
  segmentation_model = smp.Unet("resnet34", encoder_weights=None, in_channels=3, classes=1).to(DEVICE)
31
  segmentation_model.load_state_dict(torch.load(seg_model_path, map_location=DEVICE))
32
  segmentation_model.eval()
33
- print(" Segmentation model loaded.")
34
  except Exception as e:
35
  print(f"!!! FATAL: Could not load segmentation model: {e}")
36
 
 
37
  try:
38
- class_model_path = hf_hub_download("sheikh987/efficientnet-isic", "efficientnet_augmented_best.pth")
 
39
  classification_model = timm.create_model('efficientnet_b3', pretrained=False, num_classes=7).to(DEVICE)
40
  classification_model.load_state_dict(torch.load(class_model_path, map_location=DEVICE))
41
  classification_model.eval()
42
- print(" Classification model loaded.")
43
  except Exception as e:
44
  print(f"!!! FATAL: Could not load classification model: {e}")
45
 
46
- with open('knowledge_base.json', 'r') as f:
47
- knowledge_base = json.load(f)
48
- print(" Knowledge base loaded.")
 
 
 
 
49
 
 
50
  transform_segment = A.Compose([A.Resize(256, 256), A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], max_pixel_value=255.0), ToTensorV2()])
51
  transform_classify = transforms.Compose([transforms.Resize((300, 300)), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
52
- print("--> API ready.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
 
54
- def process_image(image_np):
55
  # STAGE 1: SEGMENTATION
56
- aug = transform_segment(image=image_np)
57
- tensor = aug['image'].unsqueeze(0).to(DEVICE)
58
  with torch.no_grad():
59
- logits = segmentation_model(tensor)
60
- mask = (torch.sigmoid(logits) > 0.5).float().squeeze().cpu().numpy()
61
-
62
- if mask.sum() < 200: return None
 
63
 
64
- # STAGE 2: CROP
65
- rows, cols = np.any(mask, axis=1), np.any(mask, axis=0)
66
  rmin, rmax = np.where(rows)[0][[0, -1]]
67
  cmin, cmax = np.where(cols)[0][[0, -1]]
68
  padding = 15
69
  rmin, rmax = max(0, rmin - padding), min(image_np.shape[0], rmax + padding)
70
  cmin, cmax = max(0, cmin - padding), min(image_np.shape[1], cmax + padding)
71
- cropped = Image.fromarray(image_np[rmin:rmax, cmin:cmax])
72
 
73
- # STAGE 3: CLASSIFY
74
- tensor = transform_classify(cropped).unsqueeze(0).to(DEVICE)
75
  with torch.no_grad():
76
- logits = classification_model(tensor)
77
- probs = torch.nn.functional.softmax(logits, dim=1)
78
 
79
- conf, idx = torch.max(probs, 1)
80
- return idx.item(), conf.item()
81
-
82
- # --- 2. API ENDPOINTS ---
83
- @app.post("/analyze/")
84
- async def analyze_image(request: ImageRequest):
85
- if not all([segmentation_model, classification_model]):
86
- raise HTTPException(status_code=503, detail="Models are not ready.")
87
- try:
88
- img_data = base64.b64decode(request.image_base64)
89
- img_np = np.array(Image.open(io.BytesIO(img_data)).convert("RGB"))
90
- except:
91
- raise HTTPException(status_code=400, detail="Invalid base64 image.")
92
-
93
- analysis = process_image(img_np)
94
- if analysis is None:
95
- return {"status": "Failed", "message": "No lesion could be identified."}
96
-
97
- pred_idx, confidence = analysis
98
- if confidence < 0.75:
99
- return {"status": "Inconclusive", "message": f"Model confidence ({confidence*100:.2f}%) is below the 75% threshold."}
100
-
101
- pred_abbr = idx_to_class_abbr[pred_idx]
102
- info = knowledge_base.get(pred_abbr, {})
103
- return {"status": "Success", "prediction": info, "abbreviation": pred_abbr, "confidence": f"{confidence*100:.2f}%"}
104
 
 
 
 
 
 
 
 
 
 
 
 
 
105
  @app.get("/")
106
  def root():
107
  return {"message": "AI Skin Lesion Analyzer API is running."}
 
1
  # main.py
2
 
3
+ import base64
4
+ import io
5
+ import json
6
+ import numpy as np
7
+ import torch
8
  from fastapi import FastAPI, HTTPException
9
  from huggingface_hub import hf_hub_download
10
  from PIL import Image
11
  from pydantic import BaseModel
12
  from torchvision import transforms
13
+
14
+ # Import necessary model and processing libraries
15
+ import timm
16
+ import segmentation_models_pytorch as smp
17
+ import albumentations as A
18
  from albumentations.pytorch import ToTensorV2
19
 
20
+ # --- 1. SETUP: Create the FastAPI app ---
21
  app = FastAPI(title="AI Skin Lesion Analyzer API")
22
 
23
+ # --- Global variables to hold the loaded models and assets ---
24
  DEVICE = "cpu"
25
+ segmentation_model = None
26
+ classification_model = None
27
+ knowledge_base = None
28
  idx_to_class_abbr = {0: 'MEL', 1: 'NV', 2: 'BCC', 3: 'AKIEC', 4: 'BKL', 5: 'DF', 6: 'VASC'}
29
+ transform_segment = None
30
+ transform_classify = None
31
 
32
+ # --- Define the request body model for receiving the image ---
33
  class ImageRequest(BaseModel):
34
  image_base64: str
35
 
36
+ # --- This function runs once when the server starts up ---
37
  @app.on_event("startup")
38
  def load_assets():
39
+ """Load all models and assets from Hugging Face Hub when the server starts."""
40
  global segmentation_model, classification_model, knowledge_base, transform_segment, transform_classify
 
41
 
42
+ print("--> API starting up: This may take a few minutes...")
43
+
44
+ # Load Segmentation Model
45
  try:
46
+ print(" Downloading UNet segmentation model...")
47
+ seg_model_path = hf_hub_download(repo_id="sheikh987/unet-isic2018", filename="unet_full_data_best_model.pth")
48
  segmentation_model = smp.Unet("resnet34", encoder_weights=None, in_channels=3, classes=1).to(DEVICE)
49
  segmentation_model.load_state_dict(torch.load(seg_model_path, map_location=DEVICE))
50
  segmentation_model.eval()
51
+ print(" Segmentation model loaded.")
52
  except Exception as e:
53
  print(f"!!! FATAL: Could not load segmentation model: {e}")
54
 
55
+ # Load Classification Model
56
  try:
57
+ print(" Downloading EfficientNet classification model...")
58
+ class_model_path = hf_hub_download(repo_id="sheikh987/efficientnet-isic-classifier", filename="efficientnet_isic_classifier_best.pth")
59
  classification_model = timm.create_model('efficientnet_b3', pretrained=False, num_classes=7).to(DEVICE)
60
  classification_model.load_state_dict(torch.load(class_model_path, map_location=DEVICE))
61
  classification_model.eval()
62
+ print(" Classification model loaded.")
63
  except Exception as e:
64
  print(f"!!! FATAL: Could not load classification model: {e}")
65
 
66
+ # Load Knowledge Base
67
+ try:
68
+ with open('knowledge_base.json', 'r') as f:
69
+ knowledge_base = json.load(f)
70
+ print(" ✅ Knowledge base loaded.")
71
+ except Exception as e:
72
+ print(f"!!! FATAL: Could not load knowledge_base.json: {e}")
73
 
74
+ # Define Image Transforms
75
  transform_segment = A.Compose([A.Resize(256, 256), A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], max_pixel_value=255.0), ToTensorV2()])
76
  transform_classify = transforms.Compose([transforms.Resize((300, 300)), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
77
+
78
+ print("\n--> API is ready to accept requests.")
79
+
80
+ # --- 2. DEFINE THE MAIN API ENDPOINT ---
81
+ @app.post("/analyze/")
82
+ async def analyze_image(request: ImageRequest):
83
+ """
84
+ This endpoint receives a base64 encoded image, runs the full analysis pipeline,
85
+ and returns a JSON response.
86
+ """
87
+ if not all([segmentation_model, classification_model, knowledge_base]):
88
+ raise HTTPException(status_code=503, detail="Models are not ready. The server may still be starting up. Please try again in a minute.")
89
+
90
+ try:
91
+ image_data = base64.b64decode(request.image_base64)
92
+ image = Image.open(io.BytesIO(image_data)).convert("RGB")
93
+ image_np = np.array(image)
94
+ except:
95
+ raise HTTPException(status_code=400, detail="Invalid base64 image data provided.")
96
 
97
+ # --- Full Pipeline Logic ---
98
  # STAGE 1: SEGMENTATION
99
+ augmented_seg = transform_segment(image=image_np)
100
+ seg_input_tensor = augmented_seg['image'].unsqueeze(0).to(DEVICE)
101
  with torch.no_grad():
102
+ seg_logits = segmentation_model(seg_input_tensor)
103
+ seg_mask = (torch.sigmoid(seg_logits) > 0.5).float().squeeze().cpu().numpy()
104
+
105
+ if seg_mask.sum() < 200:
106
+ return {"status": "Failed", "message": "No lesion could be clearly identified in the image."}
107
 
108
+ # STAGE 2: CROP AND CLASSIFY
109
+ rows, cols = np.any(seg_mask, axis=1), np.any(seg_mask, axis=0)
110
  rmin, rmax = np.where(rows)[0][[0, -1]]
111
  cmin, cmax = np.where(cols)[0][[0, -1]]
112
  padding = 15
113
  rmin, rmax = max(0, rmin - padding), min(image_np.shape[0], rmax + padding)
114
  cmin, cmax = max(0, cmin - padding), min(image_np.shape[1], cmax + padding)
115
+ cropped_image_pil = Image.fromarray(image_np[rmin:rmax, cmin:cmax])
116
 
117
+ class_input_tensor = transform_classify(cropped_image_pil).unsqueeze(0).to(DEVICE)
 
118
  with torch.no_grad():
119
+ class_logits = classification_model(class_input_tensor)
120
+ probabilities = torch.nn.functional.softmax(class_logits, dim=1)
121
 
122
+ confidence, predicted_idx = torch.max(probabilities, 1)
123
+ confidence_percent = confidence.item() * 100
124
+
125
+ # SAFETY NET
126
+ if confidence_percent < 75.0:
127
+ return {"status": "Inconclusive", "message": f"Model confidence ({confidence_percent:.2f}%) is below the 75% threshold."}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
128
 
129
+ # STAGE 3: LOOKUP AND RETURN RESULT
130
+ predicted_abbr = idx_to_class_abbr[predicted_idx.item()]
131
+ info = knowledge_base.get(predicted_abbr, {})
132
+
133
+ return {
134
+ "status": "Success",
135
+ "prediction": info,
136
+ "abbreviation": predicted_abbr,
137
+ "confidence": f"{confidence_percent:.2f}%"
138
+ }
139
+
140
+ # --- Root endpoint to check if the API is alive ---
141
  @app.get("/")
142
  def root():
143
  return {"message": "AI Skin Lesion Analyzer API is running."}