yassonee commited on
Commit
6cc7ff9
·
verified ·
1 Parent(s): d81ec21

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +100 -62
app.py CHANGED
@@ -3,20 +3,45 @@ from transformers import pipeline
3
  from PIL import Image, ImageDraw
4
  import torch
5
 
6
- st.set_page_config(page_title="Multi-Model Fracture Detection", layout="wide")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
  @st.cache_resource
9
  def load_models():
10
- models = {
11
- "D3STRON (Object Detection)": pipeline("object-detection", model="D3STRON/bone-fracture-detr"),
12
- "Heem2 (Classification)": pipeline("image-classification", model="Heem2/bone-fracture-detection-using-xray"),
13
- "Akhileshav8 (Classification)": pipeline("image-classification", model="akhileshav8/image_classification_for_fracture"),
14
- "Nandodeomkar (Classification)": pipeline("image-classification", model="nandodeomkar/autotrain-fracture-detection-using-google-vit-base-patch-16-54382127388"),
15
- "Anirban22 (Object Detection)": pipeline("object-detection", model="anirban22/detr-resnet-50-med_fracture")
16
  }
17
- return models
18
 
19
- def draw_boxes(image, predictions):
20
  draw = ImageDraw.Draw(image)
21
  for pred in predictions:
22
  box = pred['box']
@@ -24,76 +49,89 @@ def draw_boxes(image, predictions):
24
 
25
  draw.rectangle(
26
  [(box['xmin'], box['ymin']), (box['xmax'], box['ymax'])],
27
- outline="red",
28
- width=3
29
  )
30
 
 
31
  text_bbox = draw.textbbox((box['xmin'], box['ymin']), label)
32
- draw.rectangle(text_bbox, fill="red")
33
  draw.text((box['xmin'], box['ymin']), label, fill="white")
34
  return image
35
 
36
- def process_classification(model, image, conf_threshold):
37
- predictions = model(image)
38
- results = []
39
- for pred in predictions:
40
- if pred['score'] >= conf_threshold:
41
- results.append(f"{pred['label']}: {pred['score']:.2%}")
42
- return results
43
-
44
- def process_detection(model, image, conf_threshold):
45
- predictions = model(image)
46
- return [pred for pred in predictions if pred['score'] >= conf_threshold]
47
-
48
  def main():
49
- st.title("🦴 Multi-Model Fracture Detection")
50
 
51
  models = load_models()
52
 
53
- uploaded_file = st.file_uploader("Upload X-ray image", type=['png', 'jpg', 'jpeg'])
 
 
 
 
 
 
 
54
 
55
- conf_threshold = st.slider(
56
- "Confidence threshold",
57
- min_value=0.0,
58
- max_value=1.0,
59
- value=0.3,
60
- step=0.01
61
  )
62
 
63
  if uploaded_file:
64
- image = Image.open(uploaded_file)
65
- max_size = (400, 400)
66
- image.thumbnail(max_size, Image.Resampling.LANCZOS)
67
-
68
- st.image(image, caption="Original Image", width=400)
69
-
70
- col1, col2 = st.columns(2)
71
 
72
  with col1:
73
- st.subheader("Classification Models")
74
- for name, model in models.items():
75
- if "Classification" in name:
76
- st.write(f"**{name}**")
77
- with st.spinner(f"Running {name}..."):
78
- results = process_classification(model, image, conf_threshold)
79
- for result in results:
80
- st.write(f"• {result}")
81
-
82
  with col2:
83
- st.subheader("Object Detection Models")
84
- for name, model in models.items():
85
- if "Object Detection" in name:
86
- st.write(f"**{name}**")
87
- with st.spinner(f"Running {name}..."):
88
- detections = process_detection(model, image, conf_threshold)
89
- if detections:
90
- result_image = image.copy()
91
- result_image = draw_boxes(result_image, detections)
92
- st.image(result_image, caption=f"Results from {name}")
93
- for det in detections:
94
- st.write(f"• {det['label']}: {det['score']:.2%}")
95
- else:
96
- st.write("No detections above threshold")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
 
98
  if __name__ == "__main__":
99
  main()
 
3
  from PIL import Image, ImageDraw
4
  import torch
5
 
6
+ st.set_page_config(
7
+ page_title="Fracture Detection",
8
+ layout="wide",
9
+ initial_sidebar_state="collapsed"
10
+ )
11
+
12
+ # Custom CSS
13
+ st.markdown("""
14
+ <style>
15
+ .main > div {
16
+ padding: 2rem;
17
+ background: #f8f9fa;
18
+ border-radius: 1rem;
19
+ box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
20
+ }
21
+ .stButton button {
22
+ width: 100%;
23
+ border-radius: 0.5rem;
24
+ }
25
+ .uploadedFile {
26
+ border-radius: 0.5rem;
27
+ }
28
+ h1, h2, h3 {
29
+ color: #2c3e50;
30
+ }
31
+ </style>
32
+ """, unsafe_allow_html=True)
33
 
34
  @st.cache_resource
35
  def load_models():
36
+ return {
37
+ "D3STRON": pipeline("object-detection", model="D3STRON/bone-fracture-detr"),
38
+ "Heem2": pipeline("image-classification", model="Heem2/bone-fracture-detection-using-xray"),
39
+ "Nandodeomkar": pipeline("image-classification",
40
+ model="nandodeomkar/autotrain-fracture-detection-using-google-vit-base-patch-16-54382127388"),
41
+ "Judy07": pipeline("object-detection", model="Judy07/bone-fracture-DETA")
42
  }
 
43
 
44
+ def draw_boxes(image, predictions, color):
45
  draw = ImageDraw.Draw(image)
46
  for pred in predictions:
47
  box = pred['box']
 
49
 
50
  draw.rectangle(
51
  [(box['xmin'], box['ymin']), (box['xmax'], box['ymax'])],
52
+ outline=color,
53
+ width=2
54
  )
55
 
56
+ # Label background
57
  text_bbox = draw.textbbox((box['xmin'], box['ymin']), label)
58
+ draw.rectangle(text_bbox, fill=color)
59
  draw.text((box['xmin'], box['ymin']), label, fill="white")
60
  return image
61
 
 
 
 
 
 
 
 
 
 
 
 
 
62
  def main():
63
+ st.title("🦴 Advanced Fracture Detection System")
64
 
65
  models = load_models()
66
 
67
+ with st.expander("⚙️ Settings", expanded=True):
68
+ conf_threshold = st.slider(
69
+ "Confidence threshold",
70
+ min_value=0.0,
71
+ max_value=1.0,
72
+ value=0.3,
73
+ step=0.01
74
+ )
75
 
76
+ uploaded_file = st.file_uploader(
77
+ "Upload X-ray image",
78
+ type=['png', 'jpg', 'jpeg'],
79
+ key="xray_upload"
 
 
80
  )
81
 
82
  if uploaded_file:
83
+ col1, col2 = st.columns([1, 2])
 
 
 
 
 
 
84
 
85
  with col1:
86
+ image = Image.open(uploaded_file)
87
+ max_size = (300, 300)
88
+ image.thumbnail(max_size, Image.Resampling.LANCZOS)
89
+ st.image(image, caption="Original X-ray", use_column_width=True)
90
+
 
 
 
 
91
  with col2:
92
+ tab1, tab2 = st.tabs(["📊 Classifications", "🔍 Detections"])
93
+
94
+ with tab1:
95
+ for name in ["Heem2", "Nandodeomkar"]:
96
+ with st.container():
97
+ st.subheader(f"Model: {name}")
98
+ with st.spinner("Analyzing..."):
99
+ predictions = models[name](image)
100
+ for pred in predictions:
101
+ if pred['score'] >= conf_threshold:
102
+ score_color = "green" if pred['score'] > 0.7 else "orange"
103
+ st.markdown(f"""
104
+ <div style='padding: 10px; border-radius: 5px; background-color: #f0f2f6;'>
105
+ <span style='color: {score_color}; font-weight: bold;'>
106
+ {pred['score']:.1%}
107
+ </span> - {pred['label']}
108
+ </div>
109
+ """, unsafe_allow_html=True)
110
+
111
+ with tab2:
112
+ detection_colors = {"D3STRON": "#FF6B6B", "Judy07": "#4ECDC4"}
113
+
114
+ for name, color in detection_colors.items():
115
+ with st.container():
116
+ st.subheader(f"Model: {name}")
117
+ with st.spinner("Detecting..."):
118
+ predictions = models[name](image)
119
+ filtered_preds = [p for p in predictions if p['score'] >= conf_threshold]
120
+
121
+ if filtered_preds:
122
+ result_image = image.copy()
123
+ result_image = draw_boxes(result_image, filtered_preds, color)
124
+ st.image(result_image, use_column_width=True)
125
+
126
+ for pred in filtered_preds:
127
+ st.markdown(f"""
128
+ <div style='padding: 8px; border-left: 4px solid {color};
129
+ margin: 5px 0; background-color: #f0f2f6;'>
130
+ {pred['label']}: {pred['score']:.1%}
131
+ </div>
132
+ """, unsafe_allow_html=True)
133
+ else:
134
+ st.info("No detections above threshold")
135
 
136
  if __name__ == "__main__":
137
  main()