import os
import sys

import torch
from tqdm import tqdm

# Make the project root importable when this file is run as a script.
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..')))

from layers.summarizer import PGL_SUM
from config import DEVICE


def load_model(weights_path):
    """Build a PGL_SUM model, load trained weights, and switch it to eval mode."""
    model = PGL_SUM(
        input_size=1024,
        output_size=1024,
        num_segments=4,
        heads=8,
        fusion="add",
        pos_enc="absolute"
    ).to(DEVICE)
    model.load_state_dict(torch.load(weights_path, map_location=DEVICE))
    model.eval()
    return model


def batch_inference(model, features, batch_size=128):
    """Run the model over `features` in mini-batches and return the concatenated outputs on CPU."""
    model.eval()
    outputs = []
    with torch.no_grad():
        for i in tqdm(range(0, features.size(0), batch_size)):
            batch = features[i:i + batch_size].to(DEVICE)
            out = model(batch)
            outputs.append(out.cpu())
    return torch.cat(outputs, dim=0)
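

# Usage sketch (not part of the original file): the checkpoint path and the
# dummy feature tensor below are hypothetical placeholders, and the exact
# input shape expected by PGL_SUM's forward pass depends on its definition
# in layers.summarizer; load_model above configures it for 1024-dim features.
if __name__ == "__main__":
    model = load_model("path/to/trained_weights.pth")  # hypothetical checkpoint path
    frame_features = torch.randn(500, 1024)            # dummy 1024-dim features for illustration
    scores = batch_inference(model, frame_features, batch_size=128)
    print(scores.shape)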