# FastAPI semantic-segmentation inference service.
# (Page-scrape chrome — Space status, file size, git blame hashes, and the
# line-number gutter — removed; only the source below is meaningful.)
from fastapi import FastAPI, UploadFile, File
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image
from transformers import MobileNetV2ForSemanticSegmentation, AutoImageProcessor
import torch
from io import BytesIO
import base64
import numpy as np
# ASGI application instance served by uvicorn/gunicorn.
app = FastAPI()
# Allow cross-origin requests from any frontend. "*" for origins/methods/
# headers is wide open — acceptable for a demo, tighten for production.
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_methods=["*"],
allow_headers=["*"],
)
# Load processor and model once at import time (shared across requests).
# "seg_model" is a local directory containing a saved MobileNetV2
# semantic-segmentation checkpoint — presumably exported alongside this
# file; TODO confirm the path exists at deploy time.
processor = AutoImageProcessor.from_pretrained("seg_model")
model = MobileNetV2ForSemanticSegmentation.from_pretrained("seg_model")
@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Run semantic segmentation on an uploaded image.

    Reads the uploaded file, runs the model, and returns a JSON payload
    ``{"success": True, "mask": "data:image/png;base64,..."}`` where the
    mask is a grayscale PNG of per-pixel class indices, upsampled to the
    same width/height as the uploaded image so it overlays pixel-for-pixel.
    """
    contents = await file.read()
    img = Image.open(BytesIO(contents)).convert("RGB")
    inputs = processor(images=img, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits  # (batch, num_labels, H', W')
    # BUG FIX: segmentation logits come out at a reduced spatial resolution,
    # so the raw argmax mask did not match the input image's size. Upsample
    # to the original (H, W) before argmax — the same post-processing that
    # transformers' post_process_semantic_segmentation performs. Note PIL's
    # .size is (W, H), so it is reversed for interpolate's (H, W) argument.
    upsampled = torch.nn.functional.interpolate(
        logits,
        size=img.size[::-1],
        mode="bilinear",
        align_corners=False,
    )
    # Class indices fit in uint8 as long as num_labels <= 255 — true for
    # typical segmentation label sets; verify against the model config.
    mask = upsampled.argmax(dim=1)[0].numpy().astype(np.uint8)
    # Encode the mask as a grayscale PNG and return it as a base64 data URL.
    mask_img = Image.fromarray(mask)
    buf = BytesIO()
    mask_img.save(buf, format="PNG")
    b64 = base64.b64encode(buf.getvalue()).decode()
    return {"success": True, "mask": "data:image/png;base64," + b64}