# Provenance (from the Hugging Face file view this was copied from):
# author Salimshakeel, commit bf5699f "use original weights", 501 bytes.
import torch
import sys
import os
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..')))
from layers.summarizer import PGL_SUM
from config import DEVICE
def load_model(
    weights_path,
    *,
    device=None,
    input_size=1024,
    output_size=1024,
    num_segments=4,
    heads=8,
    fusion="add",
    pos_enc="absolute",
):
    """Build a PGL_SUM model, load saved weights into it, and return it in eval mode.

    Args:
        weights_path: Checkpoint location passed to ``torch.load``; it must
            contain a state dict matching the PGL_SUM configuration below.
        device: Device the model is moved to and the weights are mapped onto.
            ``None`` (default) uses ``config.DEVICE``.
        input_size, output_size, num_segments, heads, fusion, pos_enc:
            Forwarded unchanged to ``PGL_SUM``. The defaults are the values this
            loader has always used, so existing callers get the same model.

    Returns:
        The ``PGL_SUM`` module with the loaded weights, after ``model.eval()``.

    Raises:
        Errors from ``torch.load`` / ``load_state_dict`` propagate unchanged,
        e.g. when the checkpoint does not match the configured architecture.
    """
    # Resolve at call time so the default always reflects config.DEVICE.
    if device is None:
        device = DEVICE

    model = PGL_SUM(
        input_size=input_size,
        output_size=output_size,
        num_segments=num_segments,
        heads=heads,
        fusion=fusion,
        pos_enc=pos_enc,
    ).to(device)

    # NOTE(review): torch.load unpickles the file, which can execute arbitrary
    # code. Only load trusted checkpoints; on torch>=1.13 consider passing
    # weights_only=True, since only a state dict is needed here.
    state_dict = torch.load(weights_path, map_location=device)
    model.load_state_dict(state_dict)
    model.eval()
    return model