File size: 853 Bytes
d2542a3 1579b70 d2542a3 bf5699f a4d2baa d2542a3 1579b70 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 |
import torch
import sys
import os
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..')))
from layers.summarizer import PGL_SUM
from config import DEVICE
from tqdm import tqdm
def load_model(
    weights_path,
    *,
    input_size=1024,
    output_size=1024,
    num_segments=4,
    heads=8,
    fusion="add",
    pos_enc="absolute",
    device=None,
):
    """Build a ``PGL_SUM`` model, load weights from ``weights_path``, return it in eval mode.

    The architecture hyperparameters default to the values this module has
    always used, so existing ``load_model(path)`` calls behave exactly as
    before. Override them only when loading a checkpoint trained with a
    different configuration: the loaded state dict must match the constructed
    architecture, otherwise ``load_state_dict`` raises.

    Args:
        weights_path: Path (or file-like object) passed to ``torch.load``.
        input_size, output_size, num_segments, heads, fusion, pos_enc:
            Forwarded unchanged as keyword arguments to ``PGL_SUM``.
        device: Device the model is moved to and the checkpoint tensors are
            mapped onto. Defaults to ``DEVICE`` from ``config``.

    Returns:
        The ``PGL_SUM`` instance with the loaded weights, on ``device``,
        switched to eval mode.
    """
    target = DEVICE if device is None else device
    model = PGL_SUM(
        input_size=input_size,
        output_size=output_size,
        num_segments=num_segments,
        heads=heads,
        fusion=fusion,
        pos_enc=pos_enc,
    ).to(target)
    # NOTE(security): torch.load is pickle-based; unless weights_only=True is in
    # effect (the default only from torch 2.6), loading an untrusted file can
    # execute arbitrary code. Only load checkpoints from trusted sources.
    state_dict = torch.load(weights_path, map_location=target)
    model.load_state_dict(state_dict)
    model.eval()
    return model
def batch_inference(model, input, batch_size=128, *, device=None, progress=True):
    """Run ``model`` over ``input`` in mini-batches and concatenate the outputs on the CPU.

    ``input`` is sliced along its first dimension into chunks of at most
    ``batch_size`` rows. Each chunk is moved to ``device`` and passed through
    ``model`` with gradients disabled, and the result is moved back to the CPU.
    The per-batch outputs are then concatenated along dimension 0.

    Args:
        model: Module-like callable exposing ``.eval()`` (presumably a
            ``torch.nn.Module``). It is put into eval mode before inference.
        input: Tensor whose first dimension is iterated over in batches. The
            name shadows the builtin but is kept for keyword-caller
            compatibility.
        batch_size: Maximum number of rows per forward pass; must be positive.
        device: Device each batch is moved to. Defaults to ``DEVICE`` from
            ``config``.
        progress: If true (the default), wrap the batch loop in a ``tqdm``
            progress bar.

    Returns:
        CPU tensor holding the per-batch model outputs concatenated along
        dim 0. This assumes each output's first dimension is the batch
        dimension; confirm this for the model in use.

    Raises:
        ValueError: If ``batch_size`` is not positive or ``input`` has no rows.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")
    num_rows = input.size(0)
    if num_rows == 0:
        # torch.cat([]) would otherwise fail with an opaque RuntimeError.
        raise ValueError("batch_inference received an empty input (size(0) == 0)")

    target = DEVICE if device is None else device
    starts = range(0, num_rows, batch_size)
    if progress:
        starts = tqdm(starts)

    model.eval()
    outputs = []
    with torch.no_grad():
        for start in starts:
            batch = input[start:start + batch_size].to(target)
            # Move results off the accelerator immediately to bound device memory.
            outputs.append(model(batch).cpu())
    return torch.cat(outputs, dim=0)
|