# SAM MedTesting — FastAPI service exposing SAM (Segment Anything) mask
# prediction over base64-encoded images.
import base64
import io
import json
from typing import List, Optional

import torch
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from PIL import Image
from pydantic import BaseModel
from transformers import SamModel, SamProcessor
class SegmentRequest(BaseModel):
    """Request payload for the segmentation endpoint.

    Attributes:
        file_b64: base64-encoded image bytes in any PIL-readable format.
        input_points: optional list of ``[x, y]`` pixel coordinates used
            as SAM point prompts; ``None`` means no point prompt.
    """

    # Field names and types are the wire contract — unchanged.
    file_b64: str
    input_points: Optional[List[List[int]]] = None
# --- application and model setup (runs once, at import time) ---
app = FastAPI(title="SAM MedTesting")
MODEL_ID = "facebook/sam-vit-base"
# Prefer GPU when available; inference falls back to CPU otherwise.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Processor handles image resizing/normalization and mask post-processing.
processor = SamProcessor.from_pretrained(MODEL_ID)
# NOTE(review): downloads weights on first run — requires network access.
model = SamModel.from_pretrained(MODEL_ID).to(device)
@app.post("/segment")
async def segment(req: SegmentRequest):
    """Run SAM on a base64-encoded image and stream back a PNG mask.

    Args:
        req: request body carrying the image and optional point prompts.

    Returns:
        StreamingResponse with a single-channel PNG (255 = masked pixel).

    Raises:
        HTTPException: 400 if the payload is not valid base64 image data.
    """
    # Decode the uploaded image; reject malformed payloads with a clean 400
    # instead of letting base64/PIL raise into a 500.
    try:
        img_bytes = base64.b64decode(req.file_b64)
        img = Image.open(io.BytesIO(img_bytes)).convert("RGB")
    except Exception as exc:
        raise HTTPException(status_code=400, detail="invalid image payload") from exc

    # Processor expects point prompts nested per image: [[[x, y], ...]].
    pts = req.input_points
    inputs = processor(
        img,
        input_points=[pts] if pts else None,
        return_tensors="pt",
    ).to(device)

    # NOTE(review): model(...) is a blocking CPU/GPU call inside an async
    # handler — consider starlette's run_in_threadpool under load.
    with torch.no_grad():
        outputs = model(**inputs)

    # Map the low-resolution predicted masks back to the original image size.
    masks = processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(),
        inputs["original_sizes"].cpu(),
        inputs["reshaped_input_sizes"].cpu(),
    )

    # masks[0] holds the first image's masks; SAM emits multiple mask
    # candidates, so masks[0][0] is generally NOT 2-D. Index down leading
    # singleton/candidate dims explicitly until a (H, W) tensor remains,
    # rather than relying on a hidden [0, :, :] later.
    mask_t = masks[0][0]
    while mask_t.dim() > 2:
        mask_t = mask_t[0]  # keep the first mask candidate
    mask_np = (mask_t.to(torch.uint8) * 255).cpu().numpy()

    # Encode the 2-D mask as a PNG and stream it back.
    pil_mask = Image.fromarray(mask_np)
    buf = io.BytesIO()
    pil_mask.save(buf, format="PNG")
    buf.seek(0)
    return StreamingResponse(buf, media_type="image/png")