File size: 3,405 Bytes
a3d45a8 8a08dc9 a3d45a8 8a08dc9 a3d45a8 8a08dc9 a3d45a8 8a08dc9 a3d45a8 8a08dc9 a3d45a8 8a08dc9 a3d45a8 8a08dc9 a3d45a8 8a08dc9 a3d45a8 8a08dc9 a3d45a8 8a08dc9 a3d45a8 8a08dc9 a3d45a8 8a08dc9 a3d45a8 8a08dc9 a3d45a8 8a08dc9 a3d45a8 8a08dc9 a3d45a8 8a08dc9 a3d45a8 8a08dc9 a3d45a8 8a08dc9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 |
import torch
class Normalization(object):
    """Per-instance rescaling of a batched routing problem held by an ``actor``.

    Statistics are captured once from the actor passed to ``__init__``.
    :meth:`normalize` applies them to an actor and :meth:`inverse_normalize`
    undoes them. Both methods mutate the actor they are given.

    Attributes:
        normalize_position: whether node positions are standardized as well.
        device: stored from the constructor; not used by this class.
        greatest_distance: per-instance max of ``graph.distance_matrix``,
            shape ``(B,)``.
        greatest_drive_time: per-instance max of ``graph.time_matrix``,
            shape ``(B,)``. This is the scale used for every time quantity.
        earliest_start_time: per-instance min over ``fleet.start_time`` and
            ``graph.start_time``, shape ``(B,)``. This is the offset for
            graph time windows.
        mean_positions, std_positions: mean and (unbiased) std of
            ``graph.node_positions`` over dim 1.

    Zero scale factors (e.g. an all-zero padded instance) are replaced by 1.
    So are zero or NaN position stds (all nodes coincident, or a single
    node). This keeps both transforms finite and exactly invertible instead
    of producing 0/0 = NaN.
    """

    def __init__(self, actor, normalize_position=False, device='cpu'):
        """Capture normalization statistics from ``actor``.

        Args:
            actor: object exposing ``.graph`` (``distance_matrix``,
                ``time_matrix``, ``start_time``, ``node_positions``) and
                ``.fleet`` (``start_time``). The leading dim of each tensor
                is the batch.
            normalize_position: also standardize ``graph.node_positions``.
            device: stored as ``self.device``; unused by this class.
        """
        self.normalize_position = normalize_position
        self.device = device
        graph = actor.graph
        fleet = actor.fleet
        batch_size = graph.distance_matrix.size(0)

        # reshape (not view) so non-contiguous inputs, e.g. transposed
        # matrices, are accepted instead of raising.
        self.greatest_drive_time = self._nonzero_or_one(
            graph.time_matrix.reshape(batch_size, -1).max(dim=1)[0])  # (B,)
        self.greatest_distance = self._nonzero_or_one(
            graph.distance_matrix.reshape(batch_size, -1).max(dim=1)[0])  # (B,)

        # Earliest start across both fleet and graph start times.
        fleet_start_flat = fleet.start_time.reshape(batch_size, -1)
        graph_start_flat = graph.start_time.reshape(batch_size, -1)
        self.earliest_start_time = torch.cat(
            [fleet_start_flat, graph_start_flat], dim=1).min(dim=1)[0]

        # NOTE(review): presumably node_positions is (B, N, D), so these
        # stats are (B, D) — confirm against the producer of `graph`.
        self.mean_positions = graph.node_positions.mean(dim=1)
        std = graph.node_positions.std(dim=1)
        # `std > 0` is False for both 0 and NaN (single node), so both map to 1.
        self.std_positions = torch.where(std > 0, std, torch.ones_like(std))

    @staticmethod
    def _nonzero_or_one(scale):
        """Return ``scale`` with zero entries replaced by 1 (avoids 0/0 -> NaN)."""
        return torch.where(scale != 0, scale, torch.ones_like(scale))

    @staticmethod
    def _per_instance(stat, batch_size):
        """Reshape a per-instance ``(B,)`` statistic to ``(B, 1, 1)`` for broadcasting."""
        return stat.reshape(batch_size, 1, 1)

    def normalize(self, actor):
        """Rescale ``actor``'s graph and fleet tensors using the stored statistics.

        ``graph.distance_matrix``, ``graph.time_matrix``, ``fleet.late_time``
        and ``fleet.arrival_times`` are divided in place, so the existing
        tensor objects are modified. ``graph.start_time``, ``graph.end_time``
        and, if enabled, ``graph.node_positions`` are replaced by new tensors.
        Returns ``None``.
        """
        batch_size = actor.graph.distance_matrix.size(0)
        dist_scale = self._per_instance(self.greatest_distance, batch_size)
        time_scale = self._per_instance(self.greatest_drive_time, batch_size)
        time_offset = self._per_instance(self.earliest_start_time, batch_size)

        # Graph matrices: divide by the per-instance maximum.
        actor.graph.distance_matrix /= dist_scale
        actor.graph.time_matrix /= time_scale

        # Time windows become offsets from the earliest start, measured in
        # units of the greatest drive time.
        # NOTE(review): assumes start_time/end_time are 3-D with batch first,
        # so the (B, 1, 1) offset broadcasts per instance — confirm.
        actor.graph.start_time = (actor.graph.start_time - time_offset) / time_scale
        actor.graph.end_time = (actor.graph.end_time - time_offset) / time_scale

        # NOTE(review): these fleet times are only scaled, not shifted.
        # fleet.start_time is left untouched even though it feeds
        # earliest_start_time — confirm this is intended.
        actor.fleet.late_time /= time_scale
        actor.fleet.arrival_times /= time_scale

        if self.normalize_position:
            mean_pos = self.mean_positions.reshape(batch_size, 1, -1)
            std_pos = self.std_positions.reshape(batch_size, 1, -1)
            actor.graph.node_positions = (actor.graph.node_positions - mean_pos) / std_pos

    def inverse_normalize(self, actor):
        """Undo :meth:`normalize` on ``actor``, mirroring each step.

        The same tensors are mutated in place (or rebound) exactly as in
        :meth:`normalize`. Returns ``None``.
        """
        batch_size = actor.graph.distance_matrix.size(0)
        dist_scale = self._per_instance(self.greatest_distance, batch_size)
        time_scale = self._per_instance(self.greatest_drive_time, batch_size)
        time_offset = self._per_instance(self.earliest_start_time, batch_size)

        actor.graph.distance_matrix *= dist_scale
        actor.graph.time_matrix *= time_scale

        actor.graph.start_time = actor.graph.start_time * time_scale + time_offset
        actor.graph.end_time = actor.graph.end_time * time_scale + time_offset

        actor.fleet.late_time *= time_scale
        actor.fleet.arrival_times *= time_scale

        if self.normalize_position:
            mean_pos = self.mean_positions.reshape(batch_size, 1, -1)
            std_pos = self.std_positions.reshape(batch_size, 1, -1)
            actor.graph.node_positions = actor.graph.node_positions * std_pos + mean_pos
|