# app.py — uploaded by shekzee, commit b5d99a8 ("Update app.py"), 1.11 kB
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
import numpy as np
from PIL import Image
import tensorflow as tf
# --- Model & label set -------------------------------------------------------
# The Keras model is deserialised once, when this module is imported; the
# request handlers below all share this single `model` object.
model = tf.keras.models.load_model("hf_keras_model.keras")

# Label for each model output index. The list order must line up with the
# order of the model's output units (presumably the training-time class
# order — confirm against the training pipeline).
class_names = [
    'buildings',
    'forest',
    'glacier',
    'mountain',
    'sea',
    'street',
]

# --- Application -------------------------------------------------------------
app = FastAPI()

# Fully permissive CORS (any origin, method and header) so browser front-ends
# and ad-hoc test pages can call the API directly.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)
@app.get("/")
def root():
    """Liveness endpoint: return a fixed message showing the API is up."""
    payload = {"message": "API is working!"}
    return payload
@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded image into one of ``class_names``.

    The upload is decoded with Pillow, converted to RGB, resized to 150x150,
    scaled to [0, 1] and fed to ``model`` as a batch of one.

    Args:
        file: The uploaded image (multipart form field ``file``).

    Returns:
        ``{"top_prediction": <label>, "all_predictions": {<label>: <score>}}``
        where each score is the model's output for that label, as a float.

    Raises:
        HTTPException: 400 if the upload cannot be decoded as an image
            (unrecognised format, truncated data, or dimensions large enough
            to trip Pillow's decompression-bomb guard).
    """
    try:
        # Image.open only identifies the file; pixel data is decoded lazily
        # by convert(), which can also fail, so both stay inside the try.
        with Image.open(file.file) as img:
            image = img.convert("RGB").resize((150, 150))
    except (OSError, Image.DecompressionBombError) as err:
        # The upload is untrusted client input: report a client error
        # rather than letting it surface as a 500.
        raise HTTPException(
            status_code=400,
            detail="Uploaded file is not a valid image.",
        ) from err

    # (150, 150, 3) uint8 pixels -> floats in [0, 1], plus a leading batch axis.
    img_array = np.array(image) / 255.0
    img_array = np.expand_dims(img_array, axis=0)

    # NOTE(review): model.predict is synchronous and runs on the event loop,
    # so other requests (including "/") wait while it executes. Offloading it
    # to a thread pool would avoid that but allows concurrent calls into the
    # shared model — confirm the model is thread-safe before doing so.
    # verbose=0 stops Keras from printing a progress bar on every request.
    predictions = model.predict(img_array, verbose=0)[0]

    results = {name: float(score) for name, score in zip(class_names, predictions)}
    top_class = class_names[int(np.argmax(predictions))]
    return {"top_prediction": top_class, "all_predictions": results}