Axzyl committed on
Commit
2807856
·
verified ·
1 Parent(s): d5bd11b

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -28
app.py CHANGED
@@ -1,53 +1,48 @@
1
- from fastapi import FastAPI, File, UploadFile
2
- from transformers import SamModel, SamProcessor
3
- import torch
 
4
  from PIL import Image
5
- import io
6
- import numpy as np
7
  from fastapi.responses import StreamingResponse
8
- from pydantic import BaseModel
9
 
10
- app = FastAPI(title="SAM_MedTesting")
 
 
 
 
 
11
  MODEL_ID = "facebook/sam-vit-base"
12
- device = "cuda" if torch.cuda.is_available() else "cpu"
13
  processor = SamProcessor.from_pretrained(MODEL_ID)
14
  model = SamModel.from_pretrained(MODEL_ID).to(device)
15
 
16
  @app.post("/segment")
17
- async def segment(
18
- file: UploadFile = File(...),
19
- # an optional list of prompt points [[x1,y1], [x2,y2], …]
20
- input_points: list[list[int]] | None = None
21
- ):
22
-
23
- # read & decode image
24
- data = await file.read()
25
- img = Image.open(io.BytesIO(data)).convert("RGB")
26
-
27
- # prepare inputs for SAM
28
  inputs = processor(
29
  img,
30
- input_points=[input_points] if input_points else None,
31
  return_tensors="pt"
32
  ).to(device)
33
 
34
- # forward pass
35
  with torch.no_grad():
36
  outputs = model(**inputs)
37
 
38
- # post-process to binary masks
39
  masks = processor.image_processor.post_process_masks(
40
  outputs.pred_masks.cpu(),
41
  inputs["original_sizes"].cpu(),
42
  inputs["reshaped_input_sizes"].cpu()
43
  )
44
- # take the first mask
45
- mask_np = (masks[0][0] * 255).astype(np.uint8)
46
-
47
- # turn it into a PNG
48
- pil_mask = Image.fromarray(mask_np)
49
  buf = io.BytesIO()
50
- pil_mask.save(buf, format="PNG")
51
  buf.seek(0)
52
 
53
  return StreamingResponse(buf, media_type="image/png")
 
from fastapi import FastAPI
from pydantic import BaseModel
from typing import Optional, List
import base64, io, json
from PIL import Image
import torch
from transformers import SamModel, SamProcessor
from fastapi.responses import StreamingResponse


class SegmentRequest(BaseModel):
    """Request body for /segment.

    file_b64     -- base64-encoded image bytes (any PIL-readable format).
    input_points -- optional SAM point prompts as [[x1, y1], [x2, y2], ...]
                    in image pixel coordinates; presumably one prompt set per
                    request (verify against caller).
    """
    file_b64: str
    input_points: Optional[List[List[int]]] = None

app = FastAPI(title="SAM MedTesting")

MODEL_ID = "facebook/sam-vit-base"
device = "cuda" if torch.cuda.is_available() else "cpu"
processor = SamProcessor.from_pretrained(MODEL_ID)
model = SamModel.from_pretrained(MODEL_ID).to(device)

@app.post("/segment")
async def segment(req: SegmentRequest):
    """Run SAM on a base64-encoded image and stream back one mask as a PNG.

    Decodes the image, optionally applies the point prompts, runs a single
    forward pass, post-processes to binary masks, and returns the first
    predicted mask as an image/png StreamingResponse.
    """
    # decode image
    img_bytes = base64.b64decode(req.file_b64)
    img = Image.open(io.BytesIO(img_bytes)).convert("RGB")

    # prepare inputs; SAM expects prompts nested one level deeper (per image)
    pts = req.input_points
    inputs = processor(
        img,
        input_points=[pts] if pts else None,
        return_tensors="pt"
    ).to(device)

    with torch.no_grad():
        outputs = model(**inputs)

    masks = processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(),
        inputs["original_sizes"].cpu(),
        inputs["reshaped_input_sizes"].cpu()
    )

    # BUG FIX: post_process_masks returns torch tensors, and torch.Tensor has
    # no .astype() — the original `(masks[0][0] * 255).astype("uint8")` raised
    # AttributeError on every request. Convert to a NumPy array first.
    mask = masks[0][0]
    # SAM can emit several mask candidates per prompt, so the tensor may carry
    # extra leading dims; collapse to (H, W) by taking the first candidate.
    # NOTE(review): assumes the first candidate is the desired one — confirm.
    while mask.dim() > 2:
        mask = mask[0]
    mask_np = (mask.numpy() * 255).astype("uint8")

    buf = io.BytesIO()
    Image.fromarray(mask_np).save(buf, format="PNG")
    buf.seek(0)

    return StreamingResponse(buf, media_type="image/png")