Spaces:
Build error
Build error
File size: 4,750 Bytes
b6af722 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 |
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Registry of command-line override options (Hydra config groups) for the tokenizer training config."""
from hydra.core.config_store import ConfigStore
from cosmos_predict1.tokenizer.training.configs.base.callback import BASIC_CALLBACKS
from cosmos_predict1.tokenizer.training.configs.base.checkpoint import CHECKPOINT_LOCAL
from cosmos_predict1.tokenizer.training.configs.base.data import DATALOADER_OPTIONS
from cosmos_predict1.tokenizer.training.configs.base.loss import VideoLossConfig
from cosmos_predict1.tokenizer.training.configs.base.metric import DiscreteTokenizerMetricConfig, MetricConfig
from cosmos_predict1.tokenizer.training.configs.base.net import (
CausalContinuousFactorizedVideoTokenizerConfig,
CausalDiscreteFactorizedVideoTokenizerConfig,
ContinuousImageTokenizerConfig,
DiscreteImageTokenizerConfig,
)
from cosmos_predict1.tokenizer.training.configs.base.optim import (
AdamWConfig,
FusedAdamConfig,
WarmupCosineLRConfig,
WarmupLRConfig,
)
def register_training_data(
    cs,
    data_sources=("mock", "hdvila"),
    resolutions=("1080", "720", "480", "360", "256"),
):
    """Register one training-dataloader option per (data source, resolution) pair.

    Each option is stored in the ``data_train`` group, targeting the
    ``dataloader_train`` package, under the name
    ``f"{data_source}_video{resolution}"`` (e.g. ``hdvila_video720``).

    Args:
        cs: The Hydra ``ConfigStore`` to register into.
        data_sources: Data-source prefixes. Each one yields the dataset name
            ``f"{data_source}_video"``. Defaults to ``("mock", "hdvila")``.
        resolutions: Resolution strings forwarded to the basic video loader.
            Defaults to ``("1080", "720", "480", "360", "256")``.
    """
    for data_source in data_sources:
        for resolution in resolutions:
            cs.store(
                group="data_train",
                package="dataloader_train",
                name=f"{data_source}_video{resolution}",  # e.g. `hdvila_video720`
                node=DATALOADER_OPTIONS["video_loader_basic"](
                    dataset_name=f"{data_source}_video",
                    is_train=True,
                    resolution=resolution,
                ),
            )
def register_val_data(
    cs,
    data_sources=("mock", "hdvila"),
    resolutions=("1080", "720", "480", "360", "256"),
):
    """Register one validation-dataloader option per (data source, resolution) pair.

    Each option is stored in the ``data_val`` group, targeting the
    ``dataloader_val`` package, under the name
    ``f"{data_source}_video{resolution}"`` (e.g. ``hdvila_video720``).
    The loader is built with ``is_train=False``.

    Args:
        cs: The Hydra ``ConfigStore`` to register into.
        data_sources: Data-source prefixes. Each one yields the dataset name
            ``f"{data_source}_video"``. Defaults to ``("mock", "hdvila")``.
        resolutions: Resolution strings forwarded to the basic video loader.
            Defaults to ``("1080", "720", "480", "360", "256")``.
    """
    for data_source in data_sources:
        for resolution in resolutions:
            cs.store(
                group="data_val",
                package="dataloader_val",
                name=f"{data_source}_video{resolution}",  # e.g. `hdvila_video720`
                node=DATALOADER_OPTIONS["video_loader_basic"](
                    dataset_name=f"{data_source}_video",
                    is_train=False,
                    resolution=resolution,
                ),
            )
def register_net(cs):
    """Register the tokenizer network configs as options of the ``network`` group.

    Every option targets the ``model.config.network`` package. Options, in
    registration order: ``continuous_image``, ``discrete_image``,
    ``continuous_factorized_video``, ``discrete_factorized_video``.
    """
    network_options = (
        ("continuous_image", ContinuousImageTokenizerConfig),
        ("discrete_image", DiscreteImageTokenizerConfig),
        ("continuous_factorized_video", CausalContinuousFactorizedVideoTokenizerConfig),
        ("discrete_factorized_video", CausalDiscreteFactorizedVideoTokenizerConfig),
    )
    for option_name, config_node in network_options:
        cs.store(
            group="network",
            package="model.config.network",
            name=option_name,
            node=config_node,
        )
def register_optim(cs):
    """Register the ``fused_adam`` and ``adamw`` options of the ``optimizer`` group."""
    optimizers = {"fused_adam": FusedAdamConfig, "adamw": AdamWConfig}
    for opt_name, opt_node in optimizers.items():
        cs.store(group="optimizer", package="optimizer", name=opt_name, node=opt_node)
def register_scheduler(cs):
    """Register the ``warmup`` and ``warmup_cosine`` options of the ``scheduler`` group."""
    shared = {"group": "scheduler", "package": "scheduler"}
    cs.store(**shared, name="warmup", node=WarmupLRConfig)
    cs.store(**shared, name="warmup_cosine", node=WarmupCosineLRConfig)
def register_loss(cs):
    """Register ``VideoLossConfig`` as the ``video`` option of the ``loss`` group.

    The option targets the ``model.config.loss`` package.
    """
    loss_option = dict(group="loss", package="model.config.loss", name="video", node=VideoLossConfig)
    cs.store(**loss_option)
def register_metric(cs):
    """Register the ``reconstruction`` and ``code_usage`` options of the ``metric`` group.

    Both options target the ``model.config.metric`` package.
    """
    metric_options = (
        ("reconstruction", MetricConfig),
        ("code_usage", DiscreteTokenizerMetricConfig),
    )
    for metric_name, metric_node in metric_options:
        cs.store(group="metric", package="model.config.metric", name=metric_name, node=metric_node)
def register_checkpoint(cs):
    """Register ``CHECKPOINT_LOCAL`` as the ``local`` option of the ``checkpoint`` group."""
    checkpoint_option = dict(group="checkpoint", package="checkpoint", name="local", node=CHECKPOINT_LOCAL)
    cs.store(**checkpoint_option)
def register_callback(cs):
    """Register ``BASIC_CALLBACKS`` as the ``basic`` option of the ``callbacks`` group.

    The option targets the ``trainer.callbacks`` package.
    """
    callback_option = dict(group="callbacks", package="trainer.callbacks", name="basic", node=BASIC_CALLBACKS)
    cs.store(**callback_option)
def register_configs():
    """Register every config group defined in this module with Hydra.

    The target is the ``ConfigStore`` returned by ``ConfigStore.instance()``.
    Registration order: training data, validation data, network, optimizer,
    scheduler, loss, metric, checkpoint, callbacks.
    """
    store = ConfigStore.instance()
    registrars = (
        register_training_data,
        register_val_data,
        register_net,
        register_optim,
        register_scheduler,
        register_loss,
        register_metric,
        register_checkpoint,
        register_callback,
    )
    for register in registrars:
        register(store)
|