a-ragab-h-m committed
Commit bf2229d · verified · 1 Parent(s): e57a9a9

Delete actor.py

Files changed (1)
  1. actor.py +0 -751
actor.py DELETED
@@ -1,751 +0,0 @@
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from Actor.graph import Graph
-from Actor.fleet import Fleet
-from utils.actor_utils import widen_data, select_data
-from Actor.normalization import Normalization
-
-
-# This is the updated version of the actor
-class Actor(nn.Module):
-
-    def __init__(self, model=None, num_movers=5, num_neighbors_encoder=5,
-                 num_neighbors_action=5, normalize=False, use_fleet_attention=True,
-                 device='cpu'):
-        super().__init__()
-
-        self.device = device
-        self.num_movers = num_movers
-        self.num_neighbors_encoder = num_neighbors_encoder
-        self.num_neighbors_action = num_neighbors_action
-
-        self.apply_normalization = normalize
-        self.normalization = None
-        self.use_fleet_attention = use_fleet_attention
-
-
-        if model is None:
-            self.mode = 'nearest_neighbors'
-        else:
-            self.encoder = model.encoder.to(self.device)
-            self.decoder = model.decoder.to(self.device)
-            self.projections = model.projections.to(self.device)
-            self.fleet_attention = model.fleet_attention.to(self.device)
-            self.mode = 'train'
-
-
-        self.sample_size = 1
-
-
-    def train_mode(self, sample_size=1):
-        self.train()
-        self.sample_size = sample_size
-        self.mode = 'train'
-
-
-    def greedy_search(self):
-        self.eval()
-        self.mode = 'greedy'
-
-
-    def nearest_neighbors(self):
-        self.eval()
-        self.mode = 'nearest_neighbors'
-
-
-    def sample_mode(self, sample_size=10):
-        self.sample_size = sample_size
-        self.eval()
-        self.mode = 'sample'
-
-
-    def beam_search(self, sample_size=10):
-        self.sample_size = sample_size
-        self.eval()
-        self.mode = 'beam_search'
-
-
-    def update_batch_size(self):
-        self.batch_size = self.fleet.time.shape[0]
-        self.fleet.batch_size = self.fleet.time.shape[0]
-        self.graph.batch_size = self.fleet.time.shape[0]
-
-
-
-    def forward(self, batch, *args, **kwargs):
-
-        graph_data, fleet_data = batch
-
-        self.original_batch_size = graph_data['distance_matrix'].shape[0]
-        self.batch_size = self.original_batch_size
-
-        self.num_nodes = graph_data['distance_matrix'].shape[1]
-        self.num_cars = fleet_data['start_time'].shape[1]
-
-
-        self.graph = Graph(graph_data, device=self.device)
-        self.fleet = Fleet(fleet_data, num_nodes=self.num_nodes, device=self.device)
-        self.num_nodes = self.graph.distance_matrix.shape[1]
-        self.update_batch_size()
-
-        self.normalization = Normalization(self, normalize_position=True)
-        if self.apply_normalization:
-            self.normalization.normalize(self)
-
-        self.num_depots = self.fleet.num_depots.max().item()
-        self.num_movers_corrected = int(min(max(self.num_movers, self.num_depots), self.num_cars))
-
-
-        if self.mode != 'nearest_neighbors':
-            encoder_input = self.graph.construct_vector()
-            encoder_mask = self.compute_encoder_mask()
-            self.node_embeddings = self.encoder(encoder_input, encoder_mask)
-            self.node_projections = self.projections(self.node_embeddings)
-
-
-        if self.mode == 'sample':
-            widen_data(self, include_embeddings=True, include_projections=True)
-            self.update_batch_size()
-
-
-        self.log_probs = torch.zeros(self.batch_size).to(self.device)
-        self.counter = 0
-        while self.loop_condition() and (self.counter < self.num_nodes*4):
-
-            unavailable_moves = self.check_non_depot_options(use_time=True)
-            mover_indices = self.get_mover_indices(unavailable_moves=unavailable_moves)
-            action_mask = self.compute_action_mask(mover_indices=mover_indices,
-                                                   unavailable_moves=unavailable_moves)
-
-            if self.mode != 'nearest_neighbors':
-                decoder_input = self.construct_decoder_input(mover_indices=mover_indices)
-                decoder_mask = self.compute_decoder_mask(mover_indices=mover_indices,
-                                                         unavailable_moves=unavailable_moves)
-
-                decoder_output = self.decoder(decoder_input=decoder_input,
-                                              projections=self.node_projections,
-                                              mask=decoder_mask)
-            else:
-                decoder_output = None
-
-            next_node, car_to_move, log_prob = self.compute_action(decoder_output, action_mask, mover_indices)
-
-            if self.mode == 'beam_search':
-                widen_data(self, include_embeddings=True, include_projections=True)
-                self.update_batch_size()
-
-            self.log_probs += log_prob
-
-            self.update_time(next_node, car_to_move)
-            self.update_distance(next_node, car_to_move)
-            self.update_node_path(next_node, car_to_move)
-            self.update_traversed_nodes()
-
-            #self.return_to_depot()
-            #self.update_traversed_nodes()
-
-            if (self.mode == 'beam_search') and (self.counter > 0):
-                self.consolidate_beams()
-
-            self.update_batch_size()
-            self.counter += 1
-
-        self.return_to_depot_1()
-
-
-        if self.mode in {'beam_search', 'sample'}:
-            # we now must select out the best k elements
-
-            incomplete = self.check_complete().float()
-            cost = self.compute_cost()
-
-            max_cost = cost.max()
-            masked_cost = (1 - incomplete)*cost + incomplete*max_cost*10
-
-            p = masked_cost.reshape(self.original_batch_size, self.sample_size)
-            a = torch.argmin(p, dim=1)
-
-            b = torch.arange(self.original_batch_size).to(self.device)
-            b = b * self.sample_size
-
-            index = a + b
-            select_data(self, index=index, include_projections=True, include_embeddings=True)
-
-
-        self.update_batch_size()
-        if self.apply_normalization:
-            self.normalization.inverse_normalize(self)
-
-        total_distance = self.fleet.distance.sum(dim=1).squeeze(1).detach()
-        total_time = self.fleet.time.sum(dim=1).squeeze(1).detach()
-        total_late_time = self.fleet.late_time.sum(dim=1).squeeze(1).detach()
-
-        output = {
-            'distance': total_distance.detach(),
-            'total_time': total_time.detach(),
-            'log_probs': self.log_probs,
-            'late_time': total_late_time.detach(),
-            'incomplete': self.check_complete(),
-            'path': self.fleet.path,
-            'arrival_times': self.fleet.arrival_times
-        }
-
-        return output
-
-
-    def consolidate_beams(self):
-        p = self.log_probs.reshape(self.original_batch_size, self.sample_size * self.sample_size)
-        a = torch.topk(p, dim=1, k=self.sample_size, largest=True)[1]
-
-        b = torch.arange(self.original_batch_size).unsqueeze(1).repeat(1, self.sample_size).to(self.device)
-        b = b * self.sample_size * self.sample_size
-
-        ind = (a + b).reshape(-1)
-        select_data(self, index=ind, include_projections=True, include_embeddings=True)
-
-
-    def adjust_arrival_times(self):
-
-        if self.apply_normalization and (self.normalization_params is not None):
-            num_steps = self.fleet.arrival_times.shape[2]
-            a = self.normalization_params['earliest_start_time'].reshape(self.batch_size, 1, 1).repeat(1, self.num_cars, num_steps)
-            b = self.normalization_params['greatest_drive_time'].reshape(self.batch_size, 1, 1).repeat(1, self.num_cars, num_steps)
-            self.fleet.arrival_times = self.fleet.arrival_times*b + a
-
-            a = self.normalization_params['earliest_start_time']
-            b = self.normalization_params['greatest_drive_time']
-            self.fleet.late_time = self.fleet.late_time*b + a
-
-
-    def compute_encoder_mask(self):
-        # check for drive times being too long
-        time_window_non_compatibility = 1 - self.graph.time_window_compatibility
-
-        # compute diag mask
-        diag = torch.diag(torch.ones(self.num_nodes)).unsqueeze(0).repeat(self.batch_size, 1, 1).to(self.device)
-
-        # compute neighbors mask
-        m = (time_window_non_compatibility + diag > 0).float()
-
-        dist = self.graph.distance_matrix*(1 - m) + m*self.graph.max_dist*10
-
-        K = min(self.num_nodes, self.num_neighbors_encoder)
-        neighbors_index = torch.topk(dist, k=K, dim=2, largest=False)[1] # ~ [batch, num_nodes, num_neighbors]
-
-        a = torch.arange(self.num_nodes).reshape(1, 1, -1, 1).repeat(self.batch_size, self.num_nodes, 1, K).to(self.device)
-        b = neighbors_index.unsqueeze(2).repeat(1, 1, self.num_nodes, 1)
-        neighbors_mask = (a == b).float().sum(dim=3)
-
-
-        m = (time_window_non_compatibility == 0).float()
-        neighbs_time_mask = neighbors_mask*m
-
-        # compute depot mask
-        v = self.graph.depot.reshape(self.batch_size, 1, self.num_nodes).repeat(1, self.num_nodes, 1)
-        w = self.graph.depot.reshape(self.batch_size, self.num_nodes, 1).repeat(1, 1, self.num_nodes)
-        depot_mask = (v + w > 0).float()
-
-        diag = torch.diag(torch.ones(self.num_nodes)).unsqueeze(0).repeat(self.batch_size, 1, 1).to(self.device)
-
-        mask = (neighbs_time_mask + depot_mask + diag > 0).float()
-        return mask
-
-
-    def compute_deployment_priority_score(self):
-        # number of available nodes
-        available_nodes = (self.check_non_depot_options().float() == 0).float()
-        num_available_nodes = available_nodes.sum(dim=2)
-
-        # excess
-        ind = self.fleet.node.reshape(self.batch_size, self.num_cars, 1).repeat(1, 1, self.num_nodes)
-        distances = torch.gather(self.graph.distance_matrix, dim=1, index=ind)
-
-        max_distance = self.graph.distance_matrix.reshape(self.batch_size, -1).max(dim=1)[0].reshape(
-            self.batch_size, 1, 1).repeat(1, self.num_cars, self.num_nodes)
-        excess = ((max_distance - distances)*available_nodes).sum(dim=2)
-
-
-        max_excess = excess.max()
-
-        score = (num_available_nodes + max_excess)*10 + excess
-        return score
-
-
-
-    def check_complete(self):
-        has_untraversed_nodes = ((self.fleet.traversed_nodes == 0).float().sum(dim=1) > 0)
-        #has_cars_in_field = ((self.fleet.node != self.fleet.depot).sum(dim=1) > 0)
-        #incomplete = (has_untraversed_nodes | has_cars_in_field)
-        incomplete = has_untraversed_nodes
-        return incomplete
-
-
-    def loop_condition(self):
-        a = self.check_complete().sum().item()
-        if a == 0:
-            return False
-        else:
-            return True
-
-
-    def compute_cost(self):
-
-        dist = self.fleet.distance.sum(dim=1).squeeze(1)
-        time = self.fleet.time.sum(dim=1).squeeze(1)
-        lateness = self.fleet.late_time.sum(dim=1).squeeze(1)
-
-        # normally we compute the cost as a linear combination of these three quantities
-        return time
-
-
-
-    def compute_total_distance(self):
-        p_2 = self.fleet.path
-        p_1 = torch.cat([p_2[:,:,-1:], p_2[:,:,:-1]], dim=2)
-
-        mat = self.graph.distance_matrix.reshape(self.batch_size, 1, self.num_nodes, self.num_nodes).repeat(1, self.num_cars, 1, 1)
-        ind_1 = p_1.unsqueeze(3).repeat(1, 1, 1, self.num_nodes)
-
-        d = torch.gather(mat, dim=2, index=ind_1)
-        ind_2 = p_2.unsqueeze(3)
-        pairwise_distances = torch.gather(d, dim=3, index=ind_2).squeeze(3)
-
-        self.car_distances = pairwise_distances.sum(dim=2)
-        self.total_distance = self.car_distances.sum(dim=1)
-
-        if self.apply_normalization and (self.normalization_params is not None):
-            distance_multiplier = self.normalization_params['greatest_distance']
-            self.total_distance = self.total_distance*distance_multiplier.reshape(self.batch_size)
-
-        return self.total_distance
-
-
-    def update_traversed_nodes(self):
-
-        path_length = self.fleet.path.shape[2]
-        a = self.fleet.path.reshape(self.batch_size, self.num_cars, path_length, 1).repeat(1, 1, 1, self.num_nodes)
-        b = torch.arange(self.num_nodes).reshape(
-            1, 1, 1, self.num_nodes).repeat(
-            self.batch_size, self.num_cars, path_length, 1).to(self.device)
-
-        s = (a == b).float().reshape(self.batch_size, self.num_cars*path_length, self.num_nodes).sum(dim=1)
-        self.fleet.traversed_nodes = (s > 0).float()
-
-
-    def compute_action(self, decoder_output, action_mask, mover_indices):
-
-        if self.mode == 'nearest_neighbors':
-
-            assert decoder_output is None
-
-            num_movers = mover_indices.shape[1]
-            mover_nodes = torch.gather(self.fleet.node, dim=1, index=mover_indices)
-
-            ind = mover_nodes.reshape(self.batch_size, num_movers, 1).repeat(1, 1, self.num_nodes)
-            distances = torch.gather(self.graph.distance_matrix, dim=1, index=ind)
-
-            a = mover_nodes.reshape(self.batch_size, num_movers, 1).repeat(1, 1, self.num_nodes)
-            b = torch.arange(self.num_nodes).reshape(1, 1, self.num_nodes).repeat(self.batch_size, num_movers, 1).to(self.device)
-            current_node_indicator = (a == b).float()
-
-            a = action_mask.reshape(self.batch_size, -1)
-            no_options = (a.sum(dim=1) == 0).reshape(-1, 1, 1).repeat(1, num_movers, self.num_nodes).float()
-
-            mask = action_mask * (1 - no_options) + current_node_indicator * no_options
-            masked_distances = distances*mask + self.graph.max_dist*(1 - mask)*10
-
-            x = masked_distances.reshape(self.batch_size, -1)
-            ind = torch.argmin(x, dim=1).unsqueeze(1)
-
-            mover_index = ind // self.num_nodes
-            next_node = ind % self.num_nodes
-            car_to_move = torch.gather(mover_indices, dim=1, index=mover_index)
-
-            log_prob = torch.zeros(self.batch_size).to(self.device)
-            return next_node, car_to_move, log_prob
-
-        else:
-
-            assert decoder_output is not None
-
-            num_movers = mover_indices.shape[1]
-            assert (num_movers == action_mask.shape[1]) and (num_movers == decoder_output.shape[1])
-
-            a = action_mask.reshape(self.batch_size, -1)
-            no_options = (a.sum(dim=1) == 0).reshape(-1, 1, 1).repeat(1, num_movers, self.num_nodes).float()
-
-            mover_nodes = torch.gather(self.fleet.node, dim=1, index=mover_indices)
-            a = torch.arange(self.num_nodes).reshape(1, 1, -1).repeat(self.batch_size, num_movers, 1).to(self.device)
-            b = mover_nodes.reshape(self.batch_size, num_movers, 1).repeat(1, 1, self.num_nodes)
-            default_option = (a == b).float()
-
-            mask = action_mask*(1 - no_options) + default_option*no_options
-            masked_decoder_output = decoder_output + mask.log()
-
-            x = masked_decoder_output.reshape(self.batch_size, -1)
-            probs = torch.softmax(x, dim=1)
-
-            if self.mode == 'greedy':
-                prob, ind = torch.max(probs, dim=1)[0].unsqueeze(1), torch.argmax(probs, dim=1).unsqueeze(1)
-
-            elif self.mode in {'train', 'sample'}:
-                ind = torch.multinomial(probs, num_samples=1)
-                prob = torch.gather(probs, dim=1, index=ind)
-
-            elif self.mode == 'beam_search':
-                prob, ind = torch.topk(probs, dim=1, k=self.sample_size, largest=True)
-
-            mover_index = ind // self.num_nodes
-            next_node = ind % self.num_nodes
-            car_to_move = torch.gather(mover_indices, dim=1, index=mover_index)
-
-            next_node = next_node.reshape(-1)
-            car_to_move = car_to_move.reshape(-1)
-            prob = prob.reshape(-1)
-
-            return next_node, car_to_move, prob.log()
-
-
-    def update_distance(self, next_node, car_to_move):
-        ind = car_to_move.reshape(self.batch_size, 1)
-        n = self.fleet.node.reshape(self.batch_size, self.num_cars)
-        current_node = torch.gather(n, dim=1, index=ind).squeeze(1)
-
-        # compute distance to next node
-        ind_1 = current_node.reshape(-1, 1, 1).repeat(1, 1, self.num_nodes)
-        distances = torch.gather(self.graph.distance_matrix, dim=1, index=ind_1).squeeze(1)
-        ind_2 = next_node.reshape(-1, 1)
-        dist_to_next_node = torch.gather(distances, dim=1, index=ind_2).squeeze(1)
-
-        # compute current distance
-        t = self.fleet.distance.reshape(self.batch_size, self.num_cars)
-        current_distance = torch.gather(t, dim=1, index=car_to_move.reshape(self.batch_size, 1)).squeeze(1)
-
-        # compute updated_distance
-        updated_distance = current_distance + dist_to_next_node
-
-        # compute mover mask
-        a = torch.arange(self.num_cars).reshape(1, -1).repeat(self.batch_size, 1).to(self.device)
-        b = car_to_move.reshape(self.batch_size, 1).repeat(1, self.num_cars)
-        update_mask = (a == b).float().unsqueeze(2)
-
-        # update distance
-        t = updated_distance.reshape(self.batch_size, 1, 1).repeat(1, self.num_cars, 1)
-        self.fleet.distance = self.fleet.distance * (1 - update_mask) + t * update_mask
-
-
-    def update_time(self, next_node, car_to_move):
-        # get current node of mover
-        ind = car_to_move.reshape(self.batch_size, 1)
-        n = self.fleet.node.reshape(self.batch_size, self.num_cars)
-        current_node = torch.gather(n, dim=1, index=ind).squeeze(1)
-
-        # compute time to next node
-        ind_1 = current_node.reshape(-1, 1, 1).repeat(1, 1, self.num_nodes)
-        drive_times = torch.gather(self.graph.time_matrix, dim=1, index=ind_1).squeeze(1)
-        ind_2 = next_node.reshape(-1, 1)
-        time_to_next_node = torch.gather(drive_times, dim=1, index=ind_2).squeeze(1)
-
-        # compute start time at next node
-        start_time = self.graph.start_time.reshape(self.batch_size, self.num_nodes)
-        ind = next_node.reshape(self.batch_size, 1)
-        next_start_time = torch.gather(start_time, dim=1, index=ind).squeeze(1)
-
-        # compute current time
-        t = self.fleet.time.reshape(self.batch_size, self.num_cars)
-        current_node_time = torch.gather(t, dim=1, index=car_to_move.reshape(self.batch_size, 1)).squeeze(1)
-
-        # compute updated_time
-        a = time_to_next_node + current_node_time
-        b = next_start_time
-        updated_time = (a > b).float()*a + (a <= b).float()*b
-
-        # compute end_time at next node
-        ind = next_node.reshape(-1, 1)
-        end_times = self.graph.end_time.reshape(self.batch_size, self.num_nodes)
-        end_time_next_node = torch.gather(end_times, dim=1, index=ind).squeeze(1)
-        late_time = F.relu(updated_time - end_time_next_node)
-
-
-        # compute mover mask
-        a = torch.arange(self.num_cars).reshape(1, -1).repeat(self.batch_size, 1).to(self.device)
-        b = car_to_move.reshape(self.batch_size, 1).repeat(1, self.num_cars)
-        update_mask = (a == b).float().unsqueeze(2)
-
-        # update time
-        t = updated_time.reshape(self.batch_size, 1, 1).repeat(1, self.num_cars, 1)
-        self.fleet.time = self.fleet.time*(1 - update_mask) + t*update_mask
-
-        # update late_time
-        l = late_time.reshape(self.batch_size, 1, 1).repeat(1, self.num_cars, 1)
-        self.fleet.late_time = self.fleet.late_time*(1 - update_mask) + l*update_mask
-
-
-
-    def update_node_path(self, next_node, car_to_move):
-        # compute mover mask
-        a = torch.arange(self.num_cars).reshape(1, -1).repeat(self.batch_size, 1).to(self.device)
-        b = car_to_move.reshape(self.batch_size, 1).repeat(1, self.num_cars)
-        update_mask = (a == b).long()
-
-        new_node = next_node.reshape(self.batch_size, 1).repeat(1, self.num_cars)
-        self.fleet.node = update_mask*new_node + self.fleet.node*(1 - update_mask)
-
-        L = [self.fleet.path, self.fleet.node.unsqueeze(2)]
-        self.fleet.path = torch.cat(L, dim=2)
-
-        t = self.fleet.time.reshape(self.batch_size, self.num_cars, 1)
-        H = [self.fleet.arrival_times, t]
-        self.fleet.arrival_times = torch.cat(H, dim=2)
-
-
-
-    def check_non_depot_options(self, use_time=True):
-        '''
-        Output is a boolean tensor of shape [batch, num_cars, num_nodes] whose entry (b, i, j)
-        is True if the move of car i to node j is invalid.
-        '''
-
-        if use_time:
-            # check for arrival times
-            too_far = ~self.check_arrival_times()
-
-        # is depot
-        a = self.fleet.depot.reshape(self.batch_size, self.num_cars, 1).repeat(1, 1, self.num_nodes)
-        b = torch.arange(self.num_nodes).reshape(1, 1, -1).repeat(self.batch_size, self.num_cars, 1).to(self.device)
-        is_depot = (a == b)
-
-        # check traversed nodes
-        a = self.fleet.traversed_nodes.reshape(self.batch_size, 1, self.num_nodes).repeat(1, self.num_cars, 1)
-        traversed_nodes = (a == 1)
-
-        # has value True if the move to that node is NOT possible
-        if use_time:
-            unavailable_moves = (too_far | is_depot | traversed_nodes)
-        else:
-            unavailable_moves = (is_depot | traversed_nodes)
-
-        return unavailable_moves
-
-
-    def compute_decoder_mask(self, mover_indices, unavailable_moves):
-
-        num_movers = mover_indices.shape[1]
-
-        mover_nodes = torch.gather(self.fleet.node, dim=1, index=mover_indices)
-
-        ind = mover_indices.reshape(self.batch_size, num_movers, 1).repeat(1, 1, self.num_nodes)
-        unavailable_moves = torch.gather(unavailable_moves, dim=1, index=ind)
-
-        no_options = ((1 - unavailable_moves.float()).sum(dim=2) == 0).unsqueeze(2).repeat(1, 1, self.num_nodes).float()
-
-        a = mover_nodes.reshape(self.batch_size, num_movers, 1).repeat(1, 1, self.num_nodes)
-        b = torch.arange(self.num_nodes).reshape(1, 1, -1).repeat(self.batch_size, num_movers, 1).to(self.device)
-        default_move = (a == b).float()
-
-        decoder_mask = (1 - unavailable_moves.float())*(1 - no_options) + default_move*no_options
-        return decoder_mask
-
-
-    def compute_action_mask(self, mover_indices, unavailable_moves):
-
-        if self.mode != 'nearest_neighbors':
-            ######################################################
-            d = torch.diag(torch.ones(self.num_nodes)).unsqueeze(0).repeat(self.batch_size, 1, 1).to(self.device)
-            ind = self.fleet.node.reshape(self.batch_size, self.num_cars, 1).repeat(1, 1, self.num_nodes)
-
-            diag = torch.gather(d, dim=1, index=ind)
-            distances = torch.gather(self.graph.distance_matrix, dim=1, index=ind)
-            m = (unavailable_moves.float() + diag > 0).float()
-
-            masked_distances = distances*(1 - m) + m*self.graph.max_dist*100
-            K = min(self.num_neighbors_action, self.num_nodes)
-            neighbor_indices = torch.topk(masked_distances, dim=2, k=K, largest=False)[1]
-
-            a = torch.arange(self.num_nodes).reshape(1, 1, -1, 1).repeat(self.batch_size, self.num_cars, 1, K).to(self.device)
-            b = neighbor_indices.reshape(self.batch_size, self.num_cars, 1, K).repeat(1, 1, self.num_nodes, 1)
-
-            non_neighbor_mask = ((a == b).float().sum(dim=3) == 0)
-            mask = 1 - (non_neighbor_mask | unavailable_moves).float()
-            ######################################################
-        else:
-            mask = 1 - unavailable_moves.float()
-
-        num_movers = mover_indices.shape[1]
-        ind = mover_indices.reshape(self.batch_size, num_movers, 1).repeat(1, 1, self.num_nodes)
-        action_mask = torch.gather(mask, dim=1, index=ind)
-        return action_mask
-
-
-
-    def construct_decoder_input(self, mover_indices):
-
-        embedding_size = self.node_embeddings.shape[2]
-        mask = self.check_non_depot_options(use_time=True)
-        mask = mask.permute(0, 2, 1).unsqueeze(3).repeat(1, 1, 1, embedding_size).float()
-
-        node_vectors = self.node_embeddings.reshape(self.batch_size, self.num_nodes, 1, embedding_size).repeat(1, 1, self.num_cars, 1)
-
-        # depot vector
-        start_ind = self.fleet.depot.reshape(self.batch_size, 1, self.num_cars, 1).repeat(1, 1, 1, embedding_size)
-        a = torch.gather(node_vectors, dim=1, index=start_ind)
-        depot_vector = a.squeeze(1)
-
-        # current node vector
-        current_ind = self.fleet.node.reshape(self.batch_size, 1, self.num_cars, 1).repeat(1, 1, 1, embedding_size)
-        b = torch.gather(node_vectors, dim=1, index=current_ind)
-        current_node_vector = b.squeeze(1)
-
-        # mean graph vector
-        mean_graph_vector = node_vectors.mean(dim=1)
-
-
-        # current feature values
-        feature_values = self.fleet.construct_vector()
-
-
-        # other cars in field
-        num_movers = mover_indices.shape[1]
-        if num_movers == 1:
-            movers_vector = current_node_vector*0
-        elif not self.use_fleet_attention:
-            a = mover_indices.reshape(self.batch_size, num_movers, 1).repeat(1, 1, self.num_cars)
-            b = torch.arange(self.num_cars).reshape(1, 1, self.num_cars).repeat(self.batch_size, num_movers, 1).to(self.device)
-            c = ((a == b).float().sum(dim=1) > 0).float()
-            movers_mask = c.reshape(self.batch_size, self.num_cars, 1).repeat(1, 1, embedding_size)
-            movers_vector = ((current_node_vector*movers_mask).sum(dim=1).unsqueeze(1) - current_node_vector)/(num_movers - 1)
-        else:
-            x = torch.cat([current_node_vector, feature_values], dim=2)
-            movers_vector = self.fleet_attention(x)
-
-
-        L = [current_node_vector, depot_vector, mean_graph_vector, movers_vector, feature_values]
-        pre_output = torch.cat(L, dim=2)
-
-
-        num_movers = mover_indices.shape[1]
-        ind = mover_indices.reshape(self.batch_size, num_movers, 1).repeat(1, 1, pre_output.shape[2])
-        output = torch.gather(pre_output, dim=1, index=ind)
-
-        return output
-
-
-    def get_mover_indices(self, unavailable_moves):
-        depot = self.fleet.depot.reshape(self.batch_size, self.num_cars).long()
-        current_node = self.fleet.node.reshape(self.batch_size, self.num_cars).long()
-
-        # find all cars with "no options"
-        has_option = ((1 - unavailable_moves.float()).sum(dim=2) > 0)
-
-        at_depot = (current_node == depot)
-        in_field = (current_node != depot)
-        active_in_field = in_field & has_option
-        active_at_depot = at_depot & has_option
-
-        deployment_score = self.compute_deployment_priority_score()
-        max_deployment_score = deployment_score.max()
-
-        A = max_deployment_score * 100
-        B = deployment_score * 10
-
-        score = active_in_field.float() * A + active_at_depot.float() * B
-        score = score.reshape(self.batch_size, self.num_cars, 1).repeat(1, 1, self.num_depots)
-
-        a = self.fleet.depot.reshape(self.batch_size, self.num_cars, 1).repeat(1, 1, self.num_depots)
-        b = torch.arange(self.num_depots).reshape(1, 1, self.num_depots).repeat(self.batch_size, self.num_cars, 1).to(self.device)
-        m = (a == b).float()
-
-        score = score*m
-
-        K = self.num_movers_corrected
-        indices = torch.topk(score, k=K, dim=1, largest=True)[1]
-        indices = indices.reshape(self.batch_size, self.num_depots*K)
-        return indices
-
-
-    def check_arrival_times(self):
-        # check for arrival times
-        d = self.graph.time_matrix.unsqueeze(1).repeat(1, self.num_cars, 1, 1)
-        ind = self.fleet.node.reshape(self.batch_size, self.num_cars, 1, 1).repeat(1, 1, 1, self.num_nodes)
-        drive_times = torch.gather(d, dim=2, index=ind).squeeze(2)
-        t = self.fleet.time.reshape(self.batch_size, self.num_cars, 1).repeat(1, 1, self.num_nodes)
-        arrival_times = drive_times + t
-        b = self.graph.end_time.reshape(self.batch_size, 1, self.num_nodes).repeat(1, self.num_cars, 1)
-        attainable = (arrival_times <= b)
-        return attainable
-
-
-    def return_to_depot(self):
-        available_moves = 1 - self.check_non_depot_options(use_time=False).float()
-        return_to_depot = (available_moves.reshape(self.batch_size, -1).sum(dim=1) == 0).float()
-        return_to_depot = return_to_depot.reshape(self.batch_size, 1).repeat(1, self.num_cars)
-
-
-        if return_to_depot.sum().item() > 0:
-
-            depot = self.fleet.depot.reshape(self.batch_size, self.num_cars).long()
-            node = self.fleet.node.reshape(self.batch_size, self.num_cars).long()
-
-            # compute next node
-            return_to_depot = return_to_depot.long()
-            next_node = return_to_depot*depot + (1 - return_to_depot)*node
-
-            # update time
-            return_to_depot = return_to_depot.unsqueeze(2).float()
-
-            current_node = self.fleet.node
-            ind_1 = current_node.reshape(self.batch_size, self.num_cars, 1).repeat(1, 1, self.num_nodes)
-            drive_times = torch.gather(self.graph.time_matrix, dim=1, index=ind_1)
-            ind_2 = next_node.reshape(self.batch_size, self.num_cars, 1)
-            time_to_next_node = torch.gather(drive_times, dim=2, index=ind_2)
-            self.fleet.time = self.fleet.time + time_to_next_node*return_to_depot
-
-            # update node
-            self.fleet.node = next_node
-
-            # update path
-            n = self.fleet.node.reshape(self.batch_size, self.num_cars, 1)
-            self.fleet.path = torch.cat([self.fleet.path, n], dim=2)
-
-            # update arrival time
-            t = self.fleet.time.reshape(self.batch_size, self.num_cars, 1)
-            self.fleet.arrival_times = torch.cat([self.fleet.arrival_times, t], dim=2)
-
-
-    def return_to_depot_1(self):
-
-        depot = self.fleet.depot.reshape(self.batch_size, self.num_cars).long()
-        node = self.fleet.node.reshape(self.batch_size, self.num_cars).long()
-
-        # compute next node
-        next_node = depot
-
-        # update time
-        current_node = self.fleet.node
-        ind_1 = current_node.reshape(self.batch_size, self.num_cars, 1).repeat(1, 1, self.num_nodes)
-        drive_times = torch.gather(self.graph.time_matrix, dim=1, index=ind_1)
-        ind_2 = next_node.reshape(self.batch_size, self.num_cars, 1)
-        time_to_next_node = torch.gather(drive_times, dim=2, index=ind_2)
-        self.fleet.time = self.fleet.time + time_to_next_node
-
-
-        # update distance
-        current_node = self.fleet.node
-        ind_1 = current_node.reshape(self.batch_size, self.num_cars, 1).repeat(1, 1, self.num_nodes)
-        drive_distances = torch.gather(self.graph.distance_matrix, dim=1, index=ind_1)
-        ind_2 = next_node.reshape(self.batch_size, self.num_cars, 1)
-        distance_to_next_node = torch.gather(drive_distances, dim=2, index=ind_2)
-        self.fleet.distance = self.fleet.distance + distance_to_next_node
-
-
-        # update node
-        self.fleet.node = next_node
-
-        # update path
-        n = self.fleet.node.reshape(self.batch_size, self.num_cars, 1)
-        self.fleet.path = torch.cat([self.fleet.path, n], dim=2)
-
-        # update arrival time
-        t = self.fleet.time.reshape(self.batch_size, self.num_cars, 1)
-        self.fleet.arrival_times = torch.cat([self.fleet.arrival_times, t], dim=2)