Source code for ml4co_kit.optimizer.base

r"""
Base class for all optimizers.
"""

# Copyright (c) 2024 Thinklab@SJTU
# ML4CO-Kit is licensed under Mulan PSL v2.
# You can use this software according to the terms and conditions of the Mulan PSL v2.
# You may obtain a copy of Mulan PSL v2 at:
# http://license.coscl.org.cn/MulanPSL2
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
# See the Mulan PSL v2 for more details.


from enum import Enum
from typing import List, Callable, Dict
from concurrent.futures import ProcessPoolExecutor
from ml4co_kit.utils.impl_utils import IMPL_TYPE
from ml4co_kit.task.base import TaskBase, TASK_TYPE


[docs]class OPTIMIZER_TYPE(str, Enum): """Define the optimizer types as an enumeration.""" TWO_OPT = "two_opt" FAST_2OPT = "fast_2opt" MCTS = "mcts" CVRP_LS = "cvrp_ls" RLSA = "rlsa" MCMC = "mcmc"
[docs]class OptimizerBase: """Base class for all optimizers.""" def __init__( self, optimizer_type: OPTIMIZER_TYPE, impl_type: IMPL_TYPE ): self.optimizer_type = optimizer_type self.impl_type = impl_type # Build method mapping for single optimization self._optimize_methods: Dict[IMPL_TYPE, Callable[[TaskBase, bool], None]] = { IMPL_TYPE.AUTO: self._auto_optimize, IMPL_TYPE.CTYPES: self._ctypes_optimize, IMPL_TYPE.CYTHON: self._cython_optimize, IMPL_TYPE.NUMPY: self._numpy_optimize, IMPL_TYPE.PYBIND11: self._pybind11_optimize, IMPL_TYPE.TORCH: self._torch_optimize, } # Build method mapping for batch optimization self._batch_optimize_methods: Dict[IMPL_TYPE, Callable[[List[TaskBase]], None]] = { IMPL_TYPE.AUTO: self._auto_batch_optimize, IMPL_TYPE.CTYPES: self._ctypes_batch_optimize, IMPL_TYPE.CYTHON: self._cython_batch_optimize, IMPL_TYPE.NUMPY: self._numpy_batch_optimize, IMPL_TYPE.PYBIND11: self._pybind11_batch_optimize, IMPL_TYPE.TORCH: self._torch_batch_optimize, } ####################################### # Single Optimization Methods # #######################################
[docs] def optimize(self, task_data: TaskBase): """ Optimize the given task data. """ # Check if solution is not None if task_data.sol is None: raise ValueError("`sol` cannot be None!") # Get the appropriate optimization method optimize_method = self._optimize_methods.get(self.impl_type, None) if optimize_method is None: raise ValueError( f"Implementation type {self.impl_type} is not supported." ) # Optimize the task data optimize_method(task_data)
def _auto_optimize(self, task_data: TaskBase, return_sol: bool = False): """Optimize the given task data using auto implementation.""" raise self._get_not_implemented_error(batch=False) def _ctypes_optimize(self, task_data: TaskBase, return_sol: bool = False): """Optimize the given task data using CTypes.""" raise self._get_not_implemented_error(batch=False) def _cython_optimize(self, task_data: TaskBase, return_sol: bool = False): """Optimize the given task data using Cython.""" raise self._get_not_implemented_error(batch=False) def _numpy_optimize(self, task_data: TaskBase, return_sol: bool = False): """Optimize the given task data using NumPy.""" raise self._get_not_implemented_error(batch=False) def _pybind11_optimize(self, task_data: TaskBase, return_sol: bool = False): """Optimize the given task data using PyBind11.""" raise self._get_not_implemented_error(batch=False) def _torch_optimize(self, task_data: TaskBase, return_sol: bool = False): """Optimize the given task data using Torch.""" raise self._get_not_implemented_error(batch=False) ####################################### # Batch Optimization Methods # #######################################
[docs] def batch_optimize(self, batch_task_data: List[TaskBase]): """ Optimize the given batch task data. """ # Check if solution is not None if any(task_data.sol is None for task_data in batch_task_data): raise ValueError("`sol` cannot be None!") # Get the appropriate batch optimization method batch_optimize_method = self._batch_optimize_methods.get(self.impl_type, None) if batch_optimize_method is None: raise ValueError( f"Implementation type {self.impl_type} is not supported." ) # Optimize the batch task data batch_optimize_method(batch_task_data)
def _auto_batch_optimize(self, batch_task_data: List[TaskBase]): """Optimize the given batch task data using auto implementation.""" raise self._get_not_implemented_error(batch=True) def _ctypes_batch_optimize(self, batch_task_data: List[TaskBase]): """Optimize the given batch task data using CTypes.""" raise self._get_not_implemented_error(batch=True) def _cython_batch_optimize(self, batch_task_data: List[TaskBase]): """Optimize the given batch task data using Cython.""" raise self._get_not_implemented_error(batch=True) def _numpy_batch_optimize(self, batch_task_data: List[TaskBase]): """Optimize the given batch task data using NumPy.""" raise self._get_not_implemented_error(batch=True) def _pybind11_batch_optimize(self, batch_task_data: List[TaskBase]): """Optimize the given batch task data using PyBind11.""" raise self._get_not_implemented_error(batch=True) def _torch_batch_optimize(self, batch_task_data: List[TaskBase]): """Optimize the given batch task data using Torch.""" raise self._get_not_implemented_error(batch=True) ####################################### # Helper Methods # ####################################### def _get_not_implemented_error( self, task_type: TASK_TYPE, batch: bool ) -> NotImplementedError: """Helper method to create a consistent NotImplementedError message.""" if batch: return NotImplementedError( f"Optimizer {self.optimizer_type} with implementation type {self.impl_type} " f"is not supported for batch optimization of {task_type}." ) else: return NotImplementedError( f"Optimizer {self.optimizer_type} with implementation type {self.impl_type} " f"is not supported for single optimization of {task_type}." ) def _pool_optimize( self, batch_task_data: List[TaskBase], single_func: Callable[[TaskBase, bool], None], ): """Optimize the given batch task data using ProcessPoolExecutor.""" if not batch_task_data: return # Optimize parallelly with ProcessPoolExecutor(max_workers=len(batch_task_data)) as executor: optimized_sols = list( executor.map( single_func, batch_task_data, [True] * len(batch_task_data), ) ) # Update the original task data with the optimized solutions for task_data, sol in zip(batch_task_data, optimized_sols): task_data.sol = sol