File size: 1,104 Bytes
356590c
64a04e7
74bc278
356590c
64a04e7
 
 
356590c
6604d70
 
 
356590c
6604d70
64a04e7
 
 
 
 
 
356590c
 
 
64a04e7
 
74bc278
64a04e7
 
74bc278
6604d70
 
 
74bc278
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import asyncio
from concurrent.futures import ThreadPoolExecutor
from io import BytesIO

import numpy as np
import tensorflow as tf
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import Response, StreamingResponse
from PIL import Image

# Segmentation model: loaded once, when this module is imported, and then
# shared by every request handled below. "seg_model" is a relative path, so it
# is looked up from the server process's working directory.
model = tf.keras.models.load_model("seg_model")

app = FastAPI()

# CORS is fully open: any origin, any HTTP method and any request header.
# NOTE(review): fine for a local demo front-end; narrow allow_origins before
# exposing this service publicly.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)

# Spatial size every upload is resized to before inference.
_INPUT_SIZE = (256, 256)
# Gray-level distance between adjacent labels in the visualization PNG.
_GRAY_STEP = 50
# A single worker thread keeps blocking Pillow/TensorFlow work off the event
# loop while still running inference one request at a time, exactly as the
# previous inline call did (so no new concurrent access to `model`).
_INFERENCE_EXECUTOR = ThreadPoolExecutor(max_workers=1)


def _load_image(contents: bytes) -> np.ndarray:
    """Decode uploaded bytes into a 1-element batch of RGB pixels scaled to [0, 1].

    The image is converted to RGB and resized to ``_INPUT_SIZE``; the result
    has shape (1, height, width, 3).

    Raises:
        HTTPException: 400 if Pillow cannot decode the bytes (not an image,
            truncated file) or refuses them as a decompression bomb.
    """
    try:
        with Image.open(BytesIO(contents)) as img:
            # convert() forces the lazy decode, so truncation errors surface here.
            rgb = img.convert("RGB")
    except (OSError, Image.DecompressionBombError) as err:
        # OSError also covers PIL.UnidentifiedImageError.
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a readable image."
        ) from err
    rgb = rgb.resize(_INPUT_SIZE)
    arr = np.array(rgb) / 255.0
    return np.expand_dims(arr, 0)


def _mask_to_png(scores: np.ndarray) -> bytes:
    """Encode per-pixel scores as a grayscale PNG of their argmax labels.

    The last axis of ``scores`` holds one score per candidate label; the
    winning label k is drawn at gray level ``k * step``. ``step`` is
    ``_GRAY_STEP`` when every label still fits in 8 bits (<= 6 labels) and is
    shrunk otherwise, so labels never wrap around 255 and collide.
    """
    num_labels = scores.shape[-1]
    labels = np.argmax(scores, axis=-1)  # platform int: no uint8 overflow here
    step = max(1, min(_GRAY_STEP, 255 // max(num_labels - 1, 1)))
    gray = np.clip(labels * step, 0, 255).astype(np.uint8)

    buf = BytesIO()
    Image.fromarray(gray).save(buf, format="PNG")
    return buf.getvalue()


def _segment_to_png(contents: bytes) -> bytes:
    """Blocking pipeline: decode upload, run the model, encode the mask PNG."""
    batch = _load_image(contents)
    # verbose=0: newer Keras defaults to printing a progress bar on every call.
    prediction = model.predict(batch, verbose=0)
    # NOTE(review): assumes the model emits one score per label on the last
    # axis (e.g. softmax); a single-channel sigmoid output would argmax to 0.
    return _mask_to_png(prediction[0])


@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Segment an uploaded image and return the label mask as a PNG.

    The upload is read asynchronously. Decoding, inference and PNG encoding
    all block, so they run on a dedicated worker thread instead of the event
    loop.

    Returns:
        A ``image/png`` response containing the grayscale label mask at
        ``_INPUT_SIZE`` resolution.

    Raises:
        HTTPException: 400 when the upload cannot be decoded as an image.
    """
    # NOTE(review): the upload size is unbounded; consider capping it.
    contents = await file.read()
    loop = asyncio.get_running_loop()
    png = await loop.run_in_executor(_INFERENCE_EXECUTOR, _segment_to_png, contents)
    # The PNG is fully buffered, so a plain Response (with Content-Length)
    # suffices; StreamingResponse would iterate the buffer line by line.
    return Response(content=png, media_type="image/png")