yassonee commited on
Commit
0b5ea2b
·
verified ·
1 Parent(s): d2f6121

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -13
app.py CHANGED
@@ -4,11 +4,13 @@ import torch
4
  from PIL import Image, ImageDraw
5
  import io
6
  import base64
7
- from fastapi import FastAPI, File, UploadFile
8
  from fastapi.middleware.cors import CORSMiddleware
9
  import numpy as np
10
  import json
11
  from starlette.responses import JSONResponse
 
 
12
 
13
  # FastAPI app
14
  app = FastAPI()
@@ -53,17 +55,13 @@ def draw_boxes(image, predictions, threshold=0.6):
53
  return image, filtered_preds
54
 
55
  # API Endpoint
56
- @app.post("/detect")
57
- async def detect_fracture(file: UploadFile = File(...), confidence: float = 0.6):
58
  try:
59
- # Read and process image
60
  contents = await file.read()
61
  image = Image.open(io.BytesIO(contents))
62
 
63
- # Get predictions from all models
64
- results = {}
65
-
66
- # Object detection models
67
  detection_preds = models["D3STRON"](image)
68
  result_image = image.copy()
69
  result_image, filtered_detections = draw_boxes(result_image, detection_preds, confidence)
@@ -74,7 +72,7 @@ async def detect_fracture(file: UploadFile = File(...), confidence: float = 0.6)
74
  img_byte_arr = img_byte_arr.getvalue()
75
  img_b64 = base64.b64encode(img_byte_arr).decode()
76
 
77
- # Classification models
78
  class_results = {
79
  "Heem2": models["Heem2"](image),
80
  "Nandodeomkar": models["Nandodeomkar"](image)
@@ -91,19 +89,51 @@ async def detect_fracture(file: UploadFile = File(...), confidence: float = 0.6)
91
  return JSONResponse({
92
  "success": False,
93
  "error": str(e)
94
- })
95
 
96
  # Streamlit UI
97
  def main():
98
  st.title("🦴 Fraktur Detektion")
99
 
100
- # UI elements...
101
  uploaded_file = st.file_uploader("Röntgenbild hochladen", type=['png', 'jpg', 'jpeg'])
102
  confidence = st.slider("Konfidenzschwelle", 0.0, 1.0, 0.6, 0.05)
103
 
104
  if uploaded_file:
105
- # Process image and display results...
106
- pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
107
 
108
  if __name__ == "__main__":
 
 
 
 
 
 
109
  main()
 
4
  from PIL import Image, ImageDraw
5
  import io
6
  import base64
7
+ from fastapi import FastAPI, File, UploadFile, Form
8
  from fastapi.middleware.cors import CORSMiddleware
9
  import numpy as np
10
  import json
11
  from starlette.responses import JSONResponse
12
+ import uvicorn
13
+ from threading import Thread
14
 
15
  # FastAPI app
16
  app = FastAPI()
 
55
  return image, filtered_preds
56
 
57
  # API Endpoint
58
+ @app.post("/api/predict")
59
+ async def predict(file: UploadFile = File(...), confidence: float = Form(default=0.6)):
60
  try:
 
61
  contents = await file.read()
62
  image = Image.open(io.BytesIO(contents))
63
 
64
+ # Object detection
 
 
 
65
  detection_preds = models["D3STRON"](image)
66
  result_image = image.copy()
67
  result_image, filtered_detections = draw_boxes(result_image, detection_preds, confidence)
 
72
  img_byte_arr = img_byte_arr.getvalue()
73
  img_b64 = base64.b64encode(img_byte_arr).decode()
74
 
75
+ # Classifications
76
  class_results = {
77
  "Heem2": models["Heem2"](image),
78
  "Nandodeomkar": models["Nandodeomkar"](image)
 
89
  return JSONResponse({
90
  "success": False,
91
  "error": str(e)
92
+ }, status_code=500)
93
 
94
  # Streamlit UI
95
  def main():
96
  st.title("🦴 Fraktur Detektion")
97
 
 
98
  uploaded_file = st.file_uploader("Röntgenbild hochladen", type=['png', 'jpg', 'jpeg'])
99
  confidence = st.slider("Konfidenzschwelle", 0.0, 1.0, 0.6, 0.05)
100
 
101
  if uploaded_file:
102
+ # Afficher l'image originale
103
+ image = Image.open(uploaded_file)
104
+ st.image(image, caption="Original Röntgenbild", use_column_width=True)
105
+
106
+ if st.button("Analysieren"):
107
+ with st.spinner('Analyse läuft...'):
108
+ # Object detection
109
+ detection_preds = models["D3STRON"](image)
110
+ result_image = image.copy()
111
+ result_image, filtered_detections = draw_boxes(result_image, detection_preds, confidence)
112
+
113
+ # Afficher l'image avec les détections
114
+ st.image(result_image, caption="Erkannte Frakturen", use_column_width=True)
115
+
116
+ # Afficher les détections
117
+ st.subheader("Detektionen:")
118
+ for detection in filtered_detections:
119
+ st.write(f"- {detection['label']}: {detection['score']:.2%}")
120
+
121
+ # Classifications
122
+ st.subheader("Klassifikationen:")
123
+ class_results = {
124
+ "Heem2": models["Heem2"](image),
125
+ "Nandodeomkar": models["Nandodeomkar"](image)
126
+ }
127
+ st.json(class_results)
128
+
129
+ def run_fastapi():
130
+ uvicorn.run(app, host="0.0.0.0", port=8000)
131
 
132
  if __name__ == "__main__":
133
+ # Démarrer FastAPI dans un thread séparé
134
+ api_thread = Thread(target=run_fastapi)
135
+ api_thread.daemon = True
136
+ api_thread.start()
137
+
138
+ # Lancer Streamlit
139
  main()