# Spaces status banner captured along with this file: Sleeping (page text, not part of the program).
import io

import numpy as np
import torch
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.responses import StreamingResponse
from PIL import Image
from pydantic import BaseModel
from transformers import SamModel, SamProcessor
# FastAPI application object for this service.
app = FastAPI(title="SAM_MedTesting")

# Checkpoint identifier handed to from_pretrained for both the processor and the model.
MODEL_ID = "facebook/sam-vit-base"

# Use the GPU when torch reports CUDA as available; otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Loaded once at import time so every request reuses the same processor and model.
processor = SamProcessor.from_pretrained(MODEL_ID)
model = SamModel.from_pretrained(MODEL_ID).to(device)
async def segment(
    file: UploadFile = File(...),
    # an optional list of prompt points [[x1,y1], [x2,y2], …]
    input_points: list[list[int]] | None = None,
):
    """Segment an uploaded image with SAM and return the first mask as a PNG.

    Args:
        file: Uploaded image. It is decoded with PIL and converted to RGB.
        input_points: Optional prompt points ``[[x1, y1], [x2, y2], ...]`` for
            the image. ``None`` or an empty list runs the model without point
            prompts.

    Returns:
        A ``StreamingResponse`` with media type ``image/png`` containing a
        single-channel mask whose pixels are 255 where the mask is set and 0
        elsewhere.

    Raises:
        HTTPException: Status 400 when the upload cannot be decoded as an image.
    """
    # NOTE(review): no route decorator is visible on this function in this
    # view; confirm it is registered on `app` (e.g. via a POST route).
    # NOTE(review): `input_points` is declared next to a file upload; verify
    # how clients are expected to send it with a multipart request.

    # read & decode image
    data = await file.read()
    try:
        # convert() forces the pixel data to load, so truncated files fail here too.
        img = Image.open(io.BytesIO(data)).convert("RGB")
    except OSError as err:
        # PIL's UnidentifiedImageError is an OSError subclass; report bad
        # uploads as a client error instead of an unhandled server error.
        raise HTTPException(
            status_code=400,
            detail="Uploaded file is not a valid image",
        ) from err

    # prepare inputs for SAM (the processor expects one list of points per image)
    inputs = processor(
        img,
        input_points=[input_points] if input_points else None,
        return_tensors="pt",
    ).to(device)

    # forward pass
    with torch.no_grad():
        outputs = model(**inputs)

    # post-process to binary masks
    masks = processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(),
        inputs["original_sizes"].cpu(),
        inputs["reshaped_input_sizes"].cpu(),
    )

    # take the first mask: masks[0] belongs to the single uploaded image and
    # may still carry leading prompt/candidate axes, so flatten everything
    # except the trailing (height, width) axes before picking index 0.
    image_masks = np.asarray(masks[0])
    first_mask = image_masks.reshape(-1, *image_masks.shape[-2:])[0]
    mask_np = first_mask.astype(np.uint8) * 255

    # turn it into a PNG
    pil_mask = Image.fromarray(mask_np)
    buf = io.BytesIO()
    pil_mask.save(buf, format="PNG")
    buf.seek(0)
    return StreamingResponse(buf, media_type="image/png")