File size: 1,427 Bytes
a2a33f3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
873cef2
 
 
a2a33f3
 
 
 
873cef2
 
 
 
a2a33f3
 
 
 
 
 
 
 
873cef2
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import torch
import torch.nn as nn
from torch.nn.utils import clip_grad_norm_
from just_time_windows.Actor.actor import Actor


def train_batch(actor, baseline, batch, optimizer, gradient_clipping=True, comparison_model=None, compute_cost_ratio=True):
    """Run one REINFORCE-with-baseline training step on a single batch.

    The actor samples a rollout while the baseline runs a no-grad greedy
    rollout; the detached cost difference (advantage) weights the actor's
    log-probabilities to form the policy-gradient loss.

    Args:
        actor: policy model; must expose ``device``, ``train_mode()``,
            ``apply_normalization`` and return a dict with ``total_time``
            and ``log_probs`` when called on a batch.
        baseline: baseline model with a ``greedy_search()`` mode switch.
        batch: one training batch, forwarded to both models.
        optimizer: torch optimizer over the actor's parameters.
        gradient_clipping: if True, clip each param group's grad norm to 1.
        comparison_model: optional reference actor for the cost ratio;
            built on the fly (1-neighbour ``Actor``) when None.
        compute_cost_ratio: if False, skip the comparison and return None.

    Returns:
        float cost ratio actor/reference (lower is better), or None when
        ``compute_cost_ratio`` is False.
    """
    device = actor.device

    # Sampling/training mode for the rollout that carries gradients.
    actor.train_mode()
    actor.train()
    rollout = actor(batch)
    actor_cost = rollout['total_time']
    log_probs = rollout['log_probs']

    # Baseline rollout is greedy and gradient-free.
    with torch.no_grad():
        baseline.greedy_search()
        baseline_cost = baseline(batch)['total_time']

    # REINFORCE: detach the advantage so gradients flow only through log_probs.
    advantage = (actor_cost - baseline_cost).detach()
    loss = (advantage * log_probs).mean()

    optimizer.zero_grad()
    loss.backward()

    if gradient_clipping:
        # Clip each parameter group's gradient norm independently.
        for group in optimizer.param_groups:
            clippable = [p for p in group['params'] if p.grad is not None]
            if clippable:
                clip_grad_norm_(clippable, max_norm=1, norm_type=2)

    optimizer.step()

    if not compute_cost_ratio:
        return None

    # Fall back to a 1-nearest-neighbour reference actor when none is given.
    if comparison_model is None:
        normalize = actor.apply_normalization
        comparison_model = Actor(model=None, num_neighbors_action=1, normalize=normalize, device=device)

    with torch.no_grad():
        reference_cost = comparison_model(batch)['total_time']

    # Ratio < 1 means the trained actor beats the reference policy.
    denominator = reference_cost.sum().item()
    numerator = actor_cost.sum().item()
    return numerator / denominator