yassonee commited on
Commit
7b58439
·
verified ·
1 Parent(s): 7e5dda6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +39 -13
app.py CHANGED
@@ -1,32 +1,58 @@
1
  import streamlit as st
2
- from ultralytics import YOLO
3
- from PIL import Image
4
  import torch
 
 
5
 
6
- st.set_page_config(layout="centered")
 
 
 
 
 
 
 
 
 
7
 
8
  @st.cache_resource
9
  def load_model():
10
- model = YOLO('yolov8m.pt') # Load base YOLOv8 model
11
- model.load('keremberke/yolov8m-chest-xray-classification.pt') # Load weights
12
- return model
13
 
14
  def main():
15
- st.title("Analyse Radiographie Thoracique")
16
 
17
- model = load_model()
18
 
19
  uploaded_file = st.file_uploader("Télécharger une radiographie", type=["jpg", "jpeg", "png"])
20
 
21
  if uploaded_file:
22
- image = Image.open(uploaded_file)
23
- resized_image = image.resize((640, 640))
24
  st.image(resized_image, width=400)
25
 
26
  if st.button("Analyser"):
27
- results = model.predict(source=resized_image)
28
- st.write(f"Résultat: {results[0].names[results[0].probs.argmax()]}")
29
- st.write(f"Confiance: {results[0].probs.max():.2%}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
 
31
  if __name__ == "__main__":
32
  main()
 
1
  import streamlit as st
2
+ from transformers import AutoImageProcessor, AutoModelForImageClassification
 
3
  import torch
4
+ from PIL import Image
5
+ import numpy as np
6
 
7
+ def create_overlay(image, attention_map, alpha=0.5):
8
+ attention_map = (attention_map - attention_map.min()) / (attention_map.max() - attention_map.min())
9
+ heatmap = np.uint8(255 * attention_map)
10
+ heatmap = Image.fromarray(heatmap).resize(image.size)
11
+ heatmap = np.array(heatmap)
12
+ heatmap = np.stack([heatmap, np.zeros_like(heatmap), np.zeros_like(heatmap)], axis=-1)
13
+
14
+ image_array = np.array(image)
15
+ overlay = Image.fromarray(np.uint8(image_array * (1 - alpha) + heatmap * alpha))
16
+ return overlay
17
 
18
  @st.cache_resource
19
  def load_model():
20
+ processor = AutoImageProcessor.from_pretrained("mrm8488/vit-base-patch16-224_finetuned-pneumothorax")
21
+ model = AutoModelForImageClassification.from_pretrained("mrm8488/vit-base-patch16-224_finetuned-pneumothorax")
22
+ return processor, model
23
 
24
  def main():
25
+ st.title("Détection de Pneumothorax")
26
 
27
+ processor, model = load_model()
28
 
29
  uploaded_file = st.file_uploader("Télécharger une radiographie", type=["jpg", "jpeg", "png"])
30
 
31
  if uploaded_file:
32
+ image = Image.open(uploaded_file).convert('RGB')
33
+ resized_image = image.resize((224, 224))
34
  st.image(resized_image, width=400)
35
 
36
  if st.button("Analyser"):
37
+ with st.spinner("Analyse en cours..."):
38
+ inputs = processor(images=resized_image, return_tensors="pt")
39
+ outputs = model(**inputs)
40
+ probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
41
+
42
+ # Obtenir attentions des dernières couches
43
+ attention = outputs.hidden_states[-1].mean(1)[0].detach().numpy()
44
+ attention_map = attention.reshape(14, 14) # ViT patch size
45
+
46
+ # Créer overlay
47
+ overlay = create_overlay(resized_image, attention_map)
48
+
49
+ col1, col2 = st.columns(2)
50
+ with col1:
51
+ st.write("Résultat:", model.config.id2label[outputs.logits.argmax(-1).item()])
52
+ st.write(f"Confiance: {probs.max().item():.2%}")
53
+
54
+ with col2:
55
+ st.image(overlay, caption="Zones suspectes", width=400)
56
 
57
  if __name__ == "__main__":
58
  main()