ml4co_kit.learning.train
Trainer for ML4CO models.
Classes
- class ml4co_kit.learning.train.Checkpoint(dirpath: str = 'wandb/checkpoints', monitor: str = 'val/loss', every_n_epochs: int = 1, every_n_train_steps=None, filename=None, save_top_k: int = -1, mode: str | None = None)
Bases: ModelCheckpoint
- class ml4co_kit.learning.train.Logger(name: str = 'wandb', project: str = 'project', entity: str | None = None, save_dir: str = 'log', id: str | None = None, resume_id: str | None = None)
Bases: WandbLogger
- class ml4co_kit.learning.train.Trainer(model: Module, logger: Logger | None = None, wandb_logger_name: str = 'wandb', resume_id: str | None = None, ckpt_save_path: str | None = None, ckpt_monitor: str = 'val/loss', save_top_k: int = -1, mode: str = 'min', ckpt_every_n_epochs: int = 1, ckpt_every_n_train_steps: int | None = None, ckpt_filename: str | None = None, accelerator: str = 'auto', strategy: str | Strategy | None = None, devices: List[int] | str | int = 'auto', fp16: bool = False, max_epochs: int = 100, max_steps: int = -1, val_check_interval: int | None = None, log_every_n_steps: int | None = 50, gradient_clip_val: int = 1, inference_mode: bool = False, reload_dataloaders_every_n_epochs: int = 0, disable_profiling_executor: bool = True, ckpt_path: str | None = None, weight_path: str | None = None)
Bases: Trainer
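Checkpoint exposes monitor (default 'val/loss'), save_top_k (default -1) and mode, and Trainer exposes the matching ckpt_monitor, save_top_k and mode (default 'min'). Assuming these are forwarded unchanged to the base class, presumably PyTorch Lightning's ModelCheckpoint, the pure-Python sketch below illustrates how that style of top-k retention decides which checkpoints survive. retained_checkpoints is a hypothetical helper, not part of ml4co_kit, and it deliberately ignores details such as a separate "last" checkpoint and the on-disk file deletion the real callback performs.

```python
from typing import List, Tuple


def retained_checkpoints(
    history: List[Tuple[str, float]], save_top_k: int = -1, mode: str = "min"
) -> List[str]:
    """Names of the checkpoints kept under top-k retention (illustrative sketch).

    history: (checkpoint name, monitored metric value) pairs, in save order.
    save_top_k: -1 keeps every checkpoint, 0 keeps none, k > 0 keeps the k best.
    mode: "min" treats lower metric values as better, "max" treats higher ones as better.
    """
    if mode not in ("min", "max"):
        raise ValueError(f"mode must be 'min' or 'max', got {mode!r}")
    if save_top_k < -1:
        raise ValueError("save_top_k must be >= -1")
    if save_top_k == -1:
        # Keep everything, in the order it was saved.
        return [name for name, _ in history]
    # sorted() is stable, so ties keep their save order; the best checkpoints come first.
    ranked = sorted(history, key=lambda item: item[1], reverse=(mode == "max"))
    # Slicing with save_top_k == 0 yields an empty list, i.e. nothing is kept.
    return [name for name, _ in ranked[:save_top_k]]
```

For example, with validation losses 0.9, 0.5 and 0.7 recorded at three checkpoints, save_top_k=2 with mode="min" keeps the second and third checkpoints, while the default save_top_k=-1 keeps all three.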
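Trainer accepts both max_epochs (default 100) and max_steps (default -1). Assuming these are passed through unchanged to the base class, presumably PyTorch Lightning's Trainer, training ends at whichever limit is reached first, and max_steps=-1 disables the step limit. The hypothetical helper below, which is not part of ml4co_kit, makes that arithmetic concrete under the simplifying assumption of one optimizer step per training batch (no gradient accumulation); it does not model unbounded epoch settings.

```python
def planned_optimizer_steps(
    batches_per_epoch: int, max_epochs: int = 100, max_steps: int = -1
) -> int:
    """Optimizer steps taken before max_epochs / max_steps stops training (sketch).

    Assumes one optimizer step per training batch.
    max_steps == -1 disables the step limit, so max_epochs alone decides;
    otherwise training stops at whichever limit is reached first.
    """
    # This sketch only handles finite, non-negative epoch budgets.
    if batches_per_epoch < 0 or max_epochs < 0:
        raise ValueError("batches_per_epoch and max_epochs must be non-negative")
    epoch_limit = batches_per_epoch * max_epochs
    if max_steps == -1:
        return epoch_limit
    return min(epoch_limit, max_steps)
```

For example, with 250 batches per epoch the defaults allow 25,000 steps, max_steps=10000 caps training at 10,000 steps, and max_epochs=10 with max_steps=10000 stops after 2,500 steps because the epoch limit is reached first.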