# Copyright (c) 2025 NVIDIA CORPORATION.
# Licensed under the MIT license.

# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
# LICENSE is in incl_licenses directory.

"""
AutoResume callback.

A transformers Trainer callback for interfacing with ADLR's AutoResume SDK.

Copyright 2024 NVIDIA CORPORATION.
"""

import os
import sys

import torch
import transformers
from transformers.utils import logging

logger = logging.get_logger("transformers")


def rank_print(*s):
    """Print ``s`` prefixed with this process's distributed rank.

    The rank is reported as 0 when torch.distributed is unavailable in this
    build or when no process group has been initialized.
    """
    # is_available() guards PyTorch builds compiled without distributed support.
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        rank = torch.distributed.get_rank()
    else:
        rank = 0
    print(rank, *s)


# The SDK is imported as `userlib.auto_resume`; make the directory named by
# $SUBMIT_SCRIPTS (default: the current directory) importable first.
sys.path.append(os.environ.get("SUBMIT_SCRIPTS", "."))

try:
    logger.info("Importing AutoResume lib...")
    from userlib.auto_resume import AutoResume

    AutoResume.init()
    logger.info("Found AutoResume SDK!")
except Exception:
    # Best-effort: any import or init failure disables autoresume instead of
    # crashing, while KeyboardInterrupt / SystemExit still propagate.
    logger.warning("Did not find AutoResume SDK!")
    AutoResume = None


class AutoResumeCallback(transformers.TrainerCallback):
    """
    A [`TrainerCallback`] that handles autoresume.

    Whenever ``state.global_step`` is a multiple of ``interval``, it asks the
    AutoResume SDK whether termination has been requested. If so, the local
    main process requests a resume, and every process tells the Trainer to
    save and stop training. When the SDK is unavailable, it never preempts.

    Args:
        interval: interval (in number of iterations) between checks as to
            whether to suspend. Must be a positive integer.

    Raises:
        ValueError: if ``interval`` is not positive.
    """

    def __init__(self, interval: int = 50):
        # A zero interval would raise ZeroDivisionError on the first step;
        # fail fast at construction time instead.
        if interval <= 0:
            raise ValueError(f"AutoResumeCallback interval must be positive, got {interval}")
        self.interval = interval

    def on_step_end(self, args, state, control, **kwargs):
        """Check for a pending termination request every ``interval`` steps.

        Args:
            args: the Trainer's training arguments (unused).
            state: the Trainer state; ``global_step`` and
                ``is_local_process_zero`` are read.
            control: the Trainer control object; ``should_training_stop`` and
                ``should_save`` are set to True when preemption is requested.

        Returns:
            The (possibly updated) ``control`` object.
        """
        if state.global_step % self.interval != 0:
            return control

        rank_print("AutoResumeHook: Checking whether to suspend...")

        # Check whether to suspend the job (never true without the SDK).
        should_preempt = AutoResume is not None and AutoResume.termination_requested()
        if should_preempt:
            if state.is_local_process_zero:
                logger.warning("AutoResumeHook: Request resume...")
                # should_preempt already implies AutoResume is not None.
                AutoResume.request_resume()
            # Set on every process, not only the local main one, so all of
            # them save and stop training together.
            control.should_training_stop = True
            control.should_save = True
        return control