Spaces:
Sleeping
Sleeping
import io

import torch
import torchvision.transforms as T
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import JSONResponse
from PIL import Image
from transformers import MobileNetV2ForSemanticSegmentation
# Segmentation network, loaded once at import time from the local
# "seg_model" checkpoint and put into inference mode via eval().
model = MobileNetV2ForSemanticSegmentation.from_pretrained("seg_model")
model.eval()

# Per-channel RGB statistics used by T.Normalize below (the common ImageNet
# values). NOTE(review): confirm these match what seg_model was trained with.
_CHANNEL_MEAN = [0.485, 0.456, 0.406]
_CHANNEL_STD = [0.229, 0.224, 0.225]

# PIL image -> normalised float tensor:
#   Resize(513)  - an int size scales the shorter side to 513 px, keeping aspect ratio
#   ToTensor()   - (C, H, W) float tensor with values in [0, 1]
#   Normalize()  - subtract per-channel mean, divide by per-channel std
preprocess = T.Compose(
    [
        T.Resize(513),
        T.ToTensor(),
        T.Normalize(mean=_CHANNEL_MEAN, std=_CHANNEL_STD),
    ]
)
app = FastAPI()


# Without a route decorator FastAPI never registers the handler.
# NOTE(review): "/" is inferred from the handler's role as a status check.
@app.get("/")
def root():
    """Report that the segmentation API is running (simple health check)."""
    return {"status": "API up for segmentation"}
def _segment(image_bytes: bytes) -> list:
    """Decode *image_bytes*, run the segmentation model, return the class mask.

    Synchronous and CPU/GPU-bound, so callers on the event loop should run it
    in a worker thread.

    Args:
        image_bytes: Raw contents of an uploaded image file.

    Returns:
        Nested Python lists of per-pixel class indices (argmax over logits).

    Raises:
        HTTPException: 400 if the bytes cannot be decoded as an image.
    """
    try:
        # Image.open expects a path or file-like object, not raw bytes.
        img = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    except OSError as err:
        # Pillow raises UnidentifiedImageError (an OSError subclass) for
        # unrecognised data, and OSError for truncated/corrupt files.
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a valid image."
        ) from err

    x = preprocess(img).unsqueeze(0)  # add a leading batch dimension
    with torch.no_grad():
        logits = model(x).logits
    # Assumes logits are (batch, num_labels, h, w) as in transformers'
    # semantic-segmentation outputs: argmax(1) picks the class, [0] drops batch.
    # NOTE(review): the mask is at the model's logit resolution, which may be
    # smaller than the upload; upsample first if clients need pixel alignment.
    return logits.argmax(1)[0].tolist()


# NOTE(review): "/predict" path inferred from the handler name; without a
# decorator the endpoint is never registered.
@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Segment an uploaded image and return its class-index mask as JSON.

    Decoding and inference are offloaded to a thread pool so this coroutine
    does not block the event loop while the model runs.
    """
    data = await file.read()
    seg = await run_in_threadpool(_segment, data)
    return JSONResponse(content={"segmentation_mask": seg})