yassonee commited on
Commit
d119a53
·
verified ·
1 Parent(s): 81ff8c7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +39 -137
app.py CHANGED
@@ -1,23 +1,13 @@
1
  import streamlit as st
2
-
3
- # Set page config must be the first Streamlit command
4
- st.set_page_config(
5
- page_title="Fracture Detection System",
6
- page_icon="🦴",
7
- layout="wide"
8
- )
9
-
10
- import base64
11
- from fastapi import FastAPI, Request
12
- from fastapi.middleware.cors import CORSMiddleware
13
  from transformers import pipeline
14
  import torch
15
  from PIL import Image, ImageDraw
16
  import io
17
- from threading import Thread
18
- import uvicorn
19
- import json
20
  import numpy as np
 
21
  from starlette.responses import JSONResponse
22
 
23
  # FastAPI app
@@ -32,27 +22,19 @@ app.add_middleware(
32
  allow_headers=["*"],
33
  )
34
 
35
- # Load models with caching
36
  @st.cache_resource
37
  def load_models():
38
- try:
39
- return {
40
- "D3STRON": pipeline("object-detection", model="D3STRON/bone-fracture-detr"),
41
- "Heem2": pipeline("image-classification", model="Heem2/bone-fracture-detection-using-xray"),
42
- "Nandodeomkar": pipeline("image-classification",
43
- model="nandodeomkar/autotrain-fracture-detection-using-google-vit-base-patch-16-54382127388")
44
- }
45
- except Exception as e:
46
- st.error(f"Error loading models: {str(e)}")
47
- return None
48
 
49
- # Initialize models
50
  models = load_models()
51
 
52
  def draw_boxes(image, predictions, threshold=0.6):
53
- """
54
- Draw bounding boxes on the image for fracture detections
55
- """
56
  draw = ImageDraw.Draw(image)
57
  filtered_preds = [p for p in predictions if p['score'] >= threshold]
58
 
@@ -60,148 +42,68 @@ def draw_boxes(image, predictions, threshold=0.6):
60
  box = pred['box']
61
  label = f"{pred['label']} ({pred['score']:.2%})"
62
 
63
- # Draw rectangle
64
  draw.rectangle(
65
  [(box['xmin'], box['ymin']), (box['xmax'], box['ymax'])],
66
  outline="red",
67
  width=2
68
  )
69
 
70
- # Draw label
71
- draw.text(
72
- (box['xmin'], box['ymin'] - 10),
73
- label,
74
- fill="red"
75
- )
76
 
77
  return image, filtered_preds
78
 
79
- def process_image(image, confidence_threshold):
80
- """
81
- Process an image through all models and return results
82
- """
83
  try:
84
- # Object detection
 
 
 
 
 
 
 
85
  detection_preds = models["D3STRON"](image)
86
  result_image = image.copy()
87
- result_image, filtered_detections = draw_boxes(result_image, detection_preds, confidence_threshold)
88
 
89
  # Save result image
90
  img_byte_arr = io.BytesIO()
91
  result_image.save(img_byte_arr, format='PNG')
92
  img_byte_arr = img_byte_arr.getvalue()
93
- result_base64 = base64.b64encode(img_byte_arr).decode()
94
 
95
- # Classifications
96
  class_results = {
97
  "Heem2": models["Heem2"](image),
98
  "Nandodeomkar": models["Nandodeomkar"](image)
99
  }
100
 
101
- return {
102
  "success": True,
103
  "detections": filtered_detections,
104
  "classifications": class_results,
105
- "image": result_base64
106
- }
107
-
108
- except Exception as e:
109
- return {
110
- "success": False,
111
- "error": str(e)
112
- }
113
-
114
- # FastAPI endpoint
115
- @app.post("/api/predict")
116
- async def predict(request: Request):
117
- try:
118
- # Read JSON request body
119
- body = await request.json()
120
-
121
- # Extract base64 image and confidence threshold
122
- image_base64 = body['data'][0]
123
- confidence_threshold = float(body['data'][1])
124
-
125
- # Decode base64 image
126
- image_bytes = base64.b64decode(image_base64)
127
- image = Image.open(io.BytesIO(image_bytes))
128
-
129
- # Process image
130
- result = process_image(image, confidence_threshold)
131
-
132
- return JSONResponse(result)
133
 
134
  except Exception as e:
135
  return JSONResponse({
136
  "success": False,
137
  "error": str(e)
138
- }, status_code=500)
139
 
140
- # Streamlit interface
141
- def streamlit_interface():
142
- st.title("🦴 Système de Détection de Fractures")
143
-
144
- # File uploader
145
- uploaded_file = st.file_uploader(
146
- "Upload X-ray Image",
147
- type=['png', 'jpg', 'jpeg'],
148
- help="Upload an X-ray image for fracture detection"
149
- )
150
 
151
- # Confidence threshold slider
152
- confidence = st.slider(
153
- "Confidence Threshold",
154
- min_value=0.0,
155
- max_value=1.0,
156
- value=0.6,
157
- step=0.05,
158
- help="Adjust the confidence threshold for detection"
159
- )
160
 
161
  if uploaded_file:
162
- # Display original image
163
- col1, col2 = st.columns(2)
164
-
165
- with col1:
166
- st.subheader("Original X-ray")
167
- image = Image.open(uploaded_file)
168
- st.image(image, use_column_width=True)
169
-
170
- if st.button("Analyze"):
171
- with st.spinner('Analyzing image...'):
172
- try:
173
- # Process image
174
- results = process_image(image, confidence)
175
-
176
- if results["success"]:
177
- with col2:
178
- st.subheader("Detection Results")
179
- # Display processed image
180
- result_image = Image.open(io.BytesIO(base64.b64decode(results["image"])))
181
- st.image(result_image, use_column_width=True)
182
-
183
- # Display detections
184
- st.subheader("Detected Fractures:")
185
- for detection in results["detections"]:
186
- st.write(f"- {detection['label']}: {detection['score']:.2%}")
187
-
188
- # Display classifications
189
- st.subheader("Classification Results:")
190
- st.json(results["classifications"])
191
- else:
192
- st.error("Error processing image: " + results.get("error", "Unknown error"))
193
-
194
- except Exception as e:
195
- st.error(f"Error during analysis: {str(e)}")
196
-
197
- def run_fastapi():
198
- """Run the FastAPI server"""
199
- uvicorn.run(app, host="0.0.0.0", port=8000)
200
 
201
  if __name__ == "__main__":
202
- # Start FastAPI in a separate thread
203
- api_thread = Thread(target=run_fastapi, daemon=True)
204
- api_thread.start()
205
-
206
- # Run Streamlit interface
207
- streamlit_interface()
 
1
  import streamlit as st
 
 
 
 
 
 
 
 
 
 
 
2
  from transformers import pipeline
3
  import torch
4
  from PIL import Image, ImageDraw
5
  import io
6
+ import base64
7
+ from fastapi import FastAPI, File, UploadFile
8
+ from fastapi.middleware.cors import CORSMiddleware
9
  import numpy as np
10
+ import json
11
  from starlette.responses import JSONResponse
12
 
13
  # FastAPI app
 
22
  allow_headers=["*"],
23
  )
24
 
25
+ # Load models
26
  @st.cache_resource
27
  def load_models():
28
+ return {
29
+ "D3STRON": pipeline("object-detection", model="D3STRON/bone-fracture-detr"),
30
+ "Heem2": pipeline("image-classification", model="Heem2/bone-fracture-detection-using-xray"),
31
+ "Nandodeomkar": pipeline("image-classification",
32
+ model="nandodeomkar/autotrain-fracture-detection-using-google-vit-base-patch-16-54382127388")
33
+ }
 
 
 
 
34
 
 
35
  models = load_models()
36
 
37
  def draw_boxes(image, predictions, threshold=0.6):
 
 
 
38
  draw = ImageDraw.Draw(image)
39
  filtered_preds = [p for p in predictions if p['score'] >= threshold]
40
 
 
42
  box = pred['box']
43
  label = f"{pred['label']} ({pred['score']:.2%})"
44
 
 
45
  draw.rectangle(
46
  [(box['xmin'], box['ymin']), (box['xmax'], box['ymax'])],
47
  outline="red",
48
  width=2
49
  )
50
 
51
+ draw.text((box['xmin'], box['ymin']), label, fill="red")
 
 
 
 
 
52
 
53
  return image, filtered_preds
54
 
55
+ # API Endpoint
56
+ @app.post("/detect")
57
+ async def detect_fracture(file: UploadFile = File(...), confidence: float = 0.6):
 
58
  try:
59
+ # Read and process image
60
+ contents = await file.read()
61
+ image = Image.open(io.BytesIO(contents))
62
+
63
+ # Get predictions from all models
64
+ results = {}
65
+
66
+ # Object detection models
67
  detection_preds = models["D3STRON"](image)
68
  result_image = image.copy()
69
+ result_image, filtered_detections = draw_boxes(result_image, detection_preds, confidence)
70
 
71
  # Save result image
72
  img_byte_arr = io.BytesIO()
73
  result_image.save(img_byte_arr, format='PNG')
74
  img_byte_arr = img_byte_arr.getvalue()
75
+ img_b64 = base64.b64encode(img_byte_arr).decode()
76
 
77
+ # Classification models
78
  class_results = {
79
  "Heem2": models["Heem2"](image),
80
  "Nandodeomkar": models["Nandodeomkar"](image)
81
  }
82
 
83
+ return JSONResponse({
84
  "success": True,
85
  "detections": filtered_detections,
86
  "classifications": class_results,
87
+ "image": img_b64
88
+ })
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
 
90
  except Exception as e:
91
  return JSONResponse({
92
  "success": False,
93
  "error": str(e)
94
+ })
95
 
96
+ # Streamlit UI
97
+ def main():
98
+ st.title("🦴 Fraktur Detektion")
 
 
 
 
 
 
 
99
 
100
+ # UI elements...
101
+ uploaded_file = st.file_uploader("Röntgenbild hochladen", type=['png', 'jpg', 'jpeg'])
102
+ confidence = st.slider("Konfidenzschwelle", 0.0, 1.0, 0.6, 0.05)
 
 
 
 
 
 
103
 
104
  if uploaded_file:
105
+ # Process image and display results...
106
+ pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
107
 
108
  if __name__ == "__main__":
109
+ main()