"""FastAPI service that classifies uploaded images with a Keras model."""

import threading

import numpy as np
import tensorflow as tf
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image

# Load the model once at import time so every request reuses the same instance.
model = tf.keras.models.load_model("hf_keras_model.keras")

# Labels indexed by model output position.
# NOTE(review): order presumably matches the training label encoding — confirm.
class_names = ['buildings', 'forest', 'glacier', 'mountain', 'sea', 'street']

# (width, height) every upload is resized to before inference; PIL's resize
# takes (width, height).
IMAGE_SIZE = (150, 150)

# Inference runs in worker threads (see `predict`). This lock keeps model
# calls serialized, as they were when they ran directly on the event loop.
_model_lock = threading.Lock()

# Initialize app
app = FastAPI()

# Allow all CORS (for frontend/test requests).
# NOTE(review): wildcard origins are convenient for testing; restrict for production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.get("/")
def root():
    """Health check: report that the API is up."""
    return {"message": "API is working!"}


def _preprocess(fileobj) -> np.ndarray:
    """Decode an image file object into a (1, H, W, 3) float32 batch in [0, 1].

    Raises:
        OSError: if Pillow cannot identify or fully decode the image
            (``PIL.UnidentifiedImageError`` is a subclass).
        PIL.Image.DecompressionBombError: if the image is absurdly large.
    """
    # `with` closes the source image; convert()/resize() return new images.
    with Image.open(fileobj) as image:
        resized = image.convert("RGB").resize(IMAGE_SIZE)
    pixels = np.asarray(resized, dtype=np.float32) / 255.0
    # Add the leading batch axis the model expects.
    return np.expand_dims(pixels, axis=0)


def _infer(batch: np.ndarray) -> np.ndarray:
    """Run the model on `batch` and return the score vector for its first item."""
    with _model_lock:
        # verbose=0 suppresses the per-call Keras progress bar in server logs.
        return model.predict(batch, verbose=0)[0]


@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded image.

    Returns:
        ``{"top_prediction": <label>, "all_predictions": {label: score, ...}}``,
        where scores are the model's raw output values cast to float.

    Raises:
        HTTPException(400): if the upload cannot be decoded as an image.
        RuntimeError: if the model's output width does not match ``class_names``.
    """
    # Decoding and inference are blocking CPU work; run them off the event loop
    # so other requests (e.g. the health check) are not stalled.
    try:
        batch = await run_in_threadpool(_preprocess, file.file)
    except (OSError, Image.DecompressionBombError) as err:
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a valid image."
        ) from err

    predictions = await run_in_threadpool(_infer, batch)

    if len(predictions) != len(class_names):
        raise RuntimeError(
            f"Model returned {len(predictions)} scores but "
            f"{len(class_names)} class names are configured."
        )

    results = {name: float(score) for name, score in zip(class_names, predictions)}
    top_class = class_names[int(np.argmax(predictions))]
    return {"top_prediction": top_class, "all_predictions": results}