|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from Actor.graph import Graph |
|
from Actor.fleet import Fleet |
|
from utils.actor_utils import widen_data, select_data |
|
from Actor.normalization import Normalization |
|
|
|
|
|
|
|
class Actor(nn.Module):
    """Build routes for a fleet of cars by moving one car per decoding step.

    Decoding modes (stored in ``self.mode``):
        'train'             sample moves from the learned policy (module in train mode)
        'greedy'            take the most probable move
        'sample'            run ``sample_size`` sampled rollouts per instance and keep
                            the cheapest complete one
        'beam_search'       keep the ``sample_size`` most probable partial solutions
        'nearest_neighbors' model-free heuristic: take the closest allowed move

    ``forward`` turns a ``(graph_data, fleet_data)`` batch into ``Graph``/``Fleet``
    state, moves one car per step until every node is traversed (or
    ``4 * num_nodes`` steps have run), sends every car back to its depot and
    returns per-instance totals, log-probabilities, paths and arrival times.
    """

    def __init__(self, model=None, num_movers=5, num_neighbors_encoder=5,
                 num_neighbors_action=5, normalize=False, use_fleet_attention=True,
                 device='cpu'):
        """Create the actor.

        Args:
            model: object exposing ``encoder``, ``decoder``, ``projections`` and
                ``fleet_attention`` modules. ``None`` selects the
                nearest-neighbour heuristic.
            num_movers: candidate cars considered per depot at each step
                (clamped in ``forward`` to ``[num_depots, num_cars]``).
            num_neighbors_encoder: nearest compatible neighbours each node may
                attend to in the encoder mask.
            num_neighbors_action: nearest allowed nodes a car may move to in the
                learned modes.
            normalize: normalise the instance before decoding and undo it after.
            use_fleet_attention: build the "other movers" decoder feature with
                ``model.fleet_attention`` instead of a plain average.
            device: torch device on which new tensors are created.
        """
        super().__init__()

        self.device = device
        self.num_movers = num_movers
        self.num_neighbors_encoder = num_neighbors_encoder
        self.num_neighbors_action = num_neighbors_action

        self.apply_normalization = normalize
        self.normalization = None
        # Read by adjust_arrival_times / compute_total_distance. It was never
        # initialised before, so those methods raised AttributeError when
        # normalize=True.
        self.normalization_params = None
        self.use_fleet_attention = use_fleet_attention

        if model is None:
            self.mode = 'nearest_neighbors'
        else:
            self.encoder = model.encoder.to(self.device)
            self.decoder = model.decoder.to(self.device)
            self.projections = model.projections.to(self.device)
            self.fleet_attention = model.fleet_attention.to(self.device)
            self.mode = 'train'

        self.sample_size = 1

    # ------------------------------------------------------------------ modes

    def train_mode(self, sample_size=1):
        """Switch to stochastic decoding with the module in training mode."""
        self.train()
        self.sample_size = sample_size
        self.mode = 'train'

    def greedy_search(self):
        """Switch to greedy (arg-max) decoding in eval mode."""
        self.eval()
        self.mode = 'greedy'

    def nearest_neighbors(self):
        """Switch to the model-free nearest-neighbour heuristic in eval mode."""
        self.eval()
        self.mode = 'nearest_neighbors'

    def sample_mode(self, sample_size=10):
        """Switch to best-of-``sample_size`` sampling in eval mode."""
        self.sample_size = sample_size
        self.eval()
        self.mode = 'sample'

    def beam_search(self, sample_size=10):
        """Switch to beam search with beam width ``sample_size`` in eval mode."""
        self.sample_size = sample_size
        self.eval()
        self.mode = 'beam_search'

    def update_batch_size(self):
        """Re-sync the cached batch size after the batch was widened or narrowed."""
        batch_size = self.fleet.time.shape[0]
        self.batch_size = batch_size
        self.fleet.batch_size = batch_size
        self.graph.batch_size = batch_size

    # ---------------------------------------------------------------- forward

    def forward(self, batch, *args, **kwargs):
        """Decode routes for ``batch = (graph_data, fleet_data)``.

        Returns:
            dict with keys 'distance', 'total_time', 'log_probs', 'late_time',
            'incomplete', 'path' and 'arrival_times'. The totals are summed over
            cars and detached from the autograd graph.
        """
        graph_data, fleet_data = batch
        self._init_state(graph_data, fleet_data)

        if self.mode != 'nearest_neighbors':
            encoder_input = self.graph.construct_vector()
            encoder_mask = self.compute_encoder_mask()
            self.node_embeddings = self.encoder(encoder_input, encoder_mask)
            self.node_projections = self.projections(self.node_embeddings)

        if self.mode == 'sample':
            widen_data(self, include_embeddings=True, include_projections=True)
            self.update_batch_size()

        self._rollout()
        self.return_to_depot_1()

        if self.mode in {'beam_search', 'sample'}:
            self._select_best_samples()

        self.update_batch_size()
        if self.apply_normalization:
            self.normalization.inverse_normalize(self)

        return self._collect_output()

    def _init_state(self, graph_data, fleet_data):
        """Build Graph/Fleet state, optionally normalise it, and derive sizes."""
        self.original_batch_size = graph_data['distance_matrix'].shape[0]
        self.batch_size = self.original_batch_size

        self.num_nodes = graph_data['distance_matrix'].shape[1]
        self.num_cars = fleet_data['start_time'].shape[1]

        self.graph = Graph(graph_data, device=self.device)
        self.fleet = Fleet(fleet_data, num_nodes=self.num_nodes, device=self.device)
        # Re-read the node count from the constructed graph (it is what every
        # later tensor shape is based on).
        self.num_nodes = self.graph.distance_matrix.shape[1]
        self.update_batch_size()

        self.normalization = Normalization(self, normalize_position=True)
        if self.apply_normalization:
            self.normalization.normalize(self)

        # int() so repeat()/arange() always receive a Python integer.
        self.num_depots = int(self.fleet.num_depots.max().item())
        self.num_movers_corrected = int(min(max(self.num_movers, self.num_depots), self.num_cars))

    def _rollout(self):
        """Apply one (car, next node) move per step until all nodes are traversed."""
        self.log_probs = torch.zeros(self.batch_size, device=self.device)
        self.counter = 0
        # Hard step budget so an instance that can never be completed still stops.
        max_steps = self.num_nodes*4

        while self.loop_condition() and (self.counter < max_steps):
            unavailable_moves = self.check_non_depot_options(use_time=True)
            mover_indices = self.get_mover_indices(unavailable_moves=unavailable_moves)
            action_mask = self.compute_action_mask(mover_indices=mover_indices,
                                                   unavailable_moves=unavailable_moves)

            if self.mode != 'nearest_neighbors':
                decoder_input = self.construct_decoder_input(mover_indices=mover_indices)
                decoder_mask = self.compute_decoder_mask(mover_indices=mover_indices,
                                                         unavailable_moves=unavailable_moves)
                decoder_output = self.decoder(decoder_input=decoder_input,
                                              projections=self.node_projections,
                                              mask=decoder_mask)
            else:
                decoder_output = None

            next_node, car_to_move, log_prob = self.compute_action(decoder_output, action_mask, mover_indices)

            if self.mode == 'beam_search':
                widen_data(self, include_embeddings=True, include_projections=True)
                self.update_batch_size()

            self.log_probs += log_prob

            # update_time / update_distance read the car's *current* node, so they
            # must run before update_node_path moves it.
            self.update_time(next_node, car_to_move)
            self.update_distance(next_node, car_to_move)
            self.update_node_path(next_node, car_to_move)
            self.update_traversed_nodes()

            if (self.mode == 'beam_search') and (self.counter > 0):
                self.consolidate_beams()

            self.update_batch_size()
            self.counter += 1

    def _select_best_samples(self):
        """Keep, per original instance, the cheapest complete sample/beam."""
        incomplete = self.check_complete().float()
        cost = self.compute_cost()

        # Push incomplete solutions above the complete ones before the argmin.
        max_cost = cost.max()
        masked_cost = (1 - incomplete)*cost + incomplete*max_cost*10

        per_instance = masked_cost.reshape(self.original_batch_size, self.sample_size)
        best = torch.argmin(per_instance, dim=1)
        offsets = torch.arange(self.original_batch_size, device=self.device)*self.sample_size

        select_data(self, index=best + offsets, include_projections=True, include_embeddings=True)

    def _collect_output(self):
        """Assemble the result dictionary returned by ``forward``."""
        total_distance = self.fleet.distance.sum(dim=1).squeeze(1).detach()
        total_time = self.fleet.time.sum(dim=1).squeeze(1).detach()
        total_late_time = self.fleet.late_time.sum(dim=1).squeeze(1).detach()

        return {
            'distance': total_distance,
            'total_time': total_time,
            'log_probs': self.log_probs,
            'late_time': total_late_time,
            'incomplete': self.check_complete(),
            'path': self.fleet.path,
            'arrival_times': self.fleet.arrival_times,
        }

    def consolidate_beams(self):
        """Keep the ``sample_size`` most probable of ``sample_size**2`` beams per instance."""
        p = self.log_probs.reshape(self.original_batch_size, self.sample_size*self.sample_size)
        best = torch.topk(p, dim=1, k=self.sample_size, largest=True)[1]

        offsets = torch.arange(self.original_batch_size, device=self.device).unsqueeze(1).repeat(1, self.sample_size)
        offsets = offsets*self.sample_size*self.sample_size

        ind = (best + offsets).reshape(-1)
        select_data(self, index=ind, include_projections=True, include_embeddings=True)

    def adjust_arrival_times(self):
        """Map arrival/late times back from normalised units, if parameters exist.

        Arrival times are affine-unnormalised (``t*scale + offset``). Late time
        is a difference of two times, so only the scale is applied; adding the
        offset would be wrong.
        """
        params = getattr(self, 'normalization_params', None)
        if not (self.apply_normalization and (params is not None)):
            return

        # [batch, 1, 1] so the values broadcast over cars and steps.
        offset = params['earliest_start_time'].reshape(self.batch_size, 1, 1)
        scale = params['greatest_drive_time'].reshape(self.batch_size, 1, 1)

        self.fleet.arrival_times = self.fleet.arrival_times*scale + offset
        self.fleet.late_time = self.fleet.late_time*scale

    # ------------------------------------------------------------------ masks

    def compute_encoder_mask(self):
        """Return the encoder attention mask, float [batch, num_nodes, num_nodes].

        Entry (i, j) is 1 when any of these holds:
        - i == j;
        - i or j is a depot;
        - j is among the ``num_neighbors_encoder`` nearest nodes to i whose
          time windows are compatible with i.

        ``graph.time_window_compatibility`` is presumably 1 where the pair is
        compatible (TODO confirm).
        """
        batch_size, num_nodes = self.batch_size, self.num_nodes
        incompatible = 1 - self.graph.time_window_compatibility
        eye = torch.eye(num_nodes, device=self.device).unsqueeze(0).expand(batch_size, -1, -1)

        # Self-loops and incompatible pairs get an unbeatably large distance.
        blocked = (incompatible + eye > 0).float()
        dist = self.graph.distance_matrix*(1 - blocked) + blocked*self.graph.max_dist*10

        k = min(num_nodes, self.num_neighbors_encoder)
        neighbor_index = torch.topk(dist, k=k, dim=2, largest=False)[1]
        neighbors_mask = torch.zeros(batch_size, num_nodes, num_nodes, device=neighbor_index.device)
        neighbors_mask.scatter_(2, neighbor_index, 1.0)

        # Drop neighbours that were only picked because fewer than k compatible ones exist.
        neighbs_time_mask = neighbors_mask*(incompatible == 0).float()

        depot = self.graph.depot.reshape(batch_size, num_nodes)
        depot_mask = ((depot.unsqueeze(1) + depot.unsqueeze(2)) > 0).float()

        return (neighbs_time_mask + depot_mask + eye > 0).float()

    def compute_depoyment_priority_score(self):
        """Score each car, float [batch, num_cars]; higher means deploy first.

        ``score = 10*num_available + excess + 10*max(excess)``, where
        ``num_available`` counts the car's allowed non-depot moves and
        ``excess`` sums ``max_distance - distance`` over those moves.
        Cars with more, and closer, options therefore score higher.
        """
        batch_size, num_cars, num_nodes = self.batch_size, self.num_cars, self.num_nodes
        available = (self.check_non_depot_options().float() == 0).float()
        num_available = available.sum(dim=2)

        ind = self.fleet.node.reshape(batch_size, num_cars, 1).expand(batch_size, num_cars, num_nodes)
        distances = torch.gather(self.graph.distance_matrix, dim=1, index=ind)

        max_distance = self.graph.distance_matrix.reshape(batch_size, -1).max(dim=1)[0].reshape(batch_size, 1, 1)
        excess = ((max_distance - distances)*available).sum(dim=2)

        return (num_available + excess.max())*10 + excess

    def check_complete(self):
        """Return a bool tensor [batch] that is True where some node is still untraversed."""
        return (self.fleet.traversed_nodes == 0).any(dim=1)

    def loop_condition(self):
        """Return True while any batch entry still has untraversed nodes."""
        return bool(self.check_complete().any().item())

    def compute_cost(self):
        """Return the selection cost per batch entry: total time summed over cars."""
        return self.fleet.time.sum(dim=1).squeeze(1)

    def compute_total_distance(self):
        """Recompute the route length from ``fleet.path`` (cyclic per car).

        Each car's path is closed into a cycle: the last node connects back to
        the first. Sets ``self.car_distances`` [batch, num_cars] and
        ``self.total_distance`` [batch]. When normalisation parameters exist,
        the total is rescaled by ``greatest_distance``.
        """
        path = self.fleet.path.long()
        # prev[..., i] is the node visited just before path[..., i] (wrapping around).
        prev = torch.roll(path, shifts=1, dims=2)

        batch_idx = torch.arange(self.batch_size, device=self.graph.distance_matrix.device).reshape(-1, 1, 1)
        pairwise_distances = self.graph.distance_matrix[batch_idx, prev, path]

        self.car_distances = pairwise_distances.sum(dim=2)
        self.total_distance = self.car_distances.sum(dim=1)

        params = getattr(self, 'normalization_params', None)
        if self.apply_normalization and (params is not None):
            distance_multiplier = params['greatest_distance']
            self.total_distance = self.total_distance*distance_multiplier.reshape(self.batch_size)

        return self.total_distance

    def update_traversed_nodes(self):
        """Set ``fleet.traversed_nodes`` [batch, num_nodes] to 1 for nodes on any path."""
        flat_path = self.fleet.path.reshape(self.batch_size, -1).long()
        traversed = torch.zeros(self.batch_size, self.num_nodes, device=flat_path.device)
        self.fleet.traversed_nodes = traversed.scatter_(1, flat_path, 1.0)

    # ---------------------------------------------------------------- actions

    def _apply_default_option(self, allowed, mover_nodes, per_mover):
        """Give stuck movers a one-hot "stay at the current node" option.

        Args:
            allowed: float mask [batch, num_movers, num_nodes]; 1 means allowed.
            mover_nodes: current node of each mover, [batch, num_movers].
            per_mover: if True, every mover with no allowed node falls back on
                its own. If False, all movers of a batch entry fall back only
                when none of them has an allowed node.
        """
        stay = F.one_hot(mover_nodes.reshape(self.batch_size, -1).long(), self.num_nodes).float()
        if per_mover:
            no_options = (allowed.sum(dim=2, keepdim=True) == 0).float()
        else:
            no_options = (allowed.reshape(self.batch_size, -1).sum(dim=1) == 0).float().reshape(-1, 1, 1)
        return allowed*(1 - no_options) + stay*no_options

    def compute_action(self, decoder_output, action_mask, mover_indices):
        """Choose the next (car, node) move for every batch entry.

        Args:
            decoder_output: decoder scores [batch, num_movers, num_nodes], or
                None in 'nearest_neighbors' mode.
            action_mask: float [batch, num_movers, num_nodes]; 1 means allowed.
            mover_indices: car index of each mover, [batch, num_movers].

        Returns:
            ``(next_node, car_to_move, log_prob)``.
            - 'nearest_neighbors' mode: node and car are [batch, 1] and
              log_prob is zeros [batch].
            - Learned modes: all three are flattened, with ``sample_size``
              entries per row in beam search.

        Raises:
            ValueError: if ``self.mode`` is not a known decoding mode.
        """
        batch_size, num_nodes = self.batch_size, self.num_nodes
        num_movers = mover_indices.shape[1]

        if self.mode == 'nearest_neighbors':
            assert decoder_output is None

            mover_nodes = torch.gather(self.fleet.node, dim=1, index=mover_indices)
            mask = self._apply_default_option(action_mask, mover_nodes, per_mover=False)

            ind = mover_nodes.reshape(batch_size, num_movers, 1).expand(batch_size, num_movers, num_nodes)
            distances = torch.gather(self.graph.distance_matrix, dim=1, index=ind)
            masked_distances = distances*mask + self.graph.max_dist*(1 - mask)*10

            flat_index = torch.argmin(masked_distances.reshape(batch_size, -1), dim=1).unsqueeze(1)
            mover_index = flat_index // num_nodes
            next_node = flat_index % num_nodes
            car_to_move = torch.gather(mover_indices, dim=1, index=mover_index)

            log_prob = torch.zeros(batch_size, device=self.device)
            return next_node, car_to_move, log_prob

        assert decoder_output is not None
        assert (num_movers == action_mask.shape[1]) and (num_movers == decoder_output.shape[1])

        mover_nodes = torch.gather(self.fleet.node, dim=1, index=mover_indices)
        mask = self._apply_default_option(action_mask, mover_nodes, per_mover=False)

        # log(0) = -inf removes disallowed moves from the softmax.
        logits = (decoder_output + mask.log()).reshape(batch_size, -1)
        probs = torch.softmax(logits, dim=1)

        if self.mode == 'greedy':
            prob = torch.max(probs, dim=1)[0].unsqueeze(1)
            ind = torch.argmax(probs, dim=1).unsqueeze(1)
        elif self.mode in {'train', 'sample'}:
            ind = torch.multinomial(probs, num_samples=1)
            prob = torch.gather(probs, dim=1, index=ind)
        elif self.mode == 'beam_search':
            prob, ind = torch.topk(probs, dim=1, k=self.sample_size, largest=True)
        else:
            raise ValueError(f"Unknown decoding mode: {self.mode!r}")

        mover_index = ind // num_nodes
        next_node = ind % num_nodes
        car_to_move = torch.gather(mover_indices, dim=1, index=mover_index)

        return next_node.reshape(-1), car_to_move.reshape(-1), prob.reshape(-1).log()

    # ------------------------------------------------------------ state update

    def _current_node_of(self, car_to_move):
        """Return the current node [batch] of the car selected in each batch entry."""
        nodes = self.fleet.node.reshape(self.batch_size, self.num_cars)
        return torch.gather(nodes, dim=1, index=car_to_move.reshape(self.batch_size, 1)).squeeze(1)

    def _lookup_edge(self, matrix, from_nodes, to_nodes):
        """Return ``matrix[b, from_nodes[b], to_nodes[b]]`` for every batch entry b."""
        batch_idx = torch.arange(self.batch_size, device=matrix.device)
        return matrix[batch_idx, from_nodes.reshape(-1), to_nodes.reshape(-1)]

    def _car_update_mask(self, car_to_move):
        """Return a float mask [batch, num_cars, 1] that is 1 only for the moved car."""
        car = car_to_move.reshape(self.batch_size).long()
        return F.one_hot(car, self.num_cars).float().unsqueeze(2)

    def update_distance(self, next_node, car_to_move):
        """Add the travelled edge length to the moved car's ``fleet.distance``."""
        batch_size = self.batch_size
        current_node = self._current_node_of(car_to_move)
        step = self._lookup_edge(self.graph.distance_matrix, current_node, next_node)

        distance = self.fleet.distance.reshape(batch_size, self.num_cars)
        current_distance = torch.gather(distance, dim=1, index=car_to_move.reshape(batch_size, 1)).squeeze(1)
        updated_distance = (current_distance + step).reshape(batch_size, 1, 1)

        mask = self._car_update_mask(car_to_move)
        self.fleet.distance = self.fleet.distance*(1 - mask) + updated_distance*mask

    def update_time(self, next_node, car_to_move):
        """Advance the moved car's clock and record its lateness at ``next_node``.

        The car arrives after the drive time and waits until the node's start
        time if it is early. NOTE(review): ``late_time`` is overwritten, not
        accumulated, so only the lateness at the latest stop is kept; confirm
        this is intended.
        """
        batch_size, num_nodes = self.batch_size, self.num_nodes
        current_node = self._current_node_of(car_to_move)
        drive_time = self._lookup_edge(self.graph.time_matrix, current_node, next_node)

        start_time = self.graph.start_time.reshape(batch_size, num_nodes)
        next_start_time = torch.gather(start_time, dim=1, index=next_node.reshape(batch_size, 1)).squeeze(1)

        times = self.fleet.time.reshape(batch_size, self.num_cars)
        current_time = torch.gather(times, dim=1, index=car_to_move.reshape(batch_size, 1)).squeeze(1)

        updated_time = torch.maximum(drive_time + current_time, next_start_time)

        end_time = self.graph.end_time.reshape(batch_size, num_nodes)
        next_end_time = torch.gather(end_time, dim=1, index=next_node.reshape(-1, 1)).squeeze(1)
        late_time = F.relu(updated_time - next_end_time)

        mask = self._car_update_mask(car_to_move)
        self.fleet.time = self.fleet.time*(1 - mask) + updated_time.reshape(batch_size, 1, 1)*mask
        self.fleet.late_time = self.fleet.late_time*(1 - mask) + late_time.reshape(batch_size, 1, 1)*mask

    def update_node_path(self, next_node, car_to_move):
        """Move the car to ``next_node`` and append nodes/times for all cars.

        Cars that did not move repeat their current node and time.
        """
        batch_size, num_cars = self.batch_size, self.num_cars
        update_mask = F.one_hot(car_to_move.reshape(batch_size).long(), num_cars)
        new_node = next_node.reshape(batch_size, 1).expand(batch_size, num_cars)
        self.fleet.node = update_mask*new_node + self.fleet.node*(1 - update_mask)

        self.fleet.path = torch.cat([self.fleet.path, self.fleet.node.unsqueeze(2)], dim=2)

        t = self.fleet.time.reshape(batch_size, num_cars, 1)
        self.fleet.arrival_times = torch.cat([self.fleet.arrival_times, t], dim=2)

    # -------------------------------------------------------------- move rules

    def check_non_depot_options(self, use_time=True):
        """Return a bool tensor [batch, num_cars, num_nodes] of invalid moves.

        Entry (b, i, j) is True if moving car i to node j is invalid because:
        - j is the car's depot; or
        - j was already traversed; or
        - (when ``use_time``) j cannot be reached before its end time.
        """
        batch_size, num_cars, num_nodes = self.batch_size, self.num_cars, self.num_nodes

        node_ids = torch.arange(num_nodes, device=self.device).reshape(1, 1, num_nodes)
        is_depot = self.fleet.depot.reshape(batch_size, num_cars, 1) == node_ids

        traversed = self.fleet.traversed_nodes.reshape(batch_size, 1, num_nodes) == 1

        unavailable_moves = is_depot | traversed
        if use_time:
            unavailable_moves = unavailable_moves | torch.logical_not(self.check_arival_times())
        return unavailable_moves

    def compute_decoder_mask(self, mover_indices, unavailable_moves):
        """Return the decoder mask [batch, num_movers, num_nodes] for the movers.

        A mover with no allowed node gets a single "stay at current node" option.
        """
        num_movers = mover_indices.shape[1]
        mover_nodes = torch.gather(self.fleet.node, dim=1, index=mover_indices)

        ind = mover_indices.reshape(self.batch_size, num_movers, 1).expand(self.batch_size, num_movers, self.num_nodes)
        allowed = 1 - torch.gather(unavailable_moves, dim=1, index=ind).float()

        return self._apply_default_option(allowed, mover_nodes, per_mover=True)

    def compute_action_mask(self, mover_indices, unavailable_moves):
        """Return the allowed-move mask [batch, num_movers, num_nodes] (1 = allowed).

        In learned modes each car may only move to its ``num_neighbors_action``
        nearest available nodes. The nearest-neighbour heuristic allows every
        available node.
        """
        batch_size, num_cars, num_nodes = self.batch_size, self.num_cars, self.num_nodes

        if self.mode != 'nearest_neighbors':
            node = self.fleet.node.reshape(batch_size, num_cars)
            at_current = F.one_hot(node.long(), num_nodes).float()

            ind = node.reshape(batch_size, num_cars, 1).expand(batch_size, num_cars, num_nodes)
            distances = torch.gather(self.graph.distance_matrix, dim=1, index=ind)

            # Unavailable nodes and the current node rank last in the neighbour search.
            blocked = (unavailable_moves.float() + at_current > 0).float()
            masked_distances = distances*(1 - blocked) + blocked*self.graph.max_dist*100

            k = min(self.num_neighbors_action, num_nodes)
            neighbor_indices = torch.topk(masked_distances, dim=2, k=k, largest=False)[1]
            is_neighbor = torch.zeros(batch_size, num_cars, num_nodes, device=neighbor_indices.device)
            is_neighbor.scatter_(2, neighbor_indices, 1.0)

            non_neighbor_mask = is_neighbor == 0
            mask = 1 - (non_neighbor_mask | unavailable_moves).float()
        else:
            mask = 1 - unavailable_moves.float()

        num_movers = mover_indices.shape[1]
        ind = mover_indices.reshape(batch_size, num_movers, 1).expand(batch_size, num_movers, num_nodes)
        return torch.gather(mask, dim=1, index=ind)

    def construct_decoder_input(self, mover_indices):
        """Build the decoder input [batch, num_movers, features] for the movers.

        Per car, the following are concatenated:
        - current node embedding;
        - depot embedding;
        - mean node embedding;
        - a "movers" context vector;
        - ``fleet.construct_vector()`` features (presumably [batch, num_cars, F];
          TODO confirm).

        The rows of the selected movers are then gathered.
        """
        batch_size, num_cars = self.batch_size, self.num_cars
        embeddings = self.node_embeddings
        embedding_size = embeddings.shape[2]

        def embeddings_at(nodes):
            # embeddings[b, nodes[b, c], :] for every car c -> [batch, num_cars, E]
            idx = nodes.reshape(batch_size, num_cars, 1).expand(batch_size, num_cars, embedding_size)
            return torch.gather(embeddings, dim=1, index=idx)

        depot_vector = embeddings_at(self.fleet.depot)
        current_node_vector = embeddings_at(self.fleet.node)
        mean_graph_vector = embeddings.mean(dim=1, keepdim=True).expand(batch_size, num_cars, embedding_size)

        feature_values = self.fleet.construct_vector()

        num_movers = mover_indices.shape[1]
        if num_movers == 1:
            movers_vector = torch.zeros_like(current_node_vector)
        elif not self.use_fleet_attention:
            # Average current-node embedding of the *other* selected movers.
            is_mover = (F.one_hot(mover_indices.long(), num_cars).sum(dim=1) > 0).float()
            movers_mask = is_mover.reshape(batch_size, num_cars, 1)
            movers_sum = (current_node_vector*movers_mask).sum(dim=1, keepdim=True)
            movers_vector = (movers_sum - current_node_vector)/(num_movers - 1)
        else:
            x = torch.cat([current_node_vector, feature_values], dim=2)
            movers_vector = self.fleet_attention(x)

        pre_output = torch.cat(
            [current_node_vector, depot_vector, mean_graph_vector, movers_vector, feature_values], dim=2)

        ind = mover_indices.reshape(batch_size, num_movers, 1).expand(batch_size, num_movers, pre_output.shape[2])
        return torch.gather(pre_output, dim=1, index=ind)

    def get_mover_indices(self, unavailable_moves):
        """Select candidate cars, long [batch, num_depots * num_movers_corrected].

        For each depot, the top ``num_movers_corrected`` cars of that depot are
        taken by score, then flattened:
        - cars in the field with an allowed move rank highest;
        - then cars at the depot with an allowed move, ordered by deployment
          score;
        - other cars score 0.

        NOTE(review): when a depot has fewer candidates than that count,
        zero-score cars (possibly from other depots) fill the remaining slots.
        """
        batch_size, num_cars, num_depots = self.batch_size, self.num_cars, self.num_depots
        depot = self.fleet.depot.reshape(batch_size, num_cars).long()
        current_node = self.fleet.node.reshape(batch_size, num_cars).long()

        has_option = ((1 - unavailable_moves.float()).sum(dim=2) > 0)
        at_depot = (current_node == depot)
        active_in_field = ~at_depot & has_option
        active_at_depot = at_depot & has_option

        deployment_score = self.compute_depoyment_priority_score()
        field_score = deployment_score.max()*100
        depot_score = deployment_score*10
        score = active_in_field.float()*field_score + active_at_depot.float()*depot_score

        depot_ids = torch.arange(num_depots, device=self.device).reshape(1, 1, num_depots)
        same_depot = (self.fleet.depot.reshape(batch_size, num_cars, 1) == depot_ids).float()
        score = score.reshape(batch_size, num_cars, 1)*same_depot

        k = self.num_movers_corrected
        indices = torch.topk(score, k=k, dim=1, largest=True)[1]
        return indices.reshape(batch_size, num_depots*k)

    def check_arival_times(self):
        """Return a bool tensor [batch, num_cars, num_nodes]; True if car i reaches j by its end time."""
        batch_size, num_cars, num_nodes = self.batch_size, self.num_cars, self.num_nodes
        ind = self.fleet.node.reshape(batch_size, num_cars, 1).expand(batch_size, num_cars, num_nodes)
        drive_times = torch.gather(self.graph.time_matrix, dim=1, index=ind)
        arrival_times = drive_times + self.fleet.time.reshape(batch_size, num_cars, 1)
        return arrival_times <= self.graph.end_time.reshape(batch_size, 1, num_nodes)

    # ----------------------------------------------------------- depot return

    def _edge_cost_per_car(self, matrix, targets):
        """Return ``matrix[b, node[b, c], targets[b, c]]`` as [batch, num_cars, 1]."""
        batch_size, num_cars, num_nodes = self.batch_size, self.num_cars, self.num_nodes
        ind = self.fleet.node.reshape(batch_size, num_cars, 1).expand(batch_size, num_cars, num_nodes)
        rows = torch.gather(matrix, dim=1, index=ind)
        return torch.gather(rows, dim=2, index=targets.reshape(batch_size, num_cars, 1))

    def return_to_depot(self):
        """Send all cars of a batch entry home once no car has a non-depot move.

        Time and distance are charged for the return trip. Path and
        arrival-time histories grow by one step for every batch entry.
        """
        batch_size, num_cars = self.batch_size, self.num_cars
        available_moves = 1 - self.check_non_depot_options(use_time=False).float()
        returning = (available_moves.reshape(batch_size, -1).sum(dim=1) == 0).long()
        returning = returning.reshape(batch_size, 1).repeat(1, num_cars)

        if returning.sum().item() == 0:
            return

        depot = self.fleet.depot.reshape(batch_size, num_cars).long()
        node = self.fleet.node.reshape(batch_size, num_cars).long()
        next_node = returning*depot + (1 - returning)*node

        weight = returning.unsqueeze(2).float()
        self.fleet.time = self.fleet.time + self._edge_cost_per_car(self.graph.time_matrix, next_node)*weight
        # Previously missing: the return trip's distance was never accumulated.
        self.fleet.distance = self.fleet.distance + self._edge_cost_per_car(self.graph.distance_matrix, next_node)*weight

        self.fleet.node = next_node
        self.fleet.path = torch.cat([self.fleet.path, next_node.reshape(batch_size, num_cars, 1)], dim=2)

        t = self.fleet.time.reshape(batch_size, num_cars, 1)
        self.fleet.arrival_times = torch.cat([self.fleet.arrival_times, t], dim=2)

    def return_to_depot_1(self):
        """Send every car back to its depot, charging drive time and distance."""
        batch_size, num_cars = self.batch_size, self.num_cars
        depot = self.fleet.depot.reshape(batch_size, num_cars).long()

        # Both costs are read from the current node, before it is overwritten.
        self.fleet.time = self.fleet.time + self._edge_cost_per_car(self.graph.time_matrix, depot)
        self.fleet.distance = self.fleet.distance + self._edge_cost_per_car(self.graph.distance_matrix, depot)

        self.fleet.node = depot
        self.fleet.path = torch.cat([self.fleet.path, depot.reshape(batch_size, num_cars, 1)], dim=2)

        t = self.fleet.time.reshape(batch_size, num_cars, 1)
        self.fleet.arrival_times = torch.cat([self.fleet.arrival_times, t], dim=2)
|