"""FastAPI service exposing a single image-segmentation endpoint.

POST an image to ``/predict``; the response is a grayscale PNG in which each
pixel's brightness encodes the predicted class index at that position.
"""

from io import BytesIO

import numpy as np
import tensorflow as tf
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from PIL import Image

# Path handed to Keras when loading the segmentation model.
MODEL_PATH = "seg_model"
# Every upload is resized to this (width, height) before inference.
INPUT_SIZE = (256, 256)
# Largest gray-level gap between consecutive class indices in the output mask.
MASK_GRAY_STEP = 50

# --------- LOAD YOUR SEGMENTATION MODEL HERE ---------
# Loaded once at import time so every request reuses the same model object.
model = tf.keras.models.load_model(MODEL_PATH)
# -----------------------------------------------------

app = FastAPI()

# NOTE(review): CORS is wide open (any origin, method and header). Confirm this
# is intended before exposing the service beyond local development.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)


def _decode_image(raw: bytes) -> np.ndarray:
    """Decode uploaded bytes into a batched float array scaled to [0, 1].

    Returns an array of shape ``(1, height, width, 3)``.

    Raises:
        HTTPException: 400 when the bytes cannot be read as an image.
    """
    try:
        img = Image.open(BytesIO(raw)).convert("RGB")
    except (OSError, Image.DecompressionBombError) as exc:
        # Empty, corrupt or non-image uploads are a client error, not a 500.
        raise HTTPException(
            status_code=400,
            detail="Uploaded file is not a readable image.",
        ) from exc

    img = img.resize(INPUT_SIZE)
    pixels = np.array(img) / 255.0
    # Add the leading batch axis expected by ``model.predict``.
    return np.expand_dims(pixels, 0)


def _render_mask(scores: np.ndarray) -> BytesIO:
    """Turn per-pixel class scores into an in-memory grayscale PNG.

    ``scores`` is one prediction (no batch axis); the class index is taken as
    the argmax over its last axis.
    """
    class_map = np.argmax(scores, axis=-1).astype(np.uint8)

    # The mask is uint8, so ``index * step`` must stay <= 255 or brightness
    # wraps around and classes become indistinguishable. Use the original step
    # of 50 when it fits, otherwise shrink it to suit the channel count.
    num_classes = scores.shape[-1]
    step = max(1, min(MASK_GRAY_STEP, 255 // max(num_classes - 1, 1)))

    buf = BytesIO()
    Image.fromarray(class_map * np.uint8(step)).save(buf, format="PNG")
    buf.seek(0)
    return buf


@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Segment an uploaded image and return the class mask as a PNG.

    Args:
        file: The uploaded image, sent as multipart form data.

    Returns:
        A ``StreamingResponse`` with media type ``image/png`` containing the
        mask at ``INPUT_SIZE`` resolution (the mask is not resized back to the
        upload's original dimensions).

    Raises:
        HTTPException: 400 when the upload cannot be decoded as an image.
    """
    contents = await file.read()
    batch = _decode_image(contents)

    # NOTE(review): ``model.predict`` is a blocking call inside an async
    # handler, so other requests wait while it runs. Consider offloading it to
    # a worker thread if concurrent throughput matters.
    prediction = model.predict(batch)

    png = _render_mask(np.asarray(prediction[0]))
    return StreamingResponse(png, media_type="image/png")