# Source code for ml4co_kit.learning.train

r"""
Trainer for ML4CO models.
"""

# Copyright (c) 2024 Thinklab@SJTU
# ML4CO-Kit is licensed under Mulan PSL v2.
# You can use this software according to the terms and conditions of the Mulan PSL v2.
# You may obtain a copy of Mulan PSL v2 at:
# http://license.coscl.org.cn/MulanPSL2
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
# See the Mulan PSL v2 for more details.


import os
import torch
from torch import nn
from typing import Optional, List
from wandb.util import generate_id
from typing import Union, Optional
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.utilities import rank_zero_info
from pytorch_lightning.trainer import Trainer as PLTrainer
from pytorch_lightning.strategies import Strategy, DDPStrategy
from pytorch_lightning.callbacks import (
    LearningRateMonitor, ModelCheckpoint, TQDMProgressBar
)


class Checkpoint(ModelCheckpoint):
    """``ModelCheckpoint`` preconfigured for ML4CO training runs.

    Always keeps a ``last`` checkpoint (``save_last=True``) and disables
    automatic insertion of the metric name into checkpoint filenames.

    Args:
        dirpath: Directory where checkpoints are written.
        monitor: Logged metric used to rank checkpoints.
        every_n_epochs: Save a checkpoint every this many epochs.
        every_n_train_steps: Save a checkpoint every this many training steps;
            ``None`` leaves step-based saving to ``ModelCheckpoint``'s default.
        filename: Checkpoint filename template; ``None`` uses the
            ``ModelCheckpoint`` default.
        save_top_k: Number of best checkpoints to keep (``-1`` keeps all).
        mode: ``"min"`` or ``"max"`` -- direction in which ``monitor``
            improves. ``None`` is treated as ``"min"``.
    """

    def __init__(
        self,
        dirpath: str = "wandb/checkpoints",
        monitor: str = "val/loss",
        every_n_epochs: int = 1,
        every_n_train_steps: Optional[int] = None,
        filename: Optional[str] = None,
        save_top_k: int = -1,
        mode: Optional[str] = "min",
    ):
        # ModelCheckpoint only accepts "min"/"max" and raises otherwise, so a
        # missing mode falls back to "min" (the usual choice for a loss).
        if mode is None:
            mode = "min"
        super().__init__(
            dirpath=dirpath,
            monitor=monitor,
            mode=mode,
            save_top_k=save_top_k,
            save_last=True,
            every_n_epochs=every_n_epochs,
            every_n_train_steps=every_n_train_steps,
            filename=filename,
            auto_insert_metric_name=False,
        )
class Logger(WandbLogger):
    """Weights & Biases logger that prepares its save directory and run id.

    Args:
        name: Display name of the W&B run.
        project: W&B project name.
        entity: W&B entity; ``None`` lets W&B pick its default.
        save_dir: Local directory for W&B files; created if it does not exist.
        id: Explicit run id. Takes precedence over ``resume_id``.
        resume_id: Run id to reuse when ``id`` is not given.

    When neither ``id`` nor ``resume_id`` is given, the ``WANDB_RUN_ID``
    environment variable is used if set, otherwise a fresh id is generated.
    """

    def __init__(
        self,
        name: str = "wandb",
        project: str = "project",
        entity: Optional[str] = None,
        save_dir: str = "log",
        id: Optional[str] = None,
        resume_id: Optional[str] = None,
    ):
        # exist_ok=True instead of exists()+makedirs(): several processes
        # (e.g. DDP ranks) may create the same directory concurrently, and the
        # check-then-create pattern can then fail with FileExistsError.
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)

        if id is not None:
            wandb_id = id
        elif resume_id is not None:
            wandb_id = resume_id
        else:
            # Honour a run id supplied through the environment, if any.
            wandb_id = os.getenv("WANDB_RUN_ID") or generate_id()

        super().__init__(
            name=name,
            project=project,
            entity=entity,
            save_dir=save_dir,
            id=wandb_id,
        )
class Trainer(PLTrainer):
    """Lightning ``Trainer`` preconfigured for ML4CO models.

    Sets up a W&B ``Logger``, a ``Checkpoint`` callback, a step-level
    ``LearningRateMonitor`` and a ``TQDMProgressBar``, then optionally loads
    pretrained weights into ``model`` before training or testing.

    Args:
        model: Model to train/test. It is passed to ``fit``/``test``, so it is
            presumably a ``LightningModule``.
        logger: Logger to use; ``None`` creates a ``Logger`` named
            ``wandb_logger_name`` that reuses ``resume_id``.
        wandb_logger_name: Run name for the default logger.
        resume_id: Run id forwarded to the default logger.
        ckpt_save_path: Checkpoint directory; ``None`` uses
            ``train_ckpts/<logger name>/<logger id>``.
        ckpt_monitor, save_top_k, mode, ckpt_every_n_epochs,
        ckpt_every_n_train_steps, ckpt_filename: Forwarded to ``Checkpoint``.
            ``ckpt_filename=None`` uses ``"epoch={epoch}-step={step}"``.
        accelerator, devices, max_epochs, max_steps, val_check_interval,
        log_every_n_steps, gradient_clip_val, inference_mode,
        reload_dataloaders_every_n_epochs: Forwarded to the Lightning Trainer.
        strategy: Distributed strategy; ``None`` builds a ``DDPStrategy``.
        fp16: Use 16-bit precision instead of 32-bit.
        disable_profiling_executor: Turn off the TorchScript profiling
            executor/mode (ignored if this torch build lacks those hooks).
        ckpt_path: Lightning checkpoint whose ``"state_dict"`` is loaded into
            ``model``. Takes precedence over ``weight_path``.
        weight_path: File containing a plain state dict to load into ``model``.
    """

    def __init__(
        self,
        model: nn.Module,
        # logger
        logger: Optional[Logger] = None,
        wandb_logger_name: str = "wandb",
        resume_id: Optional[str] = None,
        # checkpoint
        ckpt_save_path: Optional[str] = None,
        ckpt_monitor: str = "val/loss",
        save_top_k: int = -1,
        mode: str = "min",
        ckpt_every_n_epochs: int = 1,
        ckpt_every_n_train_steps: Optional[int] = None,
        ckpt_filename: Optional[str] = None,
        # trainer basic
        accelerator: str = "auto",
        strategy: Union[str, Strategy, None] = None,
        devices: Union[List[int], str, int] = "auto",
        fp16: bool = False,
        max_epochs: int = 100,
        max_steps: int = -1,
        val_check_interval: Optional[Union[int, float]] = None,
        log_every_n_steps: Optional[int] = 50,
        gradient_clip_val: Union[int, float] = 1,
        inference_mode: bool = False,
        reload_dataloaders_every_n_epochs: int = 0,
        # Disable JIT profiling executor.
        disable_profiling_executor: bool = True,
        # pretrained
        ckpt_path: Optional[str] = None,
        weight_path: Optional[str] = None,
    ):
        # logger
        self.logger = (
            Logger(name=wandb_logger_name, resume_id=resume_id)
            if logger is None
            else logger
        )

        # checkpoint: always record the directory actually used, including
        # when the caller supplies one (model_train reports it).
        if ckpt_save_path is None:
            ckpt_save_path = os.path.join(
                "train_ckpts", self.logger._name, self.logger._id
            )
        self.ckpt_save_path = ckpt_save_path
        self.ckpt_callback = Checkpoint(
            dirpath=self.ckpt_save_path,
            monitor=ckpt_monitor,
            every_n_epochs=ckpt_every_n_epochs,
            every_n_train_steps=ckpt_every_n_train_steps,
            filename=(
                "epoch={epoch}-step={step}" if ckpt_filename is None else ckpt_filename
            ),
            save_top_k=save_top_k,
            mode=mode,
        )

        # learning rate
        self.lr_callback = LearningRateMonitor(logging_interval="step")

        # strategy
        if strategy is None:
            strategy = DDPStrategy(
                static_graph=True,
                find_unused_parameters=True,
                gradient_as_bucket_view=True,
            )

        # Disable JIT profiling executor (private torch hooks; skipped when the
        # running torch build does not expose them).
        if disable_profiling_executor:
            try:
                torch._C._jit_set_profiling_executor(False)
                torch._C._jit_set_profiling_mode(False)
            except AttributeError:
                pass

        # super
        super().__init__(
            accelerator=accelerator,
            strategy=strategy,
            devices=devices,
            precision=16 if fp16 else 32,
            logger=self.logger,
            callbacks=[
                TQDMProgressBar(refresh_rate=20),
                self.ckpt_callback,
                self.lr_callback,
            ],
            max_epochs=max_epochs,
            max_steps=max_steps,
            check_val_every_n_epoch=1,
            val_check_interval=val_check_interval,
            log_every_n_steps=log_every_n_steps,
            gradient_clip_val=gradient_clip_val,
            inference_mode=inference_mode,
            reload_dataloaders_every_n_epochs=reload_dataloaders_every_n_epochs,
        )

        # pretrained weights are loaded in place into ``model``.
        # (``load_from_checkpoint`` is a classmethod that returns a *new*
        # model, so calling it on the instance never updated ``model``.)
        if ckpt_path is not None:
            self._load_pretrained_weights(model, ckpt_path, from_checkpoint=True)
        elif weight_path is not None:
            self._load_pretrained_weights(model, weight_path, from_checkpoint=False)
        self.train_model = model

    @staticmethod
    def _load_pretrained_weights(
        model: nn.Module, path: str, from_checkpoint: bool
    ) -> None:
        """Load the weights stored at ``path`` into ``model`` in place.

        Args:
            model: Module receiving the weights via ``load_state_dict``.
            path: File produced by ``torch.save`` (or a Lightning ``.ckpt``).
            from_checkpoint: If true and the file is a dict with a
                ``"state_dict"`` entry (Lightning checkpoint layout), load
                that entry; otherwise load the file contents directly.
        """
        # map_location="cpu": weights saved from GPU tensors must still load on
        # CPU-only hosts; load_state_dict copies them onto the model's devices.
        # NOTE(review): torch.load unpickles the file -- only load trusted
        # paths. On torch versions whose default is weights_only=True, full
        # Lightning checkpoints with custom objects may need extra handling.
        state = torch.load(path, map_location="cpu")
        if from_checkpoint and isinstance(state, dict) and "state_dict" in state:
            state = state["state_dict"]
        model.load_state_dict(state)

    def model_train(self):
        """Log run/checkpoint locations and the model summary, then fit the model.

        The logger is finalized with status ``"success"`` after ``fit`` returns.
        """
        rank_zero_info(
            f"Logging to {self.logger.save_dir}/{self.logger.name}/{self.logger.version}"
        )
        rank_zero_info(f"checkpoint_callback's dirpath is {self.ckpt_save_path}")
        rank_zero_info(
            f"{'-' * 100}\n"
            f"{str(self.train_model)}\n"
            f"{'-' * 100}\n"
        )
        self.fit(self.train_model)
        self.logger.finalize("success")

    def model_test(self):
        """Log the run location and the model summary, then test the model."""
        rank_zero_info(
            f"Logging to {self.logger.save_dir}/{self.logger.name}/{self.logger.version}"
        )
        rank_zero_info(
            f"{'-' * 100}\n"
            f"{str(self.train_model)}\n"
            f"{'-' * 100}\n"
        )
        self.test(self.train_model)