|
import torch |
|
|
|
|
|
class Graph(object):
    """Batched graph instance whose tensors all live on one device.

    ``graph_data`` must provide the tensor entries ``'start_times'``,
    ``'end_times'``, ``'depot'``, ``'node_vector'``, ``'distance_matrix'`` and
    ``'time_matrix'``. Both matrices are indexed ``[batch, from_node, to_node]``;
    ``batch_size`` and ``num_nodes`` are read from ``distance_matrix.shape[0]``
    and ``distance_matrix.shape[1]``.

    NOTE(review): ``construct_vector`` concatenates the per-node tensors along
    dim 2, so they are presumably shaped ``(batch, num_nodes, features)`` with
    one feature each for start time, end time and the depot flag — confirm
    against the data loader.
    """

    # Every tensor attribute that ``to`` must move.
    # ``vector`` is handled separately because it only exists after
    # ``construct_vector`` has run.
    _TENSOR_ATTRS = (
        'start_time',
        'end_time',
        'depot',
        'node_positions',
        'distance_matrix',
        'time_matrix',
        'time_window_compatibility',
        'max_dist',
        'max_drive_time',
    )

    def __init__(self, graph_data, device='cpu'):
        """Copy the tensors from ``graph_data`` onto ``device`` and derive
        the depot-corrected time windows, the compatibility mask and the
        batch-wide maxima.

        Args:
            graph_data: mapping with the keys listed in the class docstring.
            device: any value accepted by ``torch.Tensor.to``.
        """
        self.device = device

        self.start_time = graph_data['start_times'].to(device)
        self.end_time = graph_data['end_times'].to(device)
        self.depot = graph_data['depot'].to(device)
        self.node_positions = graph_data['node_vector'].to(device)
        self.distance_matrix = graph_data['distance_matrix'].to(device)
        self.time_matrix = graph_data['time_matrix'].to(device)

        self.num_nodes = self.distance_matrix.shape[1]
        self.batch_size = self.distance_matrix.shape[0]

        # Must run before the compatibility mask is computed so the mask
        # sees the depot-corrected start/end times.
        self.correct_depot_features()

        self.time_window_compatibility = self.compute_time_window_compatibility()

        # 0-d tensors: maxima over the whole batch, not per instance.
        self.max_dist = self.distance_matrix.max()
        self.max_drive_time = self.time_matrix.max()

    def correct_depot_features(self):
        """Multiply start and end times by ``(1 - depot)``.

        For entries where ``depot == 1`` (presumably a 0/1 flag — confirm)
        this zeroes the time window. The multiplication allocates new
        tensors, so the caller's original tensors are never modified.
        """
        keep = 1 - self.depot
        self.start_time = self.start_time * keep
        self.end_time = self.end_time * keep

    def construct_vector(self):
        """Concatenate node positions, start time, end time and depot flag
        along dim 2, cache the result in ``self.vector`` and return it."""
        features = [self.node_positions, self.start_time, self.end_time, self.depot]
        self.vector = torch.cat(features, dim=2)
        return self.vector

    def _lookup_pairs(self, matrix, from_nodes, to_nodes):
        """Return ``matrix[b, from_nodes[b, k], to_nodes[b, k]]``.

        Args:
            matrix: ``(batch, num_nodes, num_nodes)`` lookup table.
            from_nodes, to_nodes: 2-D integer (int64, as required by
                ``torch.gather``) node-index tensors with the same size
                ``K`` along dim 1.

        Returns:
            Tensor of shape ``(batch, K, 1)``.
        """
        num_elements = from_nodes.shape[1]
        assert num_elements == to_nodes.shape[1], (
            'from_nodes and to_nodes must contain the same number of elements'
        )
        # First gather: the full row of every source node -> (batch, K, num_nodes).
        row_index = from_nodes.unsqueeze(2).expand(-1, -1, self.num_nodes)
        rows = torch.gather(matrix, dim=1, index=row_index)
        # Second gather: the destination column within each row -> (batch, K, 1).
        return torch.gather(rows, dim=2, index=to_nodes.unsqueeze(2))

    def get_drive_times(self, from_nodes, to_nodes):
        """Look up ``time_matrix`` entries for each (from, to) pair.

        Returns a ``(batch, K, 1)`` tensor where ``K = from_nodes.shape[1]``.
        """
        return self._lookup_pairs(self.time_matrix, from_nodes, to_nodes)

    def get_distances(self, from_nodes, to_nodes):
        """Look up ``distance_matrix`` entries for each (from, to) pair.

        The values come straight from ``graph_data['distance_matrix']``;
        whatever metric produced that matrix is what is returned. Returns a
        ``(batch, K, 1)`` tensor where ``K = from_nodes.shape[1]``.
        """
        return self._lookup_pairs(self.distance_matrix, from_nodes, to_nodes)

    def compute_time_window_compatibility(self):
        """Return a float ``(batch, num_nodes, num_nodes)`` mask that is 1.0
        where ``start_time[i] + time_matrix[i][j] <= end_time[j]`` and 0.0
        otherwise."""
        # Broadcasting (B, N, 1) against (B, N, N) and (B, 1, N) gives the
        # same result as the explicit .repeat() copies, without allocating them.
        start_i = self.start_time.reshape(self.batch_size, self.num_nodes, 1)
        end_j = self.end_time.reshape(self.batch_size, 1, self.num_nodes)
        return (start_i + self.time_matrix <= end_j).float()

    def to(self, device):
        """Move every tensor held by this graph to ``device``.

        Moves the input tensors and the derived ones as well
        (``time_window_compatibility``, ``max_dist``, ``max_drive_time`` and,
        once ``construct_vector`` has run, ``vector``). This way no attribute
        is left behind on the previous device.

        Returns:
            ``self``, mirroring ``torch.nn.Module.to``, so calls can be chained.
        """
        self.device = device
        for attr in self._TENSOR_ATTRS:
            setattr(self, attr, getattr(self, attr).to(device))
        # ``vector`` only exists after construct_vector() has been called.
        if hasattr(self, 'vector'):
            self.vector = self.vector.to(device)
        return self
|
|