# Actor/graph.py -- last updated by a-ragab-h-m in commit f2809ce ("Update Actor/graph.py", verified)
import torch
class Graph:
    """Batched graph of nodes with time windows and pairwise distance/drive-time matrices.

    Every tensor is read from ``graph_data`` and placed on ``device``. The batch
    size and the node count come from ``distance_matrix.shape[0]`` and
    ``distance_matrix.shape[1]`` respectively.

    Args:
        graph_data: mapping with tensor entries ``'start_times'``,
            ``'end_times'``, ``'depot'``, ``'node_vector'``,
            ``'distance_matrix'`` and ``'time_matrix'``. The time tensors must
            each hold ``batch_size * num_nodes`` elements (they are reshaped to
            that in ``compute_time_window_compatibility``).
        device: torch device (or device string) to place every tensor on.
    """

    def __init__(self, graph_data, device='cpu'):
        self.device = device
        # Load every tensor and place it on the requested device.
        self.start_time = graph_data['start_times'].to(device)
        self.end_time = graph_data['end_times'].to(device)
        self.depot = graph_data['depot'].to(device)
        self.node_positions = graph_data['node_vector'].to(device)
        self.distance_matrix = graph_data['distance_matrix'].to(device)
        self.time_matrix = graph_data['time_matrix'].to(device)
        self.num_nodes = self.distance_matrix.shape[1]
        self.batch_size = self.distance_matrix.shape[0]
        # Must run before the compatibility mask is built so that the mask
        # sees the depot-corrected time windows.
        self.correct_depot_features()
        self.time_window_compatibility = self.compute_time_window_compatibility()
        # Global maxima over the whole batch (0-dim tensors).
        self.max_dist = self.distance_matrix.max()
        self.max_drive_time = self.time_matrix.max()

    def correct_depot_features(self):
        """Zero out the time windows of depot nodes.

        ``start_time`` and ``end_time`` are multiplied by ``1 - depot``, so
        entries where ``depot`` is 1 become 0. NOTE(review): assumes ``depot``
        is a 0/1 indicator broadcastable against the time tensors -- confirm.
        The multiplication creates new tensors, so the tensors passed in
        ``graph_data`` are not modified.
        """
        keep = 1 - self.depot
        self.start_time = self.start_time * keep
        self.end_time = self.end_time * keep

    def construct_vector(self):
        """Concatenate the per-node features into one model-input tensor.

        ``node_positions``, ``start_time``, ``end_time`` and ``depot`` are
        concatenated along dim 2, so each must be 3-D with matching first two
        dimensions. The result is cached as ``self.vector`` and returned.
        """
        features = [self.node_positions, self.start_time, self.end_time, self.depot]
        self.vector = torch.cat(features, dim=2)
        return self.vector

    def _gather_pairs(self, matrix, from_nodes, to_nodes):
        """Return ``matrix[b, from_nodes[b, k], to_nodes[b, k]]`` with shape (B, K, 1).

        Raises:
            ValueError: if ``from_nodes`` and ``to_nodes`` differ in size along dim 1.
        """
        if from_nodes.shape[1] != to_nodes.shape[1]:
            raise ValueError(
                'from_nodes and to_nodes must have the same size along dim 1, '
                f'got {from_nodes.shape[1]} and {to_nodes.shape[1]}'
            )
        # First select whole rows: rows[b, k, :] = matrix[b, from_nodes[b, k], :].
        row_index = from_nodes.unsqueeze(2).expand(-1, -1, matrix.shape[2])
        rows = torch.gather(matrix, dim=1, index=row_index)
        # Then select one column per selected row.
        return torch.gather(rows, dim=2, index=to_nodes.unsqueeze(2))

    def get_drive_times(self, from_nodes, to_nodes):
        """Look up drive times for the node pairs ``(from_nodes[b, k], to_nodes[b, k])``.

        Args:
            from_nodes: 2-D integer (int64, as ``torch.gather`` requires) index
                tensor of shape (batch, K).
            to_nodes: index tensor with the same size K along dim 1.

        Returns:
            Tensor of shape (batch, K, 1) taken from ``time_matrix``.

        Raises:
            ValueError: if the two index tensors differ in size along dim 1.
        """
        return self._gather_pairs(self.time_matrix, from_nodes, to_nodes)

    def get_distances(self, from_nodes, to_nodes):
        """Look up distances for the node pairs ``(from_nodes[b, k], to_nodes[b, k])``.

        Values come straight from ``distance_matrix``; whatever metric that
        matrix encodes is what is returned.

        Args:
            from_nodes: 2-D integer (int64, as ``torch.gather`` requires) index
                tensor of shape (batch, K).
            to_nodes: index tensor with the same size K along dim 1.

        Returns:
            Tensor of shape (batch, K, 1) taken from ``distance_matrix``.

        Raises:
            ValueError: if the two index tensors differ in size along dim 1.
        """
        return self._gather_pairs(self.distance_matrix, from_nodes, to_nodes)

    def compute_time_window_compatibility(self):
        """Return a float mask of time-window-feasible node pairs.

        ``mask[b, i, j]`` is 1.0 when
        ``start_time[b, i] + time_matrix[b, i, j] <= end_time[b, j]`` and 0.0
        otherwise; shape ``(batch_size, num_nodes, num_nodes)``.

        NOTE(review): ``correct_depot_features`` sets depot end times to 0
        before this runs, so an arc into a depot is marked feasible only when
        ``start_time[i] + drive time <= 0`` -- confirm that is intended or
        handled by callers.
        """
        # Broadcasting (B, N, 1) + (B, N, N) <= (B, 1, N) yields the same mask
        # as explicit ``repeat`` without materialising two (B, N, N) copies.
        departure = self.start_time.reshape(self.batch_size, self.num_nodes, 1)
        deadline = self.end_time.reshape(self.batch_size, 1, self.num_nodes)
        return (departure + self.time_matrix <= deadline).float()

    def to(self, device):
        """Move every tensor held by the graph to ``device`` and return ``self``.

        Besides the input tensors, this also moves the derived
        ``time_window_compatibility``, ``max_dist`` and ``max_drive_time``, and
        the cached ``vector`` if ``construct_vector`` has been called, so no
        attribute is left behind on the old device.
        """
        self.device = device
        names = [
            'start_time', 'end_time', 'depot', 'node_positions',
            'distance_matrix', 'time_matrix', 'time_window_compatibility',
            'max_dist', 'max_drive_time',
        ]
        if hasattr(self, 'vector'):
            names.append('vector')
        for name in names:
            setattr(self, name, getattr(self, name).to(device))
        return self