import torch | |
import sys | |
import os | |
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..'))) | |
from layers.summarizer import PGL_SUM | |
from config import DEVICE | |
def load_model(
    weights_path,
    *,
    input_size=1024,
    output_size=1024,
    num_segments=4,
    heads=8,
    fusion="add",
    pos_enc="absolute",
    device=None,
):
    """Build a PGL_SUM model, load saved weights into it, and return it in eval mode.

    The architecture hyperparameters default to the values this loader has
    always used. They are keyword-only so that a checkpoint trained with a
    different configuration can be loaded without editing this function.
    Each one is forwarded unchanged to the ``PGL_SUM`` constructor.

    Args:
        weights_path: Anything ``torch.load`` accepts (a path or file-like
            object). The loaded object is passed straight to
            ``load_state_dict``, so it must be a state dict matching the
            constructed model.
        input_size: Forwarded to ``PGL_SUM``.
        output_size: Forwarded to ``PGL_SUM``.
        num_segments: Forwarded to ``PGL_SUM``.
        heads: Forwarded to ``PGL_SUM``.
        fusion: Forwarded to ``PGL_SUM``.
        pos_enc: Forwarded to ``PGL_SUM``.
        device: Device to move the model to and to map the loaded tensors
            onto. When ``None``, ``config.DEVICE`` is used.

    Returns:
        The ``PGL_SUM`` instance on ``device``, with the weights loaded and
        ``eval()`` already called.

    Raises:
        RuntimeError: From ``load_state_dict`` (strict by default) when the
            checkpoint's keys or tensor shapes do not match the model.
    """
    target_device = DEVICE if device is None else device

    model = PGL_SUM(
        input_size=input_size,
        output_size=output_size,
        num_segments=num_segments,
        heads=heads,
        fusion=fusion,
        pos_enc=pos_enc,
    ).to(target_device)

    # NOTE(review): torch.load unpickles the file, which can execute arbitrary
    # code, so only load trusted checkpoints. Newer torch releases accept
    # weights_only=True to restrict unpickling (it is the default from 2.6).
    # Consider passing it once the minimum supported torch version allows.
    # map_location loads tensors straight onto the target device, even if the
    # checkpoint was saved from a different one (e.g. CUDA -> CPU).
    state_dict = torch.load(weights_path, map_location=target_device)
    model.load_state_dict(state_dict)

    # Switch layers such as dropout / batch-norm (if present) to inference behavior.
    model.eval()
    return model