Source code for ml4co_kit.solver.base

r"""
Base class for all solvers.
"""

# Copyright (c) 2024 Thinklab@SJTU
# ML4CO-Kit is licensed under Mulan PSL v2.
# You can use this software according to the terms and conditions of the Mulan PSL v2.
# You may obtain a copy of Mulan PSL v2 at:
# http://license.coscl.org.cn/MulanPSL2
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
# See the Mulan PSL v2 for more details.


from enum import Enum
from typing import List, Callable
from ml4co_kit.task.base import TaskBase
from ml4co_kit.optimizer.base import OptimizerBase


class SOLVER_TYPE(str, Enum):
    """
    Enumeration of every solver identifier known to the toolkit.

    Because the class mixes in ``str``, each member compares equal to its
    lowercase string value (e.g. ``SOLVER_TYPE.LKH == "lkh"``).
    """

    # ---- General-purpose solvers ----
    GNN4CO = "gnn4co"
    GUROBI = "gurobi"
    ILS = "ils"
    INSERTION = "insertion"
    NULL = "null"
    ORTOOLS = "ortools"
    RANDOM = "random"
    SCIP = "scip"

    # ---- EDA problems ----
    DREAMPLACE = "dreamplace"

    # ---- Graph problems ----
    FEM = "fem"
    GP_DEGREE = "gp_degree"
    ISCO = "isco"
    KAMIS = "kamis"
    LC_DEGREE = "lc_degree"
    RLSA = "rlsa"

    # ---- QAP problems ----
    PYGM = "pygm"

    # ---- Routing problems ----
    CONCORDE = "concorde"
    GA_EAX = "ga_eax"
    HGS = "hgs"
    LKH = "lkh"
    NEAREST = "nearest"
    NEUROLKH = "neurolkh"
    PYVRP = "pyvrp"

    # ---- SAT problems ----
    PYSAT = "pysat"

    # ---- Development / user-defined ----
    ML4CO = "ml4co"
    DIY = "diy"
class SolverBase(object):
    """
    Base class for all solvers.

    Concrete solvers override ``_solve`` and may override ``_batch_solve``.
    The public entry points ``solve`` and ``batch_solve`` call those hooks
    and then, if an optimizer was supplied, run it over the solved task(s).

    Args:
        solver_type (SOLVER_TYPE): Identifier of the concrete solver.
        optimizer (OptimizerBase, optional): Optimizer applied after solving.
            ``None`` (the default) skips the optimization step.
    """

    def __init__(
        self,
        solver_type: SOLVER_TYPE,
        optimizer: OptimizerBase = None,
    ):
        self.solver_type = solver_type
        # Not read anywhere in this base class; presumably filled in by
        # subclasses that dispatch on task type -- TODO confirm.
        self.solve_func_dict: dict = None
        self.optimizer = optimizer

    #######################################
    #        Single Solving Methods       #
    #######################################

    def solve(self, task_data: TaskBase) -> TaskBase:
        """
        Solve a single task, then optionally optimize it.

        The return value of ``_solve`` is discarded, so ``_solve`` is
        expected to store its result on ``task_data`` itself.

        Args:
            task_data (TaskBase): The task to solve.

        Returns:
            TaskBase: The same ``task_data`` object that was passed in.
        """
        self._solve(task_data)
        if self.optimizer is not None:
            self.optimizer.optimize(task_data)
        return task_data

    def _solve(self, task_data: TaskBase):
        """
        Solve ``task_data``; must be overridden by concrete solvers.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        # Point implementers at ``_solve`` (the hook), not ``solve``:
        # overriding ``solve`` would silently skip the optimizer step.
        raise NotImplementedError(
            "The ``_solve`` function is required to be implemented in subclasses."
        )

    #######################################
    #        Batch Solving Methods        #
    #######################################

    def batch_solve(
        self, batch_task_data: List[TaskBase], optimizer_parallel: bool = False
    ) -> List[TaskBase]:
        """
        Solve a batch of tasks, then optionally optimize them.

        Args:
            batch_task_data (List[TaskBase]): Tasks to solve.
            optimizer_parallel (bool): If True, the whole batch is handed to
                ``optimizer.batch_optimize`` in a single call. If False,
                ``optimizer.optimize`` is called on each task in turn.
                Ignored when no optimizer is set.

        Returns:
            List[TaskBase]: The same ``batch_task_data`` list that was passed in.
        """
        self._batch_solve(batch_task_data)
        if self.optimizer is not None:
            if optimizer_parallel:
                self.optimizer.batch_optimize(batch_task_data)
            else:
                for task_data in batch_task_data:
                    self.optimizer.optimize(task_data)
        return batch_task_data

    def _sequential_solve(
        self,
        batch_task_data: List[TaskBase],
        single_func: Callable,
    ):
        """
        Solve the given batch task data sequentially.

        Args:
            batch_task_data (List[TaskBase]): Tasks to solve.
            single_func (Callable): Called once per task, in list order.
                Its return value is ignored.

        Returns:
            List[TaskBase]: The same ``batch_task_data`` list.
        """
        for task_data in batch_task_data:
            single_func(task_data)
        return batch_task_data

    def _batch_solve(self, batch_task_data: List[TaskBase]):
        """
        Default batch hook: solve each task one by one via ``_solve``.

        Subclasses with a native batch implementation may override this.
        """
        return self._sequential_solve(batch_task_data, self._solve)