shekzee commited on
Commit
b5d99a8
·
verified ·
1 Parent(s): d3d58c1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -14
app.py CHANGED
@@ -1,25 +1,38 @@
1
  from fastapi import FastAPI, UploadFile, File
2
- from fastapi.responses import JSONResponse
3
- from tensorflow.keras.models import load_model
4
- from PIL import Image
5
  import numpy as np
6
- import io
 
 
 
 
 
7
 
 
8
  app = FastAPI()
9
 
10
- # Load model and class names
11
- model = load_model("hf_keras_model.keras")
12
- class_names = ['buildings', 'forest', 'glacier', 'mountain', 'sea', 'street']
 
 
 
 
 
 
 
 
13
 
14
- @app.post("/predict/")
15
  async def predict(file: UploadFile = File(...)):
16
- contents = await file.read()
17
- image = Image.open(io.BytesIO(contents)).convert("RGB")
18
- image = image.resize((150, 150))
19
  img_array = np.array(image) / 255.0
20
  img_array = np.expand_dims(img_array, axis=0)
21
 
22
- preds = model.predict(img_array)[0]
23
- confidences = {class_names[i]: float(preds[i]) for i in range(len(class_names))}
 
 
24
 
25
- return JSONResponse(content=confidences)
 
1
  from fastapi import FastAPI, UploadFile, File
2
+ from fastapi.middleware.cors import CORSMiddleware
 
 
3
  import numpy as np
4
+ from PIL import Image
5
+ import tensorflow as tf
6
+
7
+ # Load model and classes
8
+ model = tf.keras.models.load_model("hf_keras_model.keras")
9
+ class_names = ['buildings', 'forest', 'glacier', 'mountain', 'sea', 'street']
10
 
11
+ # Initialize app
12
  app = FastAPI()
13
 
14
+ # Allow all CORS (for frontend/test requests)
15
+ app.add_middleware(
16
+ CORSMiddleware,
17
+ allow_origins=["*"],
18
+ allow_methods=["*"],
19
+ allow_headers=["*"],
20
+ )
21
+
22
+ @app.get("/")
23
+ def root():
24
+ return {"message": "API is working!"}
25
 
26
+ @app.post("/predict")
27
  async def predict(file: UploadFile = File(...)):
28
+ # Load image
29
+ image = Image.open(file.file).convert("RGB").resize((150, 150))
 
30
  img_array = np.array(image) / 255.0
31
  img_array = np.expand_dims(img_array, axis=0)
32
 
33
+ # Predict
34
+ predictions = model.predict(img_array)[0]
35
+ results = {class_names[i]: float(predictions[i]) for i in range(len(class_names))}
36
+ top_class = class_names[np.argmax(predictions)]
37
 
38
+ return {"top_prediction": top_class, "all_predictions": results}