# Source code for ml4co_kit.task.base

r"""
Base class for all problems in the ML4CO kit.
"""

# Copyright (c) 2024 Thinklab@SJTU
# ML4CO-Kit is licensed under Mulan PSL v2.
# You can use this software according to the terms and conditions of the Mulan PSL v2.
# You may obtain a copy of Mulan PSL v2 at:
# http://license.coscl.org.cn/MulanPSL2
#
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
# See the Mulan PSL v2 for more details.


import hashlib
import inspect
import pathlib
import pickle
import uuid
from enum import Enum
from typing import Sequence, Union

import numpy as np

from ml4co_kit.utils.file_utils import check_file_path
from ml4co_kit.utils.pickle_utils import load_pickle


class TASK_TYPE(str, Enum):
    """Enumeration of the task types supported by the ML4CO kit.

    Members mix in ``str``, so each member is itself a string equal to its
    value (e.g. ``TASK_TYPE.TSP == "TSP"``).
    """

    # 1. Routing Problems (Routing)
    # 1.1 TSP Variants
    TSP = "TSP" # Traveling Salesman Problem
    ATSP = "ATSP" # Asymmetric Traveling Salesman Problem
    OP = "OP" # Orienteering Problem
    PCTSP = "PCTSP" # Prize-Collecting Traveling Salesman Problem
    SPCTSP = "SPCTSP" # Stochastic Prize-Collecting Traveling Salesman Problem

    # 1.2 VRP Variants (suffix letters combine: B = Backhauls,
    # L = Route Length Limit, TW = Time Windows)
    CVRP = "CVRP" # Capacitated Vehicle Routing Problem
    CVRPB = "CVRPB" # B: Backhauls
    CVRPL = "CVRPL" # L: Route Length Limit
    CVRPTW = "CVRPTW" # TW: Time Windows
    CVRPBL = "CVRPBL" # B and L
    CVRPBTW = "CVRPBTW" # B and TW
    CVRPLTW = "CVRPLTW" # L and TW
    CVRPBLTW = "CVRPBLTW" # B and L and TW

    # 2. Graph Problems (Graph)
    # NOTE: the MCL value is spelled "MCl" (lower-case "l"), mirroring "MCut";
    # values may be persisted elsewhere, so do not "fix" the spelling.
    MCL = "MCl" # Maximum Clique
    MCUT = "MCut" # Maximum Cut
    MIS = "MIS" # Maximum Independent Set
    MVC = "MVC" # Minimum Vertex Cover

    # 3. Quadratic Assignment Problems (QAP)
    GM = "GM" # Graph Matching
    GED = "GED" # Graph Edit Distance
    KQAP = "KQAP" # Koopmans-Beckmann QAP
    LQAP = "LQAP" # Lawler QAP

    # 4. Mixed Integer Programming Problems (MIP)
    MIP = "MIP" # Mixed Integer Programming
    MILP = "MILP" # Mixed Integer Linear Programming
    LP = "LP" # Linear Programming

    # 5. Portfolio Optimization Problems (Portfolio)
    MAXRETPO = "MaxRetPO" # Maximum Return Portfolio Optimization
    MINVARPO = "MinVarPO" # Minimum Variance Portfolio Optimization
    MOPO = "MOPO" # Multi-Objective Portfolio Optimization

    # 6. Boolean Satisfiability Problems (SAT)
    SATP = "SAT-P" # Satisfiability Prediction Problem
    SATA = "SAT-A" # Satisfying Assignment Prediction

    # 7. Electronic Design Automation Problems (EDA)
    EDAP = "EDA-P" # EDA Placement
    EDATDP = "EDA-TDP" # EDA Timing-Driven Placement
    EDAR = "EDA-R" # EDA Routing
class TaskBase(object):
    """Base class for all tasks in the ML4CO kit.

    Holds the state shared by every task: its type, the optimization
    direction, a numeric precision, a candidate solution (``sol``), a
    reference solution (``ref_sol``), a cache dict and a random instance
    name. Subclasses implement ``from_data``, ``check_constraints``,
    ``evaluate`` and ``render``.
    """

    def __init__(
        self,
        task_type: "TASK_TYPE",
        minimize: bool,
        precision: Union[np.float32, np.float64] = np.float32
    ):
        """Initialize the shared task state.

        Args:
            task_type: The task type, normally a ``TASK_TYPE`` member.
            minimize: ``True`` if the objective is minimized, ``False`` if
                it is maximized (controls the sign convention of the gap).
            precision: Numpy float type; ``evaluate_w_gap`` casts its
                percentage factor to this type.
        """
        self.task_type = task_type  # Task type
        self.minimize = minimize  # Whether to minimize the objective function
        self.precision = precision  # Precision (numpy float type)
        self.sol: np.ndarray = None  # Solution
        self.ref_sol: np.ndarray = None  # Reference solution
        self.cache: dict = {}  # Cache (used for optimization)
        self.name: str = uuid.uuid4().hex  # Random unique instance name

    def _check_sol_not_none(self):
        """Raise ``ValueError`` if the solution ``sol`` is None."""
        if self.sol is None:
            raise ValueError("``sol`` cannot be None!")

    def _check_ref_sol_not_none(self):
        """Raise ``ValueError`` if the reference solution ``ref_sol`` is None."""
        if self.ref_sol is None:
            raise ValueError("``ref_sol`` cannot be None!")

    def from_pickle(self, file_path: pathlib.Path):
        """Load a pickled task's state into this instance.

        The unpickled object's attributes overwrite this instance's
        ``__dict__``. Attributes missing from older (legacy) pickles are then
        filled in by ``_repair_pickle_state``.

        Warning:
            Unpickling can execute arbitrary code. Only load files from
            trusted sources.

        Args:
            file_path: Path of the ``.pkl`` file to load.
        """
        # NOTE(review): the open file object (not the path) is passed to
        # ``load_pickle``; confirm that helper expects a file object.
        with open(file_path, "rb") as file:
            loaded_instance: TaskBase = load_pickle(file)
        self.__dict__.update(loaded_instance.__dict__)
        self._repair_pickle_state()

    def _repair_pickle_state(self):
        """Fill missing attributes after loading legacy pickle files.

        A default instance of ``self.__class__`` is built, and every attribute
        it has that this instance lacks is copied over. If the class cannot
        be constructed without arguments (e.g. ``TaskBase`` itself), there is
        no default to compare against. In that case the repair is skipped and
        the loaded state is kept as-is.
        """
        cls = self.__class__
        try:
            # Only probe the signature here, so that genuine errors raised
            # inside a zero-argument constructor still propagate below.
            inspect.signature(cls).bind()
        except (TypeError, ValueError):
            return
        fresh = cls()
        for key, value in fresh.__dict__.items():
            self.__dict__.setdefault(key, value)

    def to_pickle(self, file_path: pathlib.Path):
        """Serialize this task instance with ``pickle`` to ``file_path``.

        Args:
            file_path: Destination ``.pkl`` path. It is first passed to
                ``check_file_path``.
        """
        # Check file path
        check_file_path(file_path)
        # Save task data to ``.pkl`` file (the ``with`` block closes it)
        with open(file_path, "wb") as f:
            pickle.dump(self, f)

    def from_data(self):
        """Create a problem instance from raw data. To be implemented by subclasses."""
        raise NotImplementedError("Subclasses should implement this method.")

    def check_constraints(self, sol: np.ndarray) -> bool:
        """Check if the given solution satisfies all problem constraints. To be implemented by subclasses."""
        raise NotImplementedError("Subclasses should implement this method.")

    def evaluate(self, sol: np.ndarray, check_constr: bool = True) -> np.floating:
        """Evaluate the given solution. To be implemented by subclasses."""
        raise NotImplementedError("Subclasses should implement this method.")

    def evaluate_w_gap(
        self, check_constr: bool = True
    ) -> Sequence[Union[np.floating, None]]:
        """Evaluate ``sol`` and ``ref_sol`` and compute their relative gap.

        Args:
            check_constr: Forwarded to ``evaluate`` for both solutions.

        Returns:
            ``(sol_cost, ref_cost, gap)``. ``gap`` is the relative gap in
            percent. It is positive when ``sol`` is worse than ``ref_sol``
            under the task's optimization direction. It is ``None`` when
            ``|ref_cost| < 1e-8``.

        Raises:
            ValueError: If ``sol`` or ``ref_sol`` is None.
        """
        # Check if the solution and reference solution are not None
        if self.sol is None or self.ref_sol is None:
            raise ValueError("Solution and reference solution cannot be None!")

        # Evaluate the solution and reference solution
        sol_cost = self.evaluate(self.sol, check_constr=check_constr)
        ref_cost = self.evaluate(self.ref_sol, check_constr=check_constr)

        # A (near-)zero reference cost makes the relative gap undefined.
        if abs(ref_cost) < 1e-8:
            return sol_cost, ref_cost, None

        # Normalize by |ref_cost| so that the sign of the gap reflects solution
        # quality even when objective values are negative (which can happen,
        # e.g., for MIP/LP objectives).
        denominator = abs(ref_cost)
        if self.minimize:
            gap = (sol_cost - ref_cost) / denominator
        else:
            gap = (ref_cost - sol_cost) / denominator
        gap = gap * np.array(100.0).astype(self.precision)
        return sol_cost, ref_cost, gap

    def render(self):
        """Render the problem instance. To be implemented by subclasses."""
        raise NotImplementedError("Subclasses should implement this method.")

    def get_data_md5(self) -> str:
        """Calculate an MD5 hash of the task's data content.

        The hash is computed from attribute values rather than from a pickle
        file's bytes. This is useful for verifying data integrity when pickle
        files may contain different object references.

        Attributes are visited in sorted key order, and ``None`` values are
        skipped. Numpy arrays contribute their raw ``tobytes()`` (dtype and
        shape are not included). Any other value contributes
        ``str(value).encode()``. Attribute names are not hashed.

        NOTE(review): ``cache`` is not in the ignore list, so the hash can
        change after cached computations — confirm this is intended.

        Returns:
            str: Hex digest of the MD5 hash of the task's data content.
        """
        # Attributes (some set by subclasses) that hold object references or
        # instance-specific identifiers and must not affect the hash.
        ignore_list = ['dist_eval', 'name', 'g1', 'g2', 'affn_builder']

        data_parts = []
        for key in sorted(self.__dict__):
            if key in ignore_list:
                continue
            value = self.__dict__[key]
            if value is None:
                continue
            if isinstance(value, np.ndarray):
                data_parts.append(value.tobytes())
            else:
                data_parts.append(str(value).encode())

        # Combine all data and compute MD5
        return hashlib.md5(b''.join(data_parts)).hexdigest()

    def __repr__(self):
        # Tolerate a plain-string task_type (no ``.value`` attribute).
        task_type = getattr(self.task_type, "value", self.task_type)
        return f"{task_type}Task({self.name})"