File size: 3,405 Bytes
a3d45a8
 
 
 
 
 
 
 
 
 
 
8a08dc9
 
a3d45a8
8a08dc9
 
 
a3d45a8
8a08dc9
 
 
a3d45a8
 
8a08dc9
a3d45a8
 
8a08dc9
 
 
a3d45a8
8a08dc9
 
 
a3d45a8
8a08dc9
 
 
a3d45a8
8a08dc9
 
a3d45a8
8a08dc9
 
 
a3d45a8
8a08dc9
a3d45a8
8a08dc9
 
 
a3d45a8
 
8a08dc9
 
 
a3d45a8
8a08dc9
 
 
a3d45a8
8a08dc9
 
 
a3d45a8
8a08dc9
 
a3d45a8
8a08dc9
 
 
a3d45a8
8a08dc9
a3d45a8
8a08dc9
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import torch


class Normalization(object):
    """Normalize an actor's graph/fleet tensors and invert that normalization.

    Per-batch statistics are captured once from the actor given to ``__init__``.
    ``normalize`` rescales (and, for graph time windows, also shifts) the tensors
    of the actor it receives; ``inverse_normalize`` applies the inverse using the
    same stored statistics. Neither method guards against repeated calls, so
    calling ``normalize`` twice compounds the scaling.

    Scale factors that are 0 or NaN are replaced by 1. This covers an all-zero
    distance/time matrix, and node positions whose standard deviation is zero
    (identical positions) or undefined (a single node). With this guard,
    normalization never yields inf/NaN and the round trip stays invertible.
    """

    def __init__(self, actor, normalize_position=False, device='cpu'):
        """Capture per-batch normalization statistics from ``actor``.

        Args:
            actor: object with ``graph`` (``distance_matrix``, ``time_matrix``,
                ``start_time``, ``node_positions``) and ``fleet``
                (``start_time``) tensors whose first dimension is the batch.
            normalize_position: when True, ``normalize`` / ``inverse_normalize``
                also standardize ``graph.node_positions``.
            device: stored as ``self.device``; not used by this class.
        """
        self.normalize_position = normalize_position
        self.device = device

        graph = actor.graph
        fleet = actor.fleet
        batch_size = graph.distance_matrix.size(0)

        # Per-batch scale factors, shape (B,). ``reshape`` (unlike ``view``)
        # also accepts non-contiguous tensors such as transposed matrices.
        self.greatest_drive_time = self._safe_scale(
            graph.time_matrix.reshape(batch_size, -1).amax(dim=1))
        self.greatest_distance = self._safe_scale(
            graph.distance_matrix.reshape(batch_size, -1).amax(dim=1))

        # Time origin: earliest start time over both fleet and graph, shape (B,).
        fleet_start_flat = fleet.start_time.reshape(batch_size, -1)
        graph_start_flat = graph.start_time.reshape(batch_size, -1)
        self.earliest_start_time = torch.cat(
            [fleet_start_flat, graph_start_flat], dim=1).amin(dim=1)

        # Position statistics reduced over dim 1 (presumably the node axis --
        # TODO confirm against the producer of ``node_positions``).
        self.mean_positions = graph.node_positions.mean(dim=1)
        self.std_positions = self._safe_scale(graph.node_positions.std(dim=1))

    @staticmethod
    def _safe_scale(scale):
        """Return ``scale`` with 0 / NaN entries replaced by 1 (safe divisor)."""
        degenerate = (scale == 0) | torch.isnan(scale)
        return torch.where(degenerate, torch.ones_like(scale), scale)

    def _time_stats(self, batch_size):
        """Return (distance scale, time scale, time offset), each shaped (B, 1, 1)."""
        shape = (batch_size, 1, 1)
        return (self.greatest_distance.reshape(shape),
                self.greatest_drive_time.reshape(shape),
                self.earliest_start_time.reshape(shape))

    def _position_stats(self, batch_size):
        """Return (mean, std) of node positions reshaped to (B, 1, -1)."""
        return (self.mean_positions.reshape(batch_size, 1, -1),
                self.std_positions.reshape(batch_size, 1, -1))

    def normalize(self, actor):
        """Normalize ``actor``'s graph and fleet tensors with the stored statistics.

        Distance/time matrices and fleet ``late_time`` / ``arrival_times`` are
        divided in place by the per-batch maxima. Graph ``start_time`` /
        ``end_time`` are shifted by the earliest start time, divided by the
        greatest drive time, and reassigned as new tensors. If
        ``normalize_position`` is set, node positions are standardized.
        """
        batch_size = actor.graph.distance_matrix.size(0)
        dist_scale, time_scale, time_offset = self._time_stats(batch_size)

        # Graph matrices (in place).
        actor.graph.distance_matrix /= dist_scale
        actor.graph.time_matrix /= time_scale

        # Graph time windows: shift to the common origin, then rescale.
        actor.graph.start_time = (actor.graph.start_time - time_offset) / time_scale
        actor.graph.end_time = (actor.graph.end_time - time_offset) / time_scale

        # NOTE(review): fleet times are only rescaled, never shifted by
        # ``time_offset``, and ``fleet.start_time`` is not normalized at all even
        # though it contributes to ``earliest_start_time`` -- confirm intended.
        actor.fleet.late_time /= time_scale
        actor.fleet.arrival_times /= time_scale

        if self.normalize_position:
            mean_pos, std_pos = self._position_stats(batch_size)
            actor.graph.node_positions = (actor.graph.node_positions - mean_pos) / std_pos

    def inverse_normalize(self, actor):
        """Undo ``normalize`` on ``actor`` using the same stored statistics.

        Matrices and fleet times are multiplied back in place; graph time
        windows are rescaled and re-shifted (reassigned as new tensors); node
        positions are de-standardized when ``normalize_position`` is set.
        """
        batch_size = actor.graph.distance_matrix.size(0)
        dist_scale, time_scale, time_offset = self._time_stats(batch_size)

        # Graph matrices (in place).
        actor.graph.distance_matrix *= dist_scale
        actor.graph.time_matrix *= time_scale

        # Graph time windows: rescale, then shift back to the original origin.
        actor.graph.start_time = actor.graph.start_time * time_scale + time_offset
        actor.graph.end_time = actor.graph.end_time * time_scale + time_offset

        # Fleet times (in place, scale only -- mirrors ``normalize``).
        actor.fleet.late_time *= time_scale
        actor.fleet.arrival_times *= time_scale

        if self.normalize_position:
            mean_pos, std_pos = self._position_stats(batch_size)
            actor.graph.node_positions = actor.graph.node_positions * std_pos + mean_pos