yassonee commited on
Commit
04a7bfd
·
verified ·
1 Parent(s): 806ecee

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +77 -100
app.py CHANGED
@@ -1,25 +1,28 @@
1
  import streamlit as st
2
  from transformers import pipeline
3
- from PIL import Image, ImageDraw
4
  import torch
 
 
 
 
 
 
 
 
5
 
6
- st.set_page_config(
7
- page_title="Knochenbrucherkennung",
8
- layout="wide",
9
- initial_sidebar_state="collapsed"
10
- )
11
 
12
- st.markdown("""
13
- <style>
14
- .main > div {
15
- padding: 2rem;
16
- background: #f8f9fa;
17
- border-radius: 1rem;
18
- box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
19
- }
20
- </style>
21
- """, unsafe_allow_html=True)
22
 
 
23
  @st.cache_resource
24
  def load_models():
25
  return {
@@ -29,104 +32,78 @@ def load_models():
29
  model="nandodeomkar/autotrain-fracture-detection-using-google-vit-base-patch-16-54382127388")
30
  }
31
 
32
- def translate_label(label):
33
- translations = {
34
- "fracture": "Knochenbruch",
35
- "no fracture": "Kein Bruch",
36
- "normal": "Normal",
37
- "abnormal": "Abnormal"
38
- }
39
- for eng, deu in translations.items():
40
- if eng.lower() in label.lower():
41
- return deu
42
- return label
43
 
44
- def draw_boxes(image, predictions):
45
  draw = ImageDraw.Draw(image)
46
- for pred in predictions:
 
 
47
  box = pred['box']
48
- label = f"{translate_label(pred['label'])} ({pred['score']:.2%})"
49
 
50
  draw.rectangle(
51
  [(box['xmin'], box['ymin']), (box['xmax'], box['ymax'])],
52
- outline="#FF6B6B",
53
  width=2
54
  )
55
 
56
- text_bbox = draw.textbbox((box['xmin'], box['ymin']), label)
57
- draw.rectangle(text_bbox, fill="#FF6B6B")
58
- draw.text((box['xmin'], box['ymin']), label, fill="white")
59
- return image
60
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
  def main():
62
- st.title("🦴 Knochenbrucherkennung System")
63
-
64
- models = load_models()
65
 
66
- with st.expander("⚙️ Einstellungen", expanded=True):
67
- conf_threshold = st.slider(
68
- "Konfidenzschwelle",
69
- min_value=0.0,
70
- max_value=1.0,
71
- value=0.60,
72
- step=0.01
73
- )
74
 
75
- uploaded_file = st.file_uploader(
76
- "Röntgenbild hochladen",
77
- type=['png', 'jpg', 'jpeg'],
78
- key="xray_upload"
79
- )
80
-
81
  if uploaded_file:
82
- col1, col2 = st.columns([1, 2])
83
-
84
- with col1:
85
- image = Image.open(uploaded_file)
86
- max_size = (250, 250)
87
- image.thumbnail(max_size, Image.Resampling.LANCZOS)
88
- st.image(image, caption="Original Röntgenbild", use_container_width=True)
89
-
90
- with col2:
91
- tab1, tab2 = st.tabs(["📊 Klassifizierung", "🔍 Erkennung"])
92
-
93
- with tab1:
94
- for name in ["Heem2", "Nandodeomkar"]:
95
- with st.container():
96
- st.subheader(f"Modell: {name}")
97
- with st.spinner("Analyse läuft..."):
98
- predictions = models[name](image)
99
- for pred in predictions:
100
- if pred['score'] >= conf_threshold:
101
- score_color = "green" if pred['score'] > 0.7 else "orange"
102
- st.markdown(f"""
103
- <div style='padding: 10px; border-radius: 5px; background-color: #f0f2f6;'>
104
- <span style='color: {score_color}; font-weight: bold;'>
105
- {pred['score']:.1%}
106
- </span> - {translate_label(pred['label'])}
107
- </div>
108
- """, unsafe_allow_html=True)
109
-
110
- with tab2:
111
- st.subheader("Modell: D3STRON")
112
- with st.spinner("Erkennung läuft..."):
113
- predictions = models["D3STRON"](image)
114
- filtered_preds = [p for p in predictions if p['score'] >= conf_threshold]
115
-
116
- if filtered_preds:
117
- result_image = image.copy()
118
- result_image = draw_boxes(result_image, filtered_preds)
119
- st.image(result_image, use_container_width=True)
120
-
121
- for pred in filtered_preds:
122
- st.markdown(f"""
123
- <div style='padding: 8px; border-left: 4px solid #FF6B6B;
124
- margin: 5px 0; background-color: #f0f2f6;'>
125
- {translate_label(pred['label'])}: {pred['score']:.1%}
126
- </div>
127
- """, unsafe_allow_html=True)
128
- else:
129
- st.info("Keine Erkennungen über dem Schwellenwert")
130
 
131
  if __name__ == "__main__":
132
  main()
 
1
  import streamlit as st
2
  from transformers import pipeline
 
3
  import torch
4
+ from PIL import Image, ImageDraw
5
+ import io
6
+ import base64
7
+ from fastapi import FastAPI, File, UploadFile
8
+ from fastapi.middleware.cors import CORSMiddleware
9
+ import numpy as np
10
+ import json
11
+ from starlette.responses import JSONResponse
12
 
13
+ # FastAPI app
14
+ app = FastAPI()
 
 
 
15
 
16
+ # Enable CORS
17
+ app.add_middleware(
18
+ CORSMiddleware,
19
+ allow_origins=["*"],
20
+ allow_credentials=True,
21
+ allow_methods=["*"],
22
+ allow_headers=["*"],
23
+ )
 
 
24
 
25
+ # Load models
26
  @st.cache_resource
27
  def load_models():
28
  return {
 
32
  model="nandodeomkar/autotrain-fracture-detection-using-google-vit-base-patch-16-54382127388")
33
  }
34
 
35
+ models = load_models()
 
 
 
 
 
 
 
 
 
 
36
 
37
+ def draw_boxes(image, predictions, threshold=0.6):
38
  draw = ImageDraw.Draw(image)
39
+ filtered_preds = [p for p in predictions if p['score'] >= threshold]
40
+
41
+ for pred in filtered_preds:
42
  box = pred['box']
43
+ label = f"{pred['label']} ({pred['score']:.2%})"
44
 
45
  draw.rectangle(
46
  [(box['xmin'], box['ymin']), (box['xmax'], box['ymax'])],
47
+ outline="red",
48
  width=2
49
  )
50
 
51
+ draw.text((box['xmin'], box['ymin']), label, fill="red")
52
+
53
+ return image, filtered_preds
 
54
 
55
+ # API Endpoint
56
+ @app.post("/detect")
57
+ async def detect_fracture(file: UploadFile = File(...), confidence: float = 0.6):
58
+ try:
59
+ # Read and process image
60
+ contents = await file.read()
61
+ image = Image.open(io.BytesIO(contents))
62
+
63
+ # Get predictions from all models
64
+ results = {}
65
+
66
+ # Object detection models
67
+ detection_preds = models["D3STRON"](image)
68
+ result_image = image.copy()
69
+ result_image, filtered_detections = draw_boxes(result_image, detection_preds, confidence)
70
+
71
+ # Save result image
72
+ img_byte_arr = io.BytesIO()
73
+ result_image.save(img_byte_arr, format='PNG')
74
+ img_byte_arr = img_byte_arr.getvalue()
75
+ img_b64 = base64.b64encode(img_byte_arr).decode()
76
+
77
+ # Classification models
78
+ class_results = {
79
+ "Heem2": models["Heem2"](image),
80
+ "Nandodeomkar": models["Nandodeomkar"](image)
81
+ }
82
+
83
+ return JSONResponse({
84
+ "success": True,
85
+ "detections": filtered_detections,
86
+ "classifications": class_results,
87
+ "image": img_b64
88
+ })
89
+
90
+ except Exception as e:
91
+ return JSONResponse({
92
+ "success": False,
93
+ "error": str(e)
94
+ })
95
+
96
+ # Streamlit UI
97
  def main():
98
+ st.title("🦴 Fraktur Detektion")
 
 
99
 
100
+ # UI elements...
101
+ uploaded_file = st.file_uploader("Röntgenbild hochladen", type=['png', 'jpg', 'jpeg'])
102
+ confidence = st.slider("Konfidenzschwelle", 0.0, 1.0, 0.6, 0.05)
 
 
 
 
 
103
 
 
 
 
 
 
 
104
  if uploaded_file:
105
+ # Process image and display results...
106
+ pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
107
 
108
  if __name__ == "__main__":
109
  main()