"""Configuration management module for the Dia model.

This module provides comprehensive configuration management for the Dia model,
utilizing Pydantic for validation. It defines configurations for data processing,
model architecture (encoder and decoder), and training settings.

Key components:
- DataConfig: Parameters for data loading and preprocessing.
- EncoderConfig: Architecture details for the encoder module.
- DecoderConfig: Architecture details for the decoder module.
- ModelConfig: Combined model architecture settings.
- TrainingConfig: Training hyperparameters and settings.
- DiaConfig: Master configuration combining all components.
"""

import os

from pydantic import BaseModel, Field


class EncoderConfig(BaseModel, frozen=True):
    """Configuration for the encoder component of the Dia model.

    Attributes:
        model_type: Type of the model, defaults to "dia_encoder".
        hidden_size: Size of the encoder layers, defaults to 1024.
        intermediate_size: Size of the "intermediate" (i.e., feed-forward) layer in the encoder, defaults to 4096.
        num_hidden_layers: Number of hidden layers in the encoder, defaults to 12.
        num_attention_heads: Number of attention heads in the encoder, defaults to 16.
        num_key_value_heads: Number of key-value heads in the encoder, defaults to 16.
        head_dim: Dimension of each attention head, defaults to 128.
        hidden_act: Activation function in the encoder, defaults to "silu".
        max_position_embeddings: Maximum number of position embeddings, defaults to 1024.
        initializer_range: Range for initializing weights, defaults to 0.02.
        norm_eps: Epsilon value for normalization layers, defaults to 1e-5.
        rope_theta: Theta value for RoPE, defaults to 10000.0.
        rope_scaling: Optional scaling factor for RoPE.
        vocab_size: Vocabulary size, defaults to 256.
    """

    head_dim: int = Field(default=128, gt=0)
    hidden_act: str = Field(default="silu")
    hidden_size: int = Field(default=1024, gt=0)
    initializer_range: float = Field(default=0.02)
    intermediate_size: int = Field(default=4096, gt=0)
    max_position_embeddings: int = Field(default=1024, gt=0)
    model_type: str = Field(default="dia_encoder")
    norm_eps: float = Field(default=1e-5)
    num_attention_heads: int = Field(default=16, gt=0)
    num_hidden_layers: int = Field(default=12, gt=0)
    num_key_value_heads: int = Field(default=16, gt=0)
    rope_scaling: float | None = Field(default=None)
    rope_theta: float = Field(default=10000.0)
    vocab_size: int = Field(default=256, gt=0)


class DecoderConfig(BaseModel, frozen=True):
    """Configuration for the decoder component of the Dia model.

    Attributes:
        model_type: Type of the model, defaults to "dia_decoder".
        hidden_size: Size of the decoder layers, defaults to 2048.
        intermediate_size: Size of the "intermediate" (i.e., feed-forward) layer in the decoder, defaults to 8192.
        num_hidden_layers: Number of hidden layers in the decoder, defaults to 18.
        num_attention_heads: Number of attention heads in the decoder, defaults to 16.
        num_key_value_heads: Number of key-value heads in the decoder, defaults to 4.
        head_dim: Dimension of each attention head, defaults to 128.
        cross_hidden_size: Size of the cross-attention layers, defaults to 1024.
        cross_num_attention_heads: Number of attention heads in the cross-attention mechanism, defaults to 16.
        cross_num_key_value_heads: Number of key-value heads in the cross-attention mechanism, defaults to 16.
        cross_head_dim: Dimension of each cross-attention head, defaults to 128.
        hidden_act: Activation function in the decoder, defaults to "silu".
        max_position_embeddings: Maximum number of position embeddings in the decoder, defaults to 3072.
        initializer_range: Range for initializing weights in the decoder, defaults to 0.02.
        norm_eps: Epsilon value for normalization layers in the decoder, defaults to 1e-5.
        rope_theta: Theta value for RoPE in the decoder, defaults to 10000.0.
        rope_scaling: Optional scaling factor for RoPE in the decoder.
        vocab_size: Vocabulary size for the decoder, defaults to 1028.
        num_channels: Number of channels in the decoder, defaults to 9.
    """

    cross_head_dim: int = Field(default=128, gt=0)
    cross_hidden_size: int = Field(default=1024, gt=0)
    cross_num_attention_heads: int = Field(default=16, gt=0)
    cross_num_key_value_heads: int = Field(default=16, gt=0)
    head_dim: int = Field(default=128, gt=0)
    hidden_act: str = Field(default="silu")
    hidden_size: int = Field(default=2048, gt=0)
    initializer_range: float = Field(default=0.02)
    intermediate_size: int = Field(default=8192, gt=0)
    max_position_embeddings: int = Field(default=3072, gt=0)
    model_type: str = Field(default="dia_decoder")
    norm_eps: float = Field(default=1e-5)
    num_attention_heads: int = Field(default=16, gt=0)
    num_channels: int = Field(default=9, gt=0)
    num_hidden_layers: int = Field(default=18, gt=0)
    num_key_value_heads: int = Field(default=4, gt=0)
    rope_scaling: float | None = Field(default=None)
    rope_theta: float = Field(default=10000.0)
    vocab_size: int = Field(default=1028, gt=0)


class DiaConfig(BaseModel, frozen=True):
    """Main configuration container for the Dia model architecture.

    Attributes:
        model_type: Type of the model, defaults to "dia".
        is_encoder_decoder: Flag indicating if the model is an encoder-decoder type, defaults to True.
        encoder_config: Configuration for the encoder component.
        decoder_config: Configuration for the decoder component.
        initializer_range: Range for initializing weights, defaults to 0.02.
        norm_eps: Epsilon value for normalization layers, defaults to 1e-5.
        torch_dtype: Data type for model weights in PyTorch, defaults to "float32".
        bos_token_id: Beginning-of-sequence token ID, defaults to 1026.
        eos_token_id: End-of-sequence token ID, defaults to 1024.
        pad_token_id: Padding token ID, defaults to 1025.
        transformers_version: Version of the transformers library, defaults to "4.53.0.dev0".
        architectures: List of model architectures, defaults to ["DiaForConditionalGeneration"].
        delay_pattern: List of per-audio-channel delay values, defaults to [0, 8, 9, 10, 11, 12, 13, 14, 15].
    """

    architectures: list[str] = Field(
        default_factory=lambda: ["DiaForConditionalGeneration"]
    )
    bos_token_id: int = Field(default=1026)
    decoder_config: DecoderConfig
    delay_pattern: list[int] = Field(
        default_factory=lambda: [0, 8, 9, 10, 11, 12, 13, 14, 15]
    )
    encoder_config: EncoderConfig
    eos_token_id: int = Field(default=1024)
    initializer_range: float = Field(default=0.02)
    is_encoder_decoder: bool = Field(default=True)
    model_type: str = Field(default="dia")
    norm_eps: float = Field(default=1e-5)
    pad_token_id: int = Field(default=1025)
    torch_dtype: str = Field(default="float32")
    transformers_version: str = Field(default="4.53.0.dev0")

    def save(self, path: str) -> None:
        """Save the current configuration instance to a JSON file.

        Creates the parent directory if it does not already exist.

        Args:
            path: The target file path to save the configuration to.
        """
        parent_dir = os.path.dirname(path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        config_json = self.model_dump_json(indent=2)
        with open(path, "w") as f:
            f.write(config_json)

    @classmethod
    def load(cls, path: str) -> "DiaConfig | None":
        """Load and validate a Dia configuration from a JSON file.

        Args:
            path: The path to the configuration file.

        Returns:
            A validated DiaConfig instance if the file exists and is valid,
            otherwise None if the file is not found.

        Raises:
            pydantic.ValidationError: If the file content is not valid JSON or
                fails validation against the DiaConfig schema.
        """
        try:
            with open(path, "r") as f:
                content = f.read()
            return cls.model_validate_json(content)
        except FileNotFoundError:
            return None
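

# Usage sketch (illustrative only, not part of the original module): a minimal
# round trip that builds a DiaConfig from default encoder/decoder settings,
# writes it to disk, and reloads it. The path "dia_config.json" is a
# hypothetical placeholder.
if __name__ == "__main__":
    config = DiaConfig(
        encoder_config=EncoderConfig(),
        decoder_config=DecoderConfig(),
    )
    config.save("dia_config.json")

    reloaded = DiaConfig.load("dia_config.json")
    assert reloaded is not None
    assert reloaded == config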