File size: 6,479 Bytes
f137487
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
867f506
f137487
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
867f506
f137487
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
    from fastapi import FastAPI, UploadFile, File
    import cv2
    import torch
    import pandas as pd
    from PIL import Image
    from transformers import AutoImageProcessor, AutoModelForImageClassification
    from tqdm import tqdm
    import json
    import shutil
    from fastapi.middleware.cors import CORSMiddleware
    from fastapi.responses import HTMLResponse

    app = FastAPI()

    # CORS: let the Vue.js front-end served on http://localhost:8080 call this API.
    # Replace / extend allow_origins with the URL(s) of your front-end.
    app.add_middleware(
        CORSMiddleware,
        allow_origins=["http://localhost:8080"],
        allow_credentials=True,
        allow_methods=["*"],  # Allows all HTTP methods (GET, POST, etc.)
        allow_headers=["*"],  # Allows all headers (such as Content-Type, Authorization, etc.)
    )

    # Load the image processor and the fine-tuned model from the local checkpoint.
    # A forward-slash relative path resolves on both Windows and POSIX; the previous
    # r'.\vit-finetuned-ucf101' only worked on Windows (on Linux the backslash is a
    # literal character of the directory name, so the checkpoint was not found).
    local_model_path = "./vit-finetuned-ucf101"
    # NOTE(review): the processor comes from the base checkpoint while the model comes
    # from the fine-tuned one; presumably fine-tuning kept the base preprocessing — confirm.
    processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224")
    model = AutoModelForImageClassification.from_pretrained(local_model_path)
    # Alternative: load the published checkpoint from the Hugging Face Hub instead.
    # model = AutoModelForImageClassification.from_pretrained("2nzi/vit-finetuned-ucf101")
    model.eval()  # switch the model to inference (evaluation) mode

    # Fonction pour classifier une image
    def classifier_image(image):
        """Return the label the model predicts for one video frame.

        The frame is converted from BGR to RGB with ``cv2.cvtColor``, wrapped
        in a PIL image, preprocessed by ``processor`` into PyTorch tensors and
        fed to ``model`` without gradient tracking. The index of the largest
        logit is mapped to a label through ``model.config.id2label``.

        NOTE(review): assumes ``image`` is a BGR array as produced by
        ``cv2.VideoCapture.read`` — confirm for other callers.
        """
        rgb_frame = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        model_inputs = processor(images=Image.fromarray(rgb_frame), return_tensors="pt")
        with torch.no_grad():
            scores = model(**model_inputs).logits
            best_index = scores.argmax(-1).item()
        return model.config.id2label[best_index]

    # Fonction pour traiter la vidéo et identifier les séquences de "Surfing"
    def identifier_sequences_surfing(video_path, intervalle=0.5):
        """Detect the time ranges of a video classified as "Surfing".

        Args:
            video_path: path of a video file readable by ``cv2.VideoCapture``.
            intervalle: sampling period in seconds; one frame is classified
                every ``intervalle`` seconds (never less often than every frame).

        Returns:
            A DataFrame with columns ``"Début"`` and ``"Fin"`` (seconds, rounded
            to 0.01), one row per detected sequence. It is empty when the video
            cannot be opened or no sequence is found.

        Raises:
            ValueError: the video opens but reports a non-positive frame rate.
        """
        samples, end_timestamp = _classifier_frames_echantillonnees(video_path, intervalle)
        sequences = _detecter_sequences(samples, end_timestamp)
        return pd.DataFrame(sequences, columns=["Début", "Fin"])

    def _classifier_frames_echantillonnees(video_path, intervalle):
        """Classify one frame every ``intervalle`` seconds of the video.

        Returns ``(samples, end_timestamp)``. ``samples`` is a list of
        ``(timestamp_seconds, label)`` tuples in frame order, and
        ``end_timestamp`` is ``frames_read / fps`` rounded to 0.01.
        An unopenable video yields ``([], 0.0)``.

        Raises:
            ValueError: the video opens but reports a non-positive frame rate.
        """
        cap = cv2.VideoCapture(video_path)
        try:
            if not cap.isOpened():
                return [], 0.0
            frame_rate = cap.get(cv2.CAP_PROP_FPS)
            if frame_rate <= 0:
                raise ValueError(f"Cannot read a valid frame rate from {video_path!r}")
            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            # Clamp to 1: a short interval or a low FPS must never give a zero modulus.
            frame_interval = max(1, int(frame_rate * intervalle))

            samples = []
            frame_index = 0
            with tqdm(total=total_frames, desc="Traitement des frames de la vidéo", unit="frame") as pbar:
                while True:
                    success, frame = cap.read()
                    if not success:
                        break
                    if frame_index % frame_interval == 0:
                        # Keep precision to the centisecond level.
                        timestamp = round(frame_index / frame_rate, 2)
                        samples.append((timestamp, classifier_image(frame)))
                    # Every decoded frame is counted exactly once, so timestamps
                    # stay aligned with the stream position.
                    frame_index += 1
                    pbar.update(1)
            return samples, round(frame_index / frame_rate, 2)
        finally:
            cap.release()

    def _detecter_sequences(samples, end_timestamp, target="Surfing"):
        """Group consecutive ``target`` samples into ``(start, end)`` intervals.

        A single non-``target`` sample inside a sequence is treated as a one-off
        misclassification when the *next sample* is ``target`` again. Otherwise
        the sequence ends at the timestamp of that first non-``target`` sample.
        A sequence still open after the last sample ends at ``end_timestamp``.
        """
        sequences = []
        start = None
        for i, (timestamp, label) in enumerate(samples):
            if label == target:
                if start is None:
                    start = timestamp
            elif start is not None:
                next_is_target = i + 1 < len(samples) and samples[i + 1][1] == target
                if not next_is_target:
                    sequences.append((start, timestamp))
                    start = None
        if start is not None:
            sequences.append((start, end_timestamp))
        return sequences

    # Fonction pour convertir les séquences en format JSON
    def convertir_sequences_en_json(dataframe):
        """Wrap the detected sequences into the JSON-ready event structure.

        Each row of ``dataframe`` (columns ``"Début"`` / ``"Fin"``) becomes a
        block ``{"id": "Surfing<index + 1>", "start": ..., "end": ...}``. Both
        times are rounded to two decimals, and ``index`` is the row's DataFrame
        index label. The result is a one-element list holding the ``"Surfing"``
        event with all of its blocks, even when there are no rows.
        """
        surfing_blocks = [
            {
                "id": f"Surfing{row_label + 1}",
                "start": round(row["Début"], 2),
                "end": round(row["Fin"], 2),
            }
            for row_label, row in dataframe.iterrows()
        ]
        return [{"event": "Surfing", "blocks": surfing_blocks}]

    @app.post("/analyze_video/")
    async def analyze_video(file: UploadFile = File(...)):
        """Detect "Surfing" sequences in an uploaded video.

        The upload is written to a private temporary ``.mp4`` file, so
        concurrent requests cannot overwrite each other's video. It is analysed
        with one classified frame per second and always deleted afterwards.
        The CPU-bound analysis runs in a worker thread so it does not block the
        event loop. Returns the event list built by ``convertir_sequences_en_json``.
        """
        import os
        import tempfile

        from fastapi.concurrency import run_in_threadpool

        fd, video_path = tempfile.mkstemp(suffix=".mp4")
        try:
            with os.fdopen(fd, "wb") as buffer:
                shutil.copyfileobj(file.file, buffer)
            dataframe_sequences = await run_in_threadpool(
                identifier_sequences_surfing, video_path, intervalle=1
            )
        finally:
            os.remove(video_path)
        return convertir_sequences_en_json(dataframe_sequences)

    @app.get("/", response_class=HTMLResponse)
    async def index():
        # Landing page: a static HTML snippet that points to the local Swagger UI
        # (/docs) and to the hosted documentation on Hugging Face Spaces.
        landing_page = """
            <html>
                <body>
                    <h1>Hello world!</h1>
                    <p>This `/` is the most simple and default endpoint.</p>
                    <p>If you want to learn more, check out the documentation of the API at 
                    <a href='/docs'>/docs</a> or 
                    <a href='https://2nzi-video-sequence-labeling.hf.space/docs' target='_blank'>external docs</a>.
                    </p>
                </body>
            </html>
            """
        return landing_page


    # Run the application with uvicorn from the command line, e.g.:
    #   uvicorn main:app --reload                                (development, auto-reload)
    #   uvicorn main:app --host 0.0.0.0 --port 8000 --workers 1  (serve on all interfaces)
    # The interactive API docs are then available at http://localhost:8000/docs#/