import torch
import torch.nn as nn
import numpy as np
import os
import multiprocessing
from models.solvers.general_solver import GeneralSolver
from utils.utils import calc_tour_length


def get_visited_mask(tour, step, node_feats, dist_matrix=None):
    """
    Visited nodes -> feasible, Unvisited nodes -> infeasible.
    When solving a problem with visited_paths fixed, they should be included in the solution.
    Therefore, visited nodes are set to feasible nodes.

    Parameters
    ----------
    tour: list or np.array [seq_length]
    step: int
        the first ``step`` nodes of ``tour`` count as visited
    node_feats: dict of np.array
        only ``node_feats["coords"]`` is read, and only when ``dist_matrix`` is None
    dist_matrix: np.array [num_nodes x num_nodes], optional

    Returns
    -------
    visited: np.array of bool [num_nodes]
    """
    if dist_matrix is not None:
        num_nodes = len(dist_matrix)
    else:
        num_nodes = len(node_feats["coords"])
    return np.isin(np.arange(num_nodes), tour[:step])


def get_tw_mask(tour, step, node_feats, dist_matrix=None):
    """
    Nodes whose tw exceeds current_time -> infeasible, otherwise -> feasible.

    The fixed prefix ``tour[:step]`` is replayed from time 0. When arriving before a node's window
    opens, the vehicle waits until the opening time. A node is then feasible when the arrival time
    from ``tour[step-1]`` does not exceed its closing time ``time_window[:, 1]``.

    Parameters
    ----------
    tour: list [seq_length]
    step: int
    node_feats: dict of np.array
        reads "time_window" ([num_nodes x 2]: open, close) and, without ``dist_matrix``, "coords"
    dist_matrix: np.array [num_nodes x num_nodes], optional
        travel times; Euclidean distances between coords are used when None

    Returns
    -------
    mask_tw: np.array of bool [num_nodes]
    """
    time_window = node_feats["time_window"]
    if dist_matrix is not None:
        num_nodes = len(dist_matrix)

        def travel_time(src, dst):
            return dist_matrix[src, dst]

        def travel_times_from(src):
            return dist_matrix[src]  # [num_nodes] TODO: check
    else:
        coords = node_feats["coords"]
        num_nodes = len(coords)

        def travel_time(src, dst):
            return np.linalg.norm(coords[src] - coords[dst])

        def travel_times_from(src):
            return np.linalg.norm(coords[src][None, :] - coords, axis=-1)  # [num_nodes] TODO: check

    # Replay the fixed prefix to obtain the departure time from tour[step - 1].
    # NOTE: the prefix itself is not validated against closing times
    # (i.e. arrival < time_window[curr_id, 1] is not asserted).
    curr_time = 0.0
    for i in range(1, step):
        prev_id = tour[i - 1]
        curr_id = tour[i]
        arrival = curr_time + travel_time(prev_id, curr_id)
        if arrival < time_window[curr_id, 0]:
            curr_time = time_window[curr_id, 0]  # arrived early: wait for the window to open
        else:
            curr_time = arrival

    arrival_times = curr_time + travel_times_from(tour[step - 1])  # [num_nodes]
    not_exceed_tw = np.ones(num_nodes, dtype=bool)
    not_exceed_tw[arrival_times > time_window[:, 1]] = False
    return not_exceed_tw


def get_cap_mask(tour, step, node_feats, dist_matrix=None):
    """
    Nodes whose demand exceeds the remaining vehicle capacity -> infeasible, otherwise -> feasible.

    Parameters
    ----------
    tour: list [seq_length]
    step: int
        demands of the first ``step`` nodes of ``tour`` are subtracted from the capacity
    node_feats: dict of np.array
        reads "demand" ([num_nodes]) and "capacity" (must support ``.copy()``)
    dist_matrix: unused
        accepted only for signature consistency with the other mask helpers

    Returns
    -------
    less_than_cap: np.array of bool [num_nodes]
    """
    demands = np.asarray(node_feats["demand"])
    remaining_cap = node_feats["capacity"].copy()
    for i in range(step):
        remaining_cap -= demands[tour[i]]
    # written as ~(a < b) rather than (a >= b) to keep the original handling of NaN demands
    return ~(remaining_cap < demands)


def get_pc_mask(tour, step, node_feats):
    """
    Mask for prize-collecting problems (e.g., PCTSP, PCTSPTW, PCCVRP, PCCVRPTW, ...).

    A node is feasible when the length of the fixed prefix ``tour[:step]``, plus the leg from
    ``tour[step-1]`` to that node, plus the leg from that node back to ``tour[0]``, does not
    exceed ``node_feats["max_length"]``. Lengths are scaled by ``large_value`` and truncated to
    int64 before comparing (presumably to avoid float round-off at the boundary).

    Returns
    -------
    not_exceed_max_length: np.array of bool [num_nodes]
    """
    large_value = 1e+5
    coords = node_feats["coords"]
    max_length = (node_feats["max_length"] * large_value).astype(np.int64)
    tour_length = 0
    for i in range(1, step):
        prev_id = tour[i - 1]
        curr_id = tour[i]
        tour_length += (np.linalg.norm(coords[prev_id] - coords[curr_id]) * large_value).astype(np.int64)
    curr_to_next = (np.linalg.norm(coords[tour[step - 1]][None, :] - coords, axis=-1) * large_value).astype(np.int64)  # [num_nodes]
    next_to_depot = (np.linalg.norm(coords[tour[0]][None, :] - coords, axis=-1) * large_value).astype(np.int64)  # [num_nodes]
    not_exceed_max_length = (tour_length + curr_to_next + next_to_depot) <= max_length  # [num_nodes]
    return not_exceed_max_length


def analyze_tour(tour, node_feats):
    """Debug helper: print arrival times and time-window validity for every leg of ``tour``."""
    coords = node_feats["coords"]
    time_window = node_feats["time_window"]
    curr_time = 0
    for i in range(1, len(tour)):
        prev_id = tour[i - 1]
        curr_id = tour[i]
        travel_time = np.linalg.norm(coords[prev_id] - coords[curr_id])
        valid = curr_time + travel_time < time_window[curr_id, 1]
        print(f"visit #{i}: {prev_id} -> {curr_id}, travel_time: {travel_time}, arrival_time: {curr_time + travel_time}, time_window: {time_window[curr_id]}, valid: {valid}")
        if curr_time + travel_time < time_window[curr_id, 0]:
            curr_time = time_window[curr_id, 0]
        else:
            curr_time += travel_time


# Label returned by ``GroundTruthBase.label_path`` when a solver returns no solution.
FAIL_FLAG = -1


def _to_local_ids(new_node_id, node_ids):
    """Map original node ids to their positions in ``new_node_id`` (ids of the reduced instance).

    Raises ValueError (from ``.item()``) if a node id is absent from ``new_node_id``.
    """
    return np.array([np.where(new_node_id == x)[0].item() for x in node_ids])


class GroundTruthBase(nn.Module):
    def __init__(self, problem, compared_problems, solver_type):
        """
        Parameters
        ----------
        problem:
            the problem whose tours are being explained (stored as ``self.problem``)
        compared_problems: list
            counterfactual problems; one ``GeneralSolver`` is built per entry
        solver_type:
            forwarded to every ``GeneralSolver``
        """
        super().__init__()
        self.problem = problem
        self.compared_problems = compared_problems
        self.num_compared_problems = len(compared_problems)
        self.solver_type = solver_type
        # TODO:
        self.solvers = [
            GeneralSolver(compared_problem, self.solver_type, scaling=False)
            for compared_problem in self.compared_problems
        ]

    def forward(self, inputs, annotation=False, parallel=True):
        """
        Label every step after ``first_explained_step`` of every vehicle's tour.

        Parameters
        ----------
        inputs: dict (see ``get_inputs``)
            tour: 2d list [num_vehicles x seq_length]
            first_explained_step: int
            node_feats: dict of np.array
            dist_matrix: np.array [num_nodes x num_nodes] or None (optional key)
        annotation: bool
            please set it True when annotating data
        parallel: bool
            label steps in a ``multiprocessing.Pool`` (ignored when ``annotation`` is True)

        Returns
        -------
        annotation=True:
            list (per vehicle) of ``(step, label)`` tuples, or None as soon as any step fails
        annotation=False:
            list (per vehicle) of labels for steps ``first_explained_step+1 .. len(tour)-1``;
            failed steps hold ``FAIL_FLAG``
        """
        input_tours = inputs["tour"]
        node_feats = inputs["node_feats"]
        dist_matrix = inputs.get("dist_matrix")
        first_explained_step = inputs["first_explained_step"]
        num_vehicles = len(input_tours)
        first_step = first_explained_step + 1

        if annotation:
            labels = [[] for _ in range(num_vehicles)]
            for vehicle_id in range(num_vehicles):
                input_tour = input_tours[vehicle_id]
                for step in range(first_step, len(input_tour)):
                    _, _, label = self.label_path(vehicle_id, step, input_tour, node_feats, dist_matrix)
                    if label == FAIL_FLAG:
                        return None  # the whole instance is discarded on any failure
                    labels[vehicle_id].append((step, label))
            return labels

        # One task per (vehicle, step); results are written to labels[vehicle_id][step - first_step].
        tasks = [
            (vehicle_id, step, input_tours[vehicle_id], node_feats, dist_matrix)
            for vehicle_id in range(num_vehicles)
            for step in range(first_step, len(input_tours[vehicle_id]))
        ]
        labels = [[FAIL_FLAG] * len(range(first_step, len(tour))) for tour in input_tours]
        if parallel:
            with multiprocessing.Pool(os.cpu_count()) as pool:
                results = pool.starmap(self.label_path, tasks)
        else:
            results = [self.label_path(*task) for task in tasks]
        for vehicle_id, step, label in results:
            labels[vehicle_id][step - first_step] = label

        # validate labels: one label per explained step
        for vehicle_id, tour in enumerate(input_tours):
            expected_len = len(range(first_step, len(tour)))
            assert expected_len == len(labels[vehicle_id]), f"vehicle_id={vehicle_id}, {input_tours}, {labels}"
        return labels

    def label_path(self, vehicle_id, step, input_tour, node_feats, dist_matrix=None):
        """
        Label one step of a factual tour by comparing it with counterfactual tours.

        ``input_tour[:step]`` is fixed, infeasible nodes are removed (``get_feasible_nodes``), and each
        compared problem is re-solved on the reduced instance.

        Returns
        -------
        vehicle_id, step:
            echoed back so results coming out of a process pool can be placed
        label: int
            index of the first compared problem whose tour visits ``input_tour[step]`` at ``step``,
            ``self.num_compared_problems`` if none does, or ``FAIL_FLAG`` if a solver returned None
        """
        visited_path = input_tour[:step].copy()
        new_node_id, new_node_feats, new_dist_matrix = self.get_feasible_nodes(input_tour, step, node_feats, dist_matrix)
        new_visited_path = _to_local_ids(new_node_id, visited_path)

        compared_tour_list = [[] for _ in range(self.num_compared_problems)]
        for i in range(self.num_compared_problems):
            # TODO: in CVRPTW / PCCVRPTW, need to modify classification of the first and last paths
            compared_tours = self.solvers[i].solve(new_node_feats, new_visited_path, new_dist_matrix)
            if compared_tours is None:
                return vehicle_id, step, FAIL_FLAG
            # choose the route that contains the last fixed node
            compared_tour = next(
                (tour for tour in compared_tours if new_visited_path[-1] in tour), None
            )
            assert compared_tour is not None, f"Found no appropriate vhiecle. {compared_tours}, {new_visited_path}"
            compared_tour = np.array([new_node_id[x] for x in compared_tour])  # back to original ids
            if (step > 0) and (compared_tour[1] != input_tour[1]):
                compared_tour = np.flipud(compared_tour)  # make direction of the cf tour the same as factual one
            compared_tour_list[i] = compared_tour

        label = self.get_label(input_tour, compared_tour_list, step)
        return vehicle_id, step, label

    def solve(self, step, input_tour, node_feats, instance_name=None):
        """
        Re-solve every compared problem with ``input_tour[:step]`` fixed and report route lengths.

        Returns
        -------
        compared_tours: dict
            compared problem -> list of route lengths (``calc_tour_length`` over ``node_feats["coords"]``),
            or None when that problem's solver returned None
        """
        compared_tours = {}
        visited_path = input_tour[:step].copy()
        # get_feasible_nodes returns (node_id, node_feats, dist_matrix); no dist_matrix is used here
        new_node_id, new_node_feats, _ = self.get_feasible_nodes(input_tour, step, node_feats)
        new_visited_path = _to_local_ids(new_node_id, visited_path)
        for solver, compared_problem in zip(self.solvers, self.compared_problems):
            # NOTE(review): label_path passes a distance matrix as the third positional argument of
            # GeneralSolver.solve, while this passes instance_name -- confirm against its signature.
            tours = solver.solve(new_node_feats, new_visited_path, instance_name)
            if tours is None:
                compared_tours[compared_problem] = None
                continue
            tours = [[new_node_id[x] for x in tour] for tour in tours]  # back to original ids
            compared_tours[compared_problem] = [calc_tour_length(tour, node_feats["coords"]) for tour in tours]
        return compared_tours

    def get_label(self, input_tour, compared_tours, step):
        """Return the index of the first compared tour whose node at ``step`` matches the factual one,
        or ``self.num_compared_problems`` when none matches."""
        for i in range(self.num_compared_problems):
            if input_tour[step] == compared_tours[i][step]:
                return i
        return self.num_compared_problems

    def get_inputs(self, tour, first_explained_step, node_feats, dist_matrix=None):
        """Pack arguments into the ``inputs`` dict consumed by ``forward``."""
        input_features = {
            "tour": tour,
            "first_explained_step": first_explained_step,
            "node_feats": node_feats,
            "dist_matrix": dist_matrix
        }
        return input_features

    def get_feasible_nodes(self, tour, step, node_feats, dist_matrix=None):
        """
        Drop the nodes rejected by ``self.get_mask`` and reduce the instance accordingly.

        Parameters
        ----------
        tour: np.array [seq_length]
        step: int
        node_feats: dict of np.array
            per-node entries ("coords", "time_window", "demand", "penalties", "prizes") are masked;
            every other entry is copied unchanged
        dist_matrix: np.array [num_nodes x num_nodes], optional

        Returns
        -------
        new_node_id: np.array [num_feasible_nodes]
            original ids of the kept nodes (ascending)
        new_node_feats: dict of np.array
        new_dist_matrix: np.array [num_feasible_nodes x num_feasible_nodes] or None
        """
        if dist_matrix is not None:
            num_nodes = len(dist_matrix)
        else:
            num_nodes = len(node_feats["coords"])
        mask = self.get_mask(tour, step, node_feats, dist_matrix)
        node_id = np.arange(num_nodes)
        new_node_id = node_id[mask].copy()
        per_node_keys = ["coords", "time_window", "demand", "penalties", "prizes"]
        new_node_feats = {
            key: node_feat[mask].copy() if key in per_node_keys else node_feat.copy()
            for key, node_feat in node_feats.items()
        }
        if dist_matrix is not None:
            delete_id = node_id[~mask]
            new_dist_matrix = np.delete(np.delete(dist_matrix, delete_id, 0), delete_id, 1)
        else:
            new_dist_matrix = None
        return new_node_id, new_node_feats, new_dist_matrix

    def get_mask(self, tour, step, node_feats, dist_matrix=None):
        """Return a bool feasibility mask over all nodes; implemented by subclasses."""
        raise NotImplementedError

    def check_feasibility(self, tour, node_feats):
        """Implemented by subclasses."""
        raise NotImplementedError