yassonee commited on
Commit
88fb5fa
·
verified ·
1 Parent(s): a443273

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +76 -33
app.py CHANGED
@@ -1,27 +1,59 @@
1
  import streamlit as st
2
  from transformers import pipeline
3
- from PIL import Image
4
- import io
5
 
6
- st.set_page_config(page_title="Knochenbrucherkennung", layout="centered")
7
 
8
  @st.cache_resource
9
- def load_model():
10
- return pipeline("image-classification", model="Heem2/bone-fracture-detection-using-xray")
 
 
 
 
 
 
 
11
 
12
- def main():
13
- st.title("🦴 Knochenbrucherkennung")
14
- st.write("Laden Sie ein Röntgenbild hoch.")
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
- pipe = load_model()
17
-
18
- uploaded_file = st.file_uploader(
19
- "Röntgenbild auswählen",
20
- type=['png', 'jpg', 'jpeg']
21
- )
 
 
 
 
 
22
 
 
 
 
 
 
 
 
23
  conf_threshold = st.slider(
24
- "Konfidenzschwelle",
25
  min_value=0.0,
26
  max_value=1.0,
27
  value=0.3,
@@ -30,27 +62,38 @@ def main():
30
 
31
  if uploaded_file:
32
  image = Image.open(uploaded_file)
33
-
34
- # Redimensionner l'image
35
  max_size = (400, 400)
36
  image.thumbnail(max_size, Image.Resampling.LANCZOS)
37
 
38
- st.image(image, caption="Hochgeladenes Bild")
39
-
40
- with st.spinner("Analyse läuft..."):
41
- predictions = pipe(image)
42
-
43
- st.subheader("Ergebnisse")
44
- for pred in predictions:
45
- if pred['score'] >= conf_threshold:
46
- label = "Bruch erkannt" if "fracture" in pred['label'].lower() else "Kein Bruch"
47
- st.write(f" Diagnose: {label}")
48
- st.write(f"• Konfidenz: {pred['score']:.2%}")
49
-
50
- if "fracture" in pred['label'].lower() and pred['score'] >= conf_threshold:
51
- st.warning("⚠️ Möglicher Knochenbruch erkannt!")
52
- else:
53
- st.success(" Kein Bruch erkannt")
 
 
 
 
 
 
 
 
 
 
 
 
 
54
 
55
  if __name__ == "__main__":
56
  main()
 
1
  import streamlit as st
2
  from transformers import pipeline
3
+ from PIL import Image, ImageDraw
4
+ import torch
5
 
6
+ st.set_page_config(page_title="Multi-Model Fracture Detection", layout="wide")
7
 
8
  @st.cache_resource
9
+ def load_models():
10
+ models = {
11
+ "D3STRON (Object Detection)": pipeline("object-detection", model="D3STRON/bone-fracture-detr"),
12
+ "Heem2 (Classification)": pipeline("image-classification", model="Heem2/bone-fracture-detection-using-xray"),
13
+ "Akhileshav8 (Classification)": pipeline("image-classification", model="akhileshav8/image_classification_for_fracture"),
14
+ "Nandodeomkar (Classification)": pipeline("image-classification", model="nandodeomkar/autotrain-fracture-detection-using-google-vit-base-patch-16-54382127388"),
15
+ "Anirban22 (Object Detection)": pipeline("object-detection", model="anirban22/detr-resnet-50-med_fracture")
16
+ }
17
+ return models
18
 
19
+ def draw_boxes(image, predictions):
20
+ draw = ImageDraw.Draw(image)
21
+ for pred in predictions:
22
+ box = pred['box']
23
+ label = f"{pred['label']} ({pred['score']:.2%})"
24
+
25
+ draw.rectangle(
26
+ [(box['xmin'], box['ymin']), (box['xmax'], box['ymax'])],
27
+ outline="red",
28
+ width=3
29
+ )
30
+
31
+ text_bbox = draw.textbbox((box['xmin'], box['ymin']), label)
32
+ draw.rectangle(text_bbox, fill="red")
33
+ draw.text((box['xmin'], box['ymin']), label, fill="white")
34
+ return image
35
 
36
+ def process_classification(model, image, conf_threshold):
37
+ predictions = model(image)
38
+ results = []
39
+ for pred in predictions:
40
+ if pred['score'] >= conf_threshold:
41
+ results.append(f"{pred['label']}: {pred['score']:.2%}")
42
+ return results
43
+
44
+ def process_detection(model, image, conf_threshold):
45
+ predictions = model(image)
46
+ return [pred for pred in predictions if pred['score'] >= conf_threshold]
47
 
48
+ def main():
49
+ st.title("🦴 Multi-Model Fracture Detection")
50
+
51
+ models = load_models()
52
+
53
+ uploaded_file = st.file_uploader("Upload X-ray image", type=['png', 'jpg', 'jpeg'])
54
+
55
  conf_threshold = st.slider(
56
+ "Confidence threshold",
57
  min_value=0.0,
58
  max_value=1.0,
59
  value=0.3,
 
62
 
63
  if uploaded_file:
64
  image = Image.open(uploaded_file)
 
 
65
  max_size = (400, 400)
66
  image.thumbnail(max_size, Image.Resampling.LANCZOS)
67
 
68
+ st.image(image, caption="Original Image", width=400)
69
+
70
+ col1, col2 = st.columns(2)
71
+
72
+ with col1:
73
+ st.subheader("Classification Models")
74
+ for name, model in models.items():
75
+ if "Classification" in name:
76
+ st.write(f"**{name}**")
77
+ with st.spinner(f"Running {name}..."):
78
+ results = process_classification(model, image, conf_threshold)
79
+ for result in results:
80
+ st.write(f"• {result}")
81
+
82
+ with col2:
83
+ st.subheader("Object Detection Models")
84
+ for name, model in models.items():
85
+ if "Object Detection" in name:
86
+ st.write(f"**{name}**")
87
+ with st.spinner(f"Running {name}..."):
88
+ detections = process_detection(model, image, conf_threshold)
89
+ if detections:
90
+ result_image = image.copy()
91
+ result_image = draw_boxes(result_image, detections)
92
+ st.image(result_image, caption=f"Results from {name}")
93
+ for det in detections:
94
+ st.write(f"• {det['label']}: {det['score']:.2%}")
95
+ else:
96
+ st.write("No detections above threshold")
97
 
98
  if __name__ == "__main__":
99
  main()