sheikh987 commited on
Commit
cafb713
·
verified ·
1 Parent(s): aa0d9c9

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +35 -17
main.py CHANGED
@@ -1,5 +1,3 @@
1
- # main.py
2
-
3
  import base64
4
  import io
5
  import json
@@ -16,6 +14,7 @@ import timm
16
  import segmentation_models_pytorch as smp
17
  import albumentations as A
18
  from albumentations.pytorch import ToTensorV2
 
19
 
20
  # --- 1. SETUP: Create the FastAPI app ---
21
  app = FastAPI(title="AI Skin Lesion Analyzer API")
@@ -38,13 +37,20 @@ class ImageRequest(BaseModel):
38
  def load_assets():
39
  """Load all models and assets from Hugging Face Hub when the server starts."""
40
  global segmentation_model, classification_model, knowledge_base, transform_segment, transform_classify
41
-
42
  print("--> API starting up: This may take a few minutes...")
43
-
 
 
 
44
  # Load Segmentation Model
45
  try:
46
  print(" Downloading UNet segmentation model...")
47
- seg_model_path = hf_hub_download(repo_id="sheikh987/unet-isic2018", filename="unet_full_data_best_model.pth")
 
 
 
 
48
  segmentation_model = smp.Unet("resnet34", encoder_weights=None, in_channels=3, classes=1).to(DEVICE)
49
  segmentation_model.load_state_dict(torch.load(seg_model_path, map_location=DEVICE))
50
  segmentation_model.eval()
@@ -55,7 +61,11 @@ def load_assets():
55
  # Load Classification Model
56
  try:
57
  print(" Downloading EfficientNet classification model...")
58
- class_model_path = hf_hub_download(repo_id="sheikh987/efficientnet-isic-classifier", filename="efficientnet_isic_classifier_best.pth")
 
 
 
 
59
  classification_model = timm.create_model('efficientnet_b3', pretrained=False, num_classes=7).to(DEVICE)
60
  classification_model.load_state_dict(torch.load(class_model_path, map_location=DEVICE))
61
  classification_model.eval()
@@ -70,11 +80,20 @@ def load_assets():
70
  print(" ✅ Knowledge base loaded.")
71
  except Exception as e:
72
  print(f"!!! FATAL: Could not load knowledge_base.json: {e}")
73
-
74
  # Define Image Transforms
75
- transform_segment = A.Compose([A.Resize(256, 256), A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], max_pixel_value=255.0), ToTensorV2()])
76
- transform_classify = transforms.Compose([transforms.Resize((300, 300)), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
77
-
 
 
 
 
 
 
 
 
 
78
  print("\n--> API is ready to accept requests.")
79
 
80
  # --- 2. DEFINE THE MAIN API ENDPOINT ---
@@ -101,7 +120,7 @@ async def analyze_image(request: ImageRequest):
101
  with torch.no_grad():
102
  seg_logits = segmentation_model(seg_input_tensor)
103
  seg_mask = (torch.sigmoid(seg_logits) > 0.5).float().squeeze().cpu().numpy()
104
-
105
  if seg_mask.sum() < 200:
106
  return {"status": "Failed", "message": "No lesion could be clearly identified in the image."}
107
 
@@ -118,18 +137,17 @@ async def analyze_image(request: ImageRequest):
118
  with torch.no_grad():
119
  class_logits = classification_model(class_input_tensor)
120
  probabilities = torch.nn.functional.softmax(class_logits, dim=1)
121
-
122
  confidence, predicted_idx = torch.max(probabilities, 1)
123
  confidence_percent = confidence.item() * 100
124
-
125
- # SAFETY NET
126
  if confidence_percent < 75.0:
127
- return {"status": "Inconclusive", "message": f"Model confidence ({confidence_percent:.2f}%) is below the 75% threshold."}
128
 
129
  # STAGE 3: LOOKUP AND RETURN RESULT
130
  predicted_abbr = idx_to_class_abbr[predicted_idx.item()]
131
  info = knowledge_base.get(predicted_abbr, {})
132
-
133
  return {
134
  "status": "Success",
135
  "prediction": info,
@@ -140,4 +158,4 @@ async def analyze_image(request: ImageRequest):
140
  # --- Root endpoint to check if the API is alive ---
141
  @app.get("/")
142
  def root():
143
- return {"message": "AI Skin Lesion Analyzer API is running."}
 
 
 
1
  import base64
2
  import io
3
  import json
 
14
  import segmentation_models_pytorch as smp
15
  import albumentations as A
16
  from albumentations.pytorch import ToTensorV2
17
+ import os
18
 
19
  # --- 1. SETUP: Create the FastAPI app ---
20
  app = FastAPI(title="AI Skin Lesion Analyzer API")
 
37
  def load_assets():
38
  """Load all models and assets from Hugging Face Hub when the server starts."""
39
  global segmentation_model, classification_model, knowledge_base, transform_segment, transform_classify
40
+
41
  print("--> API starting up: This may take a few minutes...")
42
+
43
+ # ✅ Create safe folder to store downloaded models
44
+ os.makedirs("./models_cache", exist_ok=True)
45
+
46
  # Load Segmentation Model
47
  try:
48
  print(" Downloading UNet segmentation model...")
49
+ seg_model_path = hf_hub_download(
50
+ repo_id="sheikh987/unet-isic2018",
51
+ filename="unet_full_data_best_model.pth",
52
+ cache_dir="./models_cache"
53
+ )
54
  segmentation_model = smp.Unet("resnet34", encoder_weights=None, in_channels=3, classes=1).to(DEVICE)
55
  segmentation_model.load_state_dict(torch.load(seg_model_path, map_location=DEVICE))
56
  segmentation_model.eval()
 
61
  # Load Classification Model
62
  try:
63
  print(" Downloading EfficientNet classification model...")
64
+ class_model_path = hf_hub_download(
65
+ repo_id="sheikh987/efficientnet-isic-classifier",
66
+ filename="efficientnet_isic_classifier_best.pth",
67
+ cache_dir="./models_cache"
68
+ )
69
  classification_model = timm.create_model('efficientnet_b3', pretrained=False, num_classes=7).to(DEVICE)
70
  classification_model.load_state_dict(torch.load(class_model_path, map_location=DEVICE))
71
  classification_model.eval()
 
80
  print(" ✅ Knowledge base loaded.")
81
  except Exception as e:
82
  print(f"!!! FATAL: Could not load knowledge_base.json: {e}")
83
+
84
  # Define Image Transforms
85
+ transform_segment = A.Compose([
86
+ A.Resize(256, 256),
87
+ A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], max_pixel_value=255.0),
88
+ ToTensorV2()
89
+ ])
90
+
91
+ transform_classify = transforms.Compose([
92
+ transforms.Resize((300, 300)),
93
+ transforms.ToTensor(),
94
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
95
+ ])
96
+
97
  print("\n--> API is ready to accept requests.")
98
 
99
  # --- 2. DEFINE THE MAIN API ENDPOINT ---
 
120
  with torch.no_grad():
121
  seg_logits = segmentation_model(seg_input_tensor)
122
  seg_mask = (torch.sigmoid(seg_logits) > 0.5).float().squeeze().cpu().numpy()
123
+
124
  if seg_mask.sum() < 200:
125
  return {"status": "Failed", "message": "No lesion could be clearly identified in the image."}
126
 
 
137
  with torch.no_grad():
138
  class_logits = classification_model(class_input_tensor)
139
  probabilities = torch.nn.functional.softmax(class_logits, dim=1)
140
+
141
  confidence, predicted_idx = torch.max(probabilities, 1)
142
  confidence_percent = confidence.item() * 100
143
+
 
144
  if confidence_percent < 75.0:
145
+ return {"status": "Inconclusive", "message": f"Model confidence ({confidence_percent:.2f}%) is below the 75% threshold."}
146
 
147
  # STAGE 3: LOOKUP AND RETURN RESULT
148
  predicted_abbr = idx_to_class_abbr[predicted_idx.item()]
149
  info = knowledge_base.get(predicted_abbr, {})
150
+
151
  return {
152
  "status": "Success",
153
  "prediction": info,
 
158
  # --- Root endpoint to check if the API is alive ---
159
  @app.get("/")
160
  def root():
161
+ return {"message": "AI Skin Lesion Analyzer API is running."}