yassonee commited on
Commit
a605c36
·
verified ·
1 Parent(s): bb66eea

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +126 -0
app.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import pipeline
3
+ from PIL import Image, ImageDraw
4
+ import torch
5
+
6
+ def load_models():
7
+ return {
8
+ "KnochenAuge": pipeline("object-detection", model="D3STRON/bone-fracture-detr"),
9
+ "KnochenWächter": pipeline("image-classification", model="Heem2/bone-fracture-detection-using-xray"),
10
+ "RöntgenMeister": pipeline("image-classification",
11
+ model="nandodeomkar/autotrain-fracture-detection-using-google-vit-base-patch-16-54382127388")
12
+ }
13
+
14
+ def draw_boxes(image, predictions, conf_threshold=0.6):
15
+ draw = ImageDraw.Draw(image)
16
+ fractures_found = False
17
+
18
+ for pred in predictions:
19
+ if pred['label'].lower() == 'fracture' and pred['score'] >= conf_threshold:
20
+ fractures_found = True
21
+ box = pred['box']
22
+ label = f"Fraktur ({pred['score']:.1%})"
23
+ color = "#2563eb" if pred['score'] > 0.7 else "#eab308"
24
+
25
+ draw.rectangle(
26
+ [(box['xmin'], box['ymin']), (box['xmax'], box['ymax'])],
27
+ outline=color,
28
+ width=2
29
+ )
30
+
31
+ text_bbox = draw.textbbox((box['xmin'], box['ymin']-15), label)
32
+ draw.rectangle(text_bbox, fill=color)
33
+ draw.text((box['xmin'], box['ymin']-15), label, fill="white")
34
+
35
+ return image if fractures_found else None
36
+
37
+ def analyze_images(images, conf_threshold=0.6):
38
+ models = load_models()
39
+ results = []
40
+
41
+ for img in images:
42
+ pil_img = Image.fromarray(img)
43
+
44
+ # KnochenAuge Analysis
45
+ predictions = models["KnochenAuge"](pil_img)
46
+ fractures_found = any(p['label'].lower() == 'fracture' and p['score'] >= conf_threshold
47
+ for p in predictions)
48
+
49
+ if fractures_found:
50
+ # Draw boxes on image
51
+ result_image = draw_boxes(pil_img.copy(), predictions, conf_threshold)
52
+
53
+ # Additional analyses
54
+ wachter_pred = models["KnochenWächter"](pil_img)[0]
55
+ meister_pred = models["RöntgenMeister"](pil_img)[0]
56
+
57
+ if result_image:
58
+ results.append({
59
+ "image": result_image,
60
+ "knochen_wachter": f"KnochenWächter: {wachter_pred['score']:.1%}",
61
+ "rontgen_meister": f"RöntgenMeister: {meister_pred['score']:.1%}"
62
+ })
63
+
64
+ # Format results for display
65
+ if not results:
66
+ return None, "Keine Frakturen gefunden."
67
+
68
+ output_images = [r["image"] for r in results]
69
+ analysis_text = "\n\n".join([
70
+ f"Bild {i+1}:\n{r['knochen_wachter']}\n{r['rontgen_meister']}"
71
+ for i, r in enumerate(results)
72
+ ])
73
+
74
+ return output_images, analysis_text
75
+
76
+ # Interface configuration
77
+ css = """
78
+ .gradio-container {
79
+ background-color: transparent !important;
80
+ }
81
+ .dark {
82
+ background-color: #1f2937;
83
+ color: #f3f4f6;
84
+ }
85
+ .light {
86
+ background-color: #ffffff;
87
+ color: #1f2937;
88
+ }
89
+ """
90
+
91
+ with gr.Blocks(css=css) as demo:
92
+ with gr.Row():
93
+ with gr.Column(scale=1):
94
+ file_upload = gr.File(
95
+ label="Röntgenbilder hochladen",
96
+ file_types=["image"],
97
+ file_count="multiple"
98
+ )
99
+ conf_slider = gr.Slider(
100
+ minimum=0.0,
101
+ maximum=1.0,
102
+ value=0.6,
103
+ step=0.05,
104
+ label="Konfidenzschwelle"
105
+ )
106
+ analyze_btn = gr.Button("Bilder analysieren", variant="primary")
107
+
108
+ with gr.Column(scale=2):
109
+ gallery = gr.Gallery(label="Ergebnisse").style(grid=2)
110
+ analysis_output = gr.Textbox(label="KI-Analyse", lines=4)
111
+
112
+ analyze_btn.click(
113
+ fn=analyze_images,
114
+ inputs=[file_upload, conf_slider],
115
+ outputs=[gallery, analysis_output]
116
+ )
117
+
118
+ # Launch configuration
119
+ demo.launch(
120
+ show_api=False,
121
+ share=False,
122
+ server_name="0.0.0.0",
123
+ server_port=7860,
124
+ show_error=True,
125
+ enable_queue=True
126
+ )