# Spaces: Sleeping
import base64
import io
import json
from typing import List, Optional

import torch
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from PIL import Image
from pydantic import BaseModel
from transformers import SamModel, SamProcessor
class SegmentRequest(BaseModel):
    """Request body consumed by the ``segment`` handler."""

    # Base64 text of the image file; the handler decodes it and opens the
    # resulting bytes with PIL.
    file_b64: str

    # Optional prompt points forwarded to the SAM processor as
    # ``input_points``. Presumably each inner list is an [x, y] pixel
    # coordinate -- TODO confirm against the client.
    input_points: Optional[List[List[int]]] = None
# Checkpoint identifier passed to both ``from_pretrained`` calls below.
MODEL_ID = "facebook/sam-vit-base"

# Run on the GPU when torch reports one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# The processor and model are created once at import time and shared by
# every request.
processor = SamProcessor.from_pretrained(MODEL_ID)
model = SamModel.from_pretrained(MODEL_ID)
model = model.to(device)

# ASGI application object.
app = FastAPI(title="SAM MedTesting")
async def segment(req: SegmentRequest):
    """Run SAM on the submitted image and return one mask as a PNG.

    Args:
        req: Request carrying the base64-encoded image and, optionally,
            prompt points for the model.

    Returns:
        A ``StreamingResponse`` with media type ``image/png`` containing a
        single-channel mask (0 = background, 255 = foreground).

    Raises:
        HTTPException: 400 when ``file_b64`` is not valid base64 or the
            decoded bytes cannot be opened as an image.
    """
    # NOTE(review): no route decorator (e.g. ``@app.post(...)``) is visible
    # on this handler in this view -- confirm it is registered with ``app``.
    # NOTE(review): model inference below is synchronous and will occupy the
    # event loop for its duration -- confirm this is acceptable.

    # Decode the image; malformed client input becomes a 400 rather than an
    # unhandled server error.
    try:
        img_bytes = base64.b64decode(req.file_b64)
        img = Image.open(io.BytesIO(img_bytes)).convert("RGB")
    except (ValueError, OSError) as err:
        # binascii.Error derives from ValueError; PIL's
        # UnidentifiedImageError derives from OSError.
        raise HTTPException(
            status_code=400,
            detail="file_b64 is not a valid base64-encoded image",
        ) from err

    # Prepare model inputs; the points are wrapped in one extra list because
    # the processor takes one list of points per image.
    pts = req.input_points
    inputs = processor(
        img,
        input_points=[pts] if pts else None,
        return_tensors="pt",
    ).to(device)

    with torch.no_grad():
        outputs = model(**inputs)

    # Resize the predicted masks back to the original image size.
    masks = processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(),
        inputs["original_sizes"].cpu(),
        inputs["reshaped_input_sizes"].cpu(),
    )

    # First image, first prompt. Per the transformers SAM documentation this
    # is a (num_masks, H, W) stack of candidate masks with matching quality
    # estimates in ``outputs.iou_scores`` -- TODO confirm against the
    # installed transformers version.
    candidates = masks[0][0]
    if candidates.ndim == 3:
        # Keep only the candidate the model scores highest so the result is
        # a 2-D array that PIL can encode as a grayscale image.
        best = int(outputs.iou_scores[0, 0].argmax())
        candidates = candidates[best]

    # torch tensors have no ``.astype``; convert to NumPy first, then scale
    # the boolean mask to 0/255.
    mask_np = candidates.numpy().astype("uint8") * 255

    buf = io.BytesIO()
    Image.fromarray(mask_np).save(buf, format="PNG")
    buf.seek(0)
    return StreamingResponse(buf, media_type="image/png")