from pathlib import Path
from typing import Any, Dict, List

import torch
from pydantic import BaseModel


class State(BaseModel):
    model_config = {"arbitrary_types_allowed": True}

    train_frames: int  # user-defined training frames (includes one image padding frame)
    train_height: int
    train_width: int

    transformer_config: Dict[str, Any] | None = None
    controlnetxs_transformer_config: Dict[str, Any] | None = None

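    # Bookkeeping fields, presumably populated during training setup and the
    # training loop (parameter counts, step/scheduler accounting, batch size).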
    weight_dtype: torch.dtype = torch.float32
    num_trainable_parameters: int = 0
    overwrote_max_train_steps: bool = False
    num_update_steps_per_epoch: int = 0
    total_batch_size_count: int = 0

    generator: torch.Generator | None = None

    validation_prompts: List[str] = []
    validation_images: List[Path | None] = []
    validation_videos: List[Path | None] = []

    image: torch.Tensor | None = None  # C, H, W; values in [0, 255]
    video: torch.Tensor | None = None  # F, C, H, W; values in [-1, 1]
    prompt_embedding: torch.Tensor | None = None  # L, D
    prompt: str | None = None
    plucker_embedding: torch.Tensor | None = None  # F, 6, H, W
    timestep: torch.Tensor | None = None

    using_deepspeed: bool = False
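

if __name__ == "__main__":
    # Minimal usage sketch. The concrete values here are assumptions for
    # illustration, not taken from any training config in this repository.
    # Only the three train_* dimensions are required; everything else
    # defaults to None/empty and is filled in during training setup.
    state = State(train_frames=49, train_height=480, train_width=720)
    state.weight_dtype = torch.bfloat16
    state.validation_prompts = ["a hypothetical validation prompt"]
    print(state.train_frames, state.weight_dtype)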