Source code for ml4co_kit.wrapper.base

r"""
Base class for all wrappers in the ML4CO kit.
"""

# Copyright (c) 2024 Thinklab@SJTU
# ML4CO-Kit is licensed under Mulan PSL v2.
# You can use this software according to the terms and conditions of the Mulan PSL v2.
# You may obtain a copy of Mulan PSL v2 at:
# http://license.coscl.org.cn/MulanPSL2
#
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
# See the Mulan PSL v2 for more details.


import os
import random
import pickle
import pathlib
import numpy as np
from multiprocessing import Pool
from typing import Sequence, Union, Type, List
from ml4co_kit.solver.base import SolverBase
from ml4co_kit.utils.time_utils import Timer
from ml4co_kit.generator.base import GeneratorBase
from ml4co_kit.utils.time_utils import tqdm_by_time
from ml4co_kit.task.base import TASK_TYPE, TaskBase
from ml4co_kit.utils.pickle_utils import load_pickle
from ml4co_kit.utils.file_utils import check_file_path


class WrapperBase(object):
    """Base class for all ML4CO-Kit wrappers.

    A wrapper owns a list of task instances (``self.task_list``) and offers
    helpers to generate, solve, evaluate and (de)serialize them. Subclasses
    must implement ``from_txt`` and ``to_txt``.
    """

    def __init__(
        self,
        task_type: TASK_TYPE,
        precision: Union[Type[np.float32], Type[np.float64]] = np.float32
    ):
        """
        Args:
            task_type: Task type handled by this wrapper; used in progress
                messages (its ``.value`` is read in :meth:`solve`).
            precision: Floating point dtype stored on the wrapper. How it is
                consumed is up to subclasses (not visible in this class).
        """
        self.task_type = task_type
        self.precision = precision
        self.task_list: List[TaskBase] = list()

    def swap_sol_and_ref_sol(self):
        """Swap ``sol`` and ``ref_sol`` in place for every stored task."""
        for task_data in self.task_list:
            # ``ref_sol`` is assigned before ``sol``, matching the original order.
            task_data.ref_sol, task_data.sol = task_data.sol, task_data.ref_sol

    def generate_w_to_txt(
        self,
        file_path: pathlib.Path,
        generator: GeneratorBase,
        solver: SolverBase,
        num_samples: int = 1280,
        num_threads: int = 1,
        batch_size: int = 1,
        write_per_iters: int = 1,
        write_mode: str = "a",
        show_time: bool = True,
        optimizer_parallel: bool = False
    ):
        """Generate and solve tasks, flushing them to a text file in chunks.

        Each iteration generates ``num_threads * batch_size`` tasks through
        :meth:`generate`. Every ``write_per_iters`` iterations the accumulated
        ``self.task_list`` is written with :meth:`to_txt` and then cleared,
        which bounds memory usage. Tasks already present in ``self.task_list``
        when this method is called are included in the first write.

        Args:
            file_path: Destination text file passed to :meth:`to_txt`.
            generator: Produces task instances via ``generator.generate()``.
            solver: Solves the generated tasks.
            num_samples: Total number of tasks to generate.
            num_threads: Worker processes per iteration (see :meth:`generate`).
            batch_size: Batch size per iteration (see :meth:`generate`).
            write_per_iters: Number of iterations between two writes.
            write_mode: File mode forwarded to :meth:`to_txt`.
            show_time: Whether to show the outer progress bar.
            optimizer_parallel: Forwarded to ``solver.batch_solve`` when
                ``batch_size > 1``.

        Raises:
            ValueError: If ``num_samples`` is not divisible by
                ``num_threads * batch_size * write_per_iters``, or (raised by
                :meth:`generate`) if both ``num_threads`` and ``batch_size``
                are greater than 1.
        """
        # Calculate the total number of iterations
        if num_samples % (num_threads * batch_size * write_per_iters) != 0:
            raise ValueError((
                "The number of samples must be divisible by "
                "the product of num_threads, batch size, and write_per_iters."
            ))
        tot_iters = num_samples // num_threads // batch_size

        # Generate tasks by chunks
        for cur_iter in tqdm_by_time(
            iterable=range(tot_iters),
            desc=f"Generating {self.task_type}",
            show_time=show_time
        ):
            self.generate(
                generator=generator,
                solver=solver,
                num_samples=num_threads * batch_size,
                num_threads=num_threads,
                batch_size=batch_size,
                optimizer_parallel=optimizer_parallel,
                show_time=False
            )

            # Write accumulated tasks and release them
            if (cur_iter + 1) % write_per_iters == 0:
                self.to_txt(file_path=file_path, show_time=False, mode=write_mode)
                self.task_list = list()

    def generate(
        self,
        generator: GeneratorBase,
        solver: SolverBase,
        num_samples: int = 1280,
        num_threads: int = 1,
        batch_size: int = 1,
        optimizer_parallel: bool = False,
        show_time: bool = True
    ):
        """Generate ``num_samples`` tasks, solve them and append to ``task_list``.

        Three execution modes are supported:

        * ``num_threads == 1`` and ``batch_size == 1``: sequential.
        * ``num_threads > 1`` and ``batch_size == 1``: a process pool generates
          and solves ``num_threads`` tasks per step (usually CPU solvers).
        * ``num_threads == 1`` and ``batch_size > 1``: tasks are solved in
          batches with ``solver.batch_solve`` (usually GPU / ML solvers).

        Args:
            generator: Produces task instances via ``generator.generate()``.
            solver: Solves tasks (``solve`` / ``batch_solve``).
            num_samples: Number of tasks to generate.
            num_threads: Number of worker processes.
            batch_size: Number of tasks per ``batch_solve`` call.
            optimizer_parallel: Forwarded to ``solver.batch_solve``.
            show_time: Whether to show progress and elapsed time.

        Raises:
            ValueError: If ``num_samples`` is not divisible by the number of
                threads / the batch size, or if both are greater than 1.
        """
        # Initialize Timer
        timer = Timer(apply=show_time)
        timer.start()

        # Case 1: Single Thread and Batch Size is 1
        if num_threads == 1 and batch_size == 1:
            for _ in tqdm_by_time(
                iterable=range(num_samples),
                desc=f"Generating {self.task_type}",
                show_time=show_time
            ):
                task_data = generator.generate()
                solver.solve(task_data)
                self.task_list.append(task_data)

        # Case 2: Multi Thread and Batch Size is 1 (usually for traditional solver, cpu)
        elif num_threads != 1 and batch_size == 1:
            if num_samples % num_threads != 0:
                raise ValueError(
                    "The number of samples must be divisible by the number of threads."
                )
            # One pool for the whole run; the worker is a staticmethod so that
            # ``self`` (and its growing ``task_list``) is not pickled per task.
            with Pool(num_threads) as pool:
                for _ in tqdm_by_time(
                    iterable=range(num_samples // num_threads),
                    desc=f"Generating {self.task_type}",
                    show_time=show_time
                ):
                    # One independent seed per call, drawn from the parent's
                    # RNG. A per-pid seed repeats whenever a single worker
                    # handles several calls, yielding duplicate instances.
                    seeds = np.random.randint(
                        0, 2**32, size=num_threads, dtype=np.int64
                    )
                    tasks = pool.starmap(
                        self._generate_worker,
                        [(generator, solver, int(seed)) for seed in seeds]
                    )
                    self.task_list.extend(tasks)

        # Case 3: Single Thread and Batch Size is not 1 (usually for ML4CO solver, gpu)
        elif num_threads == 1 and batch_size != 1:
            if num_samples % batch_size != 0:
                raise ValueError(
                    "The number of samples must be divisible by the batch size."
                )
            for _ in tqdm_by_time(
                iterable=range(num_samples // batch_size),
                desc=f"Generating {self.task_type}",
                show_time=show_time
            ):
                batch_task_data = [generator.generate() for _ in range(batch_size)]
                solver.batch_solve(batch_task_data, optimizer_parallel)
                self.task_list.extend(batch_task_data)

        # Case 4: Multi Thread and Batch Size is not 1
        else:
            raise ValueError((
                "``num_threads`` and ``batch_size`` cannot "
                "both be greater than 1 at the same time."
            ))

        # End Timer
        timer.end()
        timer.show_time()

    def _generate(
        self,
        generator: GeneratorBase,
        solver: SolverBase,
        seed: Union[int, None] = None
    ) -> TaskBase:
        """Generate and solve a single task (kept for backward compatibility).

        Delegates to :meth:`_generate_worker`.
        """
        return self._generate_worker(generator, solver, seed)

    @staticmethod
    def _generate_worker(
        generator: GeneratorBase,
        solver: SolverBase,
        seed: Union[int, None] = None
    ) -> TaskBase:
        """Seed ``random``/``numpy``, then generate and solve one task.

        Runs inside pool workers. ``seed`` must lie in ``[0, 2**32)``; when it
        is ``None`` a fresh seed is drawn from OS entropy, so repeated calls in
        the same process never reuse a seed.
        """
        if seed is None:
            seed = int.from_bytes(os.urandom(4), "little")
        random.seed(seed)     # Set seed for random
        np.random.seed(seed)  # Set seed for numpy
        task_data = generator.generate()
        solver.solve(task_data)
        return task_data

    @staticmethod
    def _solve_worker(solver: SolverBase, task_data: TaskBase) -> TaskBase:
        """Solve one task inside a pool worker and send it back to the parent.

        Elsewhere in this class ``solver.solve`` is treated as mutating the task
        in place, so the (mutated) task is returned. If ``solve`` instead
        returns a task object, that object is returned.
        """
        result = solver.solve(task_data)
        return result if isinstance(result, TaskBase) else task_data

    def from_txt(self, file_path: pathlib.Path, *args, **kwargs):
        """Load tasks from a text file. Must be implemented by subclasses.

        Raises:
            NotImplementedError: Always, in the base class.
        """
        raise NotImplementedError(
            "The ``from_txt`` function is required to implemented in subclasses."
        )

    def to_txt(
        self, file_path: pathlib.Path, show_time: bool = False, mode: str = "w"
    ):
        """Write tasks to a text file. Must be implemented by subclasses.

        Raises:
            NotImplementedError: Always, in the base class.
        """
        raise NotImplementedError(
            "The ``to_txt`` function is required to implemented in subclasses."
        )

    def from_pickle(self, file_path: pathlib.Path):
        """Replace ``task_list`` with the contents of a ``.pkl`` file.

        Every loaded ``TaskBase`` gets ``_repair_pickle_state()`` called on it.
        NOTE(review): unpickling executes arbitrary code; only load files from
        trusted sources.
        """
        with open(file_path, "rb") as file:
            self.task_list = load_pickle(file)
        for task_data in self.task_list:
            if isinstance(task_data, TaskBase):
                task_data._repair_pickle_state()

    def to_pickle(self, file_path: pathlib.Path):
        """Pickle the whole ``task_list`` into ``file_path``.

        ``check_file_path`` is called on the path first.
        """
        check_file_path(file_path)
        with open(file_path, "wb") as f:
            pickle.dump(self.task_list, f)

    def from_task_pickle_folder(
        self, task_class: Type[TaskBase], folder_path: pathlib.Path
    ):
        """Load one task per file in ``folder_path``, in filename order.

        Each file is loaded with ``task_class().from_pickle(path)``.
        Subdirectories and hidden (dot-prefixed) files are skipped.

        Args:
            task_class: Task class to instantiate for every file.
            folder_path: Folder to scan; ``str`` and ``pathlib.Path`` accepted.
        """
        folder_path = pathlib.Path(folder_path)
        pickle_files = sorted(
            (
                entry for entry in folder_path.iterdir()
                if entry.is_file() and not entry.name.startswith(".")
            ),
            key=lambda entry: entry.name
        )
        self.task_list = list()
        for pkl_path in pickle_files:
            task_data = task_class()
            task_data.from_pickle(pkl_path)
            self.task_list.append(task_data)

    def to_task_pickle_folder(self, folder_path: pathlib.Path):
        """Save each task to ``folder_path / f"{task.name}.pkl"``.

        The folder is created if missing. ``str`` and ``pathlib.Path`` are
        accepted. NOTE(review): tasks sharing a ``name`` overwrite each other.
        """
        folder_path = pathlib.Path(folder_path)
        os.makedirs(folder_path, exist_ok=True)
        for task_data in self.task_list:
            task_data.to_pickle(folder_path / f"{task_data.name}.pkl")

    def solve(
        self,
        solver: SolverBase,
        num_threads: int = 1,
        batch_size: int = 1,
        optimizer_parallel: bool = False,
        show_time: bool = False
    ):
        """Solve every task in ``task_list`` with ``solver``.

        The execution modes mirror :meth:`generate`: sequential, process pool
        (``num_threads > 1``), or batched (``batch_size > 1``).

        Args:
            solver: Solver used (``solve`` / ``batch_solve``).
            num_threads: Number of worker processes.
            batch_size: Number of tasks per ``batch_solve`` call.
            optimizer_parallel: Forwarded to ``solver.batch_solve``.
            show_time: Whether to show progress and elapsed time.

        Raises:
            ValueError: If the number of tasks is not divisible by the number
                of threads / the batch size, or if both are greater than 1.
        """
        # Initialize Timer
        timer = Timer(apply=show_time)
        timer.start()

        # Solving Message
        solve_msg = f"Solving {self.task_type.value} Using {solver.solver_type.value}"

        # Case 1: Single Thread and Batch Size is 1
        if num_threads == 1 and batch_size == 1:
            for task_data in tqdm_by_time(
                iterable=self.task_list, desc=solve_msg, show_time=show_time
            ):
                solver.solve(task_data)

        # Case 2: Multi Thread and Batch Size is 1
        elif num_threads != 1 and batch_size == 1:
            if len(self.task_list) % num_threads != 0:
                raise ValueError(
                    "The number of tasks must be divisible by the number of threads."
                )
            # One pool reused across chunks instead of one pool per chunk.
            with Pool(num_threads) as pool:
                for idx in tqdm_by_time(
                    iterable=range(len(self.task_list) // num_threads),
                    desc=solve_msg,
                    show_time=show_time
                ):
                    start = idx * num_threads
                    chunk = self.task_list[start:start + num_threads]
                    # Workers operate on copies; write the solved copies back.
                    solved = pool.starmap(
                        self._solve_worker,
                        [(solver, task_data) for task_data in chunk]
                    )
                    self.task_list[start:start + num_threads] = solved

        # Case 3: Single Thread and Batch Size is not 1
        elif num_threads == 1 and batch_size != 1:
            if len(self.task_list) % batch_size != 0:
                raise ValueError(
                    "The number of tasks must be divisible by the batch size."
                )
            for idx in tqdm_by_time(
                iterable=range(len(self.task_list) // batch_size),
                desc=solve_msg,
                show_time=show_time
            ):
                start = idx * batch_size
                batch_task_data = self.task_list[start:start + batch_size]
                solver.batch_solve(batch_task_data, optimizer_parallel)

        # Case 4: Multi Thread and Batch Size is not 1
        else:
            raise ValueError((
                "``num_threads`` and ``batch_size`` cannot "
                "both be greater than 1 at the same time."
            ))

        # End Timer
        timer.end()
        timer.show_time()

    def evaluate(self, check_constr: bool = True) -> float:
        """Return the mean cost of ``task.sol`` over all tasks.

        Each cost comes from ``task.evaluate(task.sol, check_constr=...)``.
        With an empty ``task_list`` numpy returns ``nan`` (and warns).
        """
        sol_costs_list = [
            task_data.evaluate(task_data.sol, check_constr=check_constr)
            for task_data in self.task_list
        ]
        return float(np.mean(sol_costs_list))

    def evaluate_w_gap(
        self, check_constr: bool = True
    ) -> Sequence[Union[float, None]]:
        """Evaluate solutions against reference solutions.

        Returns:
            ``(avg_sol_cost, avg_ref_cost, avg_gap, gap_std)``.
            ``avg_ref_cost`` is ``None`` if any task reports a ``None``
            reference cost. ``avg_gap`` and ``gap_std`` are ``None`` if any
            task reports a ``None`` gap.
        """
        sol_costs_list = list()
        ref_costs_list = list()
        gaps_list = list()
        for task_data in self.task_list:
            sol_cost, ref_cost, gap = task_data.evaluate_w_gap(
                check_constr=check_constr
            )
            sol_costs_list.append(sol_cost)
            ref_costs_list.append(ref_cost)
            gaps_list.append(gap)

        avg_sol_cost = float(np.mean(sol_costs_list))
        # A missing reference cost would make ``np.mean`` raise on ``None``.
        if None not in ref_costs_list:
            avg_ref_cost = float(np.mean(ref_costs_list))
        else:
            avg_ref_cost = None
        if None not in gaps_list:
            avg_gap = float(np.mean(gaps_list))
            gap_std = float(np.std(gaps_list))
        else:
            avg_gap = None
            gap_std = None
        return avg_sol_cost, avg_ref_cost, avg_gap, gap_std