# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
import torch
from cosmos_predict1.utils.lazy_config import LazyDict, instantiate


class Model(torch.nn.Module):
    """The base model class. It inherits from torch.nn.Module.

    All models should inherit from Model. It should include the implementations of all the
    computation graphs. All inheriting child classes should implement the following methods:
    - training_step(): The training step of the model, including the loss computation.
    - validation_step(): The validation step of the model, including the loss computation.
    - forward(): The computation graph for model inference.

    The following methods have default implementations in Model:
    - init_optimizer_scheduler(): Creates the optimizer and scheduler for the model.
    """

    def __init__(self) -> None:
        super().__init__()
        self.on_model_init_start(set_barrier=False)

    def init_optimizer_scheduler(
        self, optimizer_config: LazyDict, scheduler_config: LazyDict
    ) -> tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LRScheduler]:
        """Creates the optimizer and scheduler for the model.

        Args:
            optimizer_config (LazyDict): The config object for the optimizer.
            scheduler_config (LazyDict): The config object for the scheduler.

        Returns:
            optimizer (torch.optim.Optimizer): The model optimizer.
            scheduler (torch.optim.lr_scheduler.LRScheduler): The optimization scheduler.
        """
        optimizer_config.params = self.parameters()
        optimizer = instantiate(optimizer_config)
        scheduler_config.optimizer = optimizer
        scheduler = instantiate(scheduler_config)
        return optimizer, scheduler

    def training_step(
        self, data_batch: dict[str, torch.Tensor], iteration: int
    ) -> tuple[dict[str, torch.Tensor], torch.Tensor]:
        """The training step of the model, including the loss computation.

        Args:
            data_batch (dict[str, torch.Tensor]): Data batch (dictionary of tensors).
            iteration (int): Current iteration number.

        Returns:
            output_batch (dict[str, torch.Tensor]): Auxiliary model output from the training batch.
            loss (torch.Tensor): The total loss for backprop (weighted sum of various losses).
        """
        raise NotImplementedError

    @torch.no_grad()
    def validation_step(
        self, data_batch: dict[str, torch.Tensor], iteration: int
    ) -> tuple[dict[str, torch.Tensor], torch.Tensor]:
        """The validation step of the model, including the loss computation.

        Args:
            data_batch (dict[str, torch.Tensor]): Data batch (dictionary of tensors).
            iteration (int): Current iteration number.

        Returns:
            output_batch (dict[str, torch.Tensor]): Auxiliary model output from the validation batch.
            loss (torch.Tensor): The total loss (weighted sum of various losses).
        """
        raise NotImplementedError

    @torch.inference_mode()
    def forward(self, *args: Any, **kwargs: Any) -> Any:
        """The computation graph for model inference.

        Args:
            *args: Positional arguments passed to the forward method.
            **kwargs: Keyword arguments passed to the forward method.

        Returns:
            The model's output.
        """
        raise NotImplementedError

    def on_model_init_start(self, set_barrier=False) -> None:
        """Hook called at the start of model initialization."""
        return

    def on_model_init_end(self, set_barrier=False) -> None:
        """Hook called at the end of model initialization."""
        return

    def on_train_start(self, memory_format: torch.memory_format = torch.preserve_format) -> None:
        """The model preparation before the training is launched.

        Args:
            memory_format (torch.memory_format): Memory format of the model.
        """
        pass

    def on_before_zero_grad(
        self, optimizer: torch.optim.Optimizer, scheduler: torch.optim.lr_scheduler.LRScheduler, iteration: int
    ) -> None:
        """Hook before zero_grad() is called.

        Args:
            optimizer (torch.optim.Optimizer): The model optimizer.
            scheduler (torch.optim.lr_scheduler.LRScheduler): The optimization scheduler.
            iteration (int): Current iteration number.
        """
        pass

    def on_after_backward(self, iteration: int = 0) -> None:
        """Hook after loss.backward() is called.

        This method is called immediately after the backward pass, allowing for custom operations
        or modifications to be performed on the gradients before the optimizer step.

        Args:
            iteration (int): Current iteration number.
        """
        pass
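

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): a minimal subclass
# showing how the abstract methods above are typically implemented. The
# regression task, the tensor keys ("input", "target"), and the gradient
# clipping in on_after_backward are assumptions for demonstration only.
class _ExampleRegressionModel(Model):
    def __init__(self) -> None:
        super().__init__()
        self.net = torch.nn.Linear(8, 1)

    def training_step(
        self, data_batch: dict[str, torch.Tensor], iteration: int
    ) -> tuple[dict[str, torch.Tensor], torch.Tensor]:
        pred = self.net(data_batch["input"])
        loss = torch.nn.functional.mse_loss(pred, data_batch["target"])
        # Auxiliary outputs go in the dictionary; the scalar loss is returned
        # separately for backprop.
        return {"pred": pred, "loss_mse": loss.detach()}, loss

    @torch.no_grad()
    def validation_step(
        self, data_batch: dict[str, torch.Tensor], iteration: int
    ) -> tuple[dict[str, torch.Tensor], torch.Tensor]:
        pred = self.net(data_batch["input"])
        loss = torch.nn.functional.mse_loss(pred, data_batch["target"])
        return {"pred": pred}, loss

    @torch.inference_mode()
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)

    def on_after_backward(self, iteration: int = 0) -> None:
        # Example use of the hook: clip gradients before the optimizer step.
        torch.nn.utils.clip_grad_norm_(self.parameters(), max_norm=1.0)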