from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from typing import Optional, List
import base64, io
from PIL import Image
import torch
from transformers import SamModel, SamProcessor

class SegmentRequest(BaseModel):
    file_b64: str  # base64-encoded image bytes
    input_points: Optional[List[List[int]]] = None  # optional [x, y] prompt points
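
# Example request body (illustrative values):
#   {"file_b64": "<base64-encoded PNG/JPEG bytes>", "input_points": [[450, 600]]}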

app = FastAPI(title="SAM MedTesting")

# Load SAM once at startup; from_pretrained leaves the model in eval mode.
MODEL_ID = "facebook/sam-vit-base"
device = "cuda" if torch.cuda.is_available() else "cpu"
processor = SamProcessor.from_pretrained(MODEL_ID)
model = SamModel.from_pretrained(MODEL_ID).to(device)

@app.post("/segment")
async def segment(req: SegmentRequest):
    # Decode the base64 payload into an RGB image.
    img_bytes = base64.b64decode(req.file_b64)
    img = Image.open(io.BytesIO(img_bytes)).convert("RGB")

    # Prepare inputs; the processor expects points nested per image,
    # so a single image's points are wrapped in one more list.
    pts = req.input_points
    inputs = processor(
        img,
        input_points=[pts] if pts else None,
        return_tensors="pt",
    ).to(device)

    with torch.no_grad():
        outputs = model(**inputs)

    # Upscale the low-resolution predicted masks back to the original image size.
    masks = processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(),
        inputs["original_sizes"].cpu(),
        inputs["reshaped_input_sizes"].cpu(),
    )

    # masks[0][0] is the first image's first prompt: a boolean tensor of shape
    # (num_masks, H, W), since SAM predicts three candidate masks per prompt.
    single_mask = masks[0][0]
    mask_np = (single_mask * 255).to(torch.uint8).numpy()

    # Keep the first candidate mask and stream it back as a PNG.
    pil_mask = Image.fromarray(mask_np[0, :, :])
    buf = io.BytesIO()
    pil_mask.save(buf, format="PNG")
    buf.seek(0)
    return StreamingResponse(buf, media_type="image/png")
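
# A minimal sketch of a local entry point; assumes uvicorn is installed
# (the usual ASGI server for FastAPI, not imported above).
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)

# Example client call (illustrative sketch; file name and port are
# assumptions, and `requests` must be installed):
#   import base64, requests
#   with open("scan.png", "rb") as f:
#       b64 = base64.b64encode(f.read()).decode()
#   resp = requests.post(
#       "http://localhost:8000/segment",
#       json={"file_b64": b64, "input_points": [[450, 600]]},
#   )
#   open("mask.png", "wb").write(resp.content)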