File size: 3,548 Bytes
c0bb46a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d00772b
 
 
 
 
c0bb46a
 
 
 
d00772b
 
 
 
 
 
 
 
 
 
 
 
c0bb46a
d00772b
 
 
c0bb46a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d00772b
c0bb46a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d00772b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import torch
from transformers import VideoMAEForVideoClassification, VideoMAEFeatureExtractor
import os, cv2, uuid, json
import numpy as np
import gdown

# Local path of the fine-tuned checkpoint; fetched from Google Drive below if absent.
model_path = "checkpoint_epoch_1.pt"
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Download the checkpoint from Google Drive if it is not already on disk.
if not os.path.exists(model_path):
    print("Downloading checkpoint...")
    url = "https://drive.google.com/uc?id=1dIaptYPq-1fgo0yoBoPlDsbIfs3BEqJI"
    gdown.download(url, model_path, quiet=False)

# Build the VideoMAE-base architecture with a 3-class classification head,
# then replace its weights with the fine-tuned ones from the checkpoint.
model = VideoMAEForVideoClassification.from_pretrained("MCG-NJU/videomae-base", num_labels=3)
# NOTE(review): torch.load unpickles the file, and this file comes from an
# external URL. If the checkpoint only contains tensors and plain containers,
# consider torch.load(..., weights_only=True) — confirm against how it was saved.
checkpoint = torch.load(model_path, map_location=device)
# Only the "model_state_dict" entry of the checkpoint dict is used here.
model.load_state_dict(checkpoint["model_state_dict"])
# Put the model in evaluation mode and move it to `device`.
model.eval().to(device)

# Frame preprocessor loaded from the same pretrained checkpoint as the backbone.
# NOTE(review): newer transformers releases deprecate VideoMAEFeatureExtractor in
# favour of VideoMAEImageProcessor — consider switching when upgrading.
feature_extractor = VideoMAEFeatureExtractor.from_pretrained("MCG-NJU/videomae-base")
# Classifier output index -> event name reported to the user.
label_map = {0: "Goal", 1: "Card", 2: "Substitution"}

def _save_upload(video, dest_path):
    """Write the uploaded video to *dest_path*.

    Args:
        video: Either a filesystem path (str), which is copied, or a file-like
            object whose ``read()`` returns the raw bytes, which are written out.
        dest_path: Destination file path.
    """
    import shutil

    if isinstance(video, str):
        # Gradio usually passes the path of its own temporary upload.
        shutil.copy(video, dest_path)
    else:
        # Less common: a readable stream such as BytesIO.
        with open(dest_path, "wb") as f:
            f.write(video.read())


def _iter_segments(cap, segment_size):
    """Yield ``(start_frame_index, frames)`` for consecutive windows of *cap*.

    Frames are decoded one at a time and resized to 224x224. Each window holds
    *segment_size* frames, except possibly the final, shorter one. This
    generator itself only keeps the current window in memory.
    """
    segment = []
    start = 0
    while True:
        ok, frame = cap.read()
        if not ok:
            break
        segment.append(cv2.resize(frame, (224, 224)))
        if len(segment) == segment_size:
            yield start, segment
            start += segment_size
            segment = []
    if segment:
        yield start, segment


def _classify_segment(segment, start, fps, num_frames=16, threshold=0.70):
    """Classify one window; return a prediction dict, or None if not reported.

    Returns None when the window has fewer than *num_frames* frames or when
    the top-class softmax probability is not above *threshold*.
    """
    if len(segment) < num_frames:
        return None

    # Sample num_frames evenly spaced frames across the window.
    indices = np.linspace(0, len(segment) - 1, num_frames).astype(int)
    sampled_frames = [segment[idx] for idx in indices]

    # NOTE(review): OpenCV decodes frames as BGR, while the feature extractor
    # presumably expects RGB. Left unchanged because the checkpoint may have
    # been fine-tuned on BGR frames as well — confirm against training code.
    inputs = feature_extractor(sampled_frames, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}

    with torch.no_grad():
        probs = torch.nn.functional.softmax(model(**inputs).logits, dim=1)
        confidence, pred = torch.max(probs, dim=1)

    score = confidence.item()
    if score > threshold:
        return {
            "start": round(start / fps, 2),
            # The last window may be shorter than the others; use its real length.
            "end": round((start + len(segment)) / fps, 2),
            "label": label_map[pred.item()],
            "confidence": round(score, 3),
        }
    return None


def predict_gradio(video):
    """Detect Goal / Card / Substitution events in an uploaded video.

    The video is split into consecutive ~5-second windows. From every window
    with at least 16 frames, 16 evenly spaced frames are classified. Windows
    whose top-class probability exceeds 0.70 are reported and concatenated,
    in order, into a 224x224 mp4 summary clip.

    Args:
        video: A filesystem path (str) to the uploaded video, or a file-like
            object whose ``read()`` returns the raw video bytes.

    Returns:
        ``(predictions, summary_path)``. ``predictions`` is a list of dicts
        with ``start``/``end`` (seconds, 2 d.p.), ``label`` and ``confidence``
        (3 d.p.). ``summary_path`` is the written summary's path, or ``""``
        when no window passed the threshold. If the video cannot be read,
        returns ``([{"error": "Invalid or unreadable video."}], "")``.
    """
    import math
    import shutil

    # Per-request scratch directory so concurrent requests don't collide.
    work_dir = f"./temp/{uuid.uuid4()}"
    os.makedirs(work_dir, exist_ok=True)
    temp_video_path = os.path.join(work_dir, "input.mp4")
    out_path = f"{work_dir}/summary.mp4"

    cap = None
    writer = None
    try:
        _save_upload(video, temp_video_path)

        cap = cv2.VideoCapture(temp_video_path)
        fps = cap.get(cv2.CAP_PROP_FPS)
        # OpenCV reports 0 (or NaN) when it cannot determine the frame rate,
        # which also covers files it failed to open.
        if math.isnan(fps) or fps <= 0:
            return [{"error": "Invalid or unreadable video."}], ""

        # ~5-second windows; clamp to at least one frame so a very low fps
        # still produces valid windows instead of a zero-sized step.
        segment_size = max(1, int(fps * 5))

        predictions = []
        # Stream windows instead of decoding the whole video into memory.
        for start, segment in _iter_segments(cap, segment_size):
            prediction = _classify_segment(segment, start, fps)
            if prediction is None:
                continue
            predictions.append(prediction)
            # Open the writer lazily so no file is created when nothing is detected.
            if writer is None:
                fourcc = cv2.VideoWriter_fourcc(*"mp4v")
                writer = cv2.VideoWriter(out_path, fourcc, fps, (224, 224))
            for frame in segment:
                writer.write(frame)

        return predictions, (out_path if writer is not None else "")
    finally:
        if cap is not None:
            cap.release()
        if writer is not None:
            # Flush/close the summary before the caller reads out_path.
            writer.release()
            # Keep the summary for the caller; only drop the copied upload.
            try:
                os.remove(temp_video_path)
            except OSError:
                pass  # best-effort cleanup; the result is still valid
        else:
            # Nothing in the scratch dir is returned, so remove it entirely.
            shutil.rmtree(work_dir, ignore_errors=True)