from typing import Any

import torch

from cosmos_transfer1.utils.lazy_config import LazyDict, instantiate


class Model(torch.nn.Module):
    """The base model class. It inherits from torch.nn.Module.

    All models should inherit from Model and include the implementations of all their
    computation graphs. All inheriting child classes should implement the following methods:
    - training_step(): The training step of the model, including the loss computation.
    - validation_step(): The validation step of the model, including the loss computation.
    - forward(): The computation graph for model inference.
    The following methods have default implementations in Model:
    - init_optimizer_scheduler(): Creates the optimizer and scheduler for the model.
    """

    def __init__(self) -> None:
        super().__init__()
        self.on_model_init_start(set_barrier=False)

    def init_optimizer_scheduler(
        self, optimizer_config: LazyDict, scheduler_config: LazyDict
    ) -> tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LRScheduler]:
        """Creates the optimizer and scheduler for the model.

        Args:
            optimizer_config (LazyDict): The lazy config for the optimizer.
            scheduler_config (LazyDict): The lazy config for the scheduler.

        Returns:
            optimizer (torch.optim.Optimizer): The model optimizer.
            scheduler (torch.optim.lr_scheduler.LRScheduler): The optimization scheduler.
        """
        optimizer_config.params = self.parameters()
        optimizer = instantiate(optimizer_config)
        scheduler_config.optimizer = optimizer
        scheduler = instantiate(scheduler_config)
        return optimizer, scheduler
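
    # Example usage (a hedged sketch, not part of this module): the configs are
    # typically built with a LazyCall-style helper that pairs with `instantiate`;
    # the exact import below is an assumption modeled on detectron2-style lazy
    # configs. `params` and `optimizer` are filled in by this method, so the
    # configs leave them unset:
    #
    #     from cosmos_transfer1.utils.lazy_config import LazyCall as L
    #
    #     optimizer_config = L(torch.optim.AdamW)(lr=1e-4, weight_decay=0.01)
    #     scheduler_config = L(torch.optim.lr_scheduler.LambdaLR)(lr_lambda=lambda _: 1.0)
    #     optimizer, scheduler = model.init_optimizer_scheduler(optimizer_config, scheduler_config)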

    def training_step(
        self, data_batch: dict[str, torch.Tensor], iteration: int
    ) -> tuple[dict[str, torch.Tensor], torch.Tensor]:
        """The training step of the model, including the loss computation.

        Args:
            data_batch (dict[str, torch.Tensor]): Data batch (dictionary of tensors).
            iteration (int): Current iteration number.

        Returns:
            output_batch (dict[str, torch.Tensor]): Auxiliary model output from the training batch.
            loss (torch.Tensor): The total loss for backprop (weighted sum of various losses).
        """
        raise NotImplementedError

    @torch.no_grad()
    def validation_step(
        self, data_batch: dict[str, torch.Tensor], iteration: int
    ) -> tuple[dict[str, torch.Tensor], torch.Tensor]:
        """The validation step of the model, including the loss computation.

        Args:
            data_batch (dict[str, torch.Tensor]): Data batch (dictionary of tensors).
            iteration (int): Current iteration number.

        Returns:
            output_batch (dict[str, torch.Tensor]): Auxiliary model output from the validation batch.
            loss (torch.Tensor): The total loss (weighted sum of various losses).
        """
        raise NotImplementedError

    @torch.inference_mode()
    def forward(self, *args: Any, **kwargs: Any) -> Any:
        """The computation graph for model inference.

        Args:
            *args: Whatever you decide to pass into the forward method.
            **kwargs: Keyword arguments are also possible.

        Returns:
            Your model's output.
        """
        raise NotImplementedError

    def on_model_init_start(self, set_barrier=False) -> None:
        """Hook called at the start of model initialization. No-op by default."""
        return

    def on_model_init_end(self, set_barrier=False) -> None:
        """Hook called at the end of model initialization. No-op by default."""
        return

    def on_train_start(self, memory_format: torch.memory_format = torch.preserve_format) -> None:
        """Model preparation before training is launched.

        Args:
            memory_format (torch.memory_format): Memory format of the model.
        """
        pass
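
    # Illustrative override (a sketch, not shipped behavior): a subclass might use
    # this hook to convert parameters to the requested memory format before training:
    #
    #     def on_train_start(self, memory_format: torch.memory_format = torch.preserve_format) -> None:
    #         self.to(memory_format=memory_format)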

    def on_before_zero_grad(
        self, optimizer: torch.optim.Optimizer, scheduler: torch.optim.lr_scheduler.LRScheduler, iteration: int
    ) -> None:
        """Hook before zero_grad() is called.

        Args:
            optimizer (torch.optim.Optimizer): The model optimizer.
            scheduler (torch.optim.lr_scheduler.LRScheduler): The optimization scheduler.
            iteration (int): Current iteration number.
        """
        pass
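
    # Illustrative override (a sketch, not shipped behavior): log the learning rate
    # just before gradients are cleared.
    #
    #     def on_before_zero_grad(self, optimizer, scheduler, iteration) -> None:
    #         if iteration % 100 == 0:
    #             print(f"iteration {iteration}: lr={scheduler.get_last_lr()[0]:.3e}")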

    def on_after_backward(self, iteration: int = 0) -> None:
        """Hook after loss.backward() is called.

        This method is called immediately after the backward pass, allowing for custom operations
        or modifications to be performed on the gradients before the optimizer step.

        Args:
            iteration (int): Current iteration number.
        """
        pass
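
    # Illustrative override (a sketch, not shipped behavior): clip gradients between
    # the backward pass and the optimizer step.
    #
    #     def on_after_backward(self, iteration: int = 0) -> None:
    #         torch.nn.utils.clip_grad_norm_(self.parameters(), max_norm=1.0)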
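

# A minimal, self-contained sketch of a concrete subclass, for illustration only:
# `_ToyRegressionModel`, its single linear layer, and the MSE loss are hypothetical
# and not part of this module. It shows the three methods every child class must
# implement (training_step, validation_step, forward).
if __name__ == "__main__":

    class _ToyRegressionModel(Model):
        """A toy regression model exercising the Model contract."""

        def __init__(self) -> None:
            super().__init__()
            self.net = torch.nn.Linear(4, 1)
            self.on_model_init_end(set_barrier=False)

        def training_step(
            self, data_batch: dict[str, torch.Tensor], iteration: int
        ) -> tuple[dict[str, torch.Tensor], torch.Tensor]:
            # Compute predictions and the (single-term) training loss.
            pred = self.net(data_batch["input"])
            loss = torch.nn.functional.mse_loss(pred, data_batch["target"])
            return {"pred": pred}, loss

        @torch.no_grad()
        def validation_step(
            self, data_batch: dict[str, torch.Tensor], iteration: int
        ) -> tuple[dict[str, torch.Tensor], torch.Tensor]:
            # Same computation as training; gradients are disabled by the decorator.
            pred = self.net(data_batch["input"])
            loss = torch.nn.functional.mse_loss(pred, data_batch["target"])
            return {"pred": pred}, loss

        @torch.inference_mode()
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.net(x)

    model = _ToyRegressionModel()
    batch = {"input": torch.randn(8, 4), "target": torch.randn(8, 1)}
    _, loss = model.training_step(batch, iteration=0)
    print(f"toy training loss: {loss.item():.4f}")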