trl-sandbox / trl /extras /profiling.py
ivangabriele's picture
feat: initialize project
2f5127c verified
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import functools
import time
from collections.abc import Generator
from transformers import Trainer
from transformers.integrations import is_mlflow_available, is_wandb_available
if is_wandb_available():
import wandb
if is_mlflow_available():
import mlflow
@contextlib.contextmanager
def profiling_context(trainer: Trainer, name: str) -> Generator[None, None, None]:
    """
    A context manager function for profiling a block of code. Results are logged to Weights & Biases or MLflow
    depending on the trainer's configuration.

    The wall-clock duration of the wrapped block is measured with `time.perf_counter` and logged under the key
    `"profiling/Time taken: <TrainerClass>.<name>"`. For each backend, logging only happens when the backend is listed
    in `trainer.args.report_to`, a run is currently active for it, and `trainer.accelerator.is_main_process` is true.
    If the wrapped block raises, the exception propagates and nothing is logged.

    Args:
        trainer (`~transformers.Trainer`):
            Trainer object.
        name (`str`):
            Name of the block to be profiled. Used as a key in the logged dictionary.

    Example:
    ```python
    import numpy as np
    from transformers import Trainer
    from trl.extras.profiling import profiling_context


    class MyTrainer(Trainer):
        def some_method(self):
            A = np.random.rand(1000, 1000)
            B = np.random.rand(1000, 1000)
            with profiling_context(self, "matrix_multiplication"):
                # Code to profile: simulate a computationally expensive operation
                result = A @ B  # Matrix multiplication
    ```
    """
    start_time = time.perf_counter()
    yield
    end_time = time.perf_counter()
    duration = end_time - start_time

    profiling_metrics = {f"profiling/Time taken: {trainer.__class__.__name__}.{name}": duration}

    if "wandb" in trainer.args.report_to and wandb.run is not None and trainer.accelerator.is_main_process:
        wandb.log(profiling_metrics)

    # `mlflow.run` is the function that launches an MLflow Project, so comparing it to None is always true.
    # `mlflow.active_run()` is the call that returns None when no run is active, mirroring the `wandb.run` check.
    if (
        "mlflow" in trainer.args.report_to
        and mlflow.active_run() is not None
        and trainer.accelerator.is_main_process
    ):
        mlflow.log_metrics(profiling_metrics, step=trainer.state.global_step)
def profiling_decorator(func: callable) -> callable:
    """
    Wrap a method so that every call is timed through [`extras.profiling.profiling_context`].

    The returned wrapper expects the instance (`self`) as its first positional argument and passes it to the
    profiling context together with the wrapped function's `__name__`. The wrapped function's return value is handed
    back unchanged.

    Args:
        func (`callable`):
            Function to be profiled.

    Example:
    ```python
    import numpy as np
    from transformers import Trainer
    from trl.extras.profiling import profiling_decorator


    class MyTrainer(Trainer):
        @profiling_decorator
        def some_method(self):
            A = np.random.rand(1000, 1000)
            B = np.random.rand(1000, 1000)
            # Code to profile: simulate a computationally expensive operation
            result = A @ B
    ```
    """

    def timed_call(self, *args, **kwargs):
        # The name is looked up at call time, exactly when the profiling context is entered.
        with profiling_context(self, func.__name__):
            outcome = func(self, *args, **kwargs)
        return outcome

    # Copy the wrapped function's metadata (`__name__`, `__doc__`, `__wrapped__`, ...) onto the wrapper.
    return functools.wraps(func)(timed_call)