# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import contextlib
import functools
import time
from collections.abc import Callable, Generator

from transformers import Trainer
from transformers.integrations import is_mlflow_available, is_wandb_available


if is_wandb_available():
    import wandb

if is_mlflow_available():
    import mlflow


@contextlib.contextmanager
def profiling_context(trainer: Trainer, name: str) -> Generator[None, None, None]:
    """
    A context manager function for profiling a block of code. Results are logged to Weights & Biases or MLflow
    depending on the trainer's configuration.

    The wall-clock duration of the block, in seconds (measured with `time.perf_counter`), is logged under the key
    `"profiling/Time taken: {TrainerClassName}.{name}"`. Logging only happens on the main process, and only to the
    backends listed in `trainer.args.report_to` that currently have an active run. If the profiled block raises, the
    exception propagates and nothing is logged.

    Args:
        trainer (`~transformers.Trainer`):
            Trainer object.
        name (`str`):
            Name of the block to be profiled. Used as a key in the logged dictionary.

    Example:
    ```python
    import numpy as np
    from transformers import Trainer
    from trl.extras.profiling import profiling_context


    class MyTrainer(Trainer):
        def some_method(self):
            A = np.random.rand(1000, 1000)
            B = np.random.rand(1000, 1000)
            with profiling_context(self, "matrix_multiplication"):
                # Code to profile: simulate a computationally expensive operation
                result = A @ B  # Matrix multiplication
    ```
    """
    start_time = time.perf_counter()
    yield
    end_time = time.perf_counter()
    duration = end_time - start_time

    profiling_metrics = {f"profiling/Time taken: {trainer.__class__.__name__}.{name}": duration}

    if "wandb" in trainer.args.report_to and wandb.run is not None and trainer.accelerator.is_main_process:
        wandb.log(profiling_metrics)

    # `mlflow.run` is a function object, so comparing it to `None` would always pass; `mlflow.active_run()` returns
    # `None` when no run is in progress, which mirrors the `wandb.run is not None` guard above.
    if (
        "mlflow" in trainer.args.report_to
        and mlflow.active_run() is not None
        and trainer.accelerator.is_main_process
    ):
        mlflow.log_metrics(profiling_metrics, step=trainer.state.global_step)


def profiling_decorator(func: Callable) -> Callable:
    """
    Decorator to profile a function and log execution time using [`extras.profiling.profiling_context`].

    The decorated function must be a method whose first argument (`self`) is the trainer: it is forwarded to
    [`extras.profiling.profiling_context`] as the `trainer`, and `func.__name__` is used as the profiled block name.

    Args:
        func (`Callable`):
            Function to be profiled.

    Returns:
        `Callable`: Wrapper with the metadata of `func` that runs `func` inside the profiling context and returns
        its result.

    Example:
    ```python
    import numpy as np
    from transformers import Trainer
    from trl.extras.profiling import profiling_decorator


    class MyTrainer(Trainer):
        @profiling_decorator
        def some_method(self):
            A = np.random.rand(1000, 1000)
            B = np.random.rand(1000, 1000)
            # Code to profile: simulate a computationally expensive operation
            result = A @ B
    ```
    """

    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        with profiling_context(self, func.__name__):
            return func(self, *args, **kwargs)

    return wrapper