# SAM_MedTesting / app.py
# Uploaded by Axzyl — commit 6744a7e (verified), 1.68 kB.
# (Hugging Face Space page residue converted to comments so the file parses.)
from fastapi import FastAPI
from pydantic import BaseModel
from typing import Optional, List
import base64, io, json
from PIL import Image
import torch
from transformers import SamModel, SamProcessor
from fastapi.responses import StreamingResponse
class SegmentRequest(BaseModel):
    """Request body for POST /segment: a base64 image plus optional point prompts."""

    # Base64-encoded image bytes (any format PIL can open; converted to RGB downstream).
    file_b64: str
    # Optional list of pixel coordinates used as SAM point prompts; None = no prompt.
    # NOTE(review): presumably each inner list is [x, y] — confirm against the client.
    input_points: Optional[List[List[int]]] = None
app = FastAPI(title="SAM MedTesting")

# Segment Anything Model (ViT-Base variant) pulled from the Hugging Face hub.
MODEL_ID = "facebook/sam-vit-base"
# Prefer GPU when available; weights are moved there once at import time so each
# request only pays for inference, not model loading.
device = "cuda" if torch.cuda.is_available() else "cpu"
processor = SamProcessor.from_pretrained(MODEL_ID)
model = SamModel.from_pretrained(MODEL_ID).to(device)
@app.post("/segment")
async def segment(req: SegmentRequest):
    """Run SAM on a base64-encoded image and stream back a grayscale PNG mask.

    The request carries the image as base64 plus optional point prompts.
    Of SAM's several candidate masks, the one with the highest predicted IoU
    is returned as an 8-bit PNG (255 = foreground, 0 = background).
    """
    # Decode the uploaded image; SAM expects 3-channel RGB input.
    img_bytes = base64.b64decode(req.file_b64)
    img = Image.open(io.BytesIO(img_bytes)).convert("RGB")

    # Wrap the points once more: the processor expects one point list per image.
    pts = req.input_points
    inputs = processor(
        img,
        input_points=[pts] if pts else None,
        return_tensors="pt",
    ).to(device)

    with torch.no_grad():
        outputs = model(**inputs)

    # Upscale the low-resolution mask logits back to the original image size.
    masks = processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(),
        inputs["original_sizes"].cpu(),
        inputs["reshaped_input_sizes"].cpu(),
    )

    # BUG FIX: masks[0][0] has shape (num_masks, H, W) — SAM proposes several
    # candidate masks per prompt.  The original code handed that whole stack to
    # PIL, which rejects (num_masks, H, W) arrays (3-D input must be (H, W, 3/4)).
    # Select the candidate with the highest predicted IoU instead.
    candidates = masks[0][0]  # (num_masks, H, W) boolean tensor
    best_idx = outputs.iou_scores[0, 0].argmax().item()
    mask_np = (candidates[best_idx].to(torch.uint8) * 255).cpu().numpy()

    # Stream the single-channel mask back as a PNG.
    buf = io.BytesIO()
    Image.fromarray(mask_np).save(buf, format="PNG")
    buf.seek(0)
    return StreamingResponse(buf, media_type="image/png")