shekzee commited on
Commit
6604d70
·
verified ·
1 Parent(s): 74bc278

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -10
app.py CHANGED
@@ -6,8 +6,11 @@ from io import BytesIO
6
  import numpy as np
7
  import tensorflow as tf
8
 
9
- app = FastAPI()
 
 
10
 
 
11
  app.add_middleware(
12
  CORSMiddleware,
13
  allow_origins=["*"],
@@ -15,9 +18,6 @@ app.add_middleware(
15
  allow_headers=["*"],
16
  )
17
 
18
- # Load your trained segmentation model here
19
- # model = tf.keras.models.load_model("seg_model_path")
20
-
21
  @app.post("/predict")
22
  async def predict(file: UploadFile = File(...)):
23
  contents = await file.read()
@@ -26,12 +26,9 @@ async def predict(file: UploadFile = File(...)):
26
  arr = np.array(img) / 255.0
27
  arr = np.expand_dims(arr, 0)
28
 
29
- # Prediction
30
- prediction = model.predict(arr) # (1, 256, 256, num_classes)
31
- mask = np.argmax(prediction[0], axis=-1).astype(np.uint8) # (256, 256)
32
-
33
- # Convert to image (you can colorize or just multiply for visualization)
34
- mask_img = Image.fromarray(mask * 50) # Optional scaling for visibility
35
 
36
  buf = BytesIO()
37
  mask_img.save(buf, format='PNG')
 
6
  import numpy as np
7
  import tensorflow as tf
8
 
9
+ # --------- LOAD YOUR SEGMENTATION MODEL HERE ---------
10
+ model = tf.keras.models.load_model("seg_model") # <<<<=== THIS LINE!
11
+ # -----------------------------------------------------
12
 
13
+ app = FastAPI()
14
  app.add_middleware(
15
  CORSMiddleware,
16
  allow_origins=["*"],
 
18
  allow_headers=["*"],
19
  )
20
 
 
 
 
21
  @app.post("/predict")
22
  async def predict(file: UploadFile = File(...)):
23
  contents = await file.read()
 
26
  arr = np.array(img) / 255.0
27
  arr = np.expand_dims(arr, 0)
28
 
29
+ prediction = model.predict(arr)
30
+ mask = np.argmax(prediction[0], axis=-1).astype(np.uint8)
31
+ mask_img = Image.fromarray(mask * 50) # For visualization
 
 
 
32
 
33
  buf = BytesIO()
34
  mask_img.save(buf, format='PNG')