r"""
Base class for all solvers.
"""
# Copyright (c) 2024 Thinklab@SJTU
# ML4CO-Kit is licensed under Mulan PSL v2.
# You can use this software according to the terms and conditions of the Mulan PSL v2.
# You may obtain a copy of Mulan PSL v2 at:
# http://license.coscl.org.cn/MulanPSL2
#
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
# See the Mulan PSL v2 for more details.
from enum import Enum
from typing import Callable, List, Optional

from ml4co_kit.task.base import TaskBase
from ml4co_kit.optimizer.base import OptimizerBase
class SOLVER_TYPE(str, Enum):
    """Enumeration of the solver backends known to the toolkit.

    Because the enum mixes in ``str``, every member compares equal to its
    plain lowercase string value (for example ``SOLVER_TYPE.LKH == "lkh"``),
    and ``SOLVER_TYPE("lkh")`` looks a member up by that value.
    """

    # ---- General-purpose solvers (usable across problem families) ----
    GNN4CO = "gnn4co"
    GUROBI = "gurobi"
    ILS = "ils"
    INSERTION = "insertion"
    NULL = "null"
    ORTOOLS = "ortools"
    RANDOM = "random"
    SCIP = "scip"

    # ---- EDA (electronic design automation) problems ----
    DREAMPLACE = "dreamplace"

    # ---- Graph problems ----
    FEM = "fem"
    GP_DEGREE = "gp_degree"
    ISCO = "isco"
    KAMIS = "kamis"
    LC_DEGREE = "lc_degree"
    RLSA = "rlsa"

    # ---- QAP problems ----
    PYGM = "pygm"

    # ---- Routing problems ----
    CONCORDE = "concorde"
    GA_EAX = "ga_eax"
    HGS = "hgs"
    LKH = "lkh"
    NEAREST = "nearest"
    NEUROLKH = "neurolkh"
    PYVRP = "pyvrp"

    # ---- SAT problems ----
    PYSAT = "pysat"

    # ---- Development / user-defined solvers ----
    ML4CO = "ml4co"
    DIY = "diy"
class SolverBase(object):
    """Base class for all solvers.

    Subclasses implement :meth:`_solve` (and may override :meth:`_batch_solve`
    with a genuinely batched implementation). The public entry points
    :meth:`solve` and :meth:`batch_solve` call those hooks and then, if an
    ``optimizer`` was supplied, run it over the solved task(s).

    Args:
        solver_type: Identifier of the concrete solver.
        optimizer: Optional post-processing optimizer applied after solving.
            ``None`` disables post-processing.
    """

    def __init__(
        self,
        solver_type: SOLVER_TYPE,
        optimizer: Optional[OptimizerBase] = None,
    ):
        self.solver_type = solver_type
        # Not used by this base class; initialised to ``None`` here.
        # NOTE(review): presumably populated by subclasses -- confirm.
        self.solve_func_dict: Optional[dict] = None
        self.optimizer = optimizer

    #######################################
    #       Single Solving Methods        #
    #######################################

    def solve(self, task_data: TaskBase) -> TaskBase:
        """Solve a single task, then apply the optimizer if one is set.

        The return value of :meth:`_solve` is ignored, so implementations
        must record their result on ``task_data`` itself.

        Args:
            task_data: The task to solve.

        Returns:
            The same ``task_data`` object that was passed in.
        """
        self._solve(task_data)
        if self.optimizer is not None:
            self.optimizer.optimize(task_data)
        return task_data

    def _solve(self, task_data: TaskBase):
        """Solve ``task_data``; must be overridden by subclasses.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError(
            "The ``_solve`` function is required to be implemented in subclasses."
        )

    #######################################
    #        Batch Solving Methods        #
    #######################################

    def batch_solve(
        self,
        batch_task_data: List[TaskBase],
        optimizer_parallel: bool = False
    ) -> List[TaskBase]:
        """Solve a batch of tasks, then apply the optimizer if one is set.

        The return value of :meth:`_batch_solve` is ignored, so results must
        be recorded on the task objects themselves.

        Args:
            batch_task_data: The tasks to solve.
            optimizer_parallel: If ``True``, hand the whole batch to
                ``optimizer.batch_optimize``; otherwise call
                ``optimizer.optimize`` on each task in order. Ignored when
                no optimizer is set.

        Returns:
            The same ``batch_task_data`` list that was passed in.
        """
        self._batch_solve(batch_task_data)
        if self.optimizer is not None:
            if optimizer_parallel:
                self.optimizer.batch_optimize(batch_task_data)
            else:
                for task_data in batch_task_data:
                    self.optimizer.optimize(task_data)
        return batch_task_data

    def _sequential_solve(
        self,
        batch_task_data: List[TaskBase],
        single_func: Callable,
    ):
        """Apply ``single_func`` to each task in order.

        Args:
            batch_task_data: The tasks to process.
            single_func: Callable invoked once per task; its return value is
                discarded.

        Returns:
            The same ``batch_task_data`` list that was passed in.
        """
        for task_data in batch_task_data:
            single_func(task_data)
        return batch_task_data

    def _batch_solve(self, batch_task_data: List[TaskBase]):
        """Default batch hook: solve each task sequentially via :meth:`_solve`.

        Subclasses with a native batched solver may override this.
        """
        return self._sequential_solve(batch_task_data, self._solve)