segmentation / app.py
shekzee's picture
Update app.py
1f9611c verified
import base64
from io import BytesIO

import numpy as np
import torch
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image
from transformers import AutoImageProcessor, MobileNetV2ForSemanticSegmentation
app = FastAPI()

# Wide-open CORS: every origin, HTTP method and request header is allowed,
# e.g. so a browser front end served from another origin can call the API.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# The preprocessing pipeline and the segmentation network are both read from
# the local "seg_model" checkpoint once, at import time; the request handler
# reuses these module-level objects for every call.
processor, model = (
    AutoImageProcessor.from_pretrained("seg_model"),
    MobileNetV2ForSemanticSegmentation.from_pretrained("seg_model"),
)
def _segment(img):
    """Run the segmentation model on *img* and return its label map.

    Args:
        img: An RGB ``PIL.Image.Image``.

    Returns:
        A 2-D ``np.uint8`` array of predicted class ids for the single image
        in the batch. It has the spatial size of the model's logits; no
        resizing back to the uploaded image's size is done here.
    """
    inputs = processor(images=img, return_tensors="pt")
    # no_grad is thread-local, so it must be entered in the worker thread
    # that actually runs the forward pass.
    with torch.no_grad():
        logits = model(**inputs).logits  # (batch, num_labels, H, W)
    # Pick the highest-scoring class per pixel, keep batch item 0.
    return torch.argmax(logits, dim=1)[0].numpy().astype(np.uint8)


def _mask_to_data_url(mask):
    """Encode a 2-D uint8 label map as a grayscale PNG ``data:`` URL."""
    buf = BytesIO()
    Image.fromarray(mask).save(buf, format="PNG")
    return "data:image/png;base64," + base64.b64encode(buf.getvalue()).decode()


@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Segment an uploaded image and return the label mask as a PNG data URL.

    Args:
        file: Multipart upload holding an image in any format Pillow can
            decode; it is converted to RGB before preprocessing.

    Returns:
        ``{"success": True, "mask": "data:image/png;base64,..."}``, where the
        PNG is a single-channel (grayscale) image whose pixel values are the
        predicted class ids.

    Raises:
        HTTPException: 400 if the upload cannot be decoded as an image
            (empty, corrupt, truncated, unsupported, or oversized).
    """
    contents = await file.read()
    try:
        img = Image.open(BytesIO(contents)).convert("RGB")
    except (OSError, Image.DecompressionBombError) as err:
        # UnidentifiedImageError and truncated-image errors are OSError
        # subclasses; report them as a client error instead of a 500.
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a readable image."
        ) from err

    # Model inference is CPU-bound; running it in the threadpool keeps the
    # event loop free to serve other requests meanwhile.
    mask = await run_in_threadpool(_segment, img)
    return {"success": True, "mask": _mask_to_data_url(mask)}