segmentation / app.py
shekzee's picture
Create app.py
356590c verified
raw
history blame
880 Bytes
from fastapi import FastAPI, UploadFile, File
from fastapi.responses import JSONResponse
from PIL import Image
import torch, torchvision.transforms as T
from transformers import MobileNetV2ForSemanticSegmentation
import io
# Inference-only setup: load the segmentation weights identified by
# "seg_model" and put the network into evaluation mode in one expression
# (Module.eval() returns the module itself).
model = MobileNetV2ForSemanticSegmentation.from_pretrained("seg_model").eval()

# Per-channel normalisation constants applied after the image has been
# converted to a tensor (the values match the widely used ImageNet statistics).
_CHANNEL_MEAN = [0.485, 0.456, 0.406]
_CHANNEL_STD = [0.229, 0.224, 0.225]

# PIL image -> normalised float tensor pipeline used by the /predict endpoint.
preprocess = T.Compose(
    [
        T.Resize(513),  # single int: resize by the shorter side (see torchvision docs)
        T.ToTensor(),
        T.Normalize(mean=_CHANNEL_MEAN, std=_CHANNEL_STD),
    ]
)

# ASGI application object that the route decorators below attach to.
app = FastAPI()
@app.get("/")
def root():
    """Liveness endpoint: report that the segmentation API is running."""
    status_payload = dict(status="API up for segmentation")
    return status_payload
@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Run semantic segmentation on an uploaded image.

    Args:
        file: The uploaded image, sent as multipart form data.

    Returns:
        A JSON response whose ``segmentation_mask`` key holds the argmax over
        dimension 1 of the model's logits for the single batch item, as
        nested lists. If the upload cannot be decoded as an image, a 400
        JSON response with a ``detail`` message is returned instead.
    """
    payload = await file.read()
    try:
        # Image.open expects a path or a file-like object, not raw bytes, so
        # the upload is wrapped in an in-memory buffer first.
        img = Image.open(io.BytesIO(payload)).convert("RGB")
    except OSError:
        # Pillow raises UnidentifiedImageError (an OSError subclass) when the
        # data is not a recognisable image; treat that as a client error
        # rather than letting it surface as an unhandled 500.
        return JSONResponse(
            status_code=400,
            content={"detail": "Uploaded file is not a valid image"},
        )

    x = preprocess(img).unsqueeze(0)  # add a leading batch dimension
    with torch.no_grad():  # inference only: no autograd bookkeeping needed
        outputs = model(x).logits
    # Pick the highest-scoring index along dim 1 (presumably the class
    # dimension -- confirm against the model), then take the only batch item.
    seg = outputs.argmax(1)[0].tolist()
    return JSONResponse(content={"segmentation_mask": seg})