yassonee commited on
Commit
3982789
·
verified ·
1 Parent(s): d119a53

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +100 -77
app.py CHANGED
@@ -1,28 +1,25 @@
1
  import streamlit as st
2
  from transformers import pipeline
3
- import torch
4
  from PIL import Image, ImageDraw
5
- import io
6
- import base64
7
- from fastapi import FastAPI, File, UploadFile
8
- from fastapi.middleware.cors import CORSMiddleware
9
- import numpy as np
10
- import json
11
- from starlette.responses import JSONResponse
12
-
13
- # FastAPI app
14
- app = FastAPI()
15
 
16
- # Enable CORS
17
- app.add_middleware(
18
- CORSMiddleware,
19
- allow_origins=["*"],
20
- allow_credentials=True,
21
- allow_methods=["*"],
22
- allow_headers=["*"],
23
  )
24
 
25
- # Load models
 
 
 
 
 
 
 
 
 
 
26
  @st.cache_resource
27
  def load_models():
28
  return {
@@ -32,78 +29,104 @@ def load_models():
32
  model="nandodeomkar/autotrain-fracture-detection-using-google-vit-base-patch-16-54382127388")
33
  }
34
 
35
- models = load_models()
 
 
 
 
 
 
 
 
 
 
36
 
37
- def draw_boxes(image, predictions, threshold=0.6):
38
  draw = ImageDraw.Draw(image)
39
- filtered_preds = [p for p in predictions if p['score'] >= threshold]
40
-
41
- for pred in filtered_preds:
42
  box = pred['box']
43
- label = f"{pred['label']} ({pred['score']:.2%})"
44
 
45
  draw.rectangle(
46
  [(box['xmin'], box['ymin']), (box['xmax'], box['ymax'])],
47
- outline="red",
48
  width=2
49
  )
50
 
51
- draw.text((box['xmin'], box['ymin']), label, fill="red")
52
-
53
- return image, filtered_preds
 
54
 
55
- # API Endpoint
56
- @app.post("/detect")
57
- async def detect_fracture(file: UploadFile = File(...), confidence: float = 0.6):
58
- try:
59
- # Read and process image
60
- contents = await file.read()
61
- image = Image.open(io.BytesIO(contents))
62
-
63
- # Get predictions from all models
64
- results = {}
65
-
66
- # Object detection models
67
- detection_preds = models["D3STRON"](image)
68
- result_image = image.copy()
69
- result_image, filtered_detections = draw_boxes(result_image, detection_preds, confidence)
70
-
71
- # Save result image
72
- img_byte_arr = io.BytesIO()
73
- result_image.save(img_byte_arr, format='PNG')
74
- img_byte_arr = img_byte_arr.getvalue()
75
- img_b64 = base64.b64encode(img_byte_arr).decode()
76
-
77
- # Classification models
78
- class_results = {
79
- "Heem2": models["Heem2"](image),
80
- "Nandodeomkar": models["Nandodeomkar"](image)
81
- }
82
-
83
- return JSONResponse({
84
- "success": True,
85
- "detections": filtered_detections,
86
- "classifications": class_results,
87
- "image": img_b64
88
- })
89
-
90
- except Exception as e:
91
- return JSONResponse({
92
- "success": False,
93
- "error": str(e)
94
- })
95
-
96
- # Streamlit UI
97
  def main():
98
- st.title("🦴 Fraktur Detektion")
 
 
99
 
100
- # UI elements...
101
- uploaded_file = st.file_uploader("Röntgenbild hochladen", type=['png', 'jpg', 'jpeg'])
102
- confidence = st.slider("Konfidenzschwelle", 0.0, 1.0, 0.6, 0.05)
 
 
 
 
 
103
 
 
 
 
 
 
 
104
  if uploaded_file:
105
- # Process image and display results...
106
- pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
107
 
108
  if __name__ == "__main__":
109
  main()
 
1
  import streamlit as st
2
  from transformers import pipeline
 
3
  from PIL import Image, ImageDraw
4
+ import torch
 
 
 
 
 
 
 
 
 
5
 
6
+ st.set_page_config(
7
+ page_title="Knochenbrucherkennung",
8
+ layout="wide",
9
+ initial_sidebar_state="collapsed"
 
 
 
10
  )
11
 
12
+ st.markdown("""
13
+ <style>
14
+ .main > div {
15
+ padding: 2rem;
16
+ background: #f8f9fa;
17
+ border-radius: 1rem;
18
+ box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
19
+ }
20
+ </style>
21
+ """, unsafe_allow_html=True)
22
+
23
  @st.cache_resource
24
  def load_models():
25
  return {
 
29
  model="nandodeomkar/autotrain-fracture-detection-using-google-vit-base-patch-16-54382127388")
30
  }
31
 
32
+ def translate_label(label):
33
+ translations = {
34
+ "fracture": "Knochenbruch",
35
+ "no fracture": "Kein Bruch",
36
+ "normal": "Normal",
37
+ "abnormal": "Abnormal"
38
+ }
39
+ for eng, deu in translations.items():
40
+ if eng.lower() in label.lower():
41
+ return deu
42
+ return label
43
 
44
+ def draw_boxes(image, predictions):
45
  draw = ImageDraw.Draw(image)
46
+ for pred in predictions:
 
 
47
  box = pred['box']
48
+ label = f"{translate_label(pred['label'])} ({pred['score']:.2%})"
49
 
50
  draw.rectangle(
51
  [(box['xmin'], box['ymin']), (box['xmax'], box['ymax'])],
52
+ outline="#FF6B6B",
53
  width=2
54
  )
55
 
56
+ text_bbox = draw.textbbox((box['xmin'], box['ymin']), label)
57
+ draw.rectangle(text_bbox, fill="#FF6B6B")
58
+ draw.text((box['xmin'], box['ymin']), label, fill="white")
59
+ return image
60
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
  def main():
62
+ st.title("🦴 Knochenbrucherkennung System")
63
+
64
+ models = load_models()
65
 
66
+ with st.expander("⚙️ Einstellungen", expanded=True):
67
+ conf_threshold = st.slider(
68
+ "Konfidenzschwelle",
69
+ min_value=0.0,
70
+ max_value=1.0,
71
+ value=0.60,
72
+ step=0.01
73
+ )
74
 
75
+ uploaded_file = st.file_uploader(
76
+ "Röntgenbild hochladen",
77
+ type=['png', 'jpg', 'jpeg'],
78
+ key="xray_upload"
79
+ )
80
+
81
  if uploaded_file:
82
+ col1, col2 = st.columns([1, 2])
83
+
84
+ with col1:
85
+ image = Image.open(uploaded_file)
86
+ max_size = (250, 250)
87
+ image.thumbnail(max_size, Image.Resampling.LANCZOS)
88
+ st.image(image, caption="Original Röntgenbild", use_container_width=True)
89
+
90
+ with col2:
91
+ tab1, tab2 = st.tabs(["📊 Klassifizierung", "🔍 Erkennung"])
92
+
93
+ with tab1:
94
+ for name in ["Heem2", "Nandodeomkar"]:
95
+ with st.container():
96
+ st.subheader(f"Modell: {name}")
97
+ with st.spinner("Analyse läuft..."):
98
+ predictions = models[name](image)
99
+ for pred in predictions:
100
+ if pred['score'] >= conf_threshold:
101
+ score_color = "green" if pred['score'] > 0.7 else "orange"
102
+ st.markdown(f"""
103
+ <div style='padding: 10px; border-radius: 5px; background-color: #f0f2f6;'>
104
+ <span style='color: {score_color}; font-weight: bold;'>
105
+ {pred['score']:.1%}
106
+ </span> - {translate_label(pred['label'])}
107
+ </div>
108
+ """, unsafe_allow_html=True)
109
+
110
+ with tab2:
111
+ st.subheader("Modell: D3STRON")
112
+ with st.spinner("Erkennung läuft..."):
113
+ predictions = models["D3STRON"](image)
114
+ filtered_preds = [p for p in predictions if p['score'] >= conf_threshold]
115
+
116
+ if filtered_preds:
117
+ result_image = image.copy()
118
+ result_image = draw_boxes(result_image, filtered_preds)
119
+ st.image(result_image, use_container_width=True)
120
+
121
+ for pred in filtered_preds:
122
+ st.markdown(f"""
123
+ <div style='padding: 8px; border-left: 4px solid #FF6B6B;
124
+ margin: 5px 0; background-color: #f0f2f6;'>
125
+ {translate_label(pred['label'])}: {pred['score']:.1%}
126
+ </div>
127
+ """, unsafe_allow_html=True)
128
+ else:
129
+ st.info("Keine Erkennungen über dem Schwellenwert")
130
 
131
  if __name__ == "__main__":
132
  main()