# Spaces: Sleeping
import gradio as gr | |
from transformers import pipeline | |
from PIL import Image, ImageDraw | |
import numpy as np | |
# Model loading
def load_models():
    """Create the three transformers pipelines used by the app.

    Returns:
        dict: Display name -> pipeline. "KnochenAuge" is an object-detection
        pipeline (produces boxes); "KnochenWächter" and "RöntgenMeister" are
        image-classification pipelines.
    """
    detector = pipeline("object-detection", model="D3STRON/bone-fracture-detr")
    watcher = pipeline(
        "image-classification",
        model="Heem2/bone-fracture-detection-using-xray",
    )
    master = pipeline(
        "image-classification",
        model="nandodeomkar/autotrain-fracture-detection-using-google-vit-base-patch-16-54382127388",
    )
    return {
        "KnochenAuge": detector,
        "KnochenWächter": watcher,
        "RöntgenMeister": master,
    }
# Lowercase model label -> German display label. Keys MUST be lowercase:
# translate_label lowercases its input before the lookup.
_LABEL_TRANSLATIONS = {
    "fracture": "Knochenbruch",
    "no fracture": "Kein Knochenbruch",
    "normal": "Normal",
    "abnormal": "Auffällig",
    # Previously spelled "F1"/"NF"; upper-case keys could never match the
    # lowercased lookup, so these labels were shown untranslated.
    "f1": "Knochenbruch",
    "nf": "Kein Knochenbruch",
}


def translate_label(label):
    """Translate a classifier/detector label into German for display.

    The lookup is case-insensitive.

    Args:
        label: Raw label string as returned by a pipeline.

    Returns:
        The German display text, or ``label`` unchanged if it has no
        known translation.
    """
    return _LABEL_TRANSLATIONS.get(label.lower(), label)
def create_heatmap_overlay(image, box, score):
    """Build a transparent RGBA layer containing one score-coloured box.

    The box is red when ``score > 0.8``, orange when ``score > 0.6`` and
    yellow otherwise. Its interior is drawn with alpha 100 and its 2 px
    border is opaque.

    Args:
        image: Image whose size the returned layer matches.
        box: Mapping with ``xmin``, ``ymin``, ``xmax`` and ``ymax`` pixel
            coordinates.
        score: Detection confidence used to choose the colour.

    Returns:
        A new RGBA image of ``image.size``. Everything outside the box is
        fully transparent.
    """
    layer = Image.new('RGBA', image.size, (0, 0, 0, 0))
    pen = ImageDraw.Draw(layer)
    corners = [box['xmin'], box['ymin'], box['xmax'], box['ymax']]

    if score > 0.8:
        rgb = (255, 0, 0)      # red
    elif score > 0.6:
        rgb = (255, 165, 0)    # orange
    else:
        rgb = (255, 255, 0)    # yellow

    # Translucent interior first, then an opaque border on top of it.
    pen.rectangle(corners, fill=rgb + (100,))
    pen.rectangle(corners, outline=rgb + (255,), width=2)
    return layer
def draw_boxes(image, predictions):
    """Render detection boxes and captions onto a copy of ``image``.

    Each prediction gets a score-coloured translucent box from
    ``create_heatmap_overlay`` and a caption on a translucent black
    background. The caption shows the German label, the confidence and a
    value derived from the score.

    Args:
        image: Source PIL image; it is not modified.
        predictions: Detection results, each a dict with ``'box'``
            (``xmin``/``ymin``/``xmax``/``ymax``), ``'score'`` and ``'label'``.

    Returns:
        A new RGBA PIL image with all annotations composited on top.
    """
    result_image = image.copy().convert('RGBA')
    for pred in predictions:
        box = pred['box']
        score = pred['score']
        result_image = Image.alpha_composite(
            result_image, create_heatmap_overlay(image, box, score)
        )

        # NOTE(review): this "temperature" is computed purely from the model
        # score (36.5-39.0 for scores in [0, 1]); it is not a measurement
        # taken from the image. Consider removing it from the caption.
        temp = 36.5 + (score * 2.5)
        label = f"{translate_label(pred['label'])} ({score:.1%} • {temp:.1f}°C)"

        # Caption sits 20 px above the box. Clamp it to the top edge so boxes
        # near the top of the image still get a visible caption.
        text_pos = (box['xmin'], max(0, box['ymin'] - 20))

        # ImageDraw on an RGBA image replaces pixels instead of blending.
        # Drawing the (0, 0, 0, 180) background directly would punch a
        # semi-transparent hole into the output. Draw it on its own layer and
        # alpha-composite it instead.
        caption_layer = Image.new('RGBA', result_image.size, (0, 0, 0, 0))
        caption_draw = ImageDraw.Draw(caption_layer)
        caption_draw.rectangle(
            caption_draw.textbbox(text_pos, label), fill=(0, 0, 0, 180)
        )
        result_image = Image.alpha_composite(result_image, caption_layer)

        # Opaque white text can be drawn directly; no blending is needed.
        ImageDraw.Draw(result_image).text(
            text_pos, label, fill=(255, 255, 255, 255)
        )
    return result_image
# Models are loaded once, at module level. analyze_image looks them up in
# this dict on every call instead of reloading them.
models = load_models()
def _is_fracture_label(label):
    """Return True if a classifier label denotes a positive fracture finding.

    A plain ``'fracture' in label`` test also matches negative labels such
    as "no fracture" or "not fractured", so negated labels are rejected
    explicitly. "F1" is accepted as well, consistent with the label
    translations used for display.
    """
    words = label.lower().replace("_", " ").replace("-", " ").split()
    if any(word in ("no", "not", "non") for word in words):
        return False
    return words == ["f1"] or any(word.startswith("fracture") for word in words)


def _render_prediction_cards(title, predictions):
    """Return an <h3> heading followed by one HTML card per prediction."""
    parts = [f"<h3>{title}</h3>"]
    for pred in predictions:
        # Blue for confident predictions (> 70 %), orange otherwise.
        confidence_color = '#0066cc' if pred['score'] > 0.7 else '#ffa500'
        parts.append(f"""
<div style='background: white; padding: 10px; margin: 5px 0; border-radius: 5px; border: 1px solid #e9ecef;'>
<span style='color: {confidence_color}; font-weight: 500;'>
{pred['score']:.1%}
</span> - {translate_label(pred['label'])}
</div>
""")
    return "".join(parts)


def analyze_image(image, conf_threshold=0.60):
    """Run all three models on an X-ray; build the annotated image and report.

    Args:
        image: Uploaded image (a PIL image, or an array accepted by
            ``Image.fromarray``), or ``None`` when nothing was uploaded.
        conf_threshold: Minimum score a KnochenWächter fracture label needs
            to count towards the fracture probability. It is also the
            minimum score a KnochenAuge box needs to be drawn.

    Returns:
        ``(image, html)``: the annotated image (or the input unchanged when
        no detection passes the threshold) and an HTML report. For ``None``
        input, ``(None, message)``.
    """
    if image is None:
        return None, "Bitte laden Sie ein Bild hoch."

    # The pipelines and the drawing code work on PIL images.
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)

    predictions_watcher = models["KnochenWächter"](image)
    predictions_master = models["RöntgenMeister"](image)
    # Run the detector once; its result is reused for the box overlay below.
    predictions_locator = models["KnochenAuge"](image)

    # Highest KnochenWächter fracture score at or above the threshold (0 if none).
    max_fracture_score = max(
        (
            pred['score']
            for pred in predictions_watcher
            if pred['score'] >= conf_threshold and _is_fracture_label(pred['label'])
        ),
        default=0,
    )

    result_html = "<div style='background: #f8f9fa; padding: 20px; border-radius: 10px;'>"
    result_html += _render_prediction_cards("🛡️ KnochenWächter", predictions_watcher)
    result_html += _render_prediction_cards("🎓 RöntgenMeister", predictions_master)

    # Summary, shown only when a fracture label passed the threshold.
    if max_fracture_score > 0:
        no_fracture_prob = 1 - max_fracture_score
        result_html += f"""
<h3>📊 Wahrscheinlichkeit</h3>
<div style='background: white; padding: 10px; margin: 5px 0; border-radius: 5px; border: 1px solid #e9ecef;'>
Knochenbruch: <strong style='color: #0066cc'>{max_fracture_score:.1%}</strong><br>
Kein Knochenbruch: <strong style='color: #ffa500'>{no_fracture_prob:.1%}</strong>
</div>
"""
    result_html += "</div>"

    filtered_preds = [p for p in predictions_locator if p['score'] >= conf_threshold]
    if filtered_preds:
        return draw_boxes(image, filtered_preds), result_html
    return image, result_html
# Gradio interface
# Custom stylesheet passed to gr.Blocks(css=...) below.
# NOTE(review): no component in this file sets elem_classes=["output-html"],
# so the .output-html rule appears unused. Whether .gr-button matches
# anything depends on the class names the installed Gradio version
# generates; confirm both selectors.
css = """
.gradio-container {
    font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, 'Helvetica Neue', Arial, sans-serif;
    background-color: #f0f2f5;
}
.gr-button {
    background-color: #f8f9fa !important;
    border: 1px solid #e9ecef !important;
    color: #1a1a1a !important;
}
.gr-button:hover {
    background-color: #e9ecef !important;
    transform: translateY(-1px);
}
.output-html {
    background: white;
    padding: 20px;
    border-radius: 10px;
    box-shadow: 0 2px 4px rgba(0,0,0,0.1);
}
"""
# Two-column layout:
# - left: upload, confidence slider and the trigger button;
# - right: annotated image and HTML report.
with gr.Blocks(css=css) as demo:
    gr.Markdown("### 📤 Fraktur Detektion")
    with gr.Row():
        with gr.Column(scale=1):
            input_image = gr.Image(label="Röntgenbild hochladen", type="pil")
            conf_threshold = gr.Slider(
                label="Konfidenzschwelle",
                minimum=0.0,
                maximum=1.0,
                step=0.05,
                value=0.60,
            )
            analyze_button = gr.Button("Analysieren", variant="primary")
        with gr.Column(scale=1):
            output_image = gr.Image(label="Analysiertes Bild", type="pil")
            output_html = gr.HTML(label="Ergebnisse")

    # The slider value is passed as analyze_image's conf_threshold argument.
    # Its (image, html) result fills the two right-hand outputs.
    analyze_button.click(
        fn=analyze_image,
        outputs=[output_image, output_html],
        inputs=[input_image, conf_threshold],
    )
# Start the web server only when this file is executed as a script
# (e.g. `python app.py`). Importing the module (tests, reload tooling) then
# no longer blocks on launch().
if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",  # listen on all interfaces (container use)
        server_port=7860,
        share=False,
        favicon_path=None,
        show_api=False,
        show_error=False,
        debug=False,
    )