SAM_MedTesting / app.py
Axzyl's picture
Upload app.py
2807856 verified
raw
history blame
1.46 kB
import base64, io, json
from typing import Optional, List

import torch
from fastapi import FastAPI, HTTPException
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import StreamingResponse
from PIL import Image
from pydantic import BaseModel
from transformers import SamModel, SamProcessor
class SegmentRequest(BaseModel):
    """JSON request body for the ``/segment`` endpoint."""

    # Base64 text of the raw image file bytes to segment.
    file_b64: str

    # Optional point prompts, one inner list per point (expected to be
    # [x, y] pairs — confirm with callers). The endpoint wraps this list
    # once more before handing it to SamProcessor; None or an empty list
    # means no point prompt is sent.
    input_points: Optional[List[List[int]]] = None
# The FastAPI application; route handlers below register on it.
app = FastAPI(title="SAM MedTesting")

# Hugging Face checkpoint shared by the processor and the model.
MODEL_ID = "facebook/sam-vit-base"

# Prefer CUDA when torch can see it, otherwise run on the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# Loaded once at import time so every request reuses the same weights.
processor = SamProcessor.from_pretrained(MODEL_ID)
model = SamModel.from_pretrained(MODEL_ID)
model = model.to(device)
def _decode_image(file_b64: str) -> Image.Image:
    """Decode a base64-encoded image file into an RGB PIL image.

    Raises:
        HTTPException: 400 when the payload is not valid base64 or PIL
            cannot read it as an image.
    """
    try:
        img_bytes = base64.b64decode(file_b64)
        # convert() forces PIL to decode the pixel data, so corrupt or
        # truncated files fail here rather than inside the processor.
        return Image.open(io.BytesIO(img_bytes)).convert("RGB")
    except (ValueError, OSError) as err:
        # binascii.Error (e.g. bad padding) subclasses ValueError; PIL's
        # UnidentifiedImageError and truncated-image errors subclass OSError.
        raise HTTPException(
            status_code=400,
            detail=f"file_b64 is not a readable image: {err}",
        ) from err


def _predict_mask_png(
    img: Image.Image, pts: Optional[List[List[int]]]
) -> io.BytesIO:
    """Run SAM on ``img`` and return its best mask encoded as a PNG.

    This is blocking model inference; call it off the event loop.

    Returns:
        A BytesIO positioned at 0 holding a single-channel PNG in which
        mask pixels are 255 and everything else is 0.
    """
    inputs = processor(
        img,
        # SamProcessor takes one list of points per image; there is one image.
        input_points=[pts] if pts else None,
        return_tensors="pt",
    ).to(device)
    with torch.no_grad():
        outputs = model(**inputs)
    # Upsample the low-resolution predictions back to the original image
    # size and binarize them.
    masks = processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(),
        inputs["original_sizes"].cpu(),
        inputs["reshaped_input_sizes"].cpu(),
    )
    # masks[0][0] is the stack of candidate masks for the first image's first
    # prompt group, and iou_scores[0, 0] scores those same candidates.
    # Per the transformers SAM docs these are (num_candidates, H, W) and
    # (num_candidates,). Keep only the candidate SAM rates highest.
    candidates = masks[0][0]
    best = int(outputs.iou_scores[0, 0].argmax())
    # These are torch tensors, not numpy arrays: scale in torch, then
    # hand a 2-D uint8 array to PIL.
    mask_np = (candidates[best].to(torch.uint8) * 255).numpy()
    buf = io.BytesIO()
    Image.fromarray(mask_np).save(buf, format="PNG")
    buf.seek(0)
    return buf


@app.post("/segment")
async def segment(req: SegmentRequest):
    """Segment the posted image with SAM and stream back a PNG mask.

    The mask is single-channel (255 = mask, 0 = background) and has the
    original image's size.

    Raises:
        HTTPException: 422 if any prompt point is not an [x, y] pair;
            400 if ``file_b64`` cannot be decoded as an image.
    """
    pts = req.input_points
    if pts and any(len(p) != 2 for p in pts):
        raise HTTPException(
            status_code=422,
            detail="input_points must be a list of [x, y] pairs",
        )
    # Decoding and inference are blocking. Running them in the threadpool
    # keeps this async handler from stalling the event loop for other
    # requests.
    img = await run_in_threadpool(_decode_image, req.file_b64)
    buf = await run_in_threadpool(_predict_mask_png, img, pts)
    return StreamingResponse(buf, media_type="image/png")