a-ragab-h-m commited on
Commit
184a603
·
verified ·
1 Parent(s): 7de92e3

Upload 2 files

Browse files
Files changed (2) hide show
  1. actor.py +751 -0
  2. actor_modified.py +155 -0
actor.py ADDED
@@ -0,0 +1,751 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from Actor.graph import Graph
5
+ from Actor.fleet import Fleet
6
+ from utils.actor_utils import widen_data, select_data
7
+ from Actor.normalization import Normalization
8
+
9
+
10
+ #This is the updated version of the actor
11
+ class Actor(nn.Module):
12
+
13
+ def __init__(self, model=None, num_movers=5, num_neighbors_encoder=5,
14
+ num_neighbors_action=5, normalize=False, use_fleet_attention=True,
15
+ device='cpu'):
16
+ super().__init__()
17
+
18
+ self.device = device
19
+ self.num_movers = num_movers
20
+ self.num_neighbors_encoder = num_neighbors_encoder
21
+ self.num_neighbors_action = num_neighbors_action
22
+
23
+ self.apply_normalization = normalize
24
+ self.normalization = None
25
+ self.use_fleet_attention = use_fleet_attention
26
+
27
+
28
+ if model is None:
29
+ self.mode = 'nearest_neighbors'
30
+ else:
31
+ self.encoder = model.encoder.to(self.device)
32
+ self.decoder = model.decoder.to(self.device)
33
+ self.projections = model.projections.to(self.device)
34
+ self.fleet_attention = model.fleet_attention.to(self.device)
35
+ self.mode = 'train'
36
+
37
+
38
+ self.sample_size = 1
39
+
40
+
41
+ def train_mode(self, sample_size=1):
42
+ self.train()
43
+ self.sample_size = sample_size
44
+ self.mode = 'train'
45
+
46
+
47
+ def greedy_search(self):
48
+ self.eval()
49
+ self.mode = 'greedy'
50
+
51
+
52
+ def nearest_neighbors(self):
53
+ self.eval()
54
+ self.mode = 'nearest_neighbors'
55
+
56
+
57
+ def sample_mode(self, sample_size=10):
58
+ self.sample_size = sample_size
59
+ self.eval()
60
+ self.mode = 'sample'
61
+
62
+
63
+ def beam_search(self, sample_size=10):
64
+ self.sample_size=sample_size
65
+ self.eval()
66
+ self.mode = 'beam_search'
67
+
68
+
69
+ def update_batch_size(self):
70
+ self.batch_size = self.fleet.time.shape[0]
71
+ self.fleet.batch_size = self.fleet.time.shape[0]
72
+ self.graph.batch_size = self.fleet.time.shape[0]
73
+
74
+
75
+
76
+ def forward(self, batch, *args, **kwargs):
77
+
78
+ graph_data, fleet_data = batch
79
+
80
+ self.original_batch_size = graph_data['distance_matrix'].shape[0]
81
+ self.batch_size = self.original_batch_size
82
+
83
+ self.num_nodes = graph_data['distance_matrix'].shape[1]
84
+ self.num_cars = fleet_data['start_time'].shape[1]
85
+
86
+
87
+ self.graph = Graph(graph_data, device=self.device)
88
+ self.fleet = Fleet(fleet_data, num_nodes=self.num_nodes, device=self.device)
89
+ self.num_nodes = self.graph.distance_matrix.shape[1]
90
+ self.update_batch_size()
91
+
92
+ self.normalization = Normalization(self, normalize_position=True)
93
+ if self.apply_normalization:
94
+ self.normalization.normalize(self)
95
+
96
+ self.num_depots = self.fleet.num_depots.max().item()
97
+ self.num_movers_corrected = int(min(max(self.num_movers, self.num_depots), self.num_cars))
98
+
99
+
100
+ if self.mode != 'nearest_neighbors':
101
+ encoder_input = self.graph.construct_vector()
102
+ encoder_mask = self.compute_encoder_mask()
103
+ self.node_embeddings = self.encoder(encoder_input, encoder_mask)
104
+ self.node_projections = self.projections(self.node_embeddings)
105
+
106
+
107
+ if self.mode == 'sample':
108
+ widen_data(self, include_embeddings=True, include_projections=True)
109
+ self.update_batch_size()
110
+
111
+
112
+ self.log_probs = torch.zeros(self.batch_size).to(self.device)
113
+ self.counter = 0
114
+ while self.loop_condition() and (self.counter < self.num_nodes*4):
115
+
116
+ unavailable_moves = self.check_non_depot_options(use_time=True)
117
+ mover_indices = self.get_mover_indices(unavailable_moves=unavailable_moves)
118
+ action_mask = self.compute_action_mask(mover_indices=mover_indices,
119
+ unavailable_moves=unavailable_moves)
120
+
121
+ if self.mode != 'nearest_neighbors':
122
+ decoder_input = self.construct_decoder_input(mover_indices=mover_indices)
123
+ decoder_mask = self.compute_decoder_mask(mover_indices=mover_indices,
124
+ unavailable_moves=unavailable_moves)
125
+
126
+ decoder_output = self.decoder(decoder_input=decoder_input,
127
+ projections=self.node_projections,
128
+ mask=decoder_mask)
129
+ else:
130
+ decoder_output = None
131
+
132
+ next_node, car_to_move, log_prob = self.compute_action(decoder_output, action_mask, mover_indices)
133
+
134
+ if self.mode == 'beam_search':
135
+ widen_data(self, include_embeddings=True, include_projections=True)
136
+ self.update_batch_size()
137
+
138
+ self.log_probs += log_prob
139
+
140
+ self.update_time(next_node, car_to_move)
141
+ self.update_distance(next_node, car_to_move)
142
+ self.update_node_path(next_node, car_to_move)
143
+ self.update_traversed_nodes()
144
+
145
+ #self.return_to_depot()
146
+ #self.update_traversed_nodes()
147
+
148
+ if (self.mode == 'beam_search') and (self.counter > 0):
149
+ self.consolidate_beams()
150
+
151
+ self.update_batch_size()
152
+ self.counter += 1
153
+
154
+ self.return_to_depot_1()
155
+
156
+
157
+ if self.mode in {'beam_search', 'sample'}:
158
+ # we now must select out the best k elments
159
+
160
+ incomplete = self.check_complete().float()
161
+ cost = self.compute_cost()
162
+
163
+ max_cost = cost.max()
164
+ masked_cost = (1 - incomplete)*cost + incomplete*max_cost*10
165
+
166
+ p = masked_cost.reshape(self.original_batch_size, self.sample_size)
167
+ a = torch.argmin(p, dim=1)
168
+
169
+ b = torch.arange(self.original_batch_size).to(self.device)
170
+ b = b * self.sample_size
171
+
172
+ index = a + b
173
+ select_data(self, index=index, include_projections=True, include_embeddings=True)
174
+
175
+
176
+ self.update_batch_size()
177
+ if self.apply_normalization:
178
+ self.normalization.inverse_normalize(self)
179
+
180
+ total_distance = self.fleet.distance.sum(dim=1).squeeze(1).detach()
181
+ total_time = self.fleet.time.sum(dim=1).squeeze(1).detach()
182
+ total_late_time = self.fleet.late_time.sum(dim=1).squeeze(1).detach()
183
+
184
+ output = {
185
+ 'distance': total_distance.detach(),
186
+ 'total_time': total_time.detach(),
187
+ 'log_probs': self.log_probs,
188
+ 'late_time': total_late_time.detach(),
189
+ 'incomplete': self.check_complete(),
190
+ 'path': self.fleet.path,
191
+ 'arrival_times': self.fleet.arrival_times
192
+ }
193
+
194
+ return output
195
+
196
+
197
+ def consolidate_beams(self):
198
+ p = self.log_probs.reshape(self.original_batch_size, self.sample_size * self.sample_size)
199
+ a = torch.topk(p, dim=1, k=self.sample_size, largest=True)[1]
200
+
201
+ b = torch.arange(self.original_batch_size).unsqueeze(1).repeat(1, self.sample_size).to(self.device)
202
+ b = b * self.sample_size * self.sample_size
203
+
204
+ ind = (a + b).reshape(-1)
205
+ select_data(self, index=ind, include_projections=True, include_embeddings=True)
206
+
207
+
208
+ def adjust_arrival_times(self):
209
+
210
+ if self.apply_normalization and (self.normalization_params is not None):
211
+ num_steps = self.fleet.arrival_times.shape[2]
212
+ a = self.normalization_params['earliest_start_time'].reshape(self.batch_size, 1, 1).repeat(1, self.num_cars, num_steps)
213
+ b = self.normalization_params['greatest_drive_time'].reshape(self.batch_size, 1, 1).repeat(1, self.num_cars, num_steps)
214
+ self.fleet.arrival_times = self.fleet.arrival_times*b + a
215
+
216
+ a = self.normalization_params['earliest_start_time']
217
+ b = self.normalization_params['greatest_drive_time']
218
+ self.fleet.late_time = self.fleet.late_time*b + a
219
+
220
+
221
+ def compute_encoder_mask(self):
222
+ #check for drive times being too long
223
+ time_window_non_compatibility = 1 - self.graph.time_window_compatibility
224
+
225
+ #compute diag mask
226
+ diag = torch.diag(torch.ones(self.num_nodes)).unsqueeze(0).repeat(self.batch_size, 1, 1).to(self.device)
227
+
228
+ # compute neighbors mask
229
+ m = (time_window_non_compatibility + diag > 0).float()
230
+
231
+ dist = self.graph.distance_matrix*(1 - m) + m*self.graph.max_dist*10
232
+
233
+ K = min(self.num_nodes, self.num_neighbors_encoder)
234
+ neighbors_index = torch.topk(dist, k=K, dim=2, largest=False)[1] # ~ [batch, num_nodes, num_neighbors]
235
+
236
+ a = torch.arange(self.num_nodes).reshape(1, 1, -1, 1).repeat(self.batch_size, self.num_nodes, 1, K).to(self.device)
237
+ b = neighbors_index.unsqueeze(2).repeat(1, 1, self.num_nodes, 1)
238
+ neighbors_mask = (a == b).float().sum(dim=3)
239
+
240
+
241
+ m = (time_window_non_compatibility == 0).float()
242
+ neighbs_time_mask = neighbors_mask*m
243
+
244
+ #compute depot mask
245
+ v = self.graph.depot.reshape(self.batch_size, 1, self.num_nodes).repeat(1, self.num_nodes, 1)
246
+ w = self.graph.depot.reshape(self.batch_size, self.num_nodes, 1).repeat(1, 1, self.num_nodes)
247
+ depot_mask = (v + w > 0).float()
248
+
249
+ diag = torch.diag(torch.ones(self.num_nodes)).unsqueeze(0).repeat(self.batch_size, 1, 1).to(self.device)
250
+
251
+ mask = (neighbs_time_mask + depot_mask + diag > 0).float()
252
+ return mask
253
+
254
+
255
+ def compute_depoyment_priority_score(self):
256
+ #number of available nodes
257
+ available_nodes = (self.check_non_depot_options().float() == 0).float()
258
+ num_available_nodes = available_nodes.sum(dim=2)
259
+
260
+ #excess
261
+ ind = self.fleet.node.reshape(self.batch_size, self.num_cars, 1).repeat(1, 1, self.num_nodes)
262
+ distances = torch.gather(self.graph.distance_matrix, dim=1, index=ind)
263
+
264
+ max_distance = self.graph.distance_matrix.reshape(self.batch_size, -1).max(dim=1)[0].reshape(
265
+ self.batch_size, 1, 1).repeat(1, self.num_cars, self.num_nodes)
266
+ excess = ((max_distance - distances)*available_nodes).sum(dim=2)
267
+
268
+
269
+ max_excess = excess.max()
270
+
271
+ score = (num_available_nodes + max_excess)*10 + excess
272
+ return score
273
+
274
+
275
+
276
+ def check_complete(self):
277
+ has_untraversed_nodes = ((self.fleet.traversed_nodes == 0).float().sum(dim=1) > 0)
278
+ #has_cars_in_field = ((self.fleet.node != self.fleet.depot).sum(dim=1) > 0)
279
+ #incomplete = (has_untraversed_nodes | has_cars_in_field)
280
+ incomplete = has_untraversed_nodes
281
+ return incomplete
282
+
283
+
284
+ def loop_condition(self):
285
+ a = self.check_complete().sum().item()
286
+ if a == 0:
287
+ return False
288
+ else:
289
+ return True
290
+
291
+
292
+ def compute_cost(self):
293
+
294
+ dist = self.fleet.distance.sum(dim=1).squeeze(1)
295
+ time = self.fleet.time.sum(dim=1).squeeze(1)
296
+ lateness = self.fleet.late_time.sum(dim=1).squeeze(1)
297
+
298
+ #normally we compute the cost as a linear combination of these three quantities
299
+ return time
300
+
301
+
302
+
303
+ def compute_total_distance(self):
304
+ p_2 = self.fleet.path
305
+ p_1 = torch.cat([p_2[:,:,-1:], p_2[:,:,:-1]], dim=2)
306
+
307
+ mat = self.graph.distance_matrix.reshape(self.batch_size, 1, self.num_nodes, self.num_nodes).repeat(1, self.num_cars, 1, 1)
308
+ ind_1 = p_1.unsqueeze(3).repeat(1, 1, 1, self.num_nodes)
309
+
310
+ d = torch.gather(mat, dim=2, index=ind_1)
311
+ ind_2 = p_2.unsqueeze(3)
312
+ pairwise_distances = torch.gather(d, dim=3, index=ind_2).squeeze(3)
313
+
314
+ self.car_distances = pairwise_distances.sum(dim=2)
315
+ self.total_distance = self.car_distances.sum(dim=1)
316
+
317
+ if self.apply_normalization and (self.normalization_params is not None):
318
+ distance_multiplier = self.normalization_params['greatest_distance']
319
+ self.total_distance = self.total_distance*distance_multiplier.reshape(self.batch_size)
320
+
321
+ return self.total_distance
322
+
323
+
324
+ def update_traversed_nodes(self):
325
+
326
+ path_length = self.fleet.path.shape[2]
327
+ a = self.fleet.path.reshape(self.batch_size, self.num_cars, path_length, 1).repeat(1, 1, 1, self.num_nodes)
328
+ b = torch.arange(self.num_nodes).reshape(
329
+ 1, 1, 1, self.num_nodes).repeat(
330
+ self.batch_size, self.num_cars, path_length, 1).to(self.device)
331
+
332
+ s = (a == b).float().reshape(self.batch_size, self.num_cars*path_length, self.num_nodes).sum(dim=1)
333
+ self.fleet.traversed_nodes = (s > 0).float()
334
+
335
+
336
+ def compute_action(self, decoder_output, action_mask, mover_indices):
337
+
338
+ if self.mode == 'nearest_neighbors':
339
+
340
+ assert decoder_output is None
341
+
342
+ num_movers = mover_indices.shape[1]
343
+ mover_nodes = torch.gather(self.fleet.node, dim=1, index=mover_indices)
344
+
345
+ ind = mover_nodes.reshape(self.batch_size, num_movers, 1).repeat(1, 1, self.num_nodes)
346
+ distances = torch.gather(self.graph.distance_matrix, dim=1, index=ind)
347
+
348
+ a = mover_nodes.reshape(self.batch_size, num_movers, 1).repeat(1, 1, self.num_nodes)
349
+ b = torch.arange(self.num_nodes).reshape(1, 1, self.num_nodes).repeat(self.batch_size, num_movers, 1).to(self.device)
350
+ current_node_indicator = (a == b).float()
351
+
352
+ a = action_mask.reshape(self.batch_size, -1)
353
+ no_options = (a.sum(dim=1) == 0).reshape(-1, 1, 1).repeat(1, num_movers, self.num_nodes).float()
354
+
355
+ mask = action_mask * (1 - no_options) + current_node_indicator * no_options
356
+ masked_distances = distances*mask + self.graph.max_dist*(1 - mask)*10
357
+
358
+ x = masked_distances.reshape(self.batch_size, -1)
359
+ ind = torch.argmin(x, dim=1).unsqueeze(1)
360
+
361
+ mover_index = ind // self.num_nodes
362
+ next_node = ind % self.num_nodes
363
+ car_to_move = torch.gather(mover_indices, dim=1, index=mover_index)
364
+
365
+ log_prob = torch.zeros(self.batch_size).to(self.device)
366
+ return next_node, car_to_move, log_prob
367
+
368
+ else:
369
+
370
+ assert decoder_output is not None
371
+
372
+ num_movers = mover_indices.shape[1]
373
+ assert (num_movers == action_mask.shape[1]) and (num_movers == decoder_output.shape[1])
374
+
375
+ a = action_mask.reshape(self.batch_size, -1)
376
+ no_options = (a.sum(dim=1) == 0).reshape(-1, 1, 1).repeat(1, num_movers, self.num_nodes).float()
377
+
378
+ mover_nodes = torch.gather(self.fleet.node, dim=1, index=mover_indices)
379
+ a = torch.arange(self.num_nodes).reshape(1, 1, -1).repeat(self.batch_size, num_movers, 1).to(self.device)
380
+ b = mover_nodes.reshape(self.batch_size, num_movers, 1).repeat(1, 1, self.num_nodes)
381
+ default_option = (a == b).float()
382
+
383
+ mask = action_mask*(1 - no_options) + default_option*no_options
384
+ masked_decoder_output = decoder_output + mask.log()
385
+
386
+ x = masked_decoder_output.reshape(self.batch_size, -1)
387
+ probs = torch.softmax(x, dim=1)
388
+
389
+ if self.mode == 'greedy':
390
+ prob, ind = torch.max(probs, dim=1)[0].unsqueeze(1), torch.argmax(probs, dim=1).unsqueeze(1)
391
+
392
+ elif self.mode in {'train', 'sample'}:
393
+ ind = torch.multinomial(probs, num_samples=1)
394
+ prob = torch.gather(probs, dim=1, index=ind)
395
+
396
+ elif self.mode == 'beam_search':
397
+ prob, ind = torch.topk(probs, dim=1, k=self.sample_size, largest=True)
398
+
399
+ mover_index = ind // self.num_nodes
400
+ next_node = ind % self.num_nodes
401
+ car_to_move = torch.gather(mover_indices, dim=1, index=mover_index)
402
+
403
+ next_node = next_node.reshape(-1)
404
+ car_to_move = car_to_move.reshape(-1)
405
+ prob = prob.reshape(-1)
406
+
407
+ return next_node, car_to_move, prob.log()
408
+
409
+
410
+ def update_distance(self, next_node, car_to_move):
411
+ ind = car_to_move.reshape(self.batch_size, 1)
412
+ n = self.fleet.node.reshape(self.batch_size, self.num_cars)
413
+ current_node = torch.gather(n, dim=1, index=ind).squeeze(1)
414
+
415
+ #compute distance to next node
416
+ ind_1 = current_node.reshape(-1, 1, 1).repeat(1, 1, self.num_nodes)
417
+ distances = torch.gather(self.graph.distance_matrix, dim=1, index=ind_1).squeeze(1)
418
+ ind_2 = next_node.reshape(-1, 1)
419
+ dist_to_next_node = torch.gather(distances, dim=1, index=ind_2).squeeze(1)
420
+
421
+ #compute current distance
422
+ t = self.fleet.distance.reshape(self.batch_size, self.num_cars)
423
+ current_distance = torch.gather(t, dim=1, index=car_to_move.reshape(self.batch_size, 1)).squeeze(1)
424
+
425
+ #compute updated_distance
426
+ updated_distance = current_distance + dist_to_next_node
427
+
428
+ #compute mover mask
429
+ a = torch.arange(self.num_cars).reshape(1, -1).repeat(self.batch_size, 1).to(self.device)
430
+ b = car_to_move.reshape(self.batch_size, 1).repeat(1, self.num_cars)
431
+ update_mask = (a == b).float().unsqueeze(2)
432
+
433
+ #update distance
434
+ t = updated_distance.reshape(self.batch_size, 1, 1).repeat(1, self.num_cars, 1)
435
+ self.fleet.distance = self.fleet.distance * (1 - update_mask) + t * update_mask
436
+
437
+
438
+ def update_time(self, next_node, car_to_move):
439
+ #get current node of mover
440
+ ind = car_to_move.reshape(self.batch_size, 1)
441
+ n = self.fleet.node.reshape(self.batch_size, self.num_cars)
442
+ current_node = torch.gather(n, dim=1, index=ind).squeeze(1)
443
+
444
+ #compute time to next node
445
+ ind_1 = current_node.reshape(-1, 1, 1).repeat(1, 1, self.num_nodes)
446
+ drive_times = torch.gather(self.graph.time_matrix, dim=1, index=ind_1).squeeze(1)
447
+ ind_2 = next_node.reshape(-1, 1)
448
+ time_to_next_node = torch.gather(drive_times, dim=1, index=ind_2).squeeze(1)
449
+
450
+ #compute start time at next node
451
+ start_time = self.graph.start_time.reshape(self.batch_size, self.num_nodes)
452
+ ind = next_node.reshape(self.batch_size, 1)
453
+ next_start_time = torch.gather(start_time, dim=1, index=ind).squeeze(1)
454
+
455
+ #compute current time
456
+ t = self.fleet.time.reshape(self.batch_size, self.num_cars)
457
+ current_node_time = torch.gather(t, dim=1, index=car_to_move.reshape(self.batch_size, 1)).squeeze(1)
458
+
459
+ #compute updated_time
460
+ a = time_to_next_node + current_node_time
461
+ b = next_start_time
462
+ updated_time = (a > b).float()*a + (a <= b).float()*b
463
+
464
+ #compute end_time at next node
465
+ ind = next_node.reshape(-1, 1)
466
+ end_times = self.graph.end_time.reshape(self.batch_size, self.num_nodes)
467
+ end_time_next_node = torch.gather(end_times, dim=1, index=ind).squeeze(1)
468
+ late_time = F.relu(updated_time - end_time_next_node)
469
+
470
+
471
+ #compute mover mask
472
+ a = torch.arange(self.num_cars).reshape(1, -1).repeat(self.batch_size, 1).to(self.device)
473
+ b = car_to_move.reshape(self.batch_size, 1).repeat(1, self.num_cars)
474
+ update_mask = (a == b).float().unsqueeze(2)
475
+
476
+ #update time
477
+ t = updated_time.reshape(self.batch_size, 1, 1).repeat(1, self.num_cars, 1)
478
+ self.fleet.time = self.fleet.time*(1 - update_mask) + t*update_mask
479
+
480
+ # update_late_time
481
+ l = late_time.reshape(self.batch_size, 1, 1).repeat(1, self.num_cars, 1)
482
+ self.fleet.late_time = self.fleet.late_time*(1 - update_mask) + l*update_mask
483
+
484
+
485
+
486
+ def update_node_path(self, next_node, car_to_move):
487
+ #compute mover mask
488
+ a = torch.arange(self.num_cars).reshape(1, -1).repeat(self.batch_size, 1).to(self.device)
489
+ b = car_to_move.reshape(self.batch_size, 1).repeat(1, self.num_cars)
490
+ update_mask = (a == b).long()
491
+
492
+ new_node = next_node.reshape(self.batch_size, 1).repeat(1, self.num_cars)
493
+ self.fleet.node = update_mask*new_node + self.fleet.node*(1 - update_mask)
494
+
495
+ L = [self.fleet.path, self.fleet.node.unsqueeze(2)]
496
+ self.fleet.path = torch.cat(L, dim=2)
497
+
498
+ t = self.fleet.time.reshape(self.batch_size, self.num_cars, 1)
499
+ H = [self.fleet.arrival_times, t]
500
+ self.fleet.arrival_times = torch.cat(H, dim=2)
501
+
502
+
503
+
504
+ def check_non_depot_options(self, use_time=True):
505
+ '''
506
+ Output is a byte tensor of shape [batch, num_cars, num_nodes] with entry (i,j) = 1 if
507
+ the move of car i to node j is invalid
508
+ '''
509
+
510
+ if use_time:
511
+ #check for arrival times
512
+ too_far = 1 - self.check_arival_times()
513
+
514
+ #is depot
515
+ a = self.fleet.depot.reshape(self.batch_size, self.num_cars, 1).repeat(1, 1, self.num_nodes)
516
+ b = torch.arange(self.num_nodes).reshape(1, 1, -1).repeat(self.batch_size, self.num_cars, 1).to(self.device)
517
+ is_depot = (a == b)
518
+
519
+ #check traversed nodes
520
+ a = self.fleet.traversed_nodes.reshape(self.batch_size, 1, self.num_nodes).repeat(1, self.num_cars, 1)
521
+ traversed_nodes = (a == 1)
522
+
523
+ # has value of 1 if the move to that node is NOT possible
524
+ if use_time:
525
+ unavailable_moves = (too_far | is_depot | traversed_nodes)
526
+ else:
527
+ unavailable_moves = (is_depot | traversed_nodes)
528
+
529
+ return unavailable_moves
530
+
531
+
532
+ def compute_decoder_mask(self, mover_indices, unavailable_moves):
533
+
534
+ num_movers = mover_indices.shape[1]
535
+
536
+ mover_nodes = torch.gather(self.fleet.node, dim=1, index=mover_indices)
537
+
538
+ ind = mover_indices.reshape(self.batch_size, num_movers, 1).repeat(1, 1, self.num_nodes)
539
+ unavailable_moves = torch.gather(unavailable_moves, dim=1, index=ind)
540
+
541
+ no_options = ((1 - unavailable_moves.float()).sum(dim=2) == 0).unsqueeze(2).repeat(1, 1, self.num_nodes).float()
542
+
543
+ a = mover_nodes.reshape(self.batch_size, num_movers, 1).repeat(1, 1, self.num_nodes)
544
+ b = torch.arange(self.num_nodes).reshape(1, 1, -1).repeat(self.batch_size, num_movers, 1).to(self.device)
545
+ default_move = (a == b).float()
546
+
547
+ decoder_mask = (1 - unavailable_moves.float())*(1 - no_options) + default_move*no_options
548
+ return decoder_mask
549
+
550
+
551
+ def compute_action_mask(self, mover_indices, unavailable_moves):
552
+
553
+ if self.mode != 'nearest_neighbors':
554
+ ######################################################
555
+ d = torch.diag(torch.ones(self.num_nodes)).unsqueeze(0).repeat(self.batch_size, 1, 1).to(self.device)
556
+ ind = self.fleet.node.reshape(self.batch_size, self.num_cars, 1).repeat(1, 1, self.num_nodes)
557
+
558
+ diag = torch.gather(d, dim=1, index=ind)
559
+ distances = torch.gather(self.graph.distance_matrix, dim=1, index=ind)
560
+ m = (unavailable_moves.float() + diag > 0).float()
561
+
562
+ masked_distances = distances*(1 - m) + m*self.graph.max_dist*100
563
+ K = min(self.num_neighbors_action, self.num_nodes)
564
+ neighbor_indices = torch.topk(masked_distances, dim=2, k=K, largest=False)[1]
565
+
566
+ a = torch.arange(self.num_nodes).reshape(1, 1, -1, 1).repeat(self.batch_size, self.num_cars, 1, K).to(self.device)
567
+ b = neighbor_indices.reshape(self.batch_size, self.num_cars, 1, K).repeat(1, 1, self.num_nodes, 1)
568
+
569
+ non_neighbor_mask = ((a == b).float().sum(dim=3) == 0)
570
+ mask = 1 - (non_neighbor_mask | unavailable_moves).float()
571
+ ######################################################
572
+ else:
573
+ mask = 1 - unavailable_moves.float()
574
+
575
+ num_movers = mover_indices.shape[1]
576
+ ind = mover_indices.reshape(self.batch_size, num_movers, 1).repeat(1, 1, self.num_nodes)
577
+ action_mask = torch.gather(mask, dim=1, index=ind)
578
+ return action_mask
579
+
580
+
581
+
582
+ def construct_decoder_input(self, mover_indices):
583
+
584
+ embedding_size = self.node_embeddings.shape[2]
585
+ mask = self.check_non_depot_options(use_time=True)
586
+ mask = mask.permute(0, 2, 1).unsqueeze(3).repeat(1, 1, 1, embedding_size).float()
587
+
588
+ node_vectors = self.node_embeddings.reshape(self.batch_size, self.num_nodes, 1, embedding_size).repeat(1, 1, self.num_cars, 1)
589
+
590
+ #depot vector
591
+ start_ind = self.fleet.depot.reshape(self.batch_size, 1, self.num_cars, 1).repeat(1, 1, 1, embedding_size)
592
+ a = torch.gather(node_vectors, dim=1, index=start_ind)
593
+ depot_vector = a.squeeze(1)
594
+
595
+ #current node vector
596
+ current_ind = self.fleet.node.reshape(self.batch_size, 1, self.num_cars, 1).repeat(1, 1, 1, embedding_size)
597
+ b = torch.gather(node_vectors, dim=1, index=current_ind)
598
+ current_node_vector = b.squeeze(1)
599
+
600
+ #mean graph vector
601
+ mean_graph_vector = node_vectors.mean(dim=1)
602
+
603
+
604
+ #current feature values
605
+ feature_values = self.fleet.construct_vector()
606
+
607
+
608
+ #other cars in field
609
+ num_movers = mover_indices.shape[1]
610
+ if num_movers == 1:
611
+ movers_vector = current_node_vector*0
612
+ elif not self.use_fleet_attention:
613
+ a = mover_indices.reshape(self.batch_size, num_movers, 1).repeat(1, 1, self.num_cars)
614
+ b = torch.arange(self.num_cars).reshape(1, 1, self.num_cars).repeat(self.batch_size, num_movers, 1).to(self.device)
615
+ c = ((a == b).float().sum(dim=1) > 0).float()
616
+ movers_mask = c.reshape(self.batch_size, self.num_cars, 1).repeat(1, 1, embedding_size)
617
+ movers_vector = ((current_node_vector*movers_mask).sum(dim=1).unsqueeze(1) - current_node_vector)/(num_movers - 1)
618
+ else:
619
+ x = torch.cat([current_node_vector, feature_values], dim=2)
620
+ movers_vector = self.fleet_attention(x)
621
+
622
+
623
+ L = [current_node_vector, depot_vector, mean_graph_vector, movers_vector, feature_values]
624
+ pre_output = torch.cat(L, dim=2)
625
+
626
+
627
+ num_movers = mover_indices.shape[1]
628
+ ind = mover_indices.reshape(self.batch_size, num_movers, 1).repeat(1, 1, pre_output.shape[2])
629
+ output = torch.gather(pre_output, dim=1, index=ind)
630
+
631
+ return output
632
+
633
+
634
+ def get_mover_indices(self, unavailable_moves):
635
+ depot = self.fleet.depot.reshape(self.batch_size, self.num_cars).long()
636
+ current_node = self.fleet.node.reshape(self.batch_size, self.num_cars).long()
637
+
638
+ # find all cars with "no options"
639
+ has_option = ((1 - unavailable_moves.float()).sum(dim=2) > 0)
640
+
641
+ at_depot = (current_node == depot)
642
+ in_field = (current_node != depot)
643
+ active_in_field = in_field & has_option
644
+ active_at_depot = at_depot & has_option
645
+
646
+ deployment_score = self.compute_depoyment_priority_score()
647
+ max_deployment_score = deployment_score.max()
648
+
649
+ A = max_deployment_score * 100
650
+ B = deployment_score * 10
651
+
652
+ score = active_in_field.float() * A + active_at_depot.float() * B
653
+ score = score.reshape(self.batch_size, self.num_cars, 1).repeat(1, 1, self.num_depots)
654
+
655
+ a = self.fleet.depot.reshape(self.batch_size, self.num_cars, 1).repeat(1, 1, self.num_depots)
656
+ b = torch.arange(self.num_depots).reshape(1, 1, self.num_depots).repeat(self.batch_size, self.num_cars, 1).to(self.device)
657
+ m = (a == b).float()
658
+
659
+ score = score*m
660
+
661
+ K = self.num_movers_corrected
662
+ indices = torch.topk(score, k=K, dim=1, largest=True)[1]
663
+ indices = indices.reshape(self.batch_size, self.num_depots*K)
664
+ return indices
665
+
666
+
667
+ def check_arival_times(self):
668
+ #check for arrival times
669
+ d = self.graph.time_matrix.unsqueeze(1).repeat(1, self.num_cars, 1, 1)
670
+ ind = self.fleet.node.reshape(self.batch_size, self.num_cars, 1, 1).repeat(1, 1, 1, self.num_nodes)
671
+ drive_times = torch.gather(d, dim=2, index=ind).squeeze(2)
672
+ t = self.fleet.time.reshape(self.batch_size, self.num_cars, 1).repeat(1, 1, self.num_nodes)
673
+ arival_times = drive_times + t
674
+ b = self.graph.end_time.reshape(self.batch_size, 1, self.num_nodes).repeat(1, self.num_cars, 1)
675
+ attainable = (arival_times <= b)
676
+ return attainable
677
+
678
+
679
+ def return_to_depot(self):
680
+ unavailable_moves = 1 - self.check_non_depot_options(use_time=False).float()
681
+ return_to_depot = (unavailable_moves.reshape(self.batch_size, -1).sum(dim=1) == 0).float()
682
+ return_to_depot = return_to_depot.reshape(self.batch_size, 1).repeat(1, self.num_cars)
683
+
684
+
685
+ if return_to_depot.sum().item() > 0:
686
+
687
+ depot = self.fleet.depot.reshape(self.batch_size, self.num_cars).long()
688
+ node = self.fleet.node.reshape(self.batch_size, self.num_cars).long()
689
+
690
+ #compute next node
691
+ return_to_depot = return_to_depot.long()
692
+ next_node = return_to_depot*depot + (1 - return_to_depot)*node
693
+
694
+ #update time
695
+ return_to_depot = return_to_depot.unsqueeze(2).float()
696
+
697
+ current_node = self.fleet.node
698
+ ind_1 = current_node.reshape(self.batch_size, self.num_cars, 1).repeat(1, 1, self.num_nodes)
699
+ drive_times = torch.gather(self.graph.time_matrix, dim=1, index=ind_1)
700
+ ind_2 = next_node.reshape(self.batch_size, self.num_cars, 1)
701
+ time_to_next_node = torch.gather(drive_times, dim=2, index=ind_2)
702
+ self.fleet.time = self.fleet.time + time_to_next_node*return_to_depot
703
+
704
+ #update node
705
+ self.fleet.node = next_node
706
+
707
+ #update path
708
+ n = self.fleet.node.reshape(self.batch_size, self.num_cars, 1)
709
+ self.fleet.path = torch.cat([self.fleet.path, n], dim=2)
710
+
711
+ #update arrival time
712
+ t = self.fleet.time.reshape(self.batch_size, self.num_cars, 1)
713
+ self.fleet.arrival_times = torch.cat([self.fleet.arrival_times, t], dim=2)
714
+
715
+
716
+ def return_to_depot_1(self):
717
+
718
+ depot = self.fleet.depot.reshape(self.batch_size, self.num_cars).long()
719
+ node = self.fleet.node.reshape(self.batch_size, self.num_cars).long()
720
+
721
+ #compute next node
722
+ next_node = depot
723
+
724
+ #update time
725
+ current_node = self.fleet.node
726
+ ind_1 = current_node.reshape(self.batch_size, self.num_cars, 1).repeat(1, 1, self.num_nodes)
727
+ drive_times = torch.gather(self.graph.time_matrix, dim=1, index=ind_1)
728
+ ind_2 = next_node.reshape(self.batch_size, self.num_cars, 1)
729
+ time_to_next_node = torch.gather(drive_times, dim=2, index=ind_2)
730
+ self.fleet.time = self.fleet.time + time_to_next_node
731
+
732
+
733
+ #update distance
734
+ current_node = self.fleet.node
735
+ ind_1 = current_node.reshape(self.batch_size, self.num_cars, 1).repeat(1, 1, self.num_nodes)
736
+ drive_distances = torch.gather(self.graph.distance_matrix, dim=1, index=ind_1)
737
+ ind_2 = next_node.reshape(self.batch_size, self.num_cars, 1)
738
+ distance_to_next_node = torch.gather(drive_distances, dim=2, index=ind_2)
739
+ self.fleet.distance = self.fleet.distance + distance_to_next_node
740
+
741
+
742
+ #update node
743
+ self.fleet.node = next_node
744
+
745
+ #update path
746
+ n = self.fleet.node.reshape(self.batch_size, self.num_cars, 1)
747
+ self.fleet.path = torch.cat([self.fleet.path, n], dim=2)
748
+
749
+ #update arrival time
750
+ t = self.fleet.time.reshape(self.batch_size, self.num_cars, 1)
751
+ self.fleet.arrival_times = torch.cat([self.fleet.arrival_times, t], dim=2)
actor_modified.py ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from Actor.graph import Graph
5
+ from Actor.fleet import Fleet
6
+ from utils.actor_utils import widen_data, select_data
7
+ from Actor.normalization import Normalization
8
+
9
+
10
+ class Actor(nn.Module):
11
+ def __init__(self, model=None, num_movers=5, num_neighbors_encoder=5,
12
+ num_neighbors_action=5, normalize=False, use_fleet_attention=True,
13
+ device='cpu'):
14
+ super().__init__()
15
+
16
+ self.device = device
17
+ self.num_movers = num_movers
18
+ self.num_neighbors_encoder = num_neighbors_encoder
19
+ self.num_neighbors_action = num_neighbors_action
20
+
21
+ self.apply_normalization = normalize
22
+ self.normalization = None
23
+ self.use_fleet_attention = use_fleet_attention
24
+
25
+ if model is None:
26
+ self.mode = 'nearest_neighbors'
27
+ else:
28
+ self.encoder = model.encoder.to(self.device)
29
+ self.decoder = model.decoder.to(self.device)
30
+ self.projections = model.projections.to(self.device)
31
+ self.fleet_attention = model.fleet_attention.to(self.device)
32
+ self.mode = 'train'
33
+
34
+ self.sample_size = 1
35
+
36
+ def train_mode(self, sample_size=1):
37
+ self.train()
38
+ self.sample_size = sample_size
39
+ self.mode = 'train'
40
+
41
+ def greedy_search(self):
42
+ self.eval()
43
+ self.mode = 'greedy'
44
+
45
+ def nearest_neighbors(self):
46
+ self.eval()
47
+ self.mode = 'nearest_neighbors'
48
+
49
+ def sample_mode(self, sample_size=10):
50
+ self.sample_size = sample_size
51
+ self.eval()
52
+ self.mode = 'sample'
53
+
54
+ def beam_search(self, sample_size=10):
55
+ self.sample_size = sample_size
56
+ self.eval()
57
+ self.mode = 'beam_search'
58
+
59
+ def update_batch_size(self):
60
+ self.batch_size = self.fleet.time.shape[0]
61
+ self.fleet.batch_size = self.batch_size
62
+ self.graph.batch_size = self.batch_size
63
+
64
+ def forward(self, batch, *args, **kwargs):
65
+ graph_data, fleet_data = batch
66
+ self.original_batch_size = graph_data['distance_matrix'].shape[0]
67
+ self.batch_size = self.original_batch_size
68
+
69
+ self.num_nodes = graph_data['distance_matrix'].shape[1]
70
+ self.num_cars = fleet_data['start_time'].shape[1]
71
+
72
+ self.graph = Graph(graph_data, device=self.device)
73
+ self.fleet = Fleet(fleet_data, num_nodes=self.num_nodes, device=self.device)
74
+ self.update_batch_size()
75
+
76
+ self.normalization = Normalization(self, normalize_position=True)
77
+ if self.apply_normalization:
78
+ self.normalization.normalize(self)
79
+
80
+ self.num_depots = self.fleet.num_depots.max().item()
81
+ self.num_movers_corrected = int(min(max(self.num_movers, self.num_depots), self.num_cars))
82
+
83
+ if self.mode != 'nearest_neighbors':
84
+ encoder_input = self.graph.construct_vector()
85
+ encoder_mask = self.compute_encoder_mask()
86
+ self.node_embeddings = self.encoder(encoder_input, encoder_mask)
87
+ self.node_projections = self.projections(self.node_embeddings)
88
+
89
+ if self.mode == 'sample':
90
+ widen_data(self, include_embeddings=True, include_projections=True)
91
+ self.update_batch_size()
92
+
93
+ self.log_probs = torch.zeros(self.batch_size).to(self.device)
94
+ self.counter = 0
95
+ while self.loop_condition() and (self.counter < self.num_nodes * 4):
96
+ unavailable_moves = self.check_non_depot_options(use_time=True)
97
+ mover_indices = self.get_mover_indices(unavailable_moves=unavailable_moves)
98
+ action_mask = self.compute_action_mask(mover_indices=mover_indices, unavailable_moves=unavailable_moves)
99
+
100
+ decoder_output = None
101
+ if self.mode != 'nearest_neighbors':
102
+ decoder_input = self.construct_decoder_input(mover_indices=mover_indices)
103
+ decoder_mask = self.compute_decoder_mask(mover_indices=mover_indices, unavailable_moves=unavailable_moves)
104
+ decoder_output = self.decoder(decoder_input=decoder_input,
105
+ projections=self.node_projections,
106
+ mask=decoder_mask)
107
+
108
+ next_node, car_to_move, log_prob = self.compute_action(decoder_output, action_mask, mover_indices)
109
+
110
+ if self.mode == 'beam_search':
111
+ widen_data(self, include_embeddings=True, include_projections=True)
112
+ self.update_batch_size()
113
+
114
+ self.log_probs += log_prob
115
+ self.update_time(next_node, car_to_move)
116
+ self.update_distance(next_node, car_to_move)
117
+ self.update_node_path(next_node, car_to_move)
118
+ self.update_traversed_nodes()
119
+
120
+ if (self.mode == 'beam_search') and (self.counter > 0):
121
+ self.consolidate_beams()
122
+
123
+ self.update_batch_size()
124
+ self.counter += 1
125
+
126
+ self.return_to_depot_1()
127
+
128
+ if self.mode in {'beam_search', 'sample'}:
129
+ incomplete = self.check_complete().float()
130
+ cost = self.compute_cost()
131
+ max_cost = cost.max()
132
+ masked_cost = (1 - incomplete) * cost + incomplete * max_cost * 10
133
+
134
+ p = masked_cost.reshape(self.original_batch_size, self.sample_size)
135
+ a = torch.argmin(p, dim=1)
136
+ b = torch.arange(self.original_batch_size).to(self.device)
137
+ b = b * self.sample_size
138
+ index = a + b
139
+ select_data(self, index=index, include_projections=True, include_embeddings=True)
140
+
141
+ self.update_batch_size()
142
+ if self.apply_normalization:
143
+ self.normalization.inverse_normalize(self)
144
+
145
+ output = {
146
+ 'distance': self.fleet.distance.sum(dim=1).squeeze(1).detach(),
147
+ 'total_time': self.fleet.time.sum(dim=1).squeeze(1).detach(),
148
+ 'log_probs': self.log_probs,
149
+ 'late_time': self.fleet.late_time.sum(dim=1).squeeze(1).detach(),
150
+ 'incomplete': self.check_complete(),
151
+ 'path': self.fleet.path,
152
+ 'arrival_times': self.fleet.arrival_times
153
+ }
154
+
155
+ return output