Spaces:
Sleeping
Sleeping
File size: 880 Bytes
356590c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 |
import io

import torch
import torchvision.transforms as T
from fastapi import FastAPI, UploadFile, File
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import JSONResponse
from PIL import Image
from transformers import MobileNetV2ForSemanticSegmentation
# Segmentation network loaded once at import time from "seg_model" (a local
# checkpoint directory or Hub repo id) and put into evaluation mode so that
# dropout/batch-norm behave deterministically during inference.
# nn.Module.eval() returns the module itself, so the call can be chained.
model = MobileNetV2ForSemanticSegmentation.from_pretrained("seg_model").eval()
# Per-channel RGB statistics (the standard ImageNet mean/std).
# NOTE(review): confirm these match the normalization the "seg_model"
# checkpoint was trained with (see its preprocessor config).
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# PIL image -> normalized float tensor fed to the model.
preprocess = T.Compose(
    [
        T.Resize(513),  # int size: shorter edge becomes 513 px, aspect ratio kept
        T.ToTensor(),  # HWC uint8 PIL image -> CHW float tensor in [0, 1]
        T.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ]
)
# FastAPI application; routes are attached to it with decorators.
app = FastAPI()


@app.get("/")
def root():
    """Report that the segmentation API is running."""
    payload = {"status": "API up for segmentation"}
    return payload
def _load_rgb_image(data: bytes) -> Image.Image:
    """Decode raw upload bytes into an RGB PIL image.

    Raises:
        OSError: the bytes are not a decodable image (Pillow's
            UnidentifiedImageError is an OSError subclass) or are truncated.
        Image.DecompressionBombError: the image exceeds Pillow's pixel limit.
    """
    # Image.open expects a path or file-like object; raw bytes would be
    # interpreted as a filename, so wrap them in an in-memory buffer.
    # .convert() forces the (lazy) decode, so errors surface here.
    return Image.open(io.BytesIO(data)).convert("RGB")


def _segment(img: Image.Image) -> list:
    """Run the segmentation model on *img* and return its argmax class mask."""
    x = preprocess(img).unsqueeze(0)  # add a leading batch dimension
    # Grad mode is thread-local, so it must be disabled inside the worker
    # thread that actually runs the forward pass.
    with torch.no_grad():
        logits = model(x).logits
    # Assumes logits are (batch, num_classes, H, W) -- TODO confirm: argmax
    # over the class axis, take the single batch item, convert to lists.
    return logits.argmax(1)[0].tolist()


@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Segment an uploaded image and return the predicted class mask as JSON.

    Success response: ``{"segmentation_mask": mask}`` where ``mask`` is the
    per-position argmax class index of the model's logits, as nested lists
    (presumably height x width -- see ``_segment``).

    NOTE(review): the mask has the logits' spatial size, which may be smaller
    than the input image; confirm whether callers expect it upsampled.

    Returns HTTP 400 with ``{"detail": ...}`` if the upload cannot be decoded
    as an image.
    """
    data = await file.read()
    # Decoding and inference are blocking CPU work; run them in the thread
    # pool so they do not stall the event loop for other requests.
    try:
        img = await run_in_threadpool(_load_rgb_image, data)
    except (OSError, Image.DecompressionBombError):
        return JSONResponse(
            status_code=400,
            content={"detail": "Uploaded file is not a valid image."},
        )
    seg = await run_in_threadpool(_segment, img)
    return JSONResponse(content={"segmentation_mask": seg})
|