a-ragab-h-m commited on
Commit
f2809ce
·
verified ·
1 Parent(s): fd3eb7b

Update Actor/graph.py

Browse files
Files changed (1) hide show
  1. Actor/graph.py +36 -23
Actor/graph.py CHANGED
@@ -2,17 +2,16 @@ import torch
2
 
3
 
4
  class Graph(object):
5
-
6
  def __init__(self, graph_data, device='cpu'):
7
-
8
  self.device = device
9
 
10
- self.start_time = graph_data['start_times']
11
- self.end_time = graph_data['end_times']
12
- self.depot = graph_data['depot']
13
- self.node_positions = graph_data['node_vector']
14
- self.distance_matrix = graph_data['distance_matrix']
15
- self.time_matrix = graph_data['time_matrix']
 
16
 
17
  self.num_nodes = self.distance_matrix.shape[1]
18
  self.batch_size = self.distance_matrix.shape[0]
@@ -23,46 +22,60 @@ class Graph(object):
23
  self.max_dist = self.distance_matrix.max()
24
  self.max_drive_time = self.time_matrix.max()
25
 
26
-
27
  def correct_depot_features(self):
28
- self.start_time = self.start_time * (1 - self.depot)
29
- self.end_time = self.end_time * (1 - self.depot)
30
-
31
 
32
  def construct_vector(self):
 
33
  L = [self.node_positions, self.start_time, self.end_time, self.depot]
34
  self.vector = torch.cat(L, dim=2)
35
  return self.vector
36
 
37
-
38
  def get_drive_times(self, from_nodes, to_nodes):
39
-
 
 
40
  num_elements = from_nodes.shape[1]
41
  assert num_elements == to_nodes.shape[1]
42
 
43
- ind_1 = from_nodes.reshape(self.batch_size, num_elements, 1).reshape(1, 1, self.num_nodes)
 
44
  dist = torch.gather(self.time_matrix, dim=1, index=ind_1)
45
-
46
- ind_2 = to_nodes.reshape(self.batch_size, num_elements, 1)
47
  drive_times = torch.gather(dist, dim=2, index=ind_2)
48
  return drive_times
49
 
50
-
51
  def get_distances(self, from_nodes, to_nodes):
 
 
 
52
  num_elements = from_nodes.shape[1]
53
  assert num_elements == to_nodes.shape[1]
54
 
55
- ind_1 = from_nodes.reshape(self.batch_size, num_elements, 1).reshape(1, 1, self.num_nodes)
 
56
  dist = torch.gather(self.distance_matrix, dim=1, index=ind_1)
57
-
58
- ind_2 = to_nodes.reshape(self.batch_size, num_elements, 1)
59
  distances = torch.gather(dist, dim=2, index=ind_2)
60
  return distances
61
 
62
-
63
  def compute_time_window_compatibility(self):
64
- #check for drive times being too long
 
 
 
65
  x = self.start_time.reshape(self.batch_size, self.num_nodes, 1).repeat(1, 1, self.num_nodes)
66
  y = self.end_time.reshape(self.batch_size, 1, self.num_nodes).repeat(1, self.num_nodes, 1)
67
  time_mask = (x + self.time_matrix <= y).float()
68
  return time_mask
 
 
 
 
 
 
 
 
 
 
2
 
3
 
4
  class Graph(object):
 
5
  def __init__(self, graph_data, device='cpu'):
 
6
  self.device = device
7
 
8
+ # Load and move data to the correct device
9
+ self.start_time = graph_data['start_times'].to(device)
10
+ self.end_time = graph_data['end_times'].to(device)
11
+ self.depot = graph_data['depot'].to(device)
12
+ self.node_positions = graph_data['node_vector'].to(device)
13
+ self.distance_matrix = graph_data['distance_matrix'].to(device)
14
+ self.time_matrix = graph_data['time_matrix'].to(device)
15
 
16
  self.num_nodes = self.distance_matrix.shape[1]
17
  self.batch_size = self.distance_matrix.shape[0]
 
22
  self.max_dist = self.distance_matrix.max()
23
  self.max_drive_time = self.time_matrix.max()
24
 
 
25
  def correct_depot_features(self):
26
+ # Remove time windows for depot nodes
27
+ self.start_time = self.start_time.clone() * (1 - self.depot)
28
+ self.end_time = self.end_time.clone() * (1 - self.depot)
29
 
30
  def construct_vector(self):
31
+ # Concatenate node features for model input
32
  L = [self.node_positions, self.start_time, self.end_time, self.depot]
33
  self.vector = torch.cat(L, dim=2)
34
  return self.vector
35
 
 
36
  def get_drive_times(self, from_nodes, to_nodes):
37
+ """
38
+ Get drive times between from_nodes and to_nodes.
39
+ """
40
  num_elements = from_nodes.shape[1]
41
  assert num_elements == to_nodes.shape[1]
42
 
43
+ # Extract relevant entries from the time matrix
44
+ ind_1 = from_nodes.unsqueeze(2).expand(-1, -1, self.num_nodes)
45
  dist = torch.gather(self.time_matrix, dim=1, index=ind_1)
46
+ ind_2 = to_nodes.unsqueeze(2)
 
47
  drive_times = torch.gather(dist, dim=2, index=ind_2)
48
  return drive_times
49
 
 
50
  def get_distances(self, from_nodes, to_nodes):
51
+ """
52
+ Get Euclidean distances between from_nodes and to_nodes.
53
+ """
54
  num_elements = from_nodes.shape[1]
55
  assert num_elements == to_nodes.shape[1]
56
 
57
+ # Extract relevant entries from the distance matrix
58
+ ind_1 = from_nodes.unsqueeze(2).expand(-1, -1, self.num_nodes)
59
  dist = torch.gather(self.distance_matrix, dim=1, index=ind_1)
60
+ ind_2 = to_nodes.unsqueeze(2)
 
61
  distances = torch.gather(dist, dim=2, index=ind_2)
62
  return distances
63
 
 
64
  def compute_time_window_compatibility(self):
65
+ """
66
+ Determine if traveling from node i to node j respects time window constraints:
67
+ i.e., start_time[i] + drive_time[i][j] <= end_time[j]
68
+ """
69
  x = self.start_time.reshape(self.batch_size, self.num_nodes, 1).repeat(1, 1, self.num_nodes)
70
  y = self.end_time.reshape(self.batch_size, 1, self.num_nodes).repeat(1, self.num_nodes, 1)
71
  time_mask = (x + self.time_matrix <= y).float()
72
  return time_mask
73
+
74
+ def to(self, device):
75
+ """
76
+ Move all tensors to the specified device.
77
+ """
78
+ self.device = device
79
+ for attr in ['start_time', 'end_time', 'depot', 'node_positions', 'distance_matrix', 'time_matrix']:
80
+ setattr(self, attr, getattr(self, attr).to(device))
81
+ self.time_window_compatibility = self.time_window_compatibility.to(device)