salimshakeel committed
Commit d2542a3 · 1 Parent(s): 30508a4

upload files

.gitattributes CHANGED
@@ -1,35 +1,173 @@
- *.7z filter=lfs diff=lfs merge=lfs -text
- *.arrow filter=lfs diff=lfs merge=lfs -text
- *.bin filter=lfs diff=lfs merge=lfs -text
- *.bz2 filter=lfs diff=lfs merge=lfs -text
- *.ckpt filter=lfs diff=lfs merge=lfs -text
- *.ftz filter=lfs diff=lfs merge=lfs -text
- *.gz filter=lfs diff=lfs merge=lfs -text
- *.h5 filter=lfs diff=lfs merge=lfs -text
- *.joblib filter=lfs diff=lfs merge=lfs -text
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
- *.model filter=lfs diff=lfs merge=lfs -text
- *.msgpack filter=lfs diff=lfs merge=lfs -text
- *.npy filter=lfs diff=lfs merge=lfs -text
- *.npz filter=lfs diff=lfs merge=lfs -text
- *.onnx filter=lfs diff=lfs merge=lfs -text
- *.ot filter=lfs diff=lfs merge=lfs -text
- *.parquet filter=lfs diff=lfs merge=lfs -text
- *.pb filter=lfs diff=lfs merge=lfs -text
- *.pickle filter=lfs diff=lfs merge=lfs -text
- *.pkl filter=lfs diff=lfs merge=lfs -text
- *.pt filter=lfs diff=lfs merge=lfs -text
- *.pth filter=lfs diff=lfs merge=lfs -text
- *.rar filter=lfs diff=lfs merge=lfs -text
- *.safetensors filter=lfs diff=lfs merge=lfs -text
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
- *.tar.* filter=lfs diff=lfs merge=lfs -text
- *.tar filter=lfs diff=lfs merge=lfs -text
- *.tflite filter=lfs diff=lfs merge=lfs -text
- *.tgz filter=lfs diff=lfs merge=lfs -text
- *.wasm filter=lfs diff=lfs merge=lfs -text
- *.xz filter=lfs diff=lfs merge=lfs -text
- *.zip filter=lfs diff=lfs merge=lfs -text
- *.zst filter=lfs diff=lfs merge=lfs -text
- *tfevents* filter=lfs diff=lfs merge=lfs -text
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[codz]
+ *$py.class
+ # C extensions
+ *.so
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py.cover
+ .hypothesis/
+ .pytest_cache/
+ cover/
+ # Translations
+ *.mo
+ *.pot
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+ # Flask stuff:
+ instance/
+ .webassets-cache
+ # Scrapy stuff:
+ .scrapy
+ # Sphinx documentation
+ docs/_build/
+ # PyBuilder
+ .pybuilder/
+ target/
+ # Jupyter Notebook
+ .ipynb_checkpoints
+ # IPython
+ profile_default/
+ ipython_config.py
+ # pyenv
+ # For a library or package, you might want to ignore these files since the code is
+ # intended to run in multiple environments; otherwise, check them in:
+ # .python-version
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+ # UV
+ # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
+ # commonly ignored for libraries.
+ #uv.lock
+ # poetry
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
+ # commonly ignored for libraries.
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+ #poetry.lock
+ #poetry.toml
+ # pdm
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+ # pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python.
+ # https://pdm-project.org/en/latest/usage/project/#working-with-version-control
+ #pdm.lock
+ #pdm.toml
+ .pdm-python
+ .pdm-build/
+ # pixi
+ # Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control.
+ #pixi.lock
+ # Pixi creates a virtual environment in the .pixi directory, just like venv module creates one
+ # in the .venv directory. It is recommended not to include this directory in version control.
+ .pixi
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+ __pypackages__/
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+ # SageMath parsed files
+ *.sage.py
+ # Environments
+ .env
+ .envrc
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+ # Rope project settings
+ .ropeproject
+ # mkdocs documentation
+ /site
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+ # Pyre type checker
+ .pyre/
+ # pytype static type analyzer
+ .pytype/
+ # Cython debug symbols
+ cython_debug/
+ # PyCharm
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
+ #.idea/
+ # Abstra
+ # Abstra is an AI-powered process automation framework.
+ # Ignore directories containing user credentials, local state, and settings.
+ # Learn more at https://abstra.io/docs
+ .abstra/
+ # Visual Studio Code
+ # Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore
+ # that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore
+ # and can be added to the global gitignore or merged into this file. However, if you prefer,
+ # you could uncomment the following to ignore the entire vscode folder
+ # .vscode/
+ # Ruff stuff:
+ .ruff_cache/
+ # PyPI configuration file
+ .pypirc
+ # Cursor
+ # Cursor is an AI-powered code editor. `.cursorignore` specifies files/directories to
+ # exclude from AI features like autocomplete and code analysis. Recommended for sensitive data
+ # refer to https://docs.cursor.com/context/ignore-files
+ .cursorignore
+ .cursorindexingignore
+ # Marimo
+ marimo/_static/
+ marimo/_lsp/
+ __marimo__/
+ # Streamlit
+ .streamlit/secrets.toml
+ static/uploads/* filter=lfs diff=lfs merge=lfs -text
Dockerfile ADDED
@@ -0,0 +1,15 @@
+ # Read the doc: https://huggingface.co/docs/hub/spaces-sdks-docker
+ # You will also find guides on how best to write your Dockerfile
+
+ FROM python:3.12-slim
+
+ WORKDIR /code
+
+ COPY ./requirements.txt /code/requirements.txt
+
+ RUN pip install --no-cache-dir --upgrade pip \
+     && pip install --no-cache-dir -r /code/requirements.txt
+
+ COPY . .
+
+ CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"]
Procfile ADDED
@@ -0,0 +1 @@
+ web: gunicorn -w 3 -k uvicorn.workers.UvicornWorker main:app
__init__.py ADDED
File without changes
config.py ADDED
@@ -0,0 +1,7 @@
+ # config.py
+ import torch
+ UPLOAD_DIR = "backend/static/uploads"
+ OUTPUT_DIR = "backend/static/outputs"
+ FRAME_RATE = 15
+ SCORE_THRESHOLD = 0.4
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
layers/attention.py ADDED
@@ -0,0 +1,132 @@
+ # -*- coding: utf-8 -*-
+ import torch
+ import torch.nn as nn
+ import numpy as np
+
+
+ class SelfAttention(nn.Module):
+     def __init__(self, input_size=1024, output_size=1024, freq=10000, heads=1, pos_enc=None):
+         """ The basic (multi-head) Attention 'cell' containing the learnable parameters of Q, K and V
+
+         :param int input_size: Feature input size of Q, K, V.
+         :param int output_size: Feature -hidden- size of Q, K, V.
+         :param int freq: The frequency of the sinusoidal positional encoding.
+         :param int heads: Number of heads for the attention module.
+         :param str | None pos_enc: The type of the positional encoding [supported: Absolute, Relative].
+         """
+         super(SelfAttention, self).__init__()
+
+         self.permitted_encodings = ["absolute", "relative"]
+         if pos_enc is not None:
+             pos_enc = pos_enc.lower()
+             assert pos_enc in self.permitted_encodings, f"Supported encodings: {*self.permitted_encodings,}"
+
+         self.input_size = input_size
+         self.output_size = output_size
+         self.heads = heads
+         self.pos_enc = pos_enc
+         self.freq = freq
+         self.Wk, self.Wq, self.Wv = nn.ModuleList(), nn.ModuleList(), nn.ModuleList()
+         for _ in range(self.heads):
+             self.Wk.append(nn.Linear(in_features=input_size, out_features=output_size//heads, bias=False))
+             self.Wq.append(nn.Linear(in_features=input_size, out_features=output_size//heads, bias=False))
+             self.Wv.append(nn.Linear(in_features=input_size, out_features=output_size//heads, bias=False))
+         self.out = nn.Linear(in_features=output_size, out_features=input_size, bias=False)
+
+         self.softmax = nn.Softmax(dim=-1)
+         self.drop = nn.Dropout(p=0.5)
+
+     def getAbsolutePosition(self, T):
+         """Calculate the sinusoidal positional encoding based on the absolute position of each considered frame.
+         Based on 'Attention is all you need' paper (https://arxiv.org/abs/1706.03762)
+
+         :param int T: Number of frames contained in Q, K and V
+         :return: Tensor with shape [T, T]
+         """
+         freq = self.freq
+         d = self.input_size
+
+         pos = torch.tensor([k for k in range(T)], device=self.out.weight.device)
+         i = torch.tensor([k for k in range(T//2)], device=self.out.weight.device)
+
+         # Reshape tensors each pos_k for each i indices
+         pos = pos.reshape(pos.shape[0], 1)
+         pos = pos.repeat_interleave(i.shape[0], dim=1)
+         i = i.repeat(pos.shape[0], 1)
+
+         AP = torch.zeros(T, T, device=self.out.weight.device)
+         AP[pos, 2*i] = torch.sin(pos / freq ** ((2 * i) / d))
+         AP[pos, 2*i+1] = torch.cos(pos / freq ** ((2 * i) / d))
+         return AP
+
+     def getRelativePosition(self, T):
+         """Calculate the sinusoidal positional encoding based on the relative position of each considered frame.
+         r_pos calculations as here: https://theaisummer.com/positional-embeddings/
+
+         :param int T: Number of frames contained in Q, K and V
+         :return: Tensor with shape [T, T]
+         """
+         freq = self.freq
+         d = 2 * T
+         min_rpos = -(T - 1)
+
+         i = torch.tensor([k for k in range(T)], device=self.out.weight.device)
+         j = torch.tensor([k for k in range(T)], device=self.out.weight.device)
+
+         # Reshape tensors each i for each j indices
+         i = i.reshape(i.shape[0], 1)
+         i = i.repeat_interleave(i.shape[0], dim=1)
+         j = j.repeat(i.shape[0], 1)
+
+         # Calculate the relative positions
+         r_pos = j - i - min_rpos
+
+         RP = torch.zeros(T, T, device=self.out.weight.device)
+         idx = torch.tensor([k for k in range(T//2)], device=self.out.weight.device)
+         RP[:, 2*idx] = torch.sin(r_pos[:, 2*idx] / freq ** ((i[:, 2*idx] + j[:, 2*idx]) / d))
+         RP[:, 2*idx+1] = torch.cos(r_pos[:, 2*idx+1] / freq ** ((i[:, 2*idx+1] + j[:, 2*idx+1]) / d))
+         return RP
+
+     def forward(self, x):
+         """ Compute the weighted frame features, based on either the global or local (multi-head) attention mechanism.
+
+         :param torch.tensor x: Frame features with shape [T, input_size]
+         :return: A tuple of:
+             y: Weighted features based on the attention weights, with shape [T, input_size]
+             att_weights : The attention weights (before dropout), with shape [T, T]
+         """
+         outputs = []
+         for head in range(self.heads):
+             K = self.Wk[head](x)
+             Q = self.Wq[head](x)
+             V = self.Wv[head](x)
+
+             # Q *= 0.06 # scale factor VASNet
+             # Q /= np.sqrt(self.output_size) # scale factor (i.e 1 / sqrt(d_k) )
+             energies = torch.matmul(Q, K.transpose(1, 0))
+             if self.pos_enc is not None:
+                 if self.pos_enc == "absolute":
+                     AP = self.getAbsolutePosition(T=energies.shape[0])
+                     energies = energies + AP
+                 elif self.pos_enc == "relative":
+                     RP = self.getRelativePosition(T=energies.shape[0])
+                     energies = energies + RP
+
+             att_weights = self.softmax(energies)
+             _att_weights = self.drop(att_weights)
+             y = torch.matmul(_att_weights, V)
+
+             # Save the current head output
+             outputs.append(y)
+         y = self.out(torch.cat(outputs, dim=1))
+         return y, att_weights.clone()  # for now we don't deal with the weights (probably max or avg pooling)
+
+
+ if __name__ == '__main__':
+     pass
+     """Uncomment for a quick proof of concept
+     model = SelfAttention(input_size=256, output_size=256, pos_enc="absolute").cuda()
+     _input = torch.randn(500, 256).cuda()  # [seq_len, hidden_size]
+     output, weights = model(_input)
+     print(f"Output shape: {output.shape}\tattention shape: {weights.shape}")
+     """
layers/summarizer.py ADDED
@@ -0,0 +1,139 @@
+ # -*- coding: utf-8 -*-
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import math
+ from .attention import SelfAttention
+
+
+
+ class MultiAttention(nn.Module):
+     def __init__(self, input_size=1024, output_size=1024, freq=10000, pos_enc=None,
+                  num_segments=None, heads=1, fusion=None):
+         """ Class wrapping the MultiAttention part of PGL-SUM; its key modules and parameters.
+
+         :param int input_size: The expected input feature size.
+         :param int output_size: The hidden feature size of the attention mechanisms.
+         :param int freq: The frequency of the sinusoidal positional encoding.
+         :param None | str pos_enc: The selected positional encoding [absolute, relative].
+         :param None | int num_segments: The selected number of segments to split the videos.
+         :param int heads: The selected number of global heads.
+         :param None | str fusion: The selected type of feature fusion.
+         """
+         super(MultiAttention, self).__init__()
+
+         # Global Attention, considering differences among all frames
+         self.attention = SelfAttention(input_size=input_size, output_size=output_size,
+                                        freq=freq, pos_enc=pos_enc, heads=heads)
+
+         self.num_segments = num_segments
+         if self.num_segments is not None:
+             assert self.num_segments >= 2, "num_segments must be None or 2+"
+             self.local_attention = nn.ModuleList()
+             for _ in range(self.num_segments):
+                 # Local Attention, considering differences among the same segment with reduce hidden size
+                 self.local_attention.append(SelfAttention(input_size=input_size, output_size=output_size//num_segments,
+                                                           freq=freq, pos_enc=pos_enc, heads=4))
+         self.permitted_fusions = ["add", "mult", "avg", "max"]
+         self.fusion = fusion
+         if self.fusion is not None:
+             self.fusion = self.fusion.lower()
+             assert self.fusion in self.permitted_fusions, f"Fusion method must be: {*self.permitted_fusions,}"
+
+     def forward(self, x):
+         """ Compute the weighted frame features, based on the global and locals (multi-head) attention mechanisms.
+
+         :param torch.Tensor x: Tensor with shape [T, input_size] containing the frame features.
+         :return: A tuple of:
+             weighted_value: Tensor with shape [T, input_size] containing the weighted frame features.
+             attn_weights: Tensor with shape [T, T] containing the attention weights.
+         """
+         weighted_value, attn_weights = self.attention(x)  # global attention
+
+         if self.num_segments is not None and self.fusion is not None:
+             segment_size = math.ceil(x.shape[0] / self.num_segments)
+             for segment in range(self.num_segments):
+                 left_pos = segment * segment_size
+                 right_pos = (segment + 1) * segment_size
+                 local_x = x[left_pos:right_pos]
+                 weighted_local_value, attn_local_weights = self.local_attention[segment](local_x)  # local attentions
+
+                 # Normalize the features vectors
+                 weighted_value[left_pos:right_pos] = F.normalize(weighted_value[left_pos:right_pos].clone(), p=2, dim=1)
+                 weighted_local_value = F.normalize(weighted_local_value, p=2, dim=1)
+                 if self.fusion == "add":
+                     weighted_value[left_pos:right_pos] += weighted_local_value
+                 elif self.fusion == "mult":
+                     weighted_value[left_pos:right_pos] *= weighted_local_value
+                 elif self.fusion == "avg":
+                     weighted_value[left_pos:right_pos] += weighted_local_value
+                     weighted_value[left_pos:right_pos] /= 2
+                 elif self.fusion == "max":
+                     weighted_value[left_pos:right_pos] = torch.max(weighted_value[left_pos:right_pos].clone(),
+                                                                    weighted_local_value)
+
+         return weighted_value, attn_weights
+
+
+ class PGL_SUM(nn.Module):
+     def __init__(self, input_size=1024, output_size=1024, freq=10000, pos_enc=None,
+                  num_segments=None, heads=1, fusion=None):
+         """ Class wrapping the PGL-SUM model; its key modules and parameters.
+
+         :param int input_size: The expected input feature size.
+         :param int output_size: The hidden feature size of the attention mechanisms.
+         :param int freq: The frequency of the sinusoidal positional encoding.
+         :param None | str pos_enc: The selected positional encoding [absolute, relative].
+         :param None | int num_segments: The selected number of segments to split the videos.
+         :param int heads: The selected number of global heads.
+         :param None | str fusion: The selected type of feature fusion.
+         """
+         super(PGL_SUM, self).__init__()
+
+         self.attention = MultiAttention(input_size=input_size, output_size=output_size, freq=freq,
+                                         pos_enc=pos_enc, num_segments=num_segments, heads=heads, fusion=fusion)
+         self.linear_1 = nn.Linear(in_features=input_size, out_features=input_size)
+         self.linear_2 = nn.Linear(in_features=self.linear_1.out_features, out_features=1)
+
+         self.drop = nn.Dropout(p=0.5)
+         self.norm_y = nn.LayerNorm(normalized_shape=input_size, eps=1e-6)
+         self.norm_linear = nn.LayerNorm(normalized_shape=self.linear_1.out_features, eps=1e-6)
+         self.relu = nn.ReLU()
+         self.sigmoid = nn.Sigmoid()
+
+     def forward(self, frame_features):
+         """ Produce frames importance scores from the frame features, using the PGL-SUM model.
+
+         :param torch.Tensor frame_features: Tensor of shape [T, input_size] containing the frame features produced by
+             using the pool5 layer of GoogleNet.
+         :return: A tuple of:
+             y: Tensor with shape [1, T] containing the frames importance scores in [0, 1].
+             attn_weights: Tensor with shape [T, T] containing the attention weights.
+         """
+         residual = frame_features
+         weighted_value, attn_weights = self.attention(frame_features)
+         y = weighted_value + residual
+         y = self.drop(y)
+         y = self.norm_y(y)
+
+         # 2-layer NN (Regressor Network)
+         y = self.linear_1(y)
+         y = self.relu(y)
+         y = self.drop(y)
+         y = self.norm_linear(y)
+
+         y = self.linear_2(y)
+         y = self.sigmoid(y)
+         y = y.view(1, -1)
+
+         return y, attn_weights
+
+
+ if __name__ == '__main__':
+     pass
+     """Uncomment for a quick proof of concept
+     model = PGL_SUM(input_size=256, output_size=256, num_segments=3, fusion="Add").cuda()
+     _input = torch.randn(500, 256).cuda()  # [seq_len, hidden_size]
+     output, weights = model(_input)
+     print(f"Output shape: {output.shape}\tattention shape: {weights.shape}")
+     """
main.py ADDED
@@ -0,0 +1,26 @@
+ from fastapi import FastAPI
+ from fastapi.middleware.cors import CORSMiddleware
+ from routes import summarize
+ from fastapi.staticfiles import StaticFiles
+ from fastapi.responses import JSONResponse
+ import os
+
+ app = FastAPI()
+ app.include_router(summarize.router)
+
+ # ✅ Root route to avoid 404 on /
+ @app.get("/")
+ def read_root():
+     return JSONResponse(content={"message": "Video summarization API is running"})
+
+ # CORS
+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=["*"],
+     allow_methods=["*"],
+     allow_headers=["*"],
+ )
+
+ # Mount static folder
+ static_dir = os.path.join("backend", "static")
+ app.mount("/static", StaticFiles(directory=static_dir), name="static")
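For a quick local check of the app wired up above, a minimal sketch using FastAPI's TestClient; it is not part of this commit and assumes the backend/static directory exists so the static mount can initialize.

# Hypothetical smoke test for main.py (assumes backend/static exists locally).
from fastapi.testclient import TestClient
from main import app

client = TestClient(app)
resp = client.get("/")  # exercises the root route defined above
print(resp.status_code, resp.json())  # 200 {"message": "Video summarization API is running"}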
routes/__init__.py ADDED
File without changes
routes/summarize.py ADDED
@@ -0,0 +1,28 @@
+ from fastapi import APIRouter, UploadFile, File
+ from fastapi.responses import JSONResponse
+ from utils.file_utils import save_uploaded_file
+ from services.extractor import extract_features
+ from services.model_loader import load_model
+ from services.summarizer import get_scores, get_selected_indices, save_summary_video
+ from config import UPLOAD_DIR, OUTPUT_DIR
+
+ router = APIRouter()
+
+ @router.post("/summarize")
+ def summarize_video(video: UploadFile = File(...)):
+     if not video.filename.lower().endswith(('.mp4', '.avi', '.mov', '.mkv')):
+         return JSONResponse(content={"error": "Unsupported file format"}, status_code=400)
+
+     video_path = save_uploaded_file(video, UPLOAD_DIR)
+     features, picks = extract_features(video_path)
+     model = load_model("backend/Model/epoch-199.pkl")
+     scores = get_scores(model, features)
+     selected = get_selected_indices(scores, picks)
+     output_path = f"{OUTPUT_DIR}/summary_{video.filename}"
+     save_summary_video(video_path, selected, output_path)
+     summary_url = f"/static/outputs/summary_{video.filename}"
+
+     return JSONResponse(content={
+         "message": "Summarization complete",
+         "summary_video_url": summary_url
+     })
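A hedged client-side sketch of how the /summarize endpoint above could be called; the host, port (taken from the Dockerfile CMD) and file name are assumptions, not part of this commit.

# Hypothetical client call; URL and file name are placeholders.
import requests

with open("match_highlights.mp4", "rb") as f:
    resp = requests.post(
        "http://localhost:7860/summarize",
        files={"video": ("match_highlights.mp4", f, "video/mp4")},
    )
print(resp.json())  # e.g. {"message": "Summarization complete", "summary_video_url": "/static/outputs/summary_..."}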
services/__init__.py ADDED
File without changes
services/extractor.py ADDED
@@ -0,0 +1,62 @@
+ import cv2
+ import torch
+ import numpy as np
+ from PIL import Image
+ from torchvision import models, transforms
+ from config import DEVICE, FRAME_RATE
+
+ # Load GoogLeNet once
+ from torchvision.models import GoogLeNet_Weights
+ weights = GoogLeNet_Weights.DEFAULT
+ googlenet = models.googlenet(weights=weights).to(DEVICE).eval()
+
+ feature_extractor = torch.nn.Sequential(
+     googlenet.conv1,
+     googlenet.maxpool1,
+     googlenet.conv2,
+     googlenet.conv3,
+     googlenet.maxpool2,
+     googlenet.inception3a,
+     googlenet.inception3b,
+     googlenet.maxpool3,
+     googlenet.inception4a,
+     googlenet.inception4b,
+     googlenet.inception4c,
+     googlenet.inception4d,
+     googlenet.inception4e,
+     googlenet.maxpool4,
+     googlenet.inception5a,
+     googlenet.inception5b,
+     googlenet.avgpool,
+     torch.nn.Flatten()
+ )
+
+ transform = transforms.Compose([
+     transforms.Resize((224, 224)),
+     transforms.ToTensor(),
+     transforms.Normalize(
+         mean=[0.485, 0.456, 0.406],
+         std=[0.229, 0.224, 0.225]
+     )
+ ])
+
+ def extract_features(video_path):
+     cap = cv2.VideoCapture(video_path)
+     fps = cap.get(cv2.CAP_PROP_FPS)
+     picks, frames = [], []
+     count = 0
+
+     while cap.isOpened():
+         ret, frame = cap.read()
+         if not ret:
+             break
+         if int(count % round(fps // FRAME_RATE)) == 0:
+             image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
+             input_tensor = transform(image).unsqueeze(0).to(DEVICE)
+             with torch.no_grad():
+                 feature = feature_extractor(input_tensor).squeeze(0).cpu().numpy()
+             frames.append(feature)
+             picks.append(count)
+         count += 1
+     cap.release()
+     return np.stack(frames), picks
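For orientation, a small sketch of what extract_features returns; the clip path is a placeholder, and the 1024-dimensional size follows from GoogLeNet's pooled (pool5) features.

# Hypothetical usage of the extractor (clip path is a placeholder).
from services.extractor import extract_features

features, picks = extract_features("static/uploads/clip.mp4")
print(features.shape)  # (N, 1024): one pooled GoogLeNet feature per sampled frame
print(picks[:5])       # original frame indices of the sampled frames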
services/model_loader.py ADDED
@@ -0,0 +1,19 @@
+ import torch
+ import sys
+ import os
+ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..')))
+ from layers.summarizer import PGL_SUM
+ from config import DEVICE
+
+ def load_model(weights_path):
+     model = PGL_SUM(
+         input_size=1024,
+         output_size=1024,
+         num_segments=4,
+         heads=8,
+         fusion="add",
+         pos_enc="absolute"
+     ).to(DEVICE)
+     model.load_state_dict(torch.load(weights_path, map_location=DEVICE))
+     model.eval()  # must be called; a bare `model.eval` is a no-op attribute access
+     return model
services/summarizer.py ADDED
@@ -0,0 +1,65 @@
+ import cv2
+ import torch
+ from config import SCORE_THRESHOLD
+
+ def get_scores(model, features):
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     model = model.to(device)
+     with torch.no_grad():
+         features_tensor = torch.tensor(features, dtype=torch.float32).to(device)
+         scores, _ = model(features_tensor)
+     return scores.squeeze().cpu().numpy()
+
+
+ def get_selected_indices(scores, picks, threshold=SCORE_THRESHOLD):
+     return [picks[i] for i, score in enumerate(scores) if score >= threshold]
+
+ import subprocess
+ import os
+
+ def save_summary_video(video_path, selected_indices, output_path, fps=15):
+     import cv2
+
+     cap = cv2.VideoCapture(video_path)
+     selected = set(selected_indices)
+     frame_id = 0
+     frames = {}
+
+     while cap.isOpened():
+         ret, frame = cap.read()
+         if not ret:
+             break
+         if frame_id in selected:
+             frames[frame_id] = frame
+         frame_id += 1
+     cap.release()
+
+     if not frames:
+         print("No frames selected.")
+         return
+
+     h, w, _ = list(frames.values())[0].shape
+
+     # 1️⃣ Save raw video first
+     raw_output_path = output_path.replace(".mp4", "_raw.mp4")
+     writer = cv2.VideoWriter(raw_output_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
+     for fid in sorted(frames.keys()):
+         writer.write(frames[fid])
+     writer.release()
+
+     # 2️⃣ Use FFmpeg to fix video (browser-compatible)
+     try:
+         subprocess.run([
+             "ffmpeg",
+             "-y",  # overwrite if file exists
+             "-i", raw_output_path,
+             "-vcodec", "libx264",
+             "-acodec", "aac",
+             output_path
+         ], check=True)
+         os.remove(raw_output_path)  # optional: remove raw file
+         print(f"✅ FFmpeg re-encoded video saved to: {output_path}")
+     except subprocess.CalledProcessError as e:
+         print("❌ FFmpeg failed:", e)
+         print("⚠️ Using raw video instead.")
+         os.rename(raw_output_path, output_path)
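Putting the services together, a hedged end-to-end sketch that mirrors the /summarize route above; all paths are assumptions, and ffmpeg must be available on PATH for the re-encode step.

# Hypothetical end-to-end run outside FastAPI; paths are placeholders.
from services.extractor import extract_features
from services.model_loader import load_model
from services.summarizer import get_scores, get_selected_indices, save_summary_video

features, picks = extract_features("static/uploads/clip.mp4")
model = load_model("backend/Model/epoch-199.pkl")
scores = get_scores(model, features)
selected = get_selected_indices(scores, picks)
save_summary_video("static/uploads/clip.mp4", selected, "static/outputs/summary_clip.mp4")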
static/uploads/77ea55af6d744160a5c7e8440b294bb6_Paris Saint-Germain vs Atlético de Madrid Highlights | FIFA Club World Cup 2025.mp4 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:99183f4ca670013008f6a45943bf878532a1db1ad0753f671d289f55f45dac93
+ size 22890415
static/uploads/84daab3df51f418ebff312b2ed129bc1_Paris Saint-Germain vs Atlético de Madrid Highlights | FIFA Club World Cup 2025.mp4 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:99183f4ca670013008f6a45943bf878532a1db1ad0753f671d289f55f45dac93
+ size 22890415
static/uploads/8ba4aec007f5404db2e9ac9570e59ca6_Paris Saint-Germain vs Atlético de Madrid Highlights | FIFA Club World Cup 2025.mp4 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:99183f4ca670013008f6a45943bf878532a1db1ad0753f671d289f55f45dac93
+ size 22890415
static/uploads/b0b93f4bcdcb4662865bb4dc26c1b243_Paris Saint-Germain vs Atlético de Madrid Highlights | FIFA Club World Cup 2025.mp4 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:99183f4ca670013008f6a45943bf878532a1db1ad0753f671d289f55f45dac93
+ size 22890415
static/uploads/e051610a8a634fd9a9de3c016d38ce73_Paris Saint-Germain vs Atlético de Madrid Highlights | FIFA Club World Cup 2025.mp4 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:99183f4ca670013008f6a45943bf878532a1db1ad0753f671d289f55f45dac93
+ size 22890415
utils/__init__.py ADDED
File without changes
utils/file_utils.py ADDED
@@ -0,0 +1,10 @@
+ import os
+ from uuid import uuid4
+
+ def save_uploaded_file(uploaded_file, upload_dir):
+     os.makedirs(upload_dir, exist_ok=True)
+     filename = f"{uuid4().hex}_{uploaded_file.filename}"
+     filepath = os.path.join(upload_dir, filename)
+     with open(filepath, "wb") as f:
+         f.write(uploaded_file.file.read())
+     return filepath