yassonee commited on
Commit
806ecee
·
verified ·
1 Parent(s): 6cc7ff9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +48 -53
app.py CHANGED
@@ -4,12 +4,11 @@ from PIL import Image, ImageDraw
4
  import torch
5
 
6
  st.set_page_config(
7
- page_title="Fracture Detection",
8
  layout="wide",
9
  initial_sidebar_state="collapsed"
10
  )
11
 
12
- # Custom CSS
13
  st.markdown("""
14
  <style>
15
  .main > div {
@@ -18,16 +17,6 @@ st.markdown("""
18
  border-radius: 1rem;
19
  box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
20
  }
21
- .stButton button {
22
- width: 100%;
23
- border-radius: 0.5rem;
24
- }
25
- .uploadedFile {
26
- border-radius: 0.5rem;
27
- }
28
- h1, h2, h3 {
29
- color: #2c3e50;
30
- }
31
  </style>
32
  """, unsafe_allow_html=True)
33
 
@@ -37,44 +26,54 @@ def load_models():
37
  "D3STRON": pipeline("object-detection", model="D3STRON/bone-fracture-detr"),
38
  "Heem2": pipeline("image-classification", model="Heem2/bone-fracture-detection-using-xray"),
39
  "Nandodeomkar": pipeline("image-classification",
40
- model="nandodeomkar/autotrain-fracture-detection-using-google-vit-base-patch-16-54382127388"),
41
- "Judy07": pipeline("object-detection", model="Judy07/bone-fracture-DETA")
 
 
 
 
 
 
 
42
  }
 
 
 
 
43
 
44
- def draw_boxes(image, predictions, color):
45
  draw = ImageDraw.Draw(image)
46
  for pred in predictions:
47
  box = pred['box']
48
- label = f"{pred['label']} ({pred['score']:.2%})"
49
 
50
  draw.rectangle(
51
  [(box['xmin'], box['ymin']), (box['xmax'], box['ymax'])],
52
- outline=color,
53
  width=2
54
  )
55
 
56
- # Label background
57
  text_bbox = draw.textbbox((box['xmin'], box['ymin']), label)
58
- draw.rectangle(text_bbox, fill=color)
59
  draw.text((box['xmin'], box['ymin']), label, fill="white")
60
  return image
61
 
62
  def main():
63
- st.title("🦴 Advanced Fracture Detection System")
64
 
65
  models = load_models()
66
 
67
- with st.expander("⚙️ Settings", expanded=True):
68
  conf_threshold = st.slider(
69
- "Confidence threshold",
70
  min_value=0.0,
71
  max_value=1.0,
72
- value=0.3,
73
  step=0.01
74
  )
75
 
76
  uploaded_file = st.file_uploader(
77
- "Upload X-ray image",
78
  type=['png', 'jpg', 'jpeg'],
79
  key="xray_upload"
80
  )
@@ -84,18 +83,18 @@ def main():
84
 
85
  with col1:
86
  image = Image.open(uploaded_file)
87
- max_size = (300, 300)
88
  image.thumbnail(max_size, Image.Resampling.LANCZOS)
89
- st.image(image, caption="Original X-ray", use_column_width=True)
90
 
91
  with col2:
92
- tab1, tab2 = st.tabs(["📊 Classifications", "🔍 Detections"])
93
 
94
  with tab1:
95
  for name in ["Heem2", "Nandodeomkar"]:
96
  with st.container():
97
- st.subheader(f"Model: {name}")
98
- with st.spinner("Analyzing..."):
99
  predictions = models[name](image)
100
  for pred in predictions:
101
  if pred['score'] >= conf_threshold:
@@ -104,34 +103,30 @@ def main():
104
  <div style='padding: 10px; border-radius: 5px; background-color: #f0f2f6;'>
105
  <span style='color: {score_color}; font-weight: bold;'>
106
  {pred['score']:.1%}
107
- </span> - {pred['label']}
108
  </div>
109
  """, unsafe_allow_html=True)
110
 
111
  with tab2:
112
- detection_colors = {"D3STRON": "#FF6B6B", "Judy07": "#4ECDC4"}
113
-
114
- for name, color in detection_colors.items():
115
- with st.container():
116
- st.subheader(f"Model: {name}")
117
- with st.spinner("Detecting..."):
118
- predictions = models[name](image)
119
- filtered_preds = [p for p in predictions if p['score'] >= conf_threshold]
120
-
121
- if filtered_preds:
122
- result_image = image.copy()
123
- result_image = draw_boxes(result_image, filtered_preds, color)
124
- st.image(result_image, use_column_width=True)
125
-
126
- for pred in filtered_preds:
127
- st.markdown(f"""
128
- <div style='padding: 8px; border-left: 4px solid {color};
129
- margin: 5px 0; background-color: #f0f2f6;'>
130
- {pred['label']}: {pred['score']:.1%}
131
- </div>
132
- """, unsafe_allow_html=True)
133
- else:
134
- st.info("No detections above threshold")
135
 
136
  if __name__ == "__main__":
137
  main()
 
4
  import torch
5
 
6
  st.set_page_config(
7
+ page_title="Knochenbrucherkennung",
8
  layout="wide",
9
  initial_sidebar_state="collapsed"
10
  )
11
 
 
12
  st.markdown("""
13
  <style>
14
  .main > div {
 
17
  border-radius: 1rem;
18
  box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
19
  }
 
 
 
 
 
 
 
 
 
 
20
  </style>
21
  """, unsafe_allow_html=True)
22
 
 
26
  "D3STRON": pipeline("object-detection", model="D3STRON/bone-fracture-detr"),
27
  "Heem2": pipeline("image-classification", model="Heem2/bone-fracture-detection-using-xray"),
28
  "Nandodeomkar": pipeline("image-classification",
29
+ model="nandodeomkar/autotrain-fracture-detection-using-google-vit-base-patch-16-54382127388")
30
+ }
31
+
32
+ def translate_label(label):
33
+ translations = {
34
+ "fracture": "Knochenbruch",
35
+ "no fracture": "Kein Bruch",
36
+ "normal": "Normal",
37
+ "abnormal": "Abnormal"
38
  }
39
+ for eng, deu in translations.items():
40
+ if eng.lower() in label.lower():
41
+ return deu
42
+ return label
43
 
44
+ def draw_boxes(image, predictions):
45
  draw = ImageDraw.Draw(image)
46
  for pred in predictions:
47
  box = pred['box']
48
+ label = f"{translate_label(pred['label'])} ({pred['score']:.2%})"
49
 
50
  draw.rectangle(
51
  [(box['xmin'], box['ymin']), (box['xmax'], box['ymax'])],
52
+ outline="#FF6B6B",
53
  width=2
54
  )
55
 
 
56
  text_bbox = draw.textbbox((box['xmin'], box['ymin']), label)
57
+ draw.rectangle(text_bbox, fill="#FF6B6B")
58
  draw.text((box['xmin'], box['ymin']), label, fill="white")
59
  return image
60
 
61
  def main():
62
+ st.title("🦴 Knochenbrucherkennung System")
63
 
64
  models = load_models()
65
 
66
+ with st.expander("⚙️ Einstellungen", expanded=True):
67
  conf_threshold = st.slider(
68
+ "Konfidenzschwelle",
69
  min_value=0.0,
70
  max_value=1.0,
71
+ value=0.60,
72
  step=0.01
73
  )
74
 
75
  uploaded_file = st.file_uploader(
76
+ "Röntgenbild hochladen",
77
  type=['png', 'jpg', 'jpeg'],
78
  key="xray_upload"
79
  )
 
83
 
84
  with col1:
85
  image = Image.open(uploaded_file)
86
+ max_size = (250, 250)
87
  image.thumbnail(max_size, Image.Resampling.LANCZOS)
88
+ st.image(image, caption="Original Röntgenbild", use_container_width=True)
89
 
90
  with col2:
91
+ tab1, tab2 = st.tabs(["📊 Klassifizierung", "🔍 Erkennung"])
92
 
93
  with tab1:
94
  for name in ["Heem2", "Nandodeomkar"]:
95
  with st.container():
96
+ st.subheader(f"Modell: {name}")
97
+ with st.spinner("Analyse läuft..."):
98
  predictions = models[name](image)
99
  for pred in predictions:
100
  if pred['score'] >= conf_threshold:
 
103
  <div style='padding: 10px; border-radius: 5px; background-color: #f0f2f6;'>
104
  <span style='color: {score_color}; font-weight: bold;'>
105
  {pred['score']:.1%}
106
+ </span> - {translate_label(pred['label'])}
107
  </div>
108
  """, unsafe_allow_html=True)
109
 
110
  with tab2:
111
+ st.subheader("Modell: D3STRON")
112
+ with st.spinner("Erkennung läuft..."):
113
+ predictions = models["D3STRON"](image)
114
+ filtered_preds = [p for p in predictions if p['score'] >= conf_threshold]
115
+
116
+ if filtered_preds:
117
+ result_image = image.copy()
118
+ result_image = draw_boxes(result_image, filtered_preds)
119
+ st.image(result_image, use_container_width=True)
120
+
121
+ for pred in filtered_preds:
122
+ st.markdown(f"""
123
+ <div style='padding: 8px; border-left: 4px solid #FF6B6B;
124
+ margin: 5px 0; background-color: #f0f2f6;'>
125
+ {translate_label(pred['label'])}: {pred['score']:.1%}
126
+ </div>
127
+ """, unsafe_allow_html=True)
128
+ else:
129
+ st.info("Keine Erkennungen über dem Schwellenwert")
 
 
 
 
130
 
131
  if __name__ == "__main__":
132
  main()