shekzee commited on
Commit
37cc80f
·
verified ·
1 Parent(s): 6604d70

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -15
app.py CHANGED
@@ -1,16 +1,14 @@
1
  from fastapi import FastAPI, UploadFile, File
2
  from fastapi.middleware.cors import CORSMiddleware
3
- from fastapi.responses import StreamingResponse
4
  from PIL import Image
 
 
5
  from io import BytesIO
 
6
  import numpy as np
7
- import tensorflow as tf
8
-
9
- # --------- LOAD YOUR SEGMENTATION MODEL HERE ---------
10
- model = tf.keras.models.load_model("seg_model") # <<<<=== THIS LINE!
11
- # -----------------------------------------------------
12
 
13
  app = FastAPI()
 
14
  app.add_middleware(
15
  CORSMiddleware,
16
  allow_origins=["*"],
@@ -18,20 +16,25 @@ app.add_middleware(
18
  allow_headers=["*"],
19
  )
20
 
 
 
 
 
21
  @app.post("/predict")
22
  async def predict(file: UploadFile = File(...)):
23
  contents = await file.read()
24
  img = Image.open(BytesIO(contents)).convert("RGB")
25
- img = img.resize((256, 256))
26
- arr = np.array(img) / 255.0
27
- arr = np.expand_dims(arr, 0)
28
 
29
- prediction = model.predict(arr)
30
- mask = np.argmax(prediction[0], axis=-1).astype(np.uint8)
31
- mask_img = Image.fromarray(mask * 50) # For visualization
 
 
32
 
 
 
33
  buf = BytesIO()
34
- mask_img.save(buf, format='PNG')
35
  buf.seek(0)
36
-
37
- return StreamingResponse(buf, media_type="image/png")
 
1
  from fastapi import FastAPI, UploadFile, File
2
  from fastapi.middleware.cors import CORSMiddleware
 
3
  from PIL import Image
4
+ from transformers import MobileNetV2ForSemanticSegmentation, AutoImageProcessor
5
+ import torch
6
  from io import BytesIO
7
+ import base64
8
  import numpy as np
 
 
 
 
 
9
 
10
  app = FastAPI()
11
+
12
  app.add_middleware(
13
  CORSMiddleware,
14
  allow_origins=["*"],
 
16
  allow_headers=["*"],
17
  )
18
 
19
+ # Load processor and model
20
+ processor = AutoImageProcessor.from_pretrained("seg_model")
21
+ model = MobileNetV2ForSemanticSegmentation.from_pretrained("seg_model")
22
+
23
  @app.post("/predict")
24
  async def predict(file: UploadFile = File(...)):
25
  contents = await file.read()
26
  img = Image.open(BytesIO(contents)).convert("RGB")
 
 
 
27
 
28
+ inputs = processor(images=img, return_tensors="pt")
29
+ with torch.no_grad():
30
+ outputs = model(**inputs)
31
+ logits = outputs.logits # (batch, num_labels, H, W)
32
+ mask = torch.argmax(logits, dim=1)[0].numpy().astype(np.uint8)
33
 
34
+ # Optionally, you can convert mask to RGB with a color map for visualization
35
+ mask_img = Image.fromarray(mask)
36
  buf = BytesIO()
37
+ mask_img.save(buf, format="PNG")
38
  buf.seek(0)
39
+ b64 = base64.b64encode(buf.read()).decode()
40
+ return {"success": True, "mask": "data:image/png;base64," + b64}