# Copyright (c) 2025 NVIDIA CORPORATION.
# Licensed under the MIT license.
# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
# LICENSE is in incl_licenses directory.
""" AutoResume callback.
A transformer trainer callback for interfacing with ADLR's AutoResume SDK.
Copyright 2024 NVIDIA CORPORATION.
"""
import os
import sys

import torch
import transformers
from transformers.utils import logging

logger = logging.get_logger("transformers")

def rank_print(*s):
    """Print with the caller's distributed rank prepended (rank 0 when not distributed)."""
    if not torch.distributed.is_initialized():
        rank = 0
    else:
        rank = torch.distributed.get_rank()
    print(rank, *s)

sys.path.append(os.environ.get("SUBMIT_SCRIPTS", "."))
try:
    logger.info("Importing AutoResume lib...")
    from userlib.auto_resume import AutoResume

    AutoResume.init()
    logger.info("Found AutoResume SDK!")
except Exception:
    logger.warning("Did not find AutoResume SDK!")
    AutoResume = None


class AutoResumeCallback(transformers.TrainerCallback):
    """
    A [`TrainerCallback`] that handles autoresume.

    Args:
        interval: interval (in number of iterations) between checks as to
            whether to suspend.
    """

    def __init__(self, interval: int = 50):
        self.interval = interval

    def on_step_end(self, args, state, control, **kwargs):
        if state.global_step % self.interval == 0:
            rank_print("AutoResumeHook: Checking whether to suspend...")
            # Check whether the cluster has requested that the job be suspended.
            should_preempt = AutoResume is not None and AutoResume.termination_requested()
            if should_preempt:
                if state.is_local_process_zero:
                    logger.warning("AutoResumeHook: Request resume...")
                    if AutoResume is not None:
                        AutoResume.request_resume()
                # Stop training and save a checkpoint so the job can pick up where it left off.
                control.should_training_stop = True
                control.should_save = True
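
# Example usage (a minimal sketch, not part of this module: `model` and
# `train_dataset` are assumed to be defined elsewhere):
#
#   from transformers import Trainer, TrainingArguments
#
#   trainer = Trainer(
#       model=model,
#       args=TrainingArguments(output_dir="out", save_strategy="steps", save_steps=500),
#       train_dataset=train_dataset,
#       callbacks=[AutoResumeCallback(interval=100)],
#   )
#   # On a restarted job, pass resume_from_checkpoint=True (valid only if a
#   # checkpoint already exists in output_dir) to continue from the checkpoint
#   # saved when the callback stopped training.
#   trainer.train(resume_from_checkpoint=True)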