import torch


class Graph(object):
    """Batched graph with per-node time windows and pairwise cost matrices.

    Wraps the tensors of a ``graph_data`` mapping, zeroes the time windows of
    depot nodes, and precomputes a pairwise time-window compatibility mask.
    """

    def __init__(self, graph_data, device='cpu'):
        """Load graph tensors from ``graph_data`` and place them on ``device``.

        Args:
            graph_data: mapping with keys 'start_times', 'end_times', 'depot',
                'node_vector', 'distance_matrix' and 'time_matrix'. The two
                matrices are indexed as (batch, num_nodes, num_nodes); the
                per-node tensors must hold batch * num_nodes entries and be
                3-D so that ``construct_vector`` can concatenate on dim 2
                (i.e. presumably (batch, num_nodes, 1) — confirm with callers).
            device: torch device every tensor is moved to.
        """
        self.device = device
        # Load and move data to the correct device.
        self.start_time = graph_data['start_times'].to(device)
        self.end_time = graph_data['end_times'].to(device)
        self.depot = graph_data['depot'].to(device)
        self.node_positions = graph_data['node_vector'].to(device)
        self.distance_matrix = graph_data['distance_matrix'].to(device)
        self.time_matrix = graph_data['time_matrix'].to(device)

        self.num_nodes = self.distance_matrix.shape[1]
        self.batch_size = self.distance_matrix.shape[0]

        self.correct_depot_features()
        self.time_window_compatibility = self.compute_time_window_compatibility()
        # Global maxima over the whole batch (0-d tensors).
        self.max_dist = self.distance_matrix.max()
        self.max_drive_time = self.time_matrix.max()

    def correct_depot_features(self):
        """Zero the start/end times wherever ``depot`` is 1 (numeric mask)."""
        # Multiplication already returns a new tensor, so no clone is needed
        # to avoid mutating the caller's input.
        keep = 1 - self.depot
        self.start_time = self.start_time * keep
        self.end_time = self.end_time * keep

    def construct_vector(self):
        """Concatenate node features on dim 2, cache as ``self.vector``, return it."""
        features = [self.node_positions, self.start_time, self.end_time, self.depot]
        self.vector = torch.cat(features, dim=2)
        return self.vector

    def _lookup_pairs(self, matrix, from_nodes, to_nodes):
        """Return ``matrix[b, from_nodes[b, k], to_nodes[b, k]]`` as (batch, k, 1).

        Args:
            matrix: (batch, num_nodes, num_nodes) tensor to read from.
            from_nodes: (batch, k) integer (int64) row indices.
            to_nodes: (batch, k) integer (int64) column indices.
        """
        num_elements = from_nodes.shape[1]
        assert num_elements == to_nodes.shape[1]
        # First gather whole rows: rows[b, k, :] = matrix[b, from_nodes[b, k], :].
        row_index = from_nodes.unsqueeze(2).expand(-1, -1, matrix.shape[2])
        rows = torch.gather(matrix, dim=1, index=row_index)
        # Then pick one column per row: rows[b, k, to_nodes[b, k]].
        return torch.gather(rows, dim=2, index=to_nodes.unsqueeze(2))

    def get_drive_times(self, from_nodes, to_nodes):
        """Get drive times between paired ``from_nodes`` and ``to_nodes``.

        Both index tensors are (batch, k); the result is (batch, k, 1) read
        from ``time_matrix``. Mismatched k raises AssertionError.
        """
        return self._lookup_pairs(self.time_matrix, from_nodes, to_nodes)

    def get_distances(self, from_nodes, to_nodes):
        """Get distances between paired ``from_nodes`` and ``to_nodes``.

        Values are looked up in ``distance_matrix`` (whatever metric it was
        built with). Both index tensors are (batch, k); the result is
        (batch, k, 1). Mismatched k raises AssertionError.
        """
        return self._lookup_pairs(self.distance_matrix, from_nodes, to_nodes)

    def compute_time_window_compatibility(self):
        """Return a float (batch, N, N) mask of feasible i -> j transitions.

        Entry [b, i, j] is 1.0 when
        ``start_time[b, i] + time_matrix[b, i, j] <= end_time[b, j]``, else 0.0.
        """
        # Broadcasting (B, N, 1) against (B, 1, N) yields the same values as
        # explicitly repeating both tensors, without the N-fold copies.
        departure = self.start_time.reshape(self.batch_size, self.num_nodes, 1)
        deadline = self.end_time.reshape(self.batch_size, 1, self.num_nodes)
        return (departure + self.time_matrix <= deadline).float()

    def to(self, device):
        """Move every tensor held by this graph to ``device`` and return self.

        This includes the derived tensors (compatibility mask, maxima and the
        cached feature vector, if ``construct_vector`` has been called).
        Otherwise they would be left behind on the old device.
        """
        self.device = device
        for attr in ('start_time', 'end_time', 'depot', 'node_positions',
                     'distance_matrix', 'time_matrix',
                     'time_window_compatibility', 'max_dist', 'max_drive_time'):
            setattr(self, attr, getattr(self, attr).to(device))
        # ``vector`` only exists after construct_vector() has run.
        if getattr(self, 'vector', None) is not None:
            self.vector = self.vector.to(device)
        return self