# segmentation / app.py
# Last change: "Update app.py" by shekzee (commit 74bc278, verified), 1.18 kB.
import os
from io import BytesIO

import numpy as np
import tensorflow as tf
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from PIL import Image
# The ASGI application object served by this module.
app = FastAPI()

# Fully permissive CORS: requests from any origin, with any method or header,
# are accepted, so a browser front-end hosted elsewhere can call /predict.
# NOTE(review): acceptable for a public demo; restrict allow_origins if the
# API ever handles credentials or private data.
_cors_options = {
    "allow_origins": ["*"],
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)

# NOTE(review): /predict relies on a module-level `model`; nothing above loads
# it at import time. To load the trained segmentation model at startup,
# define it here, e.g.:
#     model = tf.keras.models.load_model("seg_model_path")
def _get_model():
    """Return the segmentation model, loading it on first use.

    If a module-level ``model`` has already been defined (for example by a
    load call at import time), that object is returned unchanged. Otherwise
    the Keras model is loaded from the path in the ``SEG_MODEL_PATH``
    environment variable (default ``"seg_model_path"``) and stored in the
    module-level name ``model`` so later requests reuse it.

    Raises:
        Whatever ``tf.keras.models.load_model`` raises when the model cannot
        be loaded (presumably OSError or ValueError depending on the Keras
        version — confirm for the deployed version).
    """
    global model
    try:
        return model
    except NameError:
        # No model was defined at import time: load it lazily below.
        pass
    model = tf.keras.models.load_model(
        os.environ.get("SEG_MODEL_PATH", "seg_model_path")
    )
    return model


@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Segment an uploaded image and return the class mask as a PNG.

    The upload is decoded with Pillow, converted to RGB, resized to 256x256,
    scaled to [0, 1] and passed to the segmentation model with a batch
    dimension of 1. The per-pixel argmax over the last axis of the model
    output becomes the mask. Class indices are multiplied by a constant step
    so that different classes are visible in the greyscale PNG.

    Args:
        file: The uploaded image (any format Pillow can decode).

    Returns:
        A ``StreamingResponse`` containing a 256x256 single-channel PNG mask.

    Raises:
        HTTPException: 400 if the upload cannot be decoded as an image;
            503 if the model cannot be loaded.
    """
    contents = await file.read()
    try:
        img = Image.open(BytesIO(contents)).convert("RGB")
    except (OSError, Image.DecompressionBombError) as exc:
        # Pillow raises UnidentifiedImageError (an OSError) for non-images and
        # OSError for truncated data: these are client errors, not 500s.
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a valid image."
        ) from exc

    # Preprocess: fixed 256x256 RGB input scaled to [0, 1], plus a batch axis.
    img = img.resize((256, 256))
    arr = np.asarray(img, dtype=np.float32) / 255.0
    arr = np.expand_dims(arr, 0)  # (1, 256, 256, 3)

    try:
        seg_model = _get_model()
    except (OSError, ValueError) as exc:
        raise HTTPException(
            status_code=503, detail="Segmentation model is not available."
        ) from exc

    # Assumed output shape: (1, 256, 256, num_classes) per-pixel class scores.
    prediction = seg_model.predict(arr)
    num_classes = prediction.shape[-1]
    # NOTE(review): a uint8 mask assumes at most 256 classes, and a
    # single-channel (sigmoid) head would make argmax all zeros — confirm
    # the model's output layer.
    mask = np.argmax(prediction[0], axis=-1).astype(np.uint8)  # (256, 256)

    # Spread class indices over 0..255 for visibility. Keep the original step
    # of 50 while it fits (<= 6 classes); shrink it for more classes so values
    # never wrap around in uint8 (6 * 50 = 300 would become 44).
    step = min(50, 255 // max(num_classes - 1, 1))
    mask_img = Image.fromarray(mask * np.uint8(step))

    buf = BytesIO()
    mask_img.save(buf, format="PNG")
    buf.seek(0)
    return StreamingResponse(buf, media_type="image/png")