File size: 1,249 Bytes
356590c
64a04e7
356590c
37cc80f
 
64a04e7
37cc80f
64a04e7
356590c
6604d70
37cc80f
64a04e7
 
 
 
 
 
356590c
37cc80f
 
 
 
356590c
 
64a04e7
 
74bc278
37cc80f
 
 
 
 
74bc278
1f9611c
37cc80f
74bc278
37cc80f
74bc278
37cc80f
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import base64
from io import BytesIO

import numpy as np
import torch
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image
from transformers import MobileNetV2ForSemanticSegmentation, AutoImageProcessor

app = FastAPI()

# Allow cross-origin browser requests from any origin, using any HTTP method
# and any request header. Credentials are not enabled (the CORSMiddleware
# default), which is what makes the "*" origin wildcard permissible.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)

# Preprocessing config and weights both come from "seg_model" (presumably a
# local checkpoint directory resolved against the working directory -- confirm).
# They are loaded once at import time and shared by every request to /predict.
processor = AutoImageProcessor.from_pretrained("seg_model")
model = MobileNetV2ForSemanticSegmentation.from_pretrained("seg_model")

def _segment_to_png(contents: bytes) -> bytes:
    """Decode *contents* as an image, segment it, and return the mask as PNG bytes.

    This is synchronous, CPU-bound work (PIL decoding plus a model forward
    pass). ``predict`` runs it in a worker thread so it never blocks the
    event loop.

    Args:
        contents: Raw bytes of the uploaded file.

    Returns:
        PNG-encoded single-channel (mode "L") image whose pixel values are the
        predicted class indices.

    Raises:
        HTTPException: 400 if *contents* cannot be decoded as an image.
    """
    try:
        # convert() forces a full decode, so truncated files fail here too,
        # not later inside the processor.
        img = Image.open(BytesIO(contents)).convert("RGB")
    except (OSError, Image.DecompressionBombError) as exc:
        # UnidentifiedImageError (not an image / empty upload) is an OSError
        # subclass. Report a client error instead of an opaque 500.
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a readable image"
        ) from exc

    inputs = processor(images=img, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits  # (batch, num_labels, H, W)
    # Take the argmax over the label axis to get a per-pixel class index,
    # then keep the first (only) image in the batch. .cpu() keeps this
    # working if the model is ever moved to an accelerator.
    # NOTE(review): uint8 silently wraps if the model has more than 256 labels.
    # NOTE(review): the mask is at the logits' resolution, which presumably
    # differs from the uploaded image's size -- confirm what clients expect.
    mask = torch.argmax(logits, dim=1)[0].cpu().numpy().astype(np.uint8)

    buf = BytesIO()
    Image.fromarray(mask).save(buf, format="PNG")
    return buf.getvalue()


@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Segment an uploaded image and return its class mask as a PNG data URL.

    Decoding and inference run in a worker thread, so the event loop stays
    free to serve other requests while the model runs.

    Returns:
        ``{"success": True, "mask": "data:image/png;base64,..."}``. The PNG is
        a grayscale image whose pixel values are predicted class indices.

    Raises:
        HTTPException: 400 if the upload cannot be decoded as an image.
    """
    contents = await file.read()
    png_bytes = await run_in_threadpool(_segment_to_png, contents)
    b64 = base64.b64encode(png_bytes).decode()
    return {"success": True, "mask": "data:image/png;base64," + b64}