Spaces:
Running
on
Zero
Running
on
Zero
| import math | |
| import os | |
| import time | |
| from collections import defaultdict | |
| from functools import wraps | |
| import torch | |
| import torch.distributed as dist | |
| from rich import box | |
| from rich.console import Console | |
| from rich.console import Group | |
| from rich.live import Live | |
| from rich.markdown import Markdown | |
| from rich.padding import Padding | |
| from rich.panel import Panel | |
| from rich.progress import BarColumn | |
| from rich.progress import Progress | |
| from rich.progress import SpinnerColumn | |
| from rich.progress import TimeElapsedColumn | |
| from rich.progress import TimeRemainingColumn | |
| from rich.rule import Rule | |
| from rich.table import Table | |
| from torch.utils.tensorboard import SummaryWriter | |
| # This is here so that the history can be pickled. | |
def default_list():
    """Build and return a brand-new empty list.

    Defined as a named module-level function (instead of a lambda) so that
    a ``defaultdict`` using it as its factory remains picklable.
    """
    return list()
class Mean:
    """Running average of the finite values passed to :meth:`update`.

    Calling the instance returns ``total / count``; while nothing has been
    recorded the divisor is clamped to 1, so the result is ``0``.
    """

    def __init__(self):
        self.reset()

    def __call__(self):
        # Clamp the divisor so an empty accumulator does not divide by zero.
        divisor = self.count if self.count > 1 else 1
        return self.total / divisor

    def reset(self):
        """Forget everything accumulated so far."""
        self.count = 0
        self.total = 0

    def update(self, val):
        """Fold ``val`` into the average; NaN and +/-inf are ignored."""
        if not math.isfinite(val):
            return
        self.count += 1
        self.total += val
def when(condition):
    """Runs a function only when the condition is met. The condition is
    a function that is run.

    The condition is evaluated on every call of the decorated function.
    When it is falsy the decorated function is skipped and ``None`` is
    returned. The wrapper keeps the name and docstring of the decorated
    function (via ``functools.wraps``), so outer decorators that read
    ``fn.__name__`` see the real name.

    Parameters
    ----------
    condition : Callable
        Function to run to check whether or not to run the decorated
        function. It is called with no arguments.

    Example
    -------
    Checkpoint only runs every 100 iterations, and only if the
    local rank is 0.

    >>> i = 0
    >>> rank = 0
    >>>
    >>> @when(lambda: i % 100 == 0 and rank == 0)
    ... def checkpoint():
    ...     print("Saving to /runs/exp1")
    >>>
    >>> for i in range(1000):
    ...     checkpoint()
    """

    def decorator(fn):
        @wraps(fn)
        def decorated(*args, **kwargs):
            if condition():
                return fn(*args, **kwargs)
            # Condition not met: skip the call entirely.
            return None

        return decorated

    return decorator
def timer(prefix: str = "time"):
    """Adds execution time to the output dictionary of the decorated
    function. The function decorated by this must output a dictionary.
    The key added will follow the form "[prefix]/[name_of_function]"

    The elapsed time is measured with ``time.perf_counter`` (seconds) and
    covers only the call to the decorated function. The wrapper keeps the
    name of the decorated function (via ``functools.wraps``), so stacking
    it under other decorators does not turn the key into
    "[prefix]/decorated".

    Parameters
    ----------
    prefix : str, optional
        The key added will follow the form "[prefix]/[name_of_function]",
        by default "time".

    Raises
    ------
    AssertionError
        If the decorated function does not return a dictionary.
    """

    def decorator(fn):
        @wraps(fn)
        def decorated(*args, **kwargs):
            start = time.perf_counter()
            output = fn(*args, **kwargs)
            # Stop the clock before validating so only ``fn`` is timed.
            elapsed = time.perf_counter() - start
            assert isinstance(
                output, dict
            ), f"{fn.__name__}() must return a dict to be timed"
            output[f"{prefix}/{fn.__name__}"] = elapsed
            return output

        return decorated

    return decorator
class Tracker:
    """
    A tracker class that helps to monitor the progress of training and logging the metrics.

    Attributes
    ----------
    metrics : dict
        A dictionary containing the metrics for each label. Each entry has a
        "value" mapping (latest scalar per key) and a "mean" mapping
        (running ``Mean`` per key).
    history : dict
        A dictionary containing the history of metrics for each label.
    writer : SummaryWriter
        A SummaryWriter object for logging the metrics.
    rank : int
        The rank of the current process. Only rank 0 prints, updates the
        live display and writes logs.
    step : int
        The current step of the training.
    tasks : dict
        A dictionary containing the progress bars and tables for each label.
    pbar : Progress
        A progress bar object for displaying the progress.
    consoles : list
        A list of console objects for logging.
    live : Live
        A Live object for updating the display live.

    Methods
    -------
    print(msg: str)
        Prints the given message to all consoles.
    update(label: str, fn_name: str)
        Updates the progress bar and table for the given label.
    done(label: str, title: str)
        Resets the progress bar for the given label and prints the final result.
    track(label: str, length: int, completed: int = 0, op: dist.ReduceOp = dist.ReduceOp.AVG, ddp_active: bool = "LOCAL_RANK" in os.environ)
        A decorator for tracking the progress and metrics of a function.
    log(label: str, value_type: str = "value", history: bool = True)
        A decorator for logging the metrics of a function.
    is_best(label: str, key: str) -> bool
        Checks if the latest value of the given key in the label is the best so far.
    state_dict() -> dict
        Returns a dictionary containing the state of the tracker.
    load_state_dict(state_dict: dict) -> Tracker
        Loads the state of the tracker from the given state dictionary.
    """

    def __init__(
        self,
        writer: SummaryWriter = None,
        log_file: str = None,
        rank: int = 0,
        console_width: int = 100,
        step: int = 0,
    ):
        """
        Initializes the Tracker object.

        Parameters
        ----------
        writer : SummaryWriter, optional
            A SummaryWriter object for logging the metrics, by default None.
        log_file : str, optional
            The path to the log file, by default None. When given, the file
            is opened in append mode and used as a second console.
        rank : int, optional
            The rank of the current process, by default 0.
        console_width : int, optional
            The width of the console, by default 100.
        step : int, optional
            The current step of the training, by default 0.
        """
        self.metrics = {}
        self.history = {}
        self.writer = writer
        self.rank = rank
        self.step = step

        # Create progress bars etc.
        self.tasks = {}
        self.pbar = Progress(
            SpinnerColumn(),
            "[progress.description]{task.description}",
            "{task.completed}/{task.total}",
            BarColumn(),
            TimeElapsedColumn(),
            "/",
            TimeRemainingColumn(),
        )
        self.consoles = [Console(width=console_width)]
        # The live display is bound to the first (terminal) console only.
        self.live = Live(console=self.consoles[0], refresh_per_second=10)
        if log_file is not None:
            # The handle is owned by the Console and stays open for as long
            # as the tracker holds on to it.
            self.consoles.append(Console(width=console_width, file=open(log_file, "a")))

    def print(self, msg):
        """
        Prints the given message to all consoles.

        Does nothing on ranks other than 0.

        Parameters
        ----------
        msg : str
            The message to be printed.
        """
        if self.rank == 0:
            for c in self.consoles:
                c.log(msg)

    def update(self, label, fn_name):
        """
        Updates the progress bar and table for the given label.

        Advances the label's progress bar by one step, rebuilds its metrics
        table from ``self.metrics[label]`` and refreshes the live display.
        Does nothing on ranks other than 0.

        Parameters
        ----------
        label : str
            The label of the progress bar and table to be updated.
        fn_name : str
            The name of the function associated with the label.
        """
        if self.rank != 0:
            return

        self.pbar.advance(self.tasks[label]["pbar"])

        # Create table
        table = Table(title=label, expand=True, box=box.MINIMAL)
        table.add_column("key", style="cyan")
        table.add_column("value", style="bright_blue")
        table.add_column("mean", style="bright_green")

        values = self.metrics[label]["value"]
        means = self.metrics[label]["mean"]
        for k, value in values.items():
            mean = means[k]()
            table.add_row(k, f"{value:10.6f}", f"{mean:10.6f}")

        self.tasks[label]["table"] = table
        tables = [t["table"] for t in self.tasks.values()]
        group = Group(*tables, self.pbar)
        self.live.update(
            Group(
                Padding("", (0, 0)),
                Rule(f"[italic]{fn_name}()", style="white"),
                Padding("", (0, 0)),
                Panel.fit(
                    group, padding=(0, 5), title="[b]Progress", border_style="blue"
                ),
            )
        )

    def done(self, label: str, title: str):
        """
        Resets the progress bar for the given label and prints the final result.

        The running means of every tracked label are reset (not only those
        of ``label``). On rank 0 the progress bar of ``label`` is reset and
        all current tables plus the progress display are printed under
        ``title``.

        Parameters
        ----------
        label : str
            The label of the progress bar to be reset.
        title : str
            The title to be displayed when printing the final result.
        """
        # Use distinct loop names: reusing ``label`` here would clobber the
        # argument and reset the wrong progress bar below.
        for tracked in self.metrics.values():
            for mean in tracked["mean"].values():
                mean.reset()

        if self.rank == 0:
            self.pbar.reset(self.tasks[label]["pbar"])
            tables = [t["table"] for t in self.tasks.values()]
            group = Group(Markdown(f"# {title}"), *tables, self.pbar)
            self.print(group)

    def track(
        self,
        label: str,
        length: int,
        completed: int = 0,
        op: dist.ReduceOp = dist.ReduceOp.AVG,
        ddp_active: bool = "LOCAL_RANK" in os.environ,
    ):
        """
        A decorator for tracking the progress and metrics of a function.

        Registers a progress bar and fresh metric storage for ``label``
        immediately (at decoration time). On every call of the decorated
        function, scalar entries of a returned dictionary are stored as the
        latest value and folded into the running mean, then the display is
        updated. Single-element tensors (and ints/floats) in the returned
        dictionary are replaced by plain Python scalars; other tensors are
        replaced by their detached version. Non-dict outputs only advance
        the progress bar.

        Parameters
        ----------
        label : str
            The label to be associated with the progress and metrics.
        length : int
            The total number of iterations to be completed.
        completed : int, optional
            The number of iterations already completed, by default 0.
        op : dist.ReduceOp, optional
            The reduce operation to be used, by default dist.ReduceOp.AVG.
        ddp_active : bool, optional
            Whether the DistributedDataParallel is active, by default "LOCAL_RANK" in os.environ.
            Note that the default is evaluated once, when the class is defined.
        """
        self.tasks[label] = {
            "pbar": self.pbar.add_task(
                f"[white]Iteration ({label})", total=length, completed=completed
            ),
            "table": Table(),
        }
        self.metrics[label] = {
            "value": {},
            "mean": defaultdict(Mean),
        }

        def decorator(fn):
            @wraps(fn)
            def decorated(*args, **kwargs):
                output = fn(*args, **kwargs)
                if not isinstance(output, dict):
                    self.update(label, fn.__name__)
                    return output

                # Collect across all DDP processes
                scalars = {}
                # Iterate over a snapshot because entries of ``output`` are
                # overwritten inside the loop.
                for k, v in list(output.items()):
                    if isinstance(v, (int, float)):
                        v = torch.tensor([v])
                    if not torch.is_tensor(v):
                        continue
                    if ddp_active and v.is_cuda:  # pragma: no cover
                        # In-place reduction across processes.
                        dist.all_reduce(v, op=op)
                    output[k] = v.detach()
                    if torch.numel(v) == 1:
                        scalars[k] = v.item()
                        output[k] = scalars[k]

                # Save the outputs to tracker
                for k, v in scalars.items():
                    self.metrics[label]["value"][k] = v
                    # Update the running mean
                    self.metrics[label]["mean"][k].update(v)

                self.update(label, fn.__name__)
                return output

            return decorated

        return decorator

    def log(self, label: str, value_type: str = "value", history: bool = True):
        """
        A decorator for logging the metrics of a function.

        After each call of the decorated function, rank 0 writes every
        metric currently stored for ``label`` to the writer (if any) under
        the tag "[key]/[label]" at the current step, and appends it to the
        history when a history exists for ``label``.

        Parameters
        ----------
        label : str
            The label to be associated with the logging.
        value_type : str, optional
            The type of value to be logged, "value" (latest) or "mean"
            (running mean), by default "value".
        history : bool, optional
            Whether to save the history of the metrics, by default True.
        """
        assert value_type in ["mean", "value"]
        if history and label not in self.history:
            self.history[label] = defaultdict(default_list)

        def decorator(fn):
            @wraps(fn)
            def decorated(*args, **kwargs):
                output = fn(*args, **kwargs)
                if self.rank == 0:
                    metrics = self.metrics[label][value_type]
                    # A history may also exist because it was restored via
                    # load_state_dict, so check the dict, not the flag.
                    keep_history = label in self.history
                    for k, v in metrics.items():
                        v = v() if isinstance(v, Mean) else v
                        if self.writer is not None:
                            self.writer.add_scalar(f"{k}/{label}", v, self.step)
                        if keep_history:
                            self.history[label][k].append(v)

                    if keep_history:
                        self.history[label]["step"].append(self.step)

                return output

            return decorated

        return decorator

    def is_best(self, label, key):
        """
        Checks if the latest value of the given key in the label is the best so far.

        "Best" means lowest: the latest recorded value is compared against
        the minimum of the recorded history.

        Parameters
        ----------
        label : str
            The label of the metrics to be checked.
        key : str
            The key of the metric to be checked.

        Returns
        -------
        bool
            True if the latest value is the best so far, otherwise False.
        """
        values = self.history[label][key]
        return values[-1] == min(values)

    def state_dict(self):
        """
        Returns a dictionary containing the state of the tracker.

        Returns
        -------
        dict
            A dictionary containing the history and step of the tracker.
        """
        return {"history": self.history, "step": self.step}

    def load_state_dict(self, state_dict):
        """
        Loads the state of the tracker from the given state dictionary.

        Parameters
        ----------
        state_dict : dict
            A dictionary containing the history and step of the tracker.

        Returns
        -------
        Tracker
            The tracker object with the loaded state.
        """
        self.history = state_dict["history"]
        self.step = state_dict["step"]
        return self