Update Actor/graph.py
Browse files- Actor/graph.py +36 -23
Actor/graph.py
CHANGED
@@ -2,17 +2,16 @@ import torch
|
|
2 |
|
3 |
|
4 |
class Graph(object):
|
5 |
-
|
6 |
def __init__(self, graph_data, device='cpu'):
|
7 |
-
|
8 |
self.device = device
|
9 |
|
10 |
-
|
11 |
-
self.
|
12 |
-
self.
|
13 |
-
self.
|
14 |
-
self.
|
15 |
-
self.
|
|
|
16 |
|
17 |
self.num_nodes = self.distance_matrix.shape[1]
|
18 |
self.batch_size = self.distance_matrix.shape[0]
|
@@ -23,46 +22,60 @@ class Graph(object):
|
|
23 |
self.max_dist = self.distance_matrix.max()
|
24 |
self.max_drive_time = self.time_matrix.max()
|
25 |
|
26 |
-
|
27 |
def correct_depot_features(self):
|
28 |
-
|
29 |
-
self.
|
30 |
-
|
31 |
|
32 |
def construct_vector(self):
|
|
|
33 |
L = [self.node_positions, self.start_time, self.end_time, self.depot]
|
34 |
self.vector = torch.cat(L, dim=2)
|
35 |
return self.vector
|
36 |
|
37 |
-
|
38 |
def get_drive_times(self, from_nodes, to_nodes):
|
39 |
-
|
|
|
|
|
40 |
num_elements = from_nodes.shape[1]
|
41 |
assert num_elements == to_nodes.shape[1]
|
42 |
|
43 |
-
|
|
|
44 |
dist = torch.gather(self.time_matrix, dim=1, index=ind_1)
|
45 |
-
|
46 |
-
ind_2 = to_nodes.reshape(self.batch_size, num_elements, 1)
|
47 |
drive_times = torch.gather(dist, dim=2, index=ind_2)
|
48 |
return drive_times
|
49 |
|
50 |
-
|
51 |
def get_distances(self, from_nodes, to_nodes):
|
|
|
|
|
|
|
52 |
num_elements = from_nodes.shape[1]
|
53 |
assert num_elements == to_nodes.shape[1]
|
54 |
|
55 |
-
|
|
|
56 |
dist = torch.gather(self.distance_matrix, dim=1, index=ind_1)
|
57 |
-
|
58 |
-
ind_2 = to_nodes.reshape(self.batch_size, num_elements, 1)
|
59 |
distances = torch.gather(dist, dim=2, index=ind_2)
|
60 |
return distances
|
61 |
|
62 |
-
|
63 |
def compute_time_window_compatibility(self):
|
64 |
-
|
|
|
|
|
|
|
65 |
x = self.start_time.reshape(self.batch_size, self.num_nodes, 1).repeat(1, 1, self.num_nodes)
|
66 |
y = self.end_time.reshape(self.batch_size, 1, self.num_nodes).repeat(1, self.num_nodes, 1)
|
67 |
time_mask = (x + self.time_matrix <= y).float()
|
68 |
return time_mask
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
|
3 |
|
4 |
class Graph(object):
|
|
|
5 |
def __init__(self, graph_data, device='cpu'):
|
|
|
6 |
self.device = device
|
7 |
|
8 |
+
# Load and move data to the correct device
|
9 |
+
self.start_time = graph_data['start_times'].to(device)
|
10 |
+
self.end_time = graph_data['end_times'].to(device)
|
11 |
+
self.depot = graph_data['depot'].to(device)
|
12 |
+
self.node_positions = graph_data['node_vector'].to(device)
|
13 |
+
self.distance_matrix = graph_data['distance_matrix'].to(device)
|
14 |
+
self.time_matrix = graph_data['time_matrix'].to(device)
|
15 |
|
16 |
self.num_nodes = self.distance_matrix.shape[1]
|
17 |
self.batch_size = self.distance_matrix.shape[0]
|
|
|
22 |
self.max_dist = self.distance_matrix.max()
|
23 |
self.max_drive_time = self.time_matrix.max()
|
24 |
|
|
|
25 |
def correct_depot_features(self):
|
26 |
+
# Remove time windows for depot nodes
|
27 |
+
self.start_time = self.start_time.clone() * (1 - self.depot)
|
28 |
+
self.end_time = self.end_time.clone() * (1 - self.depot)
|
29 |
|
30 |
def construct_vector(self):
|
31 |
+
# Concatenate node features for model input
|
32 |
L = [self.node_positions, self.start_time, self.end_time, self.depot]
|
33 |
self.vector = torch.cat(L, dim=2)
|
34 |
return self.vector
|
35 |
|
|
|
36 |
def get_drive_times(self, from_nodes, to_nodes):
|
37 |
+
"""
|
38 |
+
Get drive times between from_nodes and to_nodes.
|
39 |
+
"""
|
40 |
num_elements = from_nodes.shape[1]
|
41 |
assert num_elements == to_nodes.shape[1]
|
42 |
|
43 |
+
# Extract relevant entries from the time matrix
|
44 |
+
ind_1 = from_nodes.unsqueeze(2).expand(-1, -1, self.num_nodes)
|
45 |
dist = torch.gather(self.time_matrix, dim=1, index=ind_1)
|
46 |
+
ind_2 = to_nodes.unsqueeze(2)
|
|
|
47 |
drive_times = torch.gather(dist, dim=2, index=ind_2)
|
48 |
return drive_times
|
49 |
|
|
|
50 |
def get_distances(self, from_nodes, to_nodes):
|
51 |
+
"""
|
52 |
+
Get Euclidean distances between from_nodes and to_nodes.
|
53 |
+
"""
|
54 |
num_elements = from_nodes.shape[1]
|
55 |
assert num_elements == to_nodes.shape[1]
|
56 |
|
57 |
+
# Extract relevant entries from the distance matrix
|
58 |
+
ind_1 = from_nodes.unsqueeze(2).expand(-1, -1, self.num_nodes)
|
59 |
dist = torch.gather(self.distance_matrix, dim=1, index=ind_1)
|
60 |
+
ind_2 = to_nodes.unsqueeze(2)
|
|
|
61 |
distances = torch.gather(dist, dim=2, index=ind_2)
|
62 |
return distances
|
63 |
|
|
|
64 |
def compute_time_window_compatibility(self):
|
65 |
+
"""
|
66 |
+
Determine if traveling from node i to node j respects time window constraints:
|
67 |
+
i.e., start_time[i] + drive_time[i][j] <= end_time[j]
|
68 |
+
"""
|
69 |
x = self.start_time.reshape(self.batch_size, self.num_nodes, 1).repeat(1, 1, self.num_nodes)
|
70 |
y = self.end_time.reshape(self.batch_size, 1, self.num_nodes).repeat(1, self.num_nodes, 1)
|
71 |
time_mask = (x + self.time_matrix <= y).float()
|
72 |
return time_mask
|
73 |
+
|
74 |
+
def to(self, device):
|
75 |
+
"""
|
76 |
+
Move all tensors to the specified device.
|
77 |
+
"""
|
78 |
+
self.device = device
|
79 |
+
for attr in ['start_time', 'end_time', 'depot', 'node_positions', 'distance_matrix', 'time_matrix']:
|
80 |
+
setattr(self, attr, getattr(self, attr).to(device))
|
81 |
+
self.time_window_compatibility = self.time_window_compatibility.to(device)
|