sheikh987 committed
Commit 4bb7ae0 · verified
1 Parent(s): ba3cd47

Update main.py

Files changed (1)
  1. main.py +42 -73
main.py CHANGED
@@ -8,23 +8,17 @@ from huggingface_hub import hf_hub_download
 from PIL import Image
 from pydantic import BaseModel
 from torchvision import transforms
-
 import timm
-import segmentation_models_pytorch as smp
-import albumentations as A
-from albumentations.pytorch import ToTensorV2
 import os
 
 # --- 1. SETUP: Create the FastAPI app ---
 app = FastAPI(title="AI Skin Lesion Analyzer API")
 
 # --- Global variables ---
-DEVICE = "cpu"
-segmentation_model = None
+DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 classification_model = None
 knowledge_base = None
 idx_to_class_abbr = {0: 'MEL', 1: 'NV', 2: 'BCC', 3: 'AKIEC', 4: 'BKL', 5: 'DF', 6: 'VASC'}
-transform_segment = None
 transform_classify = None
 
 class ImageRequest(BaseModel):
@@ -32,48 +26,41 @@ class ImageRequest(BaseModel):
 
 @app.on_event("startup")
 def load_assets():
-    global segmentation_model, classification_model, knowledge_base, transform_segment, transform_classify
+    global classification_model, knowledge_base, transform_classify
 
-    # If the assets were already loaded, skip reloading
-    if segmentation_model is not None and classification_model is not None and knowledge_base is not None:
+    if classification_model is not None and knowledge_base is not None:
         print("🔁 Models and knowledge base already loaded. Skipping reloading.")
         return
 
-    print("--> API starting up: This may take a few minutes...")
+    print("--> API starting up: Loading classification model and knowledge base...")
 
     # Use /tmp for writable cache directory
     cache_dir = "/tmp/models_cache"
     os.makedirs(cache_dir, exist_ok=True)
 
-    # Load Segmentation Model
-    try:
-        print(" Downloading UNet segmentation model...")
-        seg_model_path = hf_hub_download(
-            repo_id="sheikh987/unet-isic2018",
-            filename="unet_full_data_best_model.pth",
-            cache_dir=cache_dir
-        )
-        segmentation_model = smp.Unet("resnet34", encoder_weights=None, in_channels=3, classes=1).to(DEVICE)
-        segmentation_model.load_state_dict(torch.load(seg_model_path, map_location=DEVICE))
-        segmentation_model.eval()
-        print(" ✅ Segmentation model loaded.")
-    except Exception as e:
-        print(f"!!! FATAL: Could not load segmentation model: {e}")
-
     # Load Classification Model
     try:
-        print(" Downloading EfficientNet classification model...")
+        print(" Downloading EfficientNet-B3 classification model...")
         class_model_path = hf_hub_download(
-            repo_id="sheikh987/efficientnet-isic",
-            filename="efficientnet_augmented_best.pth",
+            repo_id="sheikh987/efficientnet-b3-skin",
+            filename="efficientnet_b3_skin_model.pth",
             cache_dir=cache_dir
         )
-        classification_model = timm.create_model('efficientnet_b3', pretrained=False, num_classes=7).to(DEVICE)
-        classification_model.load_state_dict(torch.load(class_model_path, map_location=DEVICE))
+        checkpoint = torch.load(class_model_path, map_location=DEVICE)
+
+        # Auto-detect state_dict vs full model
+        if isinstance(checkpoint, dict):
+            classification_model = timm.create_model(
+                'efficientnet_b3', pretrained=False, num_classes=7
+            ).to(DEVICE)
+            classification_model.load_state_dict(checkpoint, strict=False)
+        else:
+            classification_model = checkpoint.to(DEVICE)
+
         classification_model.eval()
-        print(" ✅ Classification model loaded.")
+        print(" ✅ Classification model loaded successfully.")
     except Exception as e:
-        print(f"!!! FATAL: Could not load classification model: {e}")
+        raise RuntimeError(f"Failed to load classification model: {e}")
 
     # Load Knowledge Base
     try:
@@ -81,69 +68,51 @@ def load_assets():
             knowledge_base = json.load(f)
         print(" ✅ Knowledge base loaded.")
     except Exception as e:
-        print(f"!!! FATAL: Could not load knowledge_base.json: {e}")
-
-    # Define Image Transforms
-    transform_segment = A.Compose([
-        A.Resize(256, 256),
-        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], max_pixel_value=255.0),
-        ToTensorV2()
-    ])
+        raise RuntimeError(f"Failed to load knowledge_base.json: {e}")
 
+    # Define Image Transform
     transform_classify = transforms.Compose([
        transforms.Resize((300, 300)),
        transforms.ToTensor(),
-        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+        transforms.Normalize(mean=[0.485, 0.456, 0.406],
+                             std=[0.229, 0.224, 0.225])
    ])
 
-    print("\n--> API is ready to accept requests.")
+    print("--> API is ready to accept requests.")
+
 
 @app.post("/analyze/")
 async def analyze_image(request: ImageRequest):
-    if not all([segmentation_model, classification_model, knowledge_base]):
-        raise HTTPException(status_code=503, detail="Models not loaded yet. Please retry shortly.")
+    if not all([classification_model, knowledge_base]):
+        raise HTTPException(status_code=503, detail="Model or knowledge base not loaded yet.")
 
     try:
        image_data = base64.b64decode(request.image_base64)
        image = Image.open(io.BytesIO(image_data)).convert("RGB")
-        image_np = np.array(image)
     except Exception:
        raise HTTPException(status_code=400, detail="Invalid base64 image data provided.")
 
-    # Stage 1: Segmentation
-    augmented_seg = transform_segment(image=image_np)
-    seg_input_tensor = augmented_seg['image'].unsqueeze(0).to(DEVICE)
-    with torch.no_grad():
-        seg_logits = segmentation_model(seg_input_tensor)
-        seg_mask = (torch.sigmoid(seg_logits) > 0.5).float().squeeze().cpu().numpy()
-
-    if seg_mask.sum() < 200:
-        return {"status": "Failed", "message": "No lesion could be clearly identified in the image."}
-
-    # Stage 2: Crop and classify
-    rows, cols = np.any(seg_mask, axis=1), np.any(seg_mask, axis=0)
-    rmin, rmax = np.where(rows)[0][[0, -1]]
-    cmin, cmax = np.where(cols)[0][[0, -1]]
-    padding = 15
-    rmin, rmax = max(0, rmin - padding), min(image_np.shape[0], rmax + padding)
-    cmin, cmax = max(0, cmin - padding), min(image_np.shape[1], cmax + padding)
-    cropped_image_pil = Image.fromarray(image_np[rmin:rmax, cmin:cmax])
-
-    class_input_tensor = transform_classify(cropped_image_pil).unsqueeze(0).to(DEVICE)
+    # Prepare image
+    class_input_tensor = transform_classify(image).unsqueeze(0).to(DEVICE)
+
+    # Classification
     with torch.no_grad():
-        class_logits = classification_model(class_input_tensor)
-        probabilities = torch.nn.functional.softmax(class_logits, dim=1)
+        logits = classification_model(class_input_tensor)
+        probabilities = torch.nn.functional.softmax(logits, dim=1)
 
     confidence, predicted_idx = torch.max(probabilities, 1)
     confidence_percent = confidence.item() * 100
-
-    if confidence_percent < 50.0:
-        return {"status": "Inconclusive", "message": f"Model confidence ({confidence_percent:.2f}%) is below the 75% threshold."}
-
-    # Stage 3: Return result
     predicted_abbr = idx_to_class_abbr[predicted_idx.item()]
     info = knowledge_base.get(predicted_abbr, {})
 
+    # Optional: Confidence threshold
+    CONFIDENCE_THRESHOLD = 50.0
+    if confidence_percent < CONFIDENCE_THRESHOLD:
+        return {
+            "status": "Inconclusive",
+            "message": f"Model confidence ({confidence_percent:.2f}%) is below the threshold of {CONFIDENCE_THRESHOLD}%."
+        }
+
     return {
        "status": "Success",
        "prediction": info,