yassonee commited on
Commit
c16e85e
·
verified ·
1 Parent(s): 0b5ea2b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +125 -59
app.py CHANGED
@@ -1,16 +1,16 @@
1
  import streamlit as st
 
 
 
2
  from transformers import pipeline
3
  import torch
4
  from PIL import Image, ImageDraw
5
  import io
6
- import base64
7
- from fastapi import FastAPI, File, UploadFile, Form
8
- from fastapi.middleware.cors import CORSMiddleware
9
- import numpy as np
10
  import json
 
11
  from starlette.responses import JSONResponse
12
- import uvicorn
13
- from threading import Thread
14
 
15
  # FastAPI app
16
  app = FastAPI()
@@ -24,19 +24,27 @@ app.add_middleware(
24
  allow_headers=["*"],
25
  )
26
 
27
- # Load models
28
  @st.cache_resource
29
  def load_models():
30
- return {
31
- "D3STRON": pipeline("object-detection", model="D3STRON/bone-fracture-detr"),
32
- "Heem2": pipeline("image-classification", model="Heem2/bone-fracture-detection-using-xray"),
33
- "Nandodeomkar": pipeline("image-classification",
34
- model="nandodeomkar/autotrain-fracture-detection-using-google-vit-base-patch-16-54382127388")
35
- }
 
 
 
 
36
 
 
37
  models = load_models()
38
 
39
  def draw_boxes(image, predictions, threshold=0.6):
 
 
 
40
  draw = ImageDraw.Draw(image)
41
  filtered_preds = [p for p in predictions if p['score'] >= threshold]
42
 
@@ -44,33 +52,37 @@ def draw_boxes(image, predictions, threshold=0.6):
44
  box = pred['box']
45
  label = f"{pred['label']} ({pred['score']:.2%})"
46
 
 
47
  draw.rectangle(
48
  [(box['xmin'], box['ymin']), (box['xmax'], box['ymax'])],
49
  outline="red",
50
  width=2
51
  )
52
 
53
- draw.text((box['xmin'], box['ymin']), label, fill="red")
 
 
 
 
 
54
 
55
  return image, filtered_preds
56
 
57
- # API Endpoint
58
- @app.post("/api/predict")
59
- async def predict(file: UploadFile = File(...), confidence: float = Form(default=0.6)):
 
60
  try:
61
- contents = await file.read()
62
- image = Image.open(io.BytesIO(contents))
63
-
64
  # Object detection
65
  detection_preds = models["D3STRON"](image)
66
  result_image = image.copy()
67
- result_image, filtered_detections = draw_boxes(result_image, detection_preds, confidence)
68
 
69
  # Save result image
70
  img_byte_arr = io.BytesIO()
71
  result_image.save(img_byte_arr, format='PNG')
72
  img_byte_arr = img_byte_arr.getvalue()
73
- img_b64 = base64.b64encode(img_byte_arr).decode()
74
 
75
  # Classifications
76
  class_results = {
@@ -78,12 +90,38 @@ async def predict(file: UploadFile = File(...), confidence: float = Form(default
78
  "Nandodeomkar": models["Nandodeomkar"](image)
79
  }
80
 
81
- return JSONResponse({
82
  "success": True,
83
  "detections": filtered_detections,
84
  "classifications": class_results,
85
- "image": img_b64
86
- })
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
 
88
  except Exception as e:
89
  return JSONResponse({
@@ -91,49 +129,77 @@ async def predict(file: UploadFile = File(...), confidence: float = Form(default
91
  "error": str(e)
92
  }, status_code=500)
93
 
94
- # Streamlit UI
95
- def main():
96
- st.title("🦴 Fraktur Detektion")
 
 
 
 
 
 
97
 
98
- uploaded_file = st.file_uploader("Röntgenbild hochladen", type=['png', 'jpg', 'jpeg'])
99
- confidence = st.slider("Konfidenzschwelle", 0.0, 1.0, 0.6, 0.05)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
 
101
  if uploaded_file:
102
- # Afficher l'image originale
103
- image = Image.open(uploaded_file)
104
- st.image(image, caption="Original Röntgenbild", use_column_width=True)
105
 
106
- if st.button("Analysieren"):
107
- with st.spinner('Analyse läuft...'):
108
- # Object detection
109
- detection_preds = models["D3STRON"](image)
110
- result_image = image.copy()
111
- result_image, filtered_detections = draw_boxes(result_image, detection_preds, confidence)
112
-
113
- # Afficher l'image avec les détections
114
- st.image(result_image, caption="Erkannte Frakturen", use_column_width=True)
115
-
116
- # Afficher les détections
117
- st.subheader("Detektionen:")
118
- for detection in filtered_detections:
119
- st.write(f"- {detection['label']}: {detection['score']:.2%}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
120
 
121
- # Classifications
122
- st.subheader("Klassifikationen:")
123
- class_results = {
124
- "Heem2": models["Heem2"](image),
125
- "Nandodeomkar": models["Nandodeomkar"](image)
126
- }
127
- st.json(class_results)
128
 
129
  def run_fastapi():
 
130
  uvicorn.run(app, host="0.0.0.0", port=8000)
131
 
132
  if __name__ == "__main__":
133
- # Démarrer FastAPI dans un thread séparé
134
- api_thread = Thread(target=run_fastapi)
135
- api_thread.daemon = True
136
  api_thread.start()
137
 
138
- # Lancer Streamlit
139
- main()
 
1
  import streamlit as st
2
+ import base64
3
+ from fastapi import FastAPI, Request
4
+ from fastapi.middleware.cors import CORSMiddleware
5
  from transformers import pipeline
6
  import torch
7
  from PIL import Image, ImageDraw
8
  import io
9
+ from threading import Thread
10
+ import uvicorn
 
 
11
  import json
12
+ import numpy as np
13
  from starlette.responses import JSONResponse
 
 
14
 
15
  # FastAPI app
16
  app = FastAPI()
 
24
  allow_headers=["*"],
25
  )
26
 
27
+ # Load models with caching
28
  @st.cache_resource
29
  def load_models():
30
+ try:
31
+ return {
32
+ "D3STRON": pipeline("object-detection", model="D3STRON/bone-fracture-detr"),
33
+ "Heem2": pipeline("image-classification", model="Heem2/bone-fracture-detection-using-xray"),
34
+ "Nandodeomkar": pipeline("image-classification",
35
+ model="nandodeomkar/autotrain-fracture-detection-using-google-vit-base-patch-16-54382127388")
36
+ }
37
+ except Exception as e:
38
+ st.error(f"Error loading models: {str(e)}")
39
+ return None
40
 
41
+ # Initialize models
42
  models = load_models()
43
 
44
  def draw_boxes(image, predictions, threshold=0.6):
45
+ """
46
+ Draw bounding boxes on the image for fracture detections
47
+ """
48
  draw = ImageDraw.Draw(image)
49
  filtered_preds = [p for p in predictions if p['score'] >= threshold]
50
 
 
52
  box = pred['box']
53
  label = f"{pred['label']} ({pred['score']:.2%})"
54
 
55
+ # Draw rectangle
56
  draw.rectangle(
57
  [(box['xmin'], box['ymin']), (box['xmax'], box['ymax'])],
58
  outline="red",
59
  width=2
60
  )
61
 
62
+ # Draw label
63
+ draw.text(
64
+ (box['xmin'], box['ymin'] - 10),
65
+ label,
66
+ fill="red"
67
+ )
68
 
69
  return image, filtered_preds
70
 
71
+ def process_image(image, confidence_threshold):
72
+ """
73
+ Process an image through all models and return results
74
+ """
75
  try:
 
 
 
76
  # Object detection
77
  detection_preds = models["D3STRON"](image)
78
  result_image = image.copy()
79
+ result_image, filtered_detections = draw_boxes(result_image, detection_preds, confidence_threshold)
80
 
81
  # Save result image
82
  img_byte_arr = io.BytesIO()
83
  result_image.save(img_byte_arr, format='PNG')
84
  img_byte_arr = img_byte_arr.getvalue()
85
+ result_base64 = base64.b64encode(img_byte_arr).decode()
86
 
87
  # Classifications
88
  class_results = {
 
90
  "Nandodeomkar": models["Nandodeomkar"](image)
91
  }
92
 
93
+ return {
94
  "success": True,
95
  "detections": filtered_detections,
96
  "classifications": class_results,
97
+ "image": result_base64
98
+ }
99
+
100
+ except Exception as e:
101
+ return {
102
+ "success": False,
103
+ "error": str(e)
104
+ }
105
+
106
+ # FastAPI endpoint
107
+ @app.post("/api/predict")
108
+ async def predict(request: Request):
109
+ try:
110
+ # Read JSON request body
111
+ body = await request.json()
112
+
113
+ # Extract base64 image and confidence threshold
114
+ image_base64 = body['data'][0]
115
+ confidence_threshold = float(body['data'][1])
116
+
117
+ # Decode base64 image
118
+ image_bytes = base64.b64decode(image_base64)
119
+ image = Image.open(io.BytesIO(image_bytes))
120
+
121
+ # Process image
122
+ result = process_image(image, confidence_threshold)
123
+
124
+ return JSONResponse(result)
125
 
126
  except Exception as e:
127
  return JSONResponse({
 
129
  "error": str(e)
130
  }, status_code=500)
131
 
132
+ # Streamlit interface
133
+ def streamlit_interface():
134
+ st.set_page_config(
135
+ page_title="Fracture Detection System",
136
+ page_icon="🦴",
137
+ layout="wide"
138
+ )
139
+
140
+ st.title("🦴 Système de Détection de Fractures")
141
 
142
+ # File uploader
143
+ uploaded_file = st.file_uploader(
144
+ "Upload X-ray Image",
145
+ type=['png', 'jpg', 'jpeg'],
146
+ help="Upload an X-ray image for fracture detection"
147
+ )
148
+
149
+ # Confidence threshold slider
150
+ confidence = st.slider(
151
+ "Confidence Threshold",
152
+ min_value=0.0,
153
+ max_value=1.0,
154
+ value=0.6,
155
+ step=0.05,
156
+ help="Adjust the confidence threshold for detection"
157
+ )
158
 
159
  if uploaded_file:
160
+ # Display original image
161
+ col1, col2 = st.columns(2)
 
162
 
163
+ with col1:
164
+ st.subheader("Original X-ray")
165
+ image = Image.open(uploaded_file)
166
+ st.image(image, use_column_width=True)
167
+
168
+ if st.button("Analyze"):
169
+ with st.spinner('Analyzing image...'):
170
+ try:
171
+ # Process image
172
+ results = process_image(image, confidence)
173
+
174
+ if results["success"]:
175
+ with col2:
176
+ st.subheader("Detection Results")
177
+ # Display processed image
178
+ result_image = Image.open(io.BytesIO(base64.b64decode(results["image"])))
179
+ st.image(result_image, use_column_width=True)
180
+
181
+ # Display detections
182
+ st.subheader("Detected Fractures:")
183
+ for detection in results["detections"]:
184
+ st.write(f"- {detection['label']}: {detection['score']:.2%}")
185
+
186
+ # Display classifications
187
+ st.subheader("Classification Results:")
188
+ st.json(results["classifications"])
189
+ else:
190
+ st.error("Error processing image: " + results.get("error", "Unknown error"))
191
 
192
+ except Exception as e:
193
+ st.error(f"Error during analysis: {str(e)}")
 
 
 
 
 
194
 
195
  def run_fastapi():
196
+ """Run the FastAPI server"""
197
  uvicorn.run(app, host="0.0.0.0", port=8000)
198
 
199
  if __name__ == "__main__":
200
+ # Start FastAPI in a separate thread
201
+ api_thread = Thread(target=run_fastapi, daemon=True)
 
202
  api_thread.start()
203
 
204
+ # Run Streamlit interface
205
+ streamlit_interface()