MO-12 committed
Commit c0bb46a · verified · 1 Parent(s): b872074

Update API_Model.py

Files changed (1)
  1. API_Model.py +81 -81
API_Model.py CHANGED
@@ -1,81 +1,81 @@
 import torch
 from transformers import VideoMAEForVideoClassification, VideoMAEFeatureExtractor
 import os, cv2, uuid, json
 import numpy as np
 import gdown
 
 model_path = "checkpoint_epoch_1.pt"
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 # Download the model from Google Drive if it is not already present
 if not os.path.exists(model_path):
     print("Downloading checkpoint...")
     url = "https://drive.google.com/uc?id=1dIaptYPq-1fgo0yoBoPlDsbIfs3BEqJI"
     gdown.download(url, model_path, quiet=False)
 
 model = VideoMAEForVideoClassification.from_pretrained("MCG-NJU/videomae-base", num_labels=3)
 checkpoint = torch.load(model_path, map_location=device)
 model.load_state_dict(checkpoint["model_state_dict"])
 model.eval().to(device)
 
 feature_extractor = VideoMAEFeatureExtractor.from_pretrained("MCG-NJU/videomae-base")
 label_map = {0: "Goal", 1: "Card", 2: "Substitution"}
 
 def predict_gradio(video_path):
     video_id = str(uuid.uuid4())
     work_dir = f"./temp/{video_id}"
     os.makedirs(work_dir, exist_ok=True)
 
     cap = cv2.VideoCapture(video_path)
     fps = cap.get(cv2.CAP_PROP_FPS)
     frames = []
     while True:
         ret, frame = cap.read()
         if not ret:
             break
         resized = cv2.resize(frame, (224, 224))
         frames.append(resized)
     cap.release()
 
     segment_size = int(fps * 5)
     predictions = []
     output_segments = []
 
     for i in range(0, len(frames), segment_size):
         segment = frames[i:i+segment_size]
         if len(segment) < 16:
             continue
         indices = np.linspace(0, len(segment)-1, 16).astype(int)
         sampled_frames = [segment[idx] for idx in indices]
 
         inputs = feature_extractor(sampled_frames, return_tensors="pt")
         inputs = {k: v.to(device) for k, v in inputs.items()}
 
         with torch.no_grad():
             outputs = model(**inputs)
             probs = torch.nn.functional.softmax(outputs.logits, dim=1)
             confidence, pred = torch.max(probs, dim=1)
 
         if confidence.item() > 0.7:
             label = label_map[pred.item()]
             start_time = i / fps
             end_time = min((i + segment_size), len(frames)) / fps
             predictions.append({
                 "start": round(start_time, 2),
                 "end": round(end_time, 2),
                 "label": label,
                 "confidence": round(confidence.item(), 3)
             })
             output_segments.append(segment)
 
     out_path = f"{work_dir}/summary.mp4"
     if output_segments:
         fourcc = cv2.VideoWriter_fourcc(*"mp4v")
         out = cv2.VideoWriter(out_path, fourcc, fps, (224, 224))
         for seg in output_segments:
             for frame in seg:
                 out.write(frame)
         out.release()
         return predictions, out_path
     else:
-        return predictions, None
+        return predictions, ""
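
For reference, predict_gradio takes a video file path and returns a list of event dictionaries together with the path of the rendered highlight clip (an empty string when no high-confidence segment is found). A minimal sketch of how the function might be wired into a Gradio app follows; the gr.Interface setup is an assumption for illustration, not part of this commit:

import gradio as gr

from API_Model import predict_gradio  # assumes this file is importable as API_Model

demo = gr.Interface(
    fn=predict_gradio,               # Gradio passes the uploaded file's path as a string
    inputs=gr.Video(label="Match video"),
    outputs=[
        gr.JSON(label="Detected events"),     # list of {start, end, label, confidence}
        gr.Video(label="Highlight summary"),  # path to summary.mp4, or "" if none
    ],
    title="Football Event Detection",
)

if __name__ == "__main__":
    demo.launch()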