SAM_MedTesting / app.py
Axzyl's picture
Upload 2 files
2474580 verified
raw
history blame
1.59 kB
import io

import numpy as np
import torch
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.responses import StreamingResponse
from PIL import Image
from pydantic import BaseModel
from transformers import SamModel, SamProcessor
# Hugging Face checkpoint used for both the processor and the model.
MODEL_ID = "facebook/sam-vit-base"

# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

app = FastAPI(title="SAM_MedTesting")

# Both objects are created once at import time and shared by every request.
processor = SamProcessor.from_pretrained(MODEL_ID)
model = SamModel.from_pretrained(MODEL_ID)
model = model.to(device)
@app.post("/segment")
async def segment(
    file: UploadFile = File(...),
    # an optional list of prompt points [[x1,y1], [x2,y2], …]
    input_points: list[list[int]] | None = None,
):
    """Segment an uploaded image with SAM and return one mask as a PNG.

    Args:
        file: The uploaded image; it is decoded with PIL and converted to RGB.
        input_points: Optional prompt points ``[[x1, y1], [x2, y2], ...]``.
            They are wrapped in one extra list before being handed to the
            processor (one entry per image). An empty list is treated the
            same as ``None``.
            NOTE(review): how FastAPI binds this parameter next to a
            multipart upload is not visible here - confirm clients can
            actually supply it.

    Returns:
        A ``StreamingResponse`` with media type ``image/png`` containing a
        single-channel mask whose pixels are 0 (background) or 255 (mask).

    Raises:
        HTTPException: 400 when the upload cannot be decoded as an image.
    """
    # read & decode image
    data = await file.read()
    try:
        img = Image.open(io.BytesIO(data)).convert("RGB")
    except (OSError, ValueError) as exc:
        # PIL signals unreadable/truncated data with OSError subclasses;
        # report that as a client error instead of an unhandled 500.
        raise HTTPException(
            status_code=400,
            detail="Uploaded file is not a valid image.",
        ) from exc

    # prepare inputs for SAM
    inputs = processor(
        img,
        input_points=[input_points] if input_points else None,
        return_tensors="pt",
    ).to(device)

    # forward pass
    with torch.no_grad():
        outputs = model(**inputs)

    # post-process to binary masks at the original image resolution
    masks = processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(),
        inputs["original_sizes"].cpu(),
        inputs["reshaped_input_sizes"].cpu(),
    )

    # masks[0] belongs to the only image in the batch and masks[0][0] to its
    # first prompt; that entry still stacks several candidate masks, so pick
    # the one the model scores highest rather than passing the whole stack on.
    candidates = masks[0][0]
    best = int(outputs.iou_scores[0, 0].argmax())

    # The masks are torch tensors (no ``.astype``), so convert to NumPy first,
    # then scale the 0/1 values to 0/255 for an 8-bit grayscale image.
    mask_np = candidates[best].numpy().astype(np.uint8) * 255

    # turn it into a PNG
    pil_mask = Image.fromarray(mask_np)
    buf = io.BytesIO()
    pil_mask.save(buf, format="PNG")
    buf.seek(0)
    return StreamingResponse(buf, media_type="image/png")