Source code for ml4co_kit.generator.base

r"""
Base class for all generators in the ML4CO kit.
"""

# Copyright (c) 2024 Thinklab@SJTU
# ML4CO-Kit is licensed under Mulan PSL v2.
# You can use this software according to the terms and conditions of the Mulan PSL v2.
# You may obtain a copy of Mulan PSL v2 at:
# http://license.coscl.org.cn/MulanPSL2
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
# See the Mulan PSL v2 for more details.


import numpy as np
from typing import Union, Any
from ml4co_kit.task.base import TaskBase, TASK_TYPE


class GeneratorBase(object):
    """Base class for all generators.

    A generator keeps a dispatch table, ``generate_func_dict``, that maps a
    distribution type to a zero-argument callable. ``generate`` looks up the
    callable registered for ``self.distribution_type`` and returns its result.

    The base class leaves the dispatch table unset (``None``); presumably
    concrete generators assign it after calling ``__init__``.
    """

    def __init__(
        self,
        task_type: "TASK_TYPE",
        distribution_type: Any,
        precision: Union[np.float32, np.float64] = np.float32,
    ):
        """Store the generator configuration.

        Args:
            task_type: Task identifier; its ``value`` attribute is used to
                build the ``repr`` of the generator.
            distribution_type: Key used to select the generation function
                from ``generate_func_dict``.
            precision: Floating-point type stored on the instance
                (``np.float32`` by default).
        """
        self.task_type = task_type
        self.distribution_type = distribution_type
        self.precision = precision

        # Dispatch table: distribution type -> callable taking no arguments.
        # Stays None until something assigns a real mapping.
        self.generate_func_dict: Union[dict, None] = None

    def generate(self) -> "TaskBase":
        """Run the generation function registered for the distribution type.

        Returns:
            Whatever the callable stored under ``self.distribution_type`` in
            ``generate_func_dict`` returns.

        Raises:
            NotImplementedError: If no callable is registered for
                ``self.distribution_type``, including the case where
                ``generate_func_dict`` has never been populated.
        """
        # An unset table is treated as an empty one, so a missing
        # registration always surfaces as NotImplementedError rather than a
        # TypeError from a membership test against None.
        func_dict = self.generate_func_dict or {}
        if self.distribution_type not in func_dict:
            raise NotImplementedError(
                f"The distribution type {self.distribution_type} is not supported."
            )
        return func_dict[self.distribution_type]()

    def __repr__(self) -> str:
        """Return ``"<task_type.value>Generator"``."""
        return f"{self.task_type.value}Generator"