|
import torch |
|
|
|
|
|
class Normalization(object):
    """Rescale an actor's graph/fleet tensors into a normalized frame, and back.

    Per-batch statistics are captured once from the actor passed to
    ``__init__`` and reused by :meth:`normalize` and
    :meth:`inverse_normalize`, so the two methods undo each other (up to
    floating-point rounding) when called on actors with the same batch size.

    Any captured scale that is 0 (e.g. an all-zero distance matrix, or a
    coordinate that is constant across nodes) or NaN (e.g. the standard
    deviation of a single node) is replaced by 1, so normalization never
    divides by zero and stays invertible.
    """

    def __init__(self, actor, normalize_position=False, device='cpu'):
        """Capture per-batch normalization statistics from ``actor``.

        Args:
            actor: object exposing ``actor.graph`` (with ``distance_matrix``,
                ``time_matrix``, ``start_time``, ``node_positions``) and
                ``actor.fleet`` (with ``start_time``). The first dimension
                of ``graph.distance_matrix`` is taken as the batch size.
            normalize_position: if True, :meth:`normalize` also standardizes
                ``graph.node_positions`` per batch element.
            device: stored as ``self.device``; not otherwise used here.
        """
        self.normalize_position = normalize_position
        self.device = device

        graph = actor.graph
        fleet = actor.fleet
        batch_size = graph.distance_matrix.size(0)

        # reshape (not view): also works for non-contiguous tensors.
        self.greatest_drive_time = self._safe_scale(
            graph.time_matrix.reshape(batch_size, -1).max(dim=1)[0]
        )
        self.greatest_distance = self._safe_scale(
            graph.distance_matrix.reshape(batch_size, -1).max(dim=1)[0]
        )

        # Time origin: the earliest start over both fleet and graph start times.
        all_start_times = torch.cat(
            [
                fleet.start_time.reshape(batch_size, -1),
                graph.start_time.reshape(batch_size, -1),
            ],
            dim=1,
        )
        self.earliest_start_time = all_start_times.min(dim=1)[0]

        # Reduced over dim 1 -- presumably the node axis of a
        # (batch, nodes, coords) tensor; TODO confirm against callers.
        self.mean_positions = graph.node_positions.mean(dim=1)
        self.std_positions = self._safe_scale(graph.node_positions.std(dim=1))

    @staticmethod
    def _safe_scale(scale):
        """Return ``scale`` with zero or NaN entries replaced by 1."""
        degenerate = (scale == 0) | torch.isnan(scale)
        return torch.where(degenerate, torch.ones_like(scale), scale)

    def normalize(self, actor):
        """Normalize ``actor``'s tensors using the stored statistics.

        ``graph.distance_matrix``, ``graph.time_matrix``, ``fleet.late_time``
        and ``fleet.arrival_times`` are divided IN PLACE by the per-batch
        maxima. ``graph.start_time`` / ``graph.end_time`` are rebound to
        ``(t - earliest_start_time) / greatest_drive_time``. When
        ``normalize_position`` is set, ``graph.node_positions`` is rebound
        to ``(p - mean) / std``.

        Nothing guards against calling this twice in a row; doing so
        normalizes twice.
        """
        batch_size = actor.graph.distance_matrix.size(0)
        distance_scale = self.greatest_distance.view(batch_size, 1, 1)
        time_scale = self.greatest_drive_time.view(batch_size, 1, 1)
        time_offset = self.earliest_start_time.view(batch_size, 1, 1)

        # In-place division: any tensor aliasing these sees the new values.
        actor.graph.distance_matrix /= distance_scale
        actor.graph.time_matrix /= time_scale

        # Node time windows are shifted to the earliest start, then scaled.
        actor.graph.start_time = (actor.graph.start_time - time_offset) / time_scale
        actor.graph.end_time = (actor.graph.end_time - time_offset) / time_scale

        # NOTE(review): fleet times are only scaled, not shifted by
        # earliest_start_time, and fleet.start_time is not touched at all --
        # confirm these are meant to be relative times/durations.
        actor.fleet.late_time /= time_scale
        actor.fleet.arrival_times /= time_scale

        if self.normalize_position:
            mean_pos = self.mean_positions.view(batch_size, 1, -1)
            std_pos = self.std_positions.view(batch_size, 1, -1)
            actor.graph.node_positions = (actor.graph.node_positions - mean_pos) / std_pos

    def inverse_normalize(self, actor):
        """Undo :meth:`normalize` on ``actor`` using the stored statistics.

        Mirrors :meth:`normalize`: the matrices and fleet times are multiplied
        IN PLACE by the per-batch maxima, the node time windows are rebound
        to ``t * greatest_drive_time + earliest_start_time`` and, when
        ``normalize_position`` is set, positions to ``p * std + mean``.
        """
        batch_size = actor.graph.distance_matrix.size(0)
        distance_scale = self.greatest_distance.view(batch_size, 1, 1)
        time_scale = self.greatest_drive_time.view(batch_size, 1, 1)
        time_offset = self.earliest_start_time.view(batch_size, 1, 1)

        # In-place multiplication, matching the in-place division in normalize().
        actor.graph.distance_matrix *= distance_scale
        actor.graph.time_matrix *= time_scale

        actor.graph.start_time = actor.graph.start_time * time_scale + time_offset
        actor.graph.end_time = actor.graph.end_time * time_scale + time_offset

        actor.fleet.late_time *= time_scale
        actor.fleet.arrival_times *= time_scale

        if self.normalize_position:
            mean_pos = self.mean_positions.view(batch_size, 1, -1)
            std_pos = self.std_positions.view(batch_size, 1, -1)
            actor.graph.node_positions = actor.graph.node_positions * std_pos + mean_pos
|
|