File size: 6,449 Bytes
c9595c6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
import json
import os
import shutil
import warnings
import zlib

import cv2
import numpy as np
import torch
from flask import Flask, request, render_template, jsonify, send_from_directory
from segment_anything import sam_model_registry, SamPredictor
from werkzeug.utils import secure_filename

# Flask application setup; templates/ and static/ live next to this module.
app = Flask(
    __name__,
    template_folder='templates',
    static_folder='static'
)
# Uploads go under static/ so the results can be served as static assets.
app.config['UPLOAD_FOLDER'] = os.path.join('static', 'uploads')
os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)

# Load the SAM (Segment Anything) checkpoint once, at import time.
MODEL_TYPE = "vit_b"
MODEL_PATH = os.path.join('models', 'sam_vit_b_01ec64.pth')
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

print("Chargement du modèle SAM...")
try:
    # Prefer the safe weights-only loader (available in newer torch versions).
    state_dict = torch.load(MODEL_PATH, map_location="cpu", weights_only=True)
except TypeError:
    # Older torch does not accept weights_only= — fall back to a plain load
    # and silence the resulting UserWarning about pickled checkpoints.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=UserWarning)
        state_dict = torch.load(MODEL_PATH, map_location="cpu")

# Instantiate the architecture, load the weights (strict=False tolerates
# key mismatches between checkpoint and model), and build the predictor
# used by the /segment route for point-prompted masks.
sam = sam_model_registry[MODEL_TYPE]()
sam.load_state_dict(state_dict, strict=False)
sam.to(device=device)
predictor = SamPredictor(sam)
print("Modèle SAM chargé avec succès!")

# Deterministic per-class display color.
def get_color_for_class(class_name):
    """Return a stable pseudo-random BGR color (3-tuple of ints in [0, 255]) for a class.

    Seeds from zlib.crc32 instead of the builtin hash(): str hashing is
    salted per-process (PYTHONHASHSEED), so the original implementation
    produced different colors after every server restart. A local
    default_rng is used so the global NumPy RNG state is not clobbered.
    """
    seed = zlib.crc32(class_name.encode("utf-8"))  # stable across runs
    rng = np.random.default_rng(seed)
    return tuple(int(c) for c in rng.integers(0, 256, size=3))

# Bounding box (center/size form, pixel units) around a binary mask.
def mask_to_yolo_bbox(mask):
    """Return (x_center, y_center, width, height) in pixels for the nonzero
    region of *mask*, or None when the mask is empty.

    Values are un-normalized; the caller divides by image dimensions to get
    the normalized YOLO label format.
    """
    ys, xs = np.nonzero(mask > 0)
    if xs.size == 0 or ys.size == 0:
        return None

    left, right = xs.min(), xs.max()
    top, bottom = ys.min(), ys.max()

    # NOTE(review): width/height use max - min (no +1), so a single-pixel
    # mask yields a zero-size box — confirm downstream tooling accepts that.
    return (left + right) / 2, (top + bottom) / 2, right - left, bottom - top

@app.route('/', methods=['GET', 'POST'])
def index():
    """Upload page.

    GET  -> render the gallery of files already present in UPLOAD_FOLDER.
    POST -> save every file from the multi-upload form field 'images' into
            UPLOAD_FOLDER, then render the gallery of the new files.
    """
    if request.method == 'POST':
        files = request.files.getlist('images')
        # Browsers submit an empty FileStorage when the form field is left
        # blank; secure_filename('') is '' and file.save(<dir>) would crash,
        # so keep only entries with a real filename.
        valid_files = [f for f in files if f and f.filename]
        if not valid_files:
            return "Aucun fichier sélectionné", 400

        filenames = []
        for file in valid_files:
            filename = secure_filename(file.filename)
            if not filename:
                # Name was sanitized away entirely (e.g. '../..') — skip it.
                continue
            file.save(os.path.join(app.config['UPLOAD_FOLDER'], filename))
            filenames.append(filename)

        return render_template('index.html', uploaded_images=filenames, all_annotated=False)

    # GET: list previously uploaded images (also includes per-image result dirs).
    uploaded_images = os.listdir(app.config['UPLOAD_FOLDER'])
    return render_template('index.html', uploaded_images=uploaded_images, all_annotated=False)

@app.route('/uploads/<filename>')
def uploaded_file(filename):
    """Serve a raw uploaded image straight from the upload folder."""
    upload_dir = app.config['UPLOAD_FOLDER']
    return send_from_directory(upload_dir, filename)

@app.route('/segment', methods=['POST'])
def segment():
    """Run SAM point-prompted segmentation on each requested image.

    Expects JSON: {"image_names": [...], "points": [{"x", "y", "class"}, ...]}.
    For every image, writes into static/uploads/<stem>/:
      - annotated_<name> : original image with colored masks and class labels
      - <stem>.txt       : YOLO-format annotations (class_id x_c y_c w h, normalized)
      - a copy of the original image
    Returns a JSON list of per-image relative result paths, or an error object
    with HTTP 400/404 on bad input.
    """
    data = request.get_json()
    print("Données reçues :", data)  # trace the payload sent by the frontend

    image_names = data.get('image_names')
    points = data.get('points')
    if not image_names or not points:
        return jsonify({'success': False, 'error': 'Données manquantes'}), 400

    output = []
    for image_name in image_names:
        image_path = os.path.join(app.config['UPLOAD_FOLDER'], image_name)
        if not os.path.exists(image_path):
            return jsonify({'success': False, 'error': f'Image {image_name} non trouvée'}), 404

        # Per-image output directory named after the file stem.
        stem = os.path.splitext(image_name)[0]
        output_dir = os.path.join(app.config['UPLOAD_FOLDER'], stem)
        os.makedirs(output_dir, exist_ok=True)

        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread signals unreadable/corrupt files by returning None,
            # not by raising — without this check cvtColor would crash.
            return jsonify({'success': False, 'error': f'Image {image_name} non trouvée'}), 404
        predictor.set_image(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
        annotated_image = image.copy()

        height_px, width_px = image.shape[:2]
        yolo_annotations = []

        for point in points:
            x, y = point['x'], point['y']
            class_name = point.get('class', 'Unknown')
            # CRC32 instead of hash(): str hashing is salted per-process, so
            # the original ids changed on every restart, corrupting datasets
            # annotated across sessions.
            class_id = zlib.crc32(class_name.encode('utf-8')) % 1000
            color = get_color_for_class(class_name)
            masks, _, _ = predictor.predict(
                point_coords=np.array([[x, y]]),
                point_labels=np.array([1]),
                multimask_output=False
            )
            mask = masks[0]
            annotated_image[mask > 0] = color  # overlay the mask in class color

            bbox = mask_to_yolo_bbox(mask)
            if bbox:
                x_center, y_center, box_w, box_h = bbox
                # Normalize to [0, 1] as required by the YOLO label format.
                yolo_annotations.append(
                    f"{class_id} {x_center / width_px:.6f} {y_center / height_px:.6f} "
                    f"{box_w / width_px:.6f} {box_h / height_px:.6f}"
                )

            # Label the clicked point with its class name (white text).
            cv2.putText(annotated_image, class_name, (int(x), int(y)),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)

        annotated_filename = f"annotated_{image_name}"
        cv2.imwrite(os.path.join(output_dir, annotated_filename), annotated_image)

        # Write YOLO annotations, one line per clicked point.
        with open(os.path.join(output_dir, f"{stem}.txt"), "w") as f:
            f.write("\n".join(yolo_annotations))

        # Copy (not move) the original next to its results: os.rename removed
        # it from UPLOAD_FOLDER, so a second /segment on the same image hit
        # the 404 branch above.
        original_copy_path = os.path.join(output_dir, image_name)
        if not os.path.exists(original_copy_path):
            shutil.copy2(image_path, original_copy_path)

        # Path relative to static/ for the frontend; os.path.relpath is
        # separator-safe where str.replace('static/', '') was not.
        relative_output_dir = os.path.relpath(output_dir, 'static').replace(os.sep, '/')
        output.append({
            'success': True,
            'image': f"{relative_output_dir}/{annotated_filename}",
            'yolo_annotations': f"{relative_output_dir}/{stem}.txt"
        })

    return jsonify(output)

if __name__ == '__main__':
    # NOTE(review): debug=True enables the Werkzeug interactive debugger while
    # host='0.0.0.0' binds on all interfaces — fine for local development, but
    # confirm this is never run as-is on a reachable network.
    app.run(debug=True, host='0.0.0.0', port=5000)