"""FastAPI service exposing a semantic-segmentation model over HTTP."""

import io

import torch
import torchvision.transforms as T
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.responses import JSONResponse
from PIL import Image
from transformers import MobileNetV2ForSemanticSegmentation

# Load the model once at import time and switch it to eval mode so every
# request reuses the same weights.
model = MobileNetV2ForSemanticSegmentation.from_pretrained("seg_model")
model.eval()

# Resize, convert to a tensor, then normalize per channel with the given
# mean/std (these look like the usual ImageNet statistics; confirm they
# match what "seg_model" was trained with).
preprocess = T.Compose([
    T.Resize(513),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

app = FastAPI()


@app.get("/")
def root():
    """Return a static status payload."""
    return {"status": "API up for segmentation"}


@app.post("/predict")
async def predict(file: UploadFile = File(...)) -> JSONResponse:
    """Run the segmentation model on an uploaded image.

    Args:
        file: The uploaded image file (multipart form field ``file``).

    Returns:
        A JSON response whose ``segmentation_mask`` key holds the argmax
        over dimension 1 of the model logits for the single input image,
        as nested lists. NOTE(review): the mask has the logits' spatial
        size, which may differ from the uploaded image's size; confirm
        that callers expect this.

    Raises:
        HTTPException: 400 if the upload cannot be decoded as an image.
    """
    payload = await file.read()

    # Image.open() expects a path or a file-like object. Raw bytes would be
    # interpreted as a filename, so wrap the payload in an in-memory buffer.
    try:
        with Image.open(io.BytesIO(payload)) as uploaded:
            img = uploaded.convert("RGB")
    except OSError as exc:
        # PIL.UnidentifiedImageError is an OSError subclass, so this also
        # covers empty or non-image uploads as well as truncated files.
        raise HTTPException(
            status_code=400,
            detail="Uploaded file is not a readable image",
        ) from exc

    # Add a leading batch dimension of size 1.
    x = preprocess(img).unsqueeze(0)

    with torch.no_grad():
        outputs = model(x).logits

    # Pick the highest-scoring index along dim 1 and drop the batch dim.
    seg = outputs.argmax(1)[0].tolist()
    return JSONResponse(content={"segmentation_mask": seg})