File size: 501 Bytes
d2542a3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bf5699f
a4d2baa
d2542a3
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
import torch
import sys
import os
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..')))
from layers.summarizer import PGL_SUM
from config import DEVICE

def load_model(
    weights_path,
    *,
    input_size: int = 1024,
    output_size: int = 1024,
    num_segments: int = 4,
    heads: int = 8,
    fusion: str = "add",
    pos_enc: str = "absolute",
    device=None,
) -> torch.nn.Module:
    """Build a PGL_SUM model, load saved weights into it, and return it in eval mode.

    Args:
        weights_path: Path (or file-like object) accepted by ``torch.load``; it is
            expected to contain a state dict for ``PGL_SUM``.
        input_size, output_size, num_segments, heads, fusion, pos_enc: Forwarded
            unchanged to the ``PGL_SUM`` constructor. The defaults reproduce the
            configuration this loader previously hard-coded; they must match the
            configuration the checkpoint was trained with.
        device: Device the model is moved to and the checkpoint tensors are mapped
            onto. ``None`` (the default) uses ``config.DEVICE``.

    Returns:
        The ``PGL_SUM`` instance with the loaded weights, switched to evaluation mode.

    Raises:
        RuntimeError: Raised by ``load_state_dict`` (strict by default) when the
            checkpoint's keys or tensor shapes do not match the constructed model.
    """
    if device is None:
        device = DEVICE

    model = PGL_SUM(
        input_size=input_size,
        output_size=output_size,
        num_segments=num_segments,
        heads=heads,
        fusion=fusion,
        pos_enc=pos_enc,
    ).to(device)

    # map_location remaps tensors saved on a different device (e.g. a CUDA-trained
    # checkpoint loaded on a CPU-only host) onto the target device.
    # NOTE(review): torch.load unpickles the file, which can execute arbitrary code;
    # only load trusted checkpoints. On torch >= 1.13 consider weights_only=True.
    state_dict = torch.load(weights_path, map_location=device)
    model.load_state_dict(state_dict)

    # Put dropout/normalization layers into inference behaviour before returning.
    model.eval()
    return model