"""FastAPI service that runs semantic segmentation on an uploaded image.

POST /predict accepts a multipart image upload and returns the per-pixel
argmax class map, encoded as a base64 PNG data URL.
"""

import base64
from io import BytesIO

import numpy as np
import torch
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image, UnidentifiedImageError
from transformers import AutoImageProcessor, MobileNetV2ForSemanticSegmentation

app = FastAPI()

# NOTE(review): wildcard CORS lets any origin call this endpoint; tighten
# allow_origins before exposing the service beyond a trusted network.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# Load processor and model once at import time, shared by all requests.
processor = AutoImageProcessor.from_pretrained("seg_model")
model = MobileNetV2ForSemanticSegmentation.from_pretrained("seg_model")
# Inference mode (dropout off); from_pretrained normally does this already,
# kept explicit so the intent is visible.
model.eval()


class _InvalidImageError(Exception):
    """Raised when uploaded bytes cannot be decoded as an image."""


def _segment(contents: bytes) -> str:
    """Segment raw image bytes and return the class mask as a PNG data URL.

    The mask has the same width/height as the decoded image; each pixel value
    is the argmax class index for that pixel.

    Raises:
        _InvalidImageError: if *contents* is not a decodable image.
    """
    try:
        img = Image.open(BytesIO(contents)).convert("RGB")
    except (UnidentifiedImageError, OSError) as err:
        raise _InvalidImageError("uploaded file is not a readable image") from err

    inputs = processor(images=img, return_tensors="pt")
    with torch.no_grad():
        # (batch, num_labels, h, w); h/w are the model's output resolution,
        # which is typically coarser than (and cropped/resized from) the input.
        logits = model(**inputs).logits

        # Upsample to the original image size (PIL reports (W, H)) so the mask
        # lines up with the uploaded image, then take the per-pixel argmax.
        logits = torch.nn.functional.interpolate(
            logits,
            size=(img.height, img.width),
            mode="bilinear",
            align_corners=False,
        )
        labels = logits.argmax(dim=1)[0].cpu().numpy()

    # Class indices are stored as raw pixel values. uint8 covers up to 256
    # labels; widen to uint16 beyond that so indices are not truncated.
    dtype = np.uint8 if logits.shape[1] <= 256 else np.uint16
    mask_img = Image.fromarray(labels.astype(dtype))

    buf = BytesIO()
    mask_img.save(buf, format="PNG")
    b64 = base64.b64encode(buf.getvalue()).decode()
    return "data:image/png;base64," + b64


@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Segment the uploaded image.

    Returns ``{"success": True, "mask": <PNG data URL>}``. Responds with
    HTTP 400 if the upload cannot be decoded as an image.
    """
    contents = await file.read()
    try:
        # Decoding and inference are synchronous and compute-bound; run them in
        # the threadpool so they do not block the event loop for other requests.
        mask = await run_in_threadpool(_segment, contents)
    except _InvalidImageError as err:
        raise HTTPException(status_code=400, detail=str(err)) from err
    return {"success": True, "mask": mask}