#!/usr/bin/env bash

# Prevent tokenizer parallelism issues
export TOKENIZERS_PARALLELISM=false
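# expandable_segments reduces CUDA memory fragmentation when allocation sizes vary between steps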
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True

# Model Configuration
MODEL_ARGS=(
    --model_path "THUDM/CogVideoX-5b-I2V"
    --model_name "cogvideox-flovd"  # ["cogvideox-i2v" or "cogvideox-flovd"]
    --model_type "i2vFlow" # ["t2v", "i2v", "i2vFlow"]
    --training_type "controlnet"
    # --additional_save_blocks "patch_embed" # additional blocks to update when using LoRA, e.g. "patch_embed,text_proj"
)

# Output Configuration
OUTPUT_ARGS=(
    --output_dir "absolute/path/to/output"
    --report_to "wandb"
    --run_name "FloVD_CogVideoX_controlnet"
)

# Data Configuration
DATA_ARGS=(
    --data_root "absolute/path/to/whole_data"  # see the example layout sketched below
    --caption_column "prompt.txt" # not needed for this setup
    --video_column "videos.txt" # not needed for this setup
    # --image_column "images.txt"  # leave commented out to use the first frame of each video as the image conditioning
    --train_resolution "49x480x720"  # (frames x height x width); frames must be of the form 8N+1, e.g. 49 = 8*6 + 1
)
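# Example data_root layout (illustrative assumption only — the FloVD dataloader may
# build its own metadata, and prompt.txt/videos.txt are marked as unused above):
#   whole_data/
#   ├── prompt.txt   # one caption per line
#   ├── videos.txt   # one video path per line, relative to data_root (e.g. videos/clip_0001.mp4)
#   └── videos/      # the video files themselves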

# Training Configuration
TRAIN_ARGS=(
    --train_epochs 10 # number of training epochs
    --seed 42 # random seed
    --batch_size 1
    --gradient_accumulation_steps 2
    --mixed_precision "bf16"  # ["no", "fp16", "bf16"]; only CogVideoX-2B supports fp16 training
    --learning_rate 1e-5
)
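# Effective batch size = batch_size x gradient_accumulation_steps x number of processes
# launched by accelerate (e.g. 1 x 2 x 8 GPUs = 16 samples per optimizer step).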

# System Configuration
SYSTEM_ARGS=(
    --num_workers 8
    --pin_memory True
    --nccl_timeout 1800
)

# Checkpointing Configuration
CHECKPOINT_ARGS=(
    --checkpointing_steps 2000 # save checkpoint every x steps
    --checkpointing_limit 2 # maximum number of checkpoints to keep, after which the oldest one is deleted
    # --resume_from_checkpoint /path/to/ckpt # if you want to resume from a checkpoint, otherwise, comment this line
)

# Validation Configuration
VALIDATION_ARGS=(
    --do_validation true  # ["true", "false"]
    --validation_dir "absolute/path/to/whole_data"
    --validation_steps 2000  # should be a multiple of checkpointing_steps
    --validation_prompts "prompts.txt" # not needed for this setup
    --validation_images "images.txt" # not needed for this setup
    --gen_fps 16
    --max_scene 4
)

# Controlnet Configuration
CONTROLNET_ARGS=(
    --controlnet_transformer_num_layers 6
    --controlnet_input_channels 16
    --controlnet_weights 1.0
    --controlnet_guidance_start 0.0
    --controlnet_guidance_end 0.4
    --controlnet_out_proj_dim_factor 64
    --enable_time_sampling false
    --time_sampling_type "truncated_normal"
    --time_sampling_mean 0.95
    --time_sampling_std 0.1
    --notextinflow true
)
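# Note: controlnet_guidance_start/end are presumably fractions of the denoising
# trajectory over which the ControlNet branch is applied (here, roughly the first 40%);
# the time_sampling_* options are inactive because enable_time_sampling is false.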


# Combine all arguments and launch training
accelerate launch --config_file accelerate_config.yaml train.py \
    "${MODEL_ARGS[@]}" \
    "${OUTPUT_ARGS[@]}" \
    "${DATA_ARGS[@]}" \
    "${TRAIN_ARGS[@]}" \
    "${SYSTEM_ARGS[@]}" \
    "${CHECKPOINT_ARGS[@]}" \
    "${VALIDATION_ARGS[@]}" \
    "${CONTROLNET_ARGS[@]}"