Spaces:
Sleeping
Sleeping
import torch.nn as nn | |
from models.solvers.ortools.ortools_tsp import ORToolsTSP | |
from models.solvers.ortools.ortools_tsptw import ORToolsTSPTW | |
from models.solvers.ortools.ortools_pctsp import ORToolsPCTSP | |
from models.solvers.ortools.ortools_pctsptw import ORToolsPCTSPTW | |
from models.solvers.ortools.ortools_cvrp import ORToolsCVRP | |
from models.solvers.ortools.ortools_cvrptw import ORToolsCVRPTW | |
class ORTools(nn.Module):
    """
    Front-end that selects a problem-specific OR-Tools solver at construction
    time and forwards ``solve`` calls to it.

    Supported problems: "tsp", "tsptw", "pctsp", "pctsptw", "cvrp", "cvrptw".
    """

    def __init__(self, problem, large_value=1e+6, scaling=False):
        """
        Parameters
        ----------
        problem: str
            problem type; one of "tsp", "tsptw", "pctsp", "pctsptw",
            "cvrp", "cvrptw"
        large_value: float
            forwarded to the problem-specific solver's constructor;
            presumably the multiplier applied when ``scaling`` is enabled
        scaling: bool
            whether or not coords are multiplied by a large value
            to convert float-coords into int-coords; forwarded to the
            problem-specific solver's constructor

        Raises
        ------
        NotImplementedError
            if ``problem`` is not one of the supported problem types
        """
        super().__init__()
        # Not read anywhere in this class; presumably consumed by callers -- TODO confirm.
        self.coord_dim = 2
        self.problem = problem
        self.large_value = large_value
        self.scaling = scaling
        # get_ortools reads self.large_value and self.scaling, so it must run
        # after both have been assigned above.
        self.ortools = self.get_ortools(problem)

    def get_ortools(self, problem):
        """
        Build the OR-Tools solver for the given problem type.

        Parameters
        ----------
        problem: str
            problem type; one of "tsp", "tsptw", "pctsp", "pctsptw",
            "cvrp", "cvrptw"

        Returns
        -------
        ortools: ortools for the specified problem, constructed with
            ``self.large_value`` and ``self.scaling``

        Raises
        ------
        NotImplementedError
            if ``problem`` is not a supported problem type
        """
        if problem == "tsp":
            return ORToolsTSP(self.large_value, self.scaling)
        if problem == "tsptw":
            return ORToolsTSPTW(self.large_value, self.scaling)
        if problem == "pctsp":
            return ORToolsPCTSP(self.large_value, self.scaling)
        if problem == "pctsptw":
            return ORToolsPCTSPTW(self.large_value, self.scaling)
        if problem == "cvrp":
            return ORToolsCVRP(self.large_value, self.scaling)
        if problem == "cvrptw":
            return ORToolsCVRPTW(self.large_value, self.scaling)
        # Name the rejected value and the valid choices so a misconfigured
        # problem string is easy to diagnose.
        raise NotImplementedError(
            f"ORTools does not support problem {problem!r}; expected one of "
            "'tsp', 'tsptw', 'pctsp', 'pctsptw', 'cvrp', 'cvrptw'"
        )

    def solve(self, node_feats, fixed_paths=None, dist_martix=None, instance_name=None):
        """
        Solve one instance with the problem-specific solver chosen at
        construction time.

        All arguments are forwarded positionally and unchanged to
        ``self.ortools.solve``.

        Parameters
        ----------
        node_feats: np.array [num_nodes x node_dim]
        fixed_paths: np.array [cf_step], optional
        dist_martix: optional
            presumably a precomputed distance matrix -- TODO confirm; the
            misspelled name is kept so existing keyword callers keep working
        instance_name: optional
            presumably an identifier for the instance -- TODO confirm

        Returns
        -------
        tour: np.array [seq_length]
            whatever the problem-specific solver's ``solve`` returns
        """
        return self.ortools.solve(node_feats, fixed_paths, dist_martix, instance_name)