| import torch | |
| import sys | |
| import os | |
| sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..'))) | |
| from layers.summarizer import PGL_SUM | |
| from config import DEVICE | |
def load_model(weights_path, *, strict=True):
    """Build the PGL_SUM model, load saved weights into it, and switch it to eval mode.

    The architecture is fixed to input_size=1024, output_size=1024,
    num_segments=4, heads=8, fusion="add", pos_enc="absolute". The checkpoint
    at ``weights_path`` must have been produced by a model with this exact
    configuration, or ``load_state_dict`` will reject it when ``strict`` is True.

    Args:
        weights_path: Path (or file-like object) accepted by ``torch.load``.
            It is assumed to contain a plain ``state_dict`` mapping, as the
            original (previously commented-out) loading line expected; TODO
            confirm against the training script. Pass ``None`` to skip loading
            and get a freshly initialised model.
        strict: Forwarded to ``Module.load_state_dict``. If True, the
            checkpoint keys must match the model's parameters and buffers
            exactly.

    Returns:
        The ``PGL_SUM`` instance, moved to ``DEVICE`` and set to eval mode.

    Raises:
        Whatever ``torch.load`` raises for a missing or unreadable file, and
        ``RuntimeError`` from ``load_state_dict`` on a key or shape mismatch.
    """
    model = PGL_SUM(
        input_size=1024,
        output_size=1024,
        num_segments=4,
        heads=8,
        fusion="add",
        pos_enc="absolute",
    ).to(DEVICE)

    # Previously this load was commented out, so weights_path was silently
    # ignored and callers received an untrained model. None is now the
    # explicit way to request that behaviour.
    if weights_path is not None:
        # NOTE(review): torch.load unpickles the file, which can execute
        # arbitrary code. Only load checkpoints from trusted sources.
        # map_location remaps tensors saved on another device (e.g. a GPU)
        # onto DEVICE.
        state_dict = torch.load(weights_path, map_location=DEVICE)
        model.load_state_dict(state_dict, strict=strict)

    # Put the model in inference mode (affects dropout/batch-norm style layers).
    model.eval()
    return model