File size: 1,043 Bytes
59d751c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
from pathlib import Path
from typing import Any, Dict, List

import torch
from pydantic import BaseModel


class State(BaseModel):
    """Pydantic model holding configuration and bookkeeping for a training run.

    Only ``train_frames``, ``train_height`` and ``train_width`` are required;
    every other field has a default. ``arbitrary_types_allowed`` is enabled
    because some fields (``torch.dtype``, ``torch.Generator``) are not types
    pydantic knows how to build a schema for natively.
    """

    model_config = {"arbitrary_types_allowed": True}

    # Training sample dimensions (required; no defaults).
    train_frames: int
    train_height: int
    train_width: int

    # Optional transformer configuration. Annotated as Optional so the `None`
    # default type-checks and callers may also pass `None` explicitly.
    transformer_config: Dict[str, Any] | None = None

    weight_dtype: torch.dtype = torch.float32  # dtype for mixed precision training
    num_trainable_parameters: int = 0
    overwrote_max_train_steps: bool = False
    num_update_steps_per_epoch: int = 0
    total_batch_size_count: int = 0

    generator: torch.Generator | None = None

    # Validation inputs. Individual path entries may be None.
    # NOTE(review): these lists presumably line up index-by-index with
    # `validation_prompts` — confirm against the code that fills them.
    # The `[]` defaults are safe: pydantic copies mutable defaults per instance,
    # so they are not shared between State objects.
    validation_prompts: List[str] = []
    validation_images: List[Path | None] = []
    validation_videos: List[Path | None] = []

    # Additional validation artifacts referenced by path (embeddings, latents,
    # masks). NOTE(review): presumably precomputed offline — confirm.
    validation_prompt_embeddings: List[Path | None] = []
    validation_video_latents: List[Path | None] = []
    validation_flow_latents: List[Path | None] = []
    validation_valid_masks: List[Path | None] = []
    validation_valid_masks_interp: List[Path | None] = []

    # Flag indicating the run uses DeepSpeed.
    using_deepspeed: bool = False