| from dataclasses import dataclass, field |
| from typing import Dict, List, Optional |
|
|
|
|
@dataclass(frozen=True)
class TextConfig:
    """Immutable set of integer settings for the text part of the model.

    Every field carries a default, so ``TextConfig()`` is a valid call.
    The dataclass is frozen: assigning to a field after construction
    raises ``dataclasses.FrozenInstanceError``.
    """

    # NOTE(review): the per-field descriptions below are inferred from the
    # field names only -- confirm against the code that consumes this config.

    # presumably the width of the text model's hidden state
    dim: int = 2048
    # presumably the number of stacked layers
    n_layers: int = 24
    # presumably the number of entries in the token vocabulary
    vocab_size: int = 51200
    # presumably the longest supported sequence length, in tokens
    max_context: int = 2048
    # presumably the number of attention heads per layer
    n_heads: int = 32
    # meaning not evident from this file -- TODO confirm with the caller
    prefix_attn: int = 730
|
|
|
|
@dataclass(frozen=True)
class VisionConfig:
    """Immutable set of integer settings for the vision part of the model.

    All eleven fields have defaults, so ``VisionConfig()`` is valid, and
    the dataclass is frozen, so fields cannot be reassigned afterwards.
    """

    # NOTE(review): the grouping and descriptions below are inferred from
    # the field-name prefixes only -- confirm against the consuming code.

    # --- ``enc_*`` fields: presumably the image encoder's shape ---
    enc_dim: int = 1152
    enc_patch_size: int = 14
    enc_n_layers: int = 27
    enc_ff_dim: int = 4304
    enc_n_heads: int = 16

    # Default equals TextConfig.dim's default (2048); presumably the two
    # must agree -- TODO confirm.
    proj_out_dim: int = 2048

    # --- presumably image cropping / input layout settings ---
    crop_size: int = 378
    in_channels: int = 3
    max_crops: int = 12
    overlap_margin: int = 4

    # Declared last, so it is also the last positional constructor argument.
    proj_inner_dim: int = 8192
|
|
|
|
@dataclass(frozen=True)
class RegionConfig:
    """Immutable set of integer settings for the region part of the model.

    Every field has a default, so ``RegionConfig()`` is valid; the
    dataclass is frozen, so fields cannot be reassigned afterwards.
    """

    # NOTE(review): descriptions are inferred from field names only --
    # confirm against the code that reads this config.

    # Default equals TextConfig.dim's default (2048); presumably the two
    # must agree -- TODO confirm.
    dim: int = 2048

    # presumably feature / output widths for coordinate values
    coord_feat_dim: int = 256
    coord_out_dim: int = 1024

    # presumably feature / output widths for size values
    size_feat_dim: int = 512
    size_out_dim: int = 2048

    # presumably the width of an intermediate layer
    inner_dim: int = 8192
|
|
|
|
def _default_templates() -> Dict[str, Optional[Dict[str, List[int]]]]:
    """Build a fresh copy of the default prompt-template table.

    A new dict (with new inner dicts and lists) is created on every call,
    so no two ``TokenizerConfig`` instances share the default containers.
    """
    caption = {
        "short": [198, 198, 16438, 8305, 25],
        "normal": [198, 198, 24334, 1159, 25],
    }
    query = {
        "prefix": [198, 198, 24361, 25],
        "suffix": [198, 198, 33706, 25],
    }
    detect = {
        "prefix": [198, 198, 47504, 25],
        "suffix": [628],
    }
    point = {
        "prefix": [198, 198, 12727, 25],
        "suffix": [628],
    }
    # Insertion order is kept as: caption, query, detect, point.
    return {
        "caption": caption,
        "query": query,
        "detect": detect,
        "point": point,
    }


@dataclass(frozen=True)
class TokenizerConfig:
    """Immutable tokenizer settings: special token ids and prompt templates.

    ``bos_id`` and ``eos_id`` default to the same value (50256).

    ``templates`` maps a task name to a dict of named token-id lists; the
    annotation also permits ``None`` for a task. By default the
    ``"caption"`` entry has ``"short"`` and ``"normal"`` lists, while
    ``"query"``, ``"detect"`` and ``"point"`` each have ``"prefix"`` and
    ``"suffix"`` lists.

    NOTE(review): the ids are presumably pre-tokenized prompt fragments for
    a specific tokenizer -- confirm which one before editing them.

    Only attribute reassignment is blocked by ``frozen=True``; the
    ``templates`` dict itself is still a mutable container.
    """

    bos_id: int = 50256
    eos_id: int = 50256
    templates: Dict[str, Optional[Dict[str, List[int]]]] = field(
        default_factory=_default_templates
    )
|
|
|
|
@dataclass(frozen=True)
class MoondreamConfig:
    """Top-level immutable configuration grouping the four sub-configs.

    Each field defaults to a default-constructed sub-config. Those default
    instances are created once, when the class is defined, and are shared
    by every ``MoondreamConfig`` that does not override them.
    """

    text: TextConfig = TextConfig()
    vision: VisionConfig = VisionConfig()
    region: RegionConfig = RegionConfig()
    tokenizer: TokenizerConfig = TokenizerConfig()

    @classmethod
    def from_dict(cls, config_dict: dict) -> "MoondreamConfig":
        """Build a config from a nested plain dict.

        Args:
            config_dict: Mapping that may contain the keys ``"text"``,
                ``"vision"``, ``"region"`` and ``"tokenizer"``; each value
                is a dict of keyword arguments for the matching sub-config.
                A missing section falls back to that sub-config's
                defaults. Other top-level keys are ignored.

        Returns:
            A new ``MoondreamConfig`` built from the four sections.

        Raises:
            TypeError: If a section contains a key that is not a field of
                the corresponding sub-config, because the section is
                expanded directly into the constructor call.
        """
        return cls(
            text=TextConfig(**config_dict.get("text", {})),
            vision=VisionConfig(**config_dict.get("vision", {})),
            region=RegionConfig(**config_dict.get("region", {})),
            tokenizer=TokenizerConfig(**config_dict.get("tokenizer", {})),
        )

    def to_dict(self) -> dict:
        """Return the configuration as a nested plain dict.

        The result has the keys ``"text"``, ``"vision"``, ``"region"`` and
        ``"tokenizer"``, each mapping field names to values, which is the
        layout ``from_dict`` accepts.

        Every level of the result is an independent copy. Handing out the
        sub-configs' own ``__dict__`` objects would let a caller who edits
        the returned dict silently modify this frozen config, and also the
        shared class-level defaults when this config uses them.
        """
        # Imported locally because the module header only imports
        # ``dataclass`` and ``field``.
        from dataclasses import asdict

        # asdict() recurses into nested dicts/lists (e.g. the tokenizer
        # templates) and copies them, so nothing mutable is shared.
        return {
            "text": asdict(self.text),
            "vision": asdict(self.vision),
            "region": asdict(self.region),
            "tokenizer": asdict(self.tokenizer),
        }
|
|