import io
import json

import numpy as np
import torch
from fastapi import FastAPI, File, Form, UploadFile
from fastapi.responses import StreamingResponse
from PIL import Image
from transformers import SamModel, SamProcessor

app = FastAPI(title="SAM_MedTesting")

MODEL_ID = "facebook/sam-vit-base"
device = "cuda" if torch.cuda.is_available() else "cpu"

# load the SAM processor and model once at startup, not per request
processor = SamProcessor.from_pretrained(MODEL_ID)
model = SamModel.from_pretrained(MODEL_ID).to(device)
model.eval()

@app.post("/segment")
async def segment(
    file: UploadFile = File(...),
    # optional prompt points, JSON-encoded as "[[x1, y1], [x2, y2], ...]";
    # sent as a form field because the request body is already
    # multipart/form-data, and FastAPI cannot mix a JSON body with File uploads
    input_points: str | None = Form(None),
):
    # read & decode image
    data = await file.read()
    img  = Image.open(io.BytesIO(data)).convert("RGB")

    # decode the optional JSON-encoded prompt points
    points = json.loads(input_points) if input_points else None

    # prepare inputs for SAM; the processor expects one point list per image,
    # hence the extra batch dimension around `points`
    inputs = processor(
        img,
        input_points=[points] if points else None,
        return_tensors="pt",
    ).to(device)

    # forward pass
    with torch.no_grad():
        outputs = model(**inputs)

    # post-process the low-resolution logits into full-size boolean masks
    masks = processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(),
        inputs["original_sizes"].cpu(),
        inputs["reshaped_input_sizes"].cpu(),
    )
    # masks[0] has shape (num_point_sets, num_masks, H, W); take the first
    # predicted mask and convert the boolean tensor to a uint8 image
    mask_np = (masks[0][0][0].numpy() * 255).astype(np.uint8)

    # turn it into a PNG
    pil_mask = Image.fromarray(mask_np)
    buf = io.BytesIO()
    pil_mask.save(buf, format="PNG")
    buf.seek(0)

    return StreamingResponse(buf, media_type="image/png")
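
# --- Example client call ---
# A minimal sketch of how the endpoint could be exercised. The URL, the local
# file name "scan.png", and the point coordinates are illustrative assumptions,
# not part of the service itself:
#
#   import requests
#
#   with open("scan.png", "rb") as f:
#       resp = requests.post(
#           "http://localhost:8000/segment",
#           files={"file": ("scan.png", f, "image/png")},
#           data={"input_points": "[[120, 80]]"},
#       )
#   resp.raise_for_status()
#   with open("mask.png", "wb") as out:
#       out.write(resp.content)
#
# Start the server first, e.g.: uvicorn main:app --reload
# ("main" assumes this file is saved as main.py)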