segmentation / app.py
shekzee's picture
Update app.py
6604d70 verified
raw
history blame
1.1 kB
from fastapi import FastAPI, UploadFile, File
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from PIL import Image
from io import BytesIO
import numpy as np
import tensorflow as tf
# Segmentation model, loaded once at import time from the "seg_model" path.
# Swap the path here to serve a different model.
model = tf.keras.models.load_model("seg_model")

# CORS is left fully open: any origin, HTTP method and request header is
# accepted by the middleware.
_cors_options = {
    "allow_origins": ["*"],
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}

app = FastAPI()
app.add_middleware(CORSMiddleware, **_cors_options)
@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Segment an uploaded image and return the class mask as a grayscale PNG.

    The upload is decoded as RGB, resized to 256x256, scaled to [0, 1] and
    passed to the module-level ``model`` as a batch of one. Each pixel's class
    is the argmax over the last axis of the first prediction; class indices
    are then spread over gray levels so they are distinguishable by eye.

    Args:
        file: Uploaded image in any format Pillow can decode.

    Returns:
        A ``StreamingResponse`` with media type ``image/png`` containing the
        visualized mask.

    Raises:
        HTTPException: With status 400 when the upload cannot be decoded as
            an image (empty body, unsupported format, truncated data).
    """
    # Local import: HTTPException is not among the module-level FastAPI imports.
    from fastapi import HTTPException

    contents = await file.read()

    # Pillow raises OSError (UnidentifiedImageError is a subclass) for data it
    # cannot decode; report that as a client error rather than a 500.
    try:
        img = Image.open(BytesIO(contents)).convert("RGB")
    except OSError as exc:
        raise HTTPException(
            status_code=400,
            detail="Uploaded file is not a valid image.",
        ) from exc

    img = img.resize((256, 256))
    arr = np.array(img) / 255.0
    arr = np.expand_dims(arr, 0)  # add the batch axis

    # NOTE(review): model.predict is synchronous and runs on the event loop
    # here, so other requests wait while it executes. Consider offloading it
    # (e.g. fastapi.concurrency.run_in_threadpool) once the model is confirmed
    # safe to call from worker threads.
    prediction = model.predict(arr)

    class_scores = np.asarray(prediction[0])
    mask = np.argmax(class_scores, axis=-1).astype(np.uint8)

    # Visualization: 50 gray levels per class, as long as the highest class
    # index still fits in 8 bits. With more than 6 classes a fixed step of 50
    # would wrap around in uint8 arithmetic and make distinct classes collide,
    # so the step shrinks to keep (num_classes - 1) * step <= 255.
    num_classes = class_scores.shape[-1]
    gray_step = max(1, min(50, 255 // max(num_classes - 1, 1)))
    mask_img = Image.fromarray(mask * np.uint8(gray_step))

    buf = BytesIO()
    mask_img.save(buf, format='PNG')
    buf.seek(0)
    return StreamingResponse(buf, media_type="image/png")