yassonee commited on
Commit
dbfa07b
·
verified ·
1 Parent(s): ab11292

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +158 -85
app.py CHANGED
@@ -1,8 +1,9 @@
1
  import gradio as gr
2
  from transformers import pipeline
3
  from PIL import Image, ImageDraw
4
- import torch
5
 
 
6
  def load_models():
7
  return {
8
  "KnochenAuge": pipeline("object-detection", model="D3STRON/bone-fracture-detr"),
@@ -11,116 +12,188 @@ def load_models():
11
  model="nandodeomkar/autotrain-fracture-detection-using-google-vit-base-patch-16-54382127388")
12
  }
13
 
14
- def draw_boxes(image, predictions, conf_threshold=0.6):
15
- draw = ImageDraw.Draw(image)
16
- fractures_found = False
 
 
 
 
 
 
 
 
 
 
 
17
 
18
- for pred in predictions:
19
- if pred['label'].lower() == 'fracture' and pred['score'] >= conf_threshold:
20
- fractures_found = True
21
- box = pred['box']
22
- label = f"Fraktur ({pred['score']:.1%})"
23
- color = "#2563eb" if pred['score'] > 0.7 else "#eab308"
24
-
25
- draw.rectangle(
26
- [(box['xmin'], box['ymin']), (box['xmax'], box['ymax'])],
27
- outline=color,
28
- width=2
29
- )
30
-
31
- text_bbox = draw.textbbox((box['xmin'], box['ymin']-15), label)
32
- draw.rectangle(text_bbox, fill=color)
33
- draw.text((box['xmin'], box['ymin']-15), label, fill="white")
34
 
35
- return image if fractures_found else None
 
 
 
 
 
 
 
 
 
 
 
 
 
36
 
37
- def analyze_images(images, conf_threshold=0.6):
38
- models = load_models()
39
- results = []
40
 
41
- for img in images:
42
- pil_img = Image.fromarray(img)
 
 
 
 
43
 
44
- # KnochenAuge Analysis
45
- predictions = models["KnochenAuge"](pil_img)
46
- fractures_found = any(p['label'].lower() == 'fracture' and p['score'] >= conf_threshold
47
- for p in predictions)
48
 
49
- if fractures_found:
50
- # Draw boxes on image
51
- result_image = draw_boxes(pil_img.copy(), predictions, conf_threshold)
52
-
53
- # Additional analyses
54
- wachter_pred = models["KnochenWächter"](pil_img)[0]
55
- meister_pred = models["RöntgenMeister"](pil_img)[0]
56
-
57
- if result_image:
58
- results.append({
59
- "image": result_image,
60
- "knochen_wachter": f"KnochenWächter: {wachter_pred['score']:.1%}",
61
- "rontgen_meister": f"RöntgenMeister: {meister_pred['score']:.1%}"
62
- })
63
-
64
- # Format results for display
65
- if not results:
66
- return None, "Keine Frakturen gefunden."
67
-
68
- output_images = [r["image"] for r in results]
69
- analysis_text = "\n\n".join([
70
- f"Bild {i+1}:\n{r['knochen_wachter']}\n{r['rontgen_meister']}"
71
- for i, r in enumerate(results)
72
- ])
73
-
74
- return output_images, analysis_text
75
 
76
- # Interface configuration
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
  css = """
78
  .gradio-container {
79
- background-color: transparent !important;
 
 
 
 
 
 
80
  }
81
- .dark {
82
- background-color: #1f2937;
83
- color: #f3f4f6;
84
  }
85
- .light {
86
- background-color: #ffffff;
87
- color: #1f2937;
 
 
88
  }
89
  """
90
 
91
  with gr.Blocks(css=css) as demo:
 
 
92
  with gr.Row():
93
  with gr.Column(scale=1):
94
- file_upload = gr.File(
95
- label="Röntgenbilder hochladen",
96
- file_types=["image"],
97
- file_count="multiple"
98
- )
99
- conf_slider = gr.Slider(
100
  minimum=0.0,
101
  maximum=1.0,
102
- value=0.6,
103
  step=0.05,
104
  label="Konfidenzschwelle"
105
  )
106
- analyze_btn = gr.Button("Bilder analysieren", variant="primary")
107
 
108
- with gr.Column(scale=2):
109
- gallery = gr.Gallery(label="Ergebnisse").style(grid=2)
110
- analysis_output = gr.Textbox(label="KI-Analyse", lines=4)
111
-
112
- analyze_btn.click(
113
- fn=analyze_images,
114
- inputs=[file_upload, conf_slider],
115
- outputs=[gallery, analysis_output]
116
  )
117
 
118
- # Launch configuration
119
  demo.launch(
120
- show_api=False,
121
- share=False,
122
  server_name="0.0.0.0",
123
  server_port=7860,
124
- show_error=True,
125
- enable_queue=True
 
 
 
126
  )
 
1
  import gradio as gr
2
  from transformers import pipeline
3
  from PIL import Image, ImageDraw
4
+ import numpy as np
5
 
6
+ # Chargement des modèles
7
  def load_models():
8
  return {
9
  "KnochenAuge": pipeline("object-detection", model="D3STRON/bone-fracture-detr"),
 
12
  model="nandodeomkar/autotrain-fracture-detection-using-google-vit-base-patch-16-54382127388")
13
  }
14
 
15
+ def translate_label(label):
16
+ translations = {
17
+ "fracture": "Knochenbruch",
18
+ "no fracture": "Kein Knochenbruch",
19
+ "normal": "Normal",
20
+ "abnormal": "Auffällig",
21
+ "F1": "Knochenbruch",
22
+ "NF": "Kein Knochenbruch"
23
+ }
24
+ return translations.get(label.lower(), label)
25
+
26
+ def create_heatmap_overlay(image, box, score):
27
+ overlay = Image.new('RGBA', image.size, (0, 0, 0, 0))
28
+ draw = ImageDraw.Draw(overlay)
29
 
30
+ x1, y1 = box['xmin'], box['ymin']
31
+ x2, y2 = box['xmax'], box['ymax']
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
 
33
+ if score > 0.8:
34
+ fill_color = (255, 0, 0, 100)
35
+ border_color = (255, 0, 0, 255)
36
+ elif score > 0.6:
37
+ fill_color = (255, 165, 0, 100)
38
+ border_color = (255, 165, 0, 255)
39
+ else:
40
+ fill_color = (255, 255, 0, 100)
41
+ border_color = (255, 255, 0, 255)
42
+
43
+ draw.rectangle([x1, y1, x2, y2], fill=fill_color)
44
+ draw.rectangle([x1, y1, x2, y2], outline=border_color, width=2)
45
+
46
+ return overlay
47
 
48
+ def draw_boxes(image, predictions):
49
+ result_image = image.copy().convert('RGBA')
 
50
 
51
+ for pred in predictions:
52
+ box = pred['box']
53
+ score = pred['score']
54
+
55
+ overlay = create_heatmap_overlay(image, box, score)
56
+ result_image = Image.alpha_composite(result_image, overlay)
57
 
58
+ draw = ImageDraw.Draw(result_image)
59
+ temp = 36.5 + (score * 2.5)
60
+ label = f"{translate_label(pred['label'])} ({score:.1%} {temp:.1f}°C)"
 
61
 
62
+ text_bbox = draw.textbbox((box['xmin'], box['ymin']-20), label)
63
+ draw.rectangle(text_bbox, fill=(0, 0, 0, 180))
64
+
65
+ draw.text(
66
+ (box['xmin'], box['ymin']-20),
67
+ label,
68
+ fill=(255, 255, 255, 255)
69
+ )
70
+
71
+ return result_image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
 
73
+ # Modèles chargés globalement
74
+ models = load_models()
75
+
76
+ def analyze_image(image, conf_threshold=0.60):
77
+ if image is None:
78
+ return None, "Bitte laden Sie ein Bild hoch."
79
+
80
+ # Convertir en PIL Image si nécessaire
81
+ if not isinstance(image, Image.Image):
82
+ image = Image.fromarray(image)
83
+
84
+ # Analyses
85
+ predictions_watcher = models["KnochenWächter"](image)
86
+ predictions_master = models["RöntgenMeister"](image)
87
+ predictions_locator = models["KnochenAuge"](image)
88
+
89
+ has_fracture = False
90
+ max_fracture_score = 0
91
+ result_html = "<div style='background: #f8f9fa; padding: 20px; border-radius: 10px;'>"
92
+
93
+ # KnochenWächter results
94
+ result_html += "<h3>🛡️ KnochenWächter</h3>"
95
+ for pred in predictions_watcher:
96
+ confidence_color = '#0066cc' if pred['score'] > 0.7 else '#ffa500'
97
+ if pred['score'] >= conf_threshold and 'fracture' in pred['label'].lower():
98
+ has_fracture = True
99
+ max_fracture_score = max(max_fracture_score, pred['score'])
100
+ result_html += f"""
101
+ <div style='background: white; padding: 10px; margin: 5px 0; border-radius: 5px; border: 1px solid #e9ecef;'>
102
+ <span style='color: {confidence_color}; font-weight: 500;'>
103
+ {pred['score']:.1%}
104
+ </span> - {translate_label(pred['label'])}
105
+ </div>
106
+ """
107
+
108
+ # RöntgenMeister results
109
+ result_html += "<h3>🎓 RöntgenMeister</h3>"
110
+ for pred in predictions_master:
111
+ confidence_color = '#0066cc' if pred['score'] > 0.7 else '#ffa500'
112
+ result_html += f"""
113
+ <div style='background: white; padding: 10px; margin: 5px 0; border-radius: 5px; border: 1px solid #e9ecef;'>
114
+ <span style='color: {confidence_color}; font-weight: 500;'>
115
+ {pred['score']:.1%}
116
+ </span> - {translate_label(pred['label'])}
117
+ </div>
118
+ """
119
+
120
+ # Probabilité
121
+ if max_fracture_score > 0:
122
+ no_fracture_prob = 1 - max_fracture_score
123
+ result_html += f"""
124
+ <h3>📊 Wahrscheinlichkeit</h3>
125
+ <div style='background: white; padding: 10px; margin: 5px 0; border-radius: 5px; border: 1px solid #e9ecef;'>
126
+ Knochenbruch: <strong style='color: #0066cc'>{max_fracture_score:.1%}</strong><br>
127
+ Kein Knochenbruch: <strong style='color: #ffa500'>{no_fracture_prob:.1%}</strong>
128
+ </div>
129
+ """
130
+
131
+ result_html += "</div>"
132
+
133
+ # Image processing
134
+ predictions = models["KnochenAuge"](image)
135
+ filtered_preds = [p for p in predictions if p['score'] >= conf_threshold]
136
+ if filtered_preds:
137
+ result_image = draw_boxes(image, filtered_preds)
138
+ return result_image, result_html
139
+ else:
140
+ return image, result_html
141
+
142
+ # Interface Gradio
143
  css = """
144
  .gradio-container {
145
+ font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, 'Helvetica Neue', Arial, sans-serif;
146
+ background-color: #f0f2f5;
147
+ }
148
+ .gr-button {
149
+ background-color: #f8f9fa !important;
150
+ border: 1px solid #e9ecef !important;
151
+ color: #1a1a1a !important;
152
  }
153
+ .gr-button:hover {
154
+ background-color: #e9ecef !important;
155
+ transform: translateY(-1px);
156
  }
157
+ .output-html {
158
+ background: white;
159
+ padding: 20px;
160
+ border-radius: 10px;
161
+ box-shadow: 0 2px 4px rgba(0,0,0,0.1);
162
  }
163
  """
164
 
165
  with gr.Blocks(css=css) as demo:
166
+ gr.Markdown("### 📤 Fraktur Detektion")
167
+
168
  with gr.Row():
169
  with gr.Column(scale=1):
170
+ input_image = gr.Image(type="pil", label="Röntgenbild hochladen")
171
+ conf_threshold = gr.Slider(
 
 
 
 
172
  minimum=0.0,
173
  maximum=1.0,
174
+ value=0.60,
175
  step=0.05,
176
  label="Konfidenzschwelle"
177
  )
178
+ analyze_button = gr.Button("Analysieren", variant="primary")
179
 
180
+ with gr.Column(scale=1):
181
+ output_image = gr.Image(type="pil", label="Analysiertes Bild")
182
+ output_html = gr.HTML(label="Ergebnisse")
183
+
184
+ analyze_button.click(
185
+ fn=analyze_image,
186
+ inputs=[input_image, conf_threshold],
187
+ outputs=[output_image, output_html]
188
  )
189
 
190
+ # Lancement de l'interface
191
  demo.launch(
 
 
192
  server_name="0.0.0.0",
193
  server_port=7860,
194
+ share=False,
195
+ favicon_path=None,
196
+ show_api=False,
197
+ show_error=False,
198
+ debug=False
199
  )