yassonee commited on
Commit
3bb1400
·
verified ·
1 Parent(s): c2384bf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +83 -54
app.py CHANGED
@@ -1,69 +1,98 @@
1
  import streamlit as st
2
- from transformers import AutoModelForObjectDetection
3
  import torch
4
- from PIL import Image
5
- import numpy as np
6
- import cv2
7
 
8
- st.set_page_config(page_title="Détection de nodules pulmonaires")
9
- st.title("Détection de nodules pulmonaires sur images scanner")
10
 
11
  @st.cache_resource
12
  def load_model():
13
- model = AutoModelForObjectDetection.from_pretrained("monai-test/lung_nodule_ct_detection")
14
- model.eval()
15
- return model
16
 
17
- def process_image(image):
18
- # Convertir en niveau de gris
19
- img_array = np.array(image.convert('L'))
20
- # Normaliser
21
- normalized = (img_array - img_array.min()) / (img_array.max() - img_array.min())
22
- # Redimensionner
23
- resized = cv2.resize(normalized, (512, 512))
24
- # Préparer pour PyTorch
25
- tensor = torch.FloatTensor(resized).unsqueeze(0).unsqueeze(0)
26
- return tensor
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
 
28
- try:
29
- model = load_model()
 
 
 
 
30
 
31
- uploaded_file = st.file_uploader("Téléchargez une image scanner", type=["jpg", "jpeg", "png"])
 
 
 
 
 
 
32
 
33
  if uploaded_file:
34
- image = Image.open(uploaded_file)
35
-
36
  col1, col2 = st.columns(2)
37
 
38
- with col1:
39
- st.image(image, caption="Image originale", use_container_width=True)
40
-
41
- with col2:
42
- with torch.no_grad():
43
- input_tensor = process_image(image)
44
- predictions = model(input_tensor)
45
-
46
- # Visualisation
47
- img_array = np.array(image)
48
- for pred in predictions:
49
- if pred['score'] > 0.5:
50
- box = pred['box']
51
- x1, y1, x2, y2 = map(int, [box['xmin'], box['ymin'], box['xmax'], box['ymax']])
52
- cv2.rectangle(img_array, (x1, y1), (x2, y2), (255, 0, 0), 2)
53
- text = f"Nodule: {pred['score']:.2f}"
54
- cv2.putText(img_array, text, (x1, y1-10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 0), 2)
55
 
56
- st.image(img_array, caption="Détections", use_container_width=True)
57
-
58
- # Résultats
59
- if len(predictions) > 0:
60
- st.warning(f"⚠️ {len(predictions)} nodules détectés")
61
- for i, pred in enumerate(predictions, 1):
62
- if pred['score'] > 0.5:
63
- st.write(f"Nodule {i}: Confiance {pred['score']:.1%}")
64
- else:
65
- st.success("✅ Aucun nodule détecté")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
 
67
- except Exception as e:
68
- st.error(f"Erreur lors du chargement du modèle: {str(e)}")
69
- st.info("Veuillez vérifier que le modèle est correctement configuré sur Hugging Face.")
 
1
  import streamlit as st
2
+ from transformers import pipeline
3
  import torch
4
+ from PIL import Image, ImageDraw
5
+ import io
 
6
 
7
+ st.set_page_config(page_title="Détection de Fractures Osseuses", layout="wide")
 
8
 
9
  @st.cache_resource
10
  def load_model():
11
+ return pipeline("object-detection", model="D3STRON/bone-fracture-detr")
 
 
12
 
13
+ def draw_boxes(image, predictions):
14
+ draw = ImageDraw.Draw(image)
15
+ for pred in predictions:
16
+ box = pred['box']
17
+ label = f"{pred['label']} ({pred['score']:.2%})"
18
+
19
+ # Draw bounding box
20
+ draw.rectangle(
21
+ [(box['xmin'], box['ymin']), (box['xmax'], box['ymax'])],
22
+ outline="red",
23
+ width=3
24
+ )
25
+
26
+ # Draw label background
27
+ text_bbox = draw.textbbox((box['xmin'], box['ymin']), label)
28
+ draw.rectangle(text_bbox, fill="red")
29
+
30
+ # Draw label text
31
+ draw.text(
32
+ (box['xmin'], box['ymin']),
33
+ label,
34
+ fill="white"
35
+ )
36
+ return image
37
+
38
+ def main():
39
+ st.title("🦴 Détecteur de Fractures Osseuses")
40
+ st.write("Téléchargez une radiographie pour détecter les fractures osseuses.")
41
 
42
+ pipe = load_model()
43
+
44
+ uploaded_file = st.file_uploader(
45
+ "Choisissez une image de radiographie",
46
+ type=['png', 'jpg', 'jpeg']
47
+ )
48
 
49
+ conf_threshold = st.slider(
50
+ "Seuil de confiance",
51
+ min_value=0.0,
52
+ max_value=1.0,
53
+ value=0.5,
54
+ step=0.05
55
+ )
56
 
57
  if uploaded_file:
 
 
58
  col1, col2 = st.columns(2)
59
 
60
+ # Original image
61
+ image = Image.open(uploaded_file)
62
+ col1.header("Image originale")
63
+ col1.image(image)
64
+
65
+ # Process image
66
+ with st.spinner("Analyse en cours..."):
67
+ predictions = pipe(image)
 
 
 
 
 
 
 
 
 
68
 
69
+ # Filter predictions based on confidence threshold
70
+ filtered_preds = [
71
+ pred for pred in predictions
72
+ if pred['score'] >= conf_threshold
73
+ ]
74
+
75
+ # Draw boxes on a copy of the image
76
+ result_image = image.copy()
77
+ result_image = draw_boxes(result_image, filtered_preds)
78
+
79
+ # Display results
80
+ col2.header("Résultats de la détection")
81
+ col2.image(result_image)
82
+
83
+ # Display detailed predictions
84
+ if filtered_preds:
85
+ st.subheader("Détails des détections")
86
+ for pred in filtered_preds:
87
+ st.write(
88
+ f"• Type: {pred['label']} - "
89
+ f"Confiance: {pred['score']:.2%}"
90
+ )
91
+ else:
92
+ st.warning(
93
+ "Aucune fracture détectée avec le seuil de confiance actuel. "
94
+ "Essayez de baisser le seuil pour plus de résultats."
95
+ )
96
 
97
+ if __name__ == "__main__":
98
+ main()