segmentation / app.py
shekzee's picture
Update app.py
37cc80f verified
raw
history blame
1.27 kB
import base64
from io import BytesIO

import numpy as np
import torch
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image
from transformers import MobileNetV2ForSemanticSegmentation, AutoImageProcessor
app = FastAPI()

# Allow cross-origin requests from any origin, with any HTTP method and any
# request header (presumably so a browser front end served from a different
# host can call /predict). No credentials are enabled.
_cors_settings = {
    "allow_origins": ["*"],
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_settings)
# Load the image pre-processor and the segmentation model once, at import
# time, so every request reuses the same instances. Both are read from the
# same checkpoint identifier (presumably a directory shipped next to this
# file — confirm).
_MODEL_DIR = "seg_model"
processor = AutoImageProcessor.from_pretrained(_MODEL_DIR)
model = MobileNetV2ForSemanticSegmentation.from_pretrained(_MODEL_DIR)
def _segment(contents: bytes) -> np.ndarray:
    """Decode uploaded image bytes and return the model's per-pixel class ids.

    This is blocking work (PIL decoding plus a PyTorch forward pass). Call it
    from a worker thread, not directly on the event loop.

    Args:
        contents: Raw bytes of the uploaded file.

    Returns:
        A 2-D array of argmax class indices with the spatial size of the
        model's logits. Its dtype is uint8, or uint16 when the model has more
        than 256 labels, so class ids never wrap.

    Raises:
        HTTPException: 400 if *contents* cannot be decoded as an image.
    """
    try:
        img = Image.open(BytesIO(contents)).convert("RGB")
    except (OSError, Image.DecompressionBombError) as err:
        # Pillow raises UnidentifiedImageError (an OSError) for non-images and
        # OSError for truncated data. Report these as client errors, not 500s.
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a readable image."
        ) from err

    inputs = processor(images=img, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits  # (batch, num_labels, H, W)

    labels = torch.argmax(logits, dim=1)[0].cpu().numpy()
    # Class ids range over [0, num_labels); uint8 only holds 256 distinct ids.
    dtype = np.uint8 if logits.shape[1] <= 256 else np.uint16
    # NOTE(review): the mask has the logits' H x W, which presumably differs
    # from the uploaded image's size — confirm whether clients rescale it.
    return labels.astype(dtype)


def _encode_mask_png(mask: np.ndarray) -> str:
    """Encode a 2-D integer label mask as a base64 PNG data URI."""
    # Optionally, the mask could be mapped through a color palette here for
    # visualization; currently each pixel value is the raw class index.
    buf = BytesIO()
    Image.fromarray(mask).save(buf, format="PNG")
    return "data:image/png;base64," + base64.b64encode(buf.getvalue()).decode()


@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Segment an uploaded image and return its class-id mask.

    Returns:
        ``{"success": True, "mask": "data:image/png;base64,..."}``. The mask
        is a single-channel PNG whose pixel values are the predicted class
        indices.

    Raises:
        HTTPException: 400 if the upload is not a readable image.
    """
    contents = await file.read()
    # Model inference is CPU-bound. Running it in the threadpool keeps the
    # event loop free to serve other requests meanwhile.
    mask = await run_in_threadpool(_segment, contents)
    return {"success": True, "mask": _encode_mask_png(mask)}