skibi11 commited on
Commit
1490129
·
verified ·
1 Parent(s): dc2c7e9

app.py using a manual FastAPI endpoint

Browse files
Files changed (1) hide show
  1. app.py +39 -29
app.py CHANGED
@@ -1,14 +1,19 @@
1
- # Final app.py using FastAPI wrapper
2
 
3
  from fastapi import FastAPI
 
 
4
  import gradio as gr
5
  import tensorflow as tf
6
  from huggingface_hub import hf_hub_download
7
  import numpy as np
8
  from PIL import Image
9
  import os
 
 
10
 
11
- # --- 1. Load the Model ---
 
12
  try:
13
  model_path = hf_hub_download(
14
  repo_id="skibi11/leukolook-eye-detector",
@@ -18,9 +23,9 @@ try:
18
  print("--- MODEL LOADED SUCCESSFULLY! ---")
19
  except Exception as e:
20
  print(f"--- ERROR LOADING MODEL: {e} ---")
21
- raise RuntimeError(f"Failed to load model: {e}")
22
 
23
- # --- 2. Pre-processing & Prediction Logic (remains the same) ---
24
  def preprocess_image(img_pil):
25
  img = img_pil.resize((224, 224))
26
  img_array = np.array(img)
@@ -30,32 +35,37 @@ def preprocess_image(img_pil):
30
  img_array = np.expand_dims(img_array, axis=0)
31
  return img_array
32
 
33
- def predict(image_from_gradio):
34
- if not isinstance(image_from_gradio, np.ndarray):
35
- return {"error": "Invalid input type. Expected an image."}
36
- try:
37
- pil_image = Image.fromarray(image_from_gradio)
38
- processed_image = preprocess_image(pil_image)
39
- prediction = model.predict(processed_image)
40
- labels = [f"Class_{i}" for i in range(prediction.shape[1])]
41
- confidences = {label: float(score) for label, score in zip(labels, prediction[0])}
42
- return confidences
43
- except Exception as e:
44
- return {"error": f"Error during prediction: {e}"}
45
 
46
- # --- 3. Create the Gradio Interface (without launching) ---
47
- gradio_interface = gr.Interface(
48
- fn=predict,
49
- inputs=gr.Image(type="numpy"),
50
- outputs=gr.JSON(),
51
- api_name="predict"
52
- )
53
 
54
- # --- 4. Create the FastAPI app and mount the Gradio app to it ---
55
  app = FastAPI()
56
- app = gr.mount_gradio_app(app, gradio_interface, path="/")
57
 
58
- # --- 5. To run the server ---
59
- if __name__ == "__main__":
60
- import uvicorn
61
- uvicorn.run(app, host="0.0.0.0", port=7860)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # The Final app.py using a manual FastAPI endpoint
2
 
3
  from fastapi import FastAPI
4
+ from fastapi.responses import JSONResponse
5
+ from pydantic import BaseModel
6
  import gradio as gr
7
  import tensorflow as tf
8
  from huggingface_hub import hf_hub_download
9
  import numpy as np
10
  from PIL import Image
11
  import os
12
+ import base64
13
+ import io
14
 
15
+ # --- 1. Load the Model (Stays the same) ---
16
+ model = None
17
  try:
18
  model_path = hf_hub_download(
19
  repo_id="skibi11/leukolook-eye-detector",
 
23
  print("--- MODEL LOADED SUCCESSFULLY! ---")
24
  except Exception as e:
25
  print(f"--- ERROR LOADING MODEL: {e} ---")
26
+ model = None # Ensure model is None if loading fails
27
 
28
+ # --- 2. Prediction Logic (Stays the same) ---
29
  def preprocess_image(img_pil):
30
  img = img_pil.resize((224, 224))
31
  img_array = np.array(img)
 
35
  img_array = np.expand_dims(img_array, axis=0)
36
  return img_array
37
 
38
+ def run_prediction(pil_image):
39
+ if model is None:
40
+ return {"error": "Model is not loaded on the server."}
 
 
 
 
 
 
 
 
 
41
 
42
+ processed_image = preprocess_image(pil_image)
43
+ prediction = model.predict(processed_image)
44
+ labels = [f"Class_{i}" for i in range(prediction.shape[1])]
45
+ confidences = {label: float(score) for label, score in zip(labels, prediction[0])}
46
+ return confidences
 
 
47
 
48
+ # --- 3. Create the FastAPI app ---
49
  app = FastAPI()
 
50
 
51
+ # --- 4. Define the input data structure for our new endpoint ---
52
+ class PredictionRequest(BaseModel):
53
+ data: list[str]
54
+
55
+ # --- 5. Create our own reliable API endpoint ---
56
+ @app.post("/api/predict/")
57
+ async def handle_prediction(request: PredictionRequest):
58
+ try:
59
+ # Get the Base64 string from the JSON payload
60
+ base64_string = request.data[0].split(',', 1)[1]
61
+ image_bytes = base64.b64decode(base64_string)
62
+ pil_image = Image.open(io.BytesIO(image_bytes))
63
+
64
+ # Run the prediction
65
+ result_dict = run_prediction(pil_image)
66
+
67
+ # Return the result in the same format Gradio does
68
+ return JSONResponse(content={"data": [result_dict]})
69
+
70
+ except Exception as e:
71
+ return JSONResponse(status_code=500, content={"error": str(e)})