a-ragab-h-m commited on
Commit
e57a9a9
·
verified ·
1 Parent(s): 184a603

Delete actor_modified.py

Browse files
Files changed (1) hide show
  1. actor_modified.py +0 -155
actor_modified.py DELETED
@@ -1,155 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- import torch.nn.functional as F
4
- from Actor.graph import Graph
5
- from Actor.fleet import Fleet
6
- from utils.actor_utils import widen_data, select_data
7
- from Actor.normalization import Normalization
8
-
9
-
10
- class Actor(nn.Module):
11
- def __init__(self, model=None, num_movers=5, num_neighbors_encoder=5,
12
- num_neighbors_action=5, normalize=False, use_fleet_attention=True,
13
- device='cpu'):
14
- super().__init__()
15
-
16
- self.device = device
17
- self.num_movers = num_movers
18
- self.num_neighbors_encoder = num_neighbors_encoder
19
- self.num_neighbors_action = num_neighbors_action
20
-
21
- self.apply_normalization = normalize
22
- self.normalization = None
23
- self.use_fleet_attention = use_fleet_attention
24
-
25
- if model is None:
26
- self.mode = 'nearest_neighbors'
27
- else:
28
- self.encoder = model.encoder.to(self.device)
29
- self.decoder = model.decoder.to(self.device)
30
- self.projections = model.projections.to(self.device)
31
- self.fleet_attention = model.fleet_attention.to(self.device)
32
- self.mode = 'train'
33
-
34
- self.sample_size = 1
35
-
36
- def train_mode(self, sample_size=1):
37
- self.train()
38
- self.sample_size = sample_size
39
- self.mode = 'train'
40
-
41
- def greedy_search(self):
42
- self.eval()
43
- self.mode = 'greedy'
44
-
45
- def nearest_neighbors(self):
46
- self.eval()
47
- self.mode = 'nearest_neighbors'
48
-
49
- def sample_mode(self, sample_size=10):
50
- self.sample_size = sample_size
51
- self.eval()
52
- self.mode = 'sample'
53
-
54
- def beam_search(self, sample_size=10):
55
- self.sample_size = sample_size
56
- self.eval()
57
- self.mode = 'beam_search'
58
-
59
- def update_batch_size(self):
60
- self.batch_size = self.fleet.time.shape[0]
61
- self.fleet.batch_size = self.batch_size
62
- self.graph.batch_size = self.batch_size
63
-
64
- def forward(self, batch, *args, **kwargs):
65
- graph_data, fleet_data = batch
66
- self.original_batch_size = graph_data['distance_matrix'].shape[0]
67
- self.batch_size = self.original_batch_size
68
-
69
- self.num_nodes = graph_data['distance_matrix'].shape[1]
70
- self.num_cars = fleet_data['start_time'].shape[1]
71
-
72
- self.graph = Graph(graph_data, device=self.device)
73
- self.fleet = Fleet(fleet_data, num_nodes=self.num_nodes, device=self.device)
74
- self.update_batch_size()
75
-
76
- self.normalization = Normalization(self, normalize_position=True)
77
- if self.apply_normalization:
78
- self.normalization.normalize(self)
79
-
80
- self.num_depots = self.fleet.num_depots.max().item()
81
- self.num_movers_corrected = int(min(max(self.num_movers, self.num_depots), self.num_cars))
82
-
83
- if self.mode != 'nearest_neighbors':
84
- encoder_input = self.graph.construct_vector()
85
- encoder_mask = self.compute_encoder_mask()
86
- self.node_embeddings = self.encoder(encoder_input, encoder_mask)
87
- self.node_projections = self.projections(self.node_embeddings)
88
-
89
- if self.mode == 'sample':
90
- widen_data(self, include_embeddings=True, include_projections=True)
91
- self.update_batch_size()
92
-
93
- self.log_probs = torch.zeros(self.batch_size).to(self.device)
94
- self.counter = 0
95
- while self.loop_condition() and (self.counter < self.num_nodes * 4):
96
- unavailable_moves = self.check_non_depot_options(use_time=True)
97
- mover_indices = self.get_mover_indices(unavailable_moves=unavailable_moves)
98
- action_mask = self.compute_action_mask(mover_indices=mover_indices, unavailable_moves=unavailable_moves)
99
-
100
- decoder_output = None
101
- if self.mode != 'nearest_neighbors':
102
- decoder_input = self.construct_decoder_input(mover_indices=mover_indices)
103
- decoder_mask = self.compute_decoder_mask(mover_indices=mover_indices, unavailable_moves=unavailable_moves)
104
- decoder_output = self.decoder(decoder_input=decoder_input,
105
- projections=self.node_projections,
106
- mask=decoder_mask)
107
-
108
- next_node, car_to_move, log_prob = self.compute_action(decoder_output, action_mask, mover_indices)
109
-
110
- if self.mode == 'beam_search':
111
- widen_data(self, include_embeddings=True, include_projections=True)
112
- self.update_batch_size()
113
-
114
- self.log_probs += log_prob
115
- self.update_time(next_node, car_to_move)
116
- self.update_distance(next_node, car_to_move)
117
- self.update_node_path(next_node, car_to_move)
118
- self.update_traversed_nodes()
119
-
120
- if (self.mode == 'beam_search') and (self.counter > 0):
121
- self.consolidate_beams()
122
-
123
- self.update_batch_size()
124
- self.counter += 1
125
-
126
- self.return_to_depot_1()
127
-
128
- if self.mode in {'beam_search', 'sample'}:
129
- incomplete = self.check_complete().float()
130
- cost = self.compute_cost()
131
- max_cost = cost.max()
132
- masked_cost = (1 - incomplete) * cost + incomplete * max_cost * 10
133
-
134
- p = masked_cost.reshape(self.original_batch_size, self.sample_size)
135
- a = torch.argmin(p, dim=1)
136
- b = torch.arange(self.original_batch_size).to(self.device)
137
- b = b * self.sample_size
138
- index = a + b
139
- select_data(self, index=index, include_projections=True, include_embeddings=True)
140
-
141
- self.update_batch_size()
142
- if self.apply_normalization:
143
- self.normalization.inverse_normalize(self)
144
-
145
- output = {
146
- 'distance': self.fleet.distance.sum(dim=1).squeeze(1).detach(),
147
- 'total_time': self.fleet.time.sum(dim=1).squeeze(1).detach(),
148
- 'log_probs': self.log_probs,
149
- 'late_time': self.fleet.late_time.sum(dim=1).squeeze(1).detach(),
150
- 'incomplete': self.check_complete(),
151
- 'path': self.fleet.path,
152
- 'arrival_times': self.fleet.arrival_times
153
- }
154
-
155
- return output