shekzee commited on
Commit
74bc278
·
verified ·
1 Parent(s): 64a04e7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -5
app.py CHANGED
@@ -1,5 +1,6 @@
1
  from fastapi import FastAPI, UploadFile, File
2
  from fastapi.middleware.cors import CORSMiddleware
 
3
  from PIL import Image
4
  from io import BytesIO
5
  import numpy as np
@@ -7,7 +8,6 @@ import tensorflow as tf
7
 
8
  app = FastAPI()
9
 
10
- # CORS config (optional)
11
  app.add_middleware(
12
  CORSMiddleware,
13
  allow_origins=["*"],
@@ -15,13 +15,26 @@ app.add_middleware(
15
  allow_headers=["*"],
16
  )
17
 
 
 
 
18
  @app.post("/predict")
19
  async def predict(file: UploadFile = File(...)):
20
  contents = await file.read()
21
  img = Image.open(BytesIO(contents)).convert("RGB")
22
- img = img.resize((256, 256)) # or whatever your model expects
23
  arr = np.array(img) / 255.0
24
  arr = np.expand_dims(arr, 0)
25
- # prediction = model.predict(arr)
26
- # result = do_something_with_prediction(prediction)
27
- return {"success": True}
 
 
 
 
 
 
 
 
 
 
 
1
  from fastapi import FastAPI, UploadFile, File
2
  from fastapi.middleware.cors import CORSMiddleware
3
+ from fastapi.responses import StreamingResponse
4
  from PIL import Image
5
  from io import BytesIO
6
  import numpy as np
 
8
 
9
  app = FastAPI()
10
 
 
11
  app.add_middleware(
12
  CORSMiddleware,
13
  allow_origins=["*"],
 
15
  allow_headers=["*"],
16
  )
17
 
18
+ # Load your trained segmentation model here
19
+ # model = tf.keras.models.load_model("seg_model_path")
20
+
21
  @app.post("/predict")
22
  async def predict(file: UploadFile = File(...)):
23
  contents = await file.read()
24
  img = Image.open(BytesIO(contents)).convert("RGB")
25
+ img = img.resize((256, 256))
26
  arr = np.array(img) / 255.0
27
  arr = np.expand_dims(arr, 0)
28
+
29
+ # Prediction
30
+ prediction = model.predict(arr) # (1, 256, 256, num_classes)
31
+ mask = np.argmax(prediction[0], axis=-1).astype(np.uint8) # (256, 256)
32
+
33
+ # Convert to image (you can colorize or just multiply for visualization)
34
+ mask_img = Image.fromarray(mask * 50) # Optional scaling for visibility
35
+
36
+ buf = BytesIO()
37
+ mask_img.save(buf, format='PNG')
38
+ buf.seek(0)
39
+
40
+ return StreamingResponse(buf, media_type="image/png")