knoch / app.py
yassonee's picture
Update app.py
dbfa07b verified
raw
history blame
6.7 kB
import gradio as gr
from transformers import pipeline
from PIL import Image, ImageDraw
import numpy as np
# Model loading
def load_models():
    """Instantiate every Hugging Face pipeline used by the app.

    Returns:
        dict mapping each model's display name to its ``pipeline`` object:
        one "object-detection" pipeline ("KnochenAuge") and two
        "image-classification" pipelines ("KnochenWächter", "RöntgenMeister").
    """
    # (task, model id) per display name; insertion order fixes the load order.
    model_specs = {
        "KnochenAuge": ("object-detection", "D3STRON/bone-fracture-detr"),
        "KnochenWächter": ("image-classification", "Heem2/bone-fracture-detection-using-xray"),
        "RöntgenMeister": (
            "image-classification",
            "nandodeomkar/autotrain-fracture-detection-using-google-vit-base-patch-16-54382127388",
        ),
    }
    return {
        display_name: pipeline(task, model=model_id)
        for display_name, (task, model_id) in model_specs.items()
    }
def translate_label(label):
    """Translate a model output label into German for display.

    Matching is case-insensitive and ignores surrounding whitespace. Labels
    without a known translation are returned unchanged.

    Args:
        label: Raw class label emitted by one of the pipelines.

    Returns:
        The German display string, or ``label`` itself when unknown.
    """
    # All keys must be lowercase because the lookup lowercases the label;
    # the former "F1"/"NF" keys could therefore never match.
    translations = {
        "fracture": "Knochenbruch",
        "fractured": "Knochenbruch",
        "no fracture": "Kein Knochenbruch",
        "not fractured": "Kein Knochenbruch",
        "normal": "Normal",
        "abnormal": "Auffällig",
        "f1": "Knochenbruch",
        "nf": "Kein Knochenbruch",
    }
    return translations.get(label.strip().lower(), label)
def create_heatmap_overlay(image, box, score):
    """Build a transparent RGBA layer that highlights one detection box.

    The box is filled with a translucent colour (alpha 100) and outlined
    with the same colour at full opacity, 2 px wide. The colour encodes the
    confidence: red above 0.8, orange above 0.6, yellow otherwise.

    Args:
        image: Image whose ``size`` the overlay must match.
        box: Mapping with ``xmin``, ``ymin``, ``xmax`` and ``ymax`` keys.
        score: Detection confidence used to pick the colour.

    Returns:
        A new ``RGBA`` image, fully transparent outside the box.
    """
    if score > 0.8:
        rgb = (255, 0, 0)        # red: high confidence
    elif score > 0.6:
        rgb = (255, 165, 0)      # orange: medium confidence
    else:
        rgb = (255, 255, 0)      # yellow: low confidence

    corners = [box['xmin'], box['ymin'], box['xmax'], box['ymax']]
    layer = Image.new('RGBA', image.size, (0, 0, 0, 0))
    painter = ImageDraw.Draw(layer)
    # Translucent fill first, then an opaque border painted over its edge.
    painter.rectangle(corners, fill=rgb + (100,))
    painter.rectangle(corners, outline=rgb + (255,), width=2)
    return layer
def draw_boxes(image, predictions):
    """Draw coloured detection boxes and captions onto a copy of ``image``.

    Args:
        image: Source PIL image; it is not modified.
        predictions: Detector outputs, each a dict with ``box`` (``xmin``,
            ``ymin``, ``xmax``, ``ymax``), ``score`` and ``label``.

    Returns:
        A new ``RGBA`` image with one highlighted box and caption per
        prediction.
    """
    result_image = image.copy().convert('RGBA')
    for pred in predictions:
        box = pred['box']
        score = pred['score']

        overlay = create_heatmap_overlay(result_image, box, score)
        result_image = Image.alpha_composite(result_image, overlay)

        # Cosmetic "heat" value derived purely from the score (36.5-39.0);
        # it is not a measured temperature.
        temp = 36.5 + (score * 2.5)
        label = f"{translate_label(pred['label'])} ({score:.1%}, {temp:.1f}°C)"

        # Place the caption just above the box, but never above the image's
        # top edge, where it would be clipped.
        text_pos = (box['xmin'], max(0, box['ymin'] - 20))

        # Drawing directly on an RGBA image replaces pixels instead of
        # blending, so the translucent caption background is drawn on its own
        # layer and alpha-composited like the box overlay.
        caption_layer = Image.new('RGBA', result_image.size, (0, 0, 0, 0))
        caption_draw = ImageDraw.Draw(caption_layer)
        text_bbox = caption_draw.textbbox(text_pos, label)
        caption_draw.rectangle(text_bbox, fill=(0, 0, 0, 180))
        caption_draw.text(text_pos, label, fill=(255, 255, 255, 255))
        result_image = Image.alpha_composite(result_image, caption_layer)
    return result_image
# Models are loaded once, when this module runs, and shared by every call
# to analyze_image.
models = load_models()
# Inline style shared by every result card in the HTML report.
_CARD_STYLE = "background: white; padding: 10px; margin: 5px 0; border-radius: 5px; border: 1px solid #e9ecef;"


def _is_fracture_label(label):
    """Return True if a classifier label names a fracture rather than its absence.

    A bare substring test for "fracture" also matches negated labels such as
    "no fracture" or "not fractured", so negations are rejected first. The
    short codes "F1" (fracture) and "NF" (no fracture) are recognised too.
    """
    text = label.lower().replace("_", " ").replace("-", " ").strip()
    if text == "nf" or text.startswith(("no ", "not ", "non ")):
        return False
    return text == "f1" or "fracture" in text


def _render_prediction_rows(predictions):
    """Render one HTML card per classifier prediction: score and German label."""
    rows = []
    for pred in predictions:
        # Blue for confident (> 70 %) predictions, orange otherwise.
        confidence_color = '#0066cc' if pred['score'] > 0.7 else '#ffa500'
        rows.append(
            f"<div style='{_CARD_STYLE}'>"
            f"<span style='color: {confidence_color}; font-weight: 500;'>"
            f"{pred['score']:.1%}"
            f"</span> - {translate_label(pred['label'])}"
            "</div>"
        )
    return "".join(rows)


def analyze_image(image, conf_threshold=0.60):
    """Run all three models on an X-ray and build the annotated image and report.

    Args:
        image: Uploaded image, either a PIL image or an array accepted by
            ``Image.fromarray``; ``None`` when nothing was uploaded.
        conf_threshold: Minimum score for a KnochenWächter prediction to count
            as a fracture, and for a KnochenAuge box to be drawn.

    Returns:
        ``(result_image, result_html)``. ``result_image`` has the detection
        boxes drawn on it, or is the unchanged input when no box passes the
        threshold. When ``image`` is None, returns ``(None, prompt message)``.
    """
    if image is None:
        return None, "Bitte laden Sie ein Bild hoch."
    # Convert to a PIL image if necessary.
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)

    predictions_watcher = models["KnochenWächter"](image)
    predictions_master = models["RöntgenMeister"](image)
    # The detector runs once; its output is reused for the boxes below.
    predictions_locator = models["KnochenAuge"](image)

    # Only KnochenWächter feeds the probability summary.
    fracture_scores = [
        pred['score']
        for pred in predictions_watcher
        if pred['score'] >= conf_threshold and _is_fracture_label(pred['label'])
    ]
    max_fracture_score = max(fracture_scores, default=0)

    result_html = "<div style='background: #f8f9fa; padding: 20px; border-radius: 10px;'>"
    result_html += "<h3>🛡️ KnochenWächter</h3>"
    result_html += _render_prediction_rows(predictions_watcher)
    result_html += "<h3>🎓 RöntgenMeister</h3>"
    result_html += _render_prediction_rows(predictions_master)

    if max_fracture_score > 0:
        no_fracture_prob = 1 - max_fracture_score
        result_html += (
            "<h3>📊 Wahrscheinlichkeit</h3>"
            f"<div style='{_CARD_STYLE}'>"
            f"Knochenbruch: <strong style='color: #0066cc'>{max_fracture_score:.1%}</strong><br>"
            f"Kein Knochenbruch: <strong style='color: #ffa500'>{no_fracture_prob:.1%}</strong>"
            "</div>"
        )
    result_html += "</div>"

    filtered_preds = [p for p in predictions_locator if p['score'] >= conf_threshold]
    if filtered_preds:
        return draw_boxes(image, filtered_preds), result_html
    return image, result_html
# Gradio interface
# Custom CSS for the page container, buttons and an ".output-html" panel.
# NOTE(review): no component below is given the "output-html" class, so that
# rule may never apply - confirm against the rendered page.
css = """
.gradio-container {
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, 'Helvetica Neue', Arial, sans-serif;
background-color: #f0f2f5;
}
.gr-button {
background-color: #f8f9fa !important;
border: 1px solid #e9ecef !important;
color: #1a1a1a !important;
}
.gr-button:hover {
background-color: #e9ecef !important;
transform: translateY(-1px);
}
.output-html {
background: white;
padding: 20px;
border-radius: 10px;
box-shadow: 0 2px 4px rgba(0,0,0,0.1);
}
"""
# Two-column layout: inputs (X-ray upload, threshold slider, button) on the
# left; annotated image and HTML report on the right.
with gr.Blocks(css=css) as demo:
    gr.Markdown("### 📤 Fraktur Detektion")
    with gr.Row():
        with gr.Column(scale=1):
            input_image = gr.Image(type="pil", label="Röntgenbild hochladen")
            # Default 0.60 matches analyze_image's conf_threshold default.
            conf_threshold = gr.Slider(
                minimum=0.0,
                maximum=1.0,
                value=0.60,
                step=0.05,
                label="Konfidenzschwelle"
            )
            analyze_button = gr.Button("Analysieren", variant="primary")
        with gr.Column(scale=1):
            output_image = gr.Image(type="pil", label="Analysiertes Bild")
            output_html = gr.HTML(label="Ergebnisse")
    # Clicking the button calls analyze_image(image, threshold); its two
    # return values fill the annotated image and the HTML report.
    analyze_button.click(
        fn=analyze_image,
        inputs=[input_image, conf_threshold],
        outputs=[output_image, output_html]
    )
# Launch the interface: listen on all interfaces at port 7860, with sharing,
# API display, error display and debug mode all switched off.
_launch_options = {
    "server_name": "0.0.0.0",
    "server_port": 7860,
    "share": False,
    "favicon_path": None,
    "show_api": False,
    "show_error": False,
    "debug": False,
}
demo.launch(**_launch_options)