Spaces:
Runtime error
Runtime error
import torch
class Graph:
    """Batched routing-graph container: node features plus pairwise matrices.

    Every tensor is read from ``graph_data`` and moved to ``device``. The batch
    size and node count come from ``distance_matrix.shape[0]`` and
    ``distance_matrix.shape[1]``. The methods below rely on these layouts
    (assumed, not validated here):

    * ``distance_matrix`` / ``time_matrix``: ``(batch, num_nodes, num_nodes)``
    * ``start_times`` / ``end_times`` / ``depot``: ``(batch, num_nodes, 1)``,
      since they are concatenated along dim 2 and reshaped to that shape
    * ``node_vector``: ``(batch, num_nodes, F)`` for some feature width ``F``

    Args:
        graph_data: Mapping with tensor values under the keys ``'start_times'``,
            ``'end_times'``, ``'depot'``, ``'node_vector'``,
            ``'distance_matrix'`` and ``'time_matrix'``.
        device: Device all tensors are moved to.
    """

    # Every tensor attribute that must follow the instance across devices.
    # 'vector' only exists after construct_vector() has been called.
    _TENSOR_ATTRS = (
        'start_time', 'end_time', 'depot', 'node_positions',
        'distance_matrix', 'time_matrix', 'time_window_compatibility',
        'max_dist', 'max_drive_time', 'vector',
    )

    def __init__(self, graph_data, device='cpu'):
        self.device = device
        # Load and move data to the correct device.
        self.start_time = graph_data['start_times'].to(device)
        self.end_time = graph_data['end_times'].to(device)
        self.depot = graph_data['depot'].to(device)
        self.node_positions = graph_data['node_vector'].to(device)
        self.distance_matrix = graph_data['distance_matrix'].to(device)
        self.time_matrix = graph_data['time_matrix'].to(device)
        self.num_nodes = self.distance_matrix.shape[1]
        self.batch_size = self.distance_matrix.shape[0]
        # The depot correction must run before the compatibility mask,
        # because the mask reads the corrected start/end times.
        self.correct_depot_features()
        self.time_window_compatibility = self.compute_time_window_compatibility()
        # Global maxima over the whole batch, kept as 0-dim tensors.
        self.max_dist = self.distance_matrix.max()
        self.max_drive_time = self.time_matrix.max()

    def correct_depot_features(self):
        """Zero the start/end time features wherever ``depot`` is 1.

        ``1 - depot`` is used as a multiplicative mask (assumes ``depot`` is a
        numeric 0/1 indicator; TODO confirm). The products are new tensors, so
        the tensors passed in via ``graph_data`` are not modified.

        NOTE(review): because ``end_time`` becomes 0 at depots, the compatibility
        mask marks a move i -> depot as feasible only when
        ``start_time[i] + drive_time[i][depot] <= 0``. Confirm callers expect
        this rather than an unconstrained depot window.
        """
        keep = 1 - self.depot
        self.start_time = self.start_time * keep
        self.end_time = self.end_time * keep

    def construct_vector(self):
        """Concatenate node features along dim 2 and cache the result.

        Returns:
            Tensor ``cat([node_positions, start_time, end_time, depot], dim=2)``,
            also stored as ``self.vector``.
        """
        features = [self.node_positions, self.start_time, self.end_time, self.depot]
        self.vector = torch.cat(features, dim=2)
        return self.vector

    @staticmethod
    def _gather_pairs(matrix, from_nodes, to_nodes):
        """Return ``matrix[b, from_nodes[b, k], to_nodes[b, k]]`` with shape ``(batch, K, 1)``.

        Args:
            matrix: ``(batch, num_nodes, num_nodes)`` pairwise matrix.
            from_nodes: ``(batch, K)`` int64 source-node indices.
            to_nodes: ``(batch, K)`` int64 destination-node indices.

        Raises:
            ValueError: if ``from_nodes`` and ``to_nodes`` differ in dimension 1.
        """
        if from_nodes.shape[1] != to_nodes.shape[1]:
            raise ValueError(
                'from_nodes and to_nodes must have the same size in dimension 1, '
                f'got {from_nodes.shape[1]} and {to_nodes.shape[1]}'
            )
        # Select the full matrix row for every source node: (batch, K, num_nodes).
        row_index = from_nodes.unsqueeze(2).expand(-1, -1, matrix.shape[2])
        rows = torch.gather(matrix, dim=1, index=row_index)
        # Pick the destination column out of each selected row: (batch, K, 1).
        return torch.gather(rows, dim=2, index=to_nodes.unsqueeze(2))

    def get_drive_times(self, from_nodes, to_nodes):
        """Look up drive times between paired source/destination nodes.

        Args:
            from_nodes: ``(batch, K)`` int64 source-node indices.
            to_nodes: ``(batch, K)`` int64 destination-node indices.

        Returns:
            ``(batch, K, 1)`` tensor of ``time_matrix[b, from, to]`` entries.

        Raises:
            ValueError: if the two index tensors differ in dimension 1.
        """
        return self._gather_pairs(self.time_matrix, from_nodes, to_nodes)

    def get_distances(self, from_nodes, to_nodes):
        """Look up distances between paired source/destination nodes.

        Values come straight from ``graph_data['distance_matrix']``. Whether
        they are Euclidean depends on how that matrix was built; not verified
        here.

        Args:
            from_nodes: ``(batch, K)`` int64 source-node indices.
            to_nodes: ``(batch, K)`` int64 destination-node indices.

        Returns:
            ``(batch, K, 1)`` tensor of ``distance_matrix[b, from, to]`` entries.

        Raises:
            ValueError: if the two index tensors differ in dimension 1.
        """
        return self._gather_pairs(self.distance_matrix, from_nodes, to_nodes)

    def compute_time_window_compatibility(self):
        """Build a float mask of time-window-feasible moves.

        ``mask[b, i, j] == 1.0`` iff
        ``start_time[b, i] + time_matrix[b, i, j] <= end_time[b, j]``.

        Returns:
            ``(batch, num_nodes, num_nodes)`` float tensor of 0.0 / 1.0 values.
        """
        departures = self.start_time.reshape(self.batch_size, self.num_nodes, 1)
        deadlines = self.end_time.reshape(self.batch_size, 1, self.num_nodes)
        # Broadcasting (B, N, 1) + (B, N, N) against (B, 1, N) yields (B, N, N)
        # without materialising repeated copies of the time vectors.
        return (departures + self.time_matrix <= deadlines).float()

    def to(self, device):
        """Move every tensor attribute (including derived ones) to ``device``.

        Returns:
            ``self``, so ``graph = graph.to(device)`` works like ``torch.Tensor.to``.
        """
        self.device = device
        for attr in self._TENSOR_ATTRS:
            value = getattr(self, attr, None)
            if value is not None:
                setattr(self, attr, value.to(device))
        return self