Upload 177 files
This view is limited to 50 files because it contains too many changes. See raw diff.
- .gitattributes +4 -0
- finetune/accelerate_config.yaml +21 -0
- finetune/configs/zero2.yaml +38 -0
- finetune/configs/zero2_controlnet.yaml +38 -0
- finetune/configs/zero2_offload.yaml +42 -0
- finetune/configs/zero3.yaml +43 -0
- finetune/configs/zero3_offload.yaml +51 -0
- finetune/constants.py +2 -0
- finetune/datasets/__init__.py +14 -0
- finetune/datasets/bucket_sampler.py +71 -0
- finetune/datasets/i2v_dataset.py +311 -0
- finetune/datasets/i2v_flow_dataset.py +188 -0
- finetune/datasets/t2v_dataset.py +251 -0
- finetune/datasets/utils.py +211 -0
- finetune/models/__init__.py +12 -0
- finetune/models/cogvideox_i2v/flovd_OMSM_lora_trainer.py +748 -0
- finetune/models/cogvideox_i2v/flovd_controlnet_trainer.py +814 -0
- finetune/models/cogvideox_i2v/lora_trainer.py +246 -0
- finetune/models/cogvideox_i2v/sft_trainer.py +9 -0
- finetune/models/utils.py +57 -0
- finetune/modules/__init__.py +0 -0
- finetune/modules/camera_flow_generator.py +46 -0
- finetune/modules/camera_sampler.py +52 -0
- finetune/modules/cogvideox_controlnet.py +353 -0
- finetune/modules/cogvideox_custom_model.py +109 -0
- finetune/modules/cogvideox_custom_modules.py +357 -0
- finetune/modules/depth_warping/__init__.py +0 -0
- finetune/modules/depth_warping/camera/Camera.py +70 -0
- finetune/modules/depth_warping/camera/WarperPytorch.py +416 -0
- finetune/modules/depth_warping/depth_anything_v2/depth_anything_wrapper.py +12 -0
- finetune/modules/depth_warping/depth_anything_v2/dinov2.py +415 -0
- finetune/modules/depth_warping/depth_anything_v2/dinov2_layers/__init__.py +11 -0
- finetune/modules/depth_warping/depth_anything_v2/dinov2_layers/attention.py +83 -0
- finetune/modules/depth_warping/depth_anything_v2/dinov2_layers/block.py +252 -0
- finetune/modules/depth_warping/depth_anything_v2/dinov2_layers/drop_path.py +35 -0
- finetune/modules/depth_warping/depth_anything_v2/dinov2_layers/layer_scale.py +28 -0
- finetune/modules/depth_warping/depth_anything_v2/dinov2_layers/mlp.py +41 -0
- finetune/modules/depth_warping/depth_anything_v2/dinov2_layers/patch_embed.py +89 -0
- finetune/modules/depth_warping/depth_anything_v2/dinov2_layers/swiglu_ffn.py +63 -0
- finetune/modules/depth_warping/depth_anything_v2/dpt.py +235 -0
- finetune/modules/depth_warping/depth_anything_v2/util/blocks.py +148 -0
- finetune/modules/depth_warping/depth_anything_v2/util/transform.py +158 -0
- finetune/modules/depth_warping/depth_pro/__init__.py +5 -0
- finetune/modules/depth_warping/depth_pro/cli/__init__.py +4 -0
- finetune/modules/depth_warping/depth_pro/cli/run.py +154 -0
- finetune/modules/depth_warping/depth_pro/depth_pro.py +298 -0
- finetune/modules/depth_warping/depth_pro/eval/boundary_metrics.py +332 -0
- finetune/modules/depth_warping/depth_pro/eval/dis5k_sample_list.txt +200 -0
- finetune/modules/depth_warping/depth_pro/network/__init__.py +2 -0
- finetune/modules/depth_warping/depth_pro/network/decoder.py +206 -0
.gitattributes
CHANGED
@@ -73,3 +73,7 @@ assets/pages/res1.mp4 filter=lfs diff=lfs merge=lfs -text
 assets/pages/res2.mp4 filter=lfs diff=lfs merge=lfs -text
 assets/pages/res3.mp4 filter=lfs diff=lfs merge=lfs -text
 assets/pages/teaser.png filter=lfs diff=lfs merge=lfs -text
+results/generated_videos/A_chef_in_a_white_coat_and_gla_1593596b99e2dde9.txt.mp4 filter=lfs diff=lfs merge=lfs -text
+results/generated_videos/A_stunning_and_untouched_coast_6b6d20c6a46b9fe9.txt.mp4 filter=lfs diff=lfs merge=lfs -text
+tools/caption/assests/CogVLM2-Caption-example.png filter=lfs diff=lfs merge=lfs -text
+tools/caption/assests/cogvlm2-video-example.png filter=lfs diff=lfs merge=lfs -text
finetune/accelerate_config.yaml
ADDED
@@ -0,0 +1,21 @@
compute_environment: LOCAL_MACHINE

gpu_ids: "0,1,2,3,4,5,6,7"
num_processes: 8  # should be the same as the number of GPUs

debug: false
deepspeed_config:
  deepspeed_config_file: configs/zero2_controlnet.yaml  # e.g. configs/zero2.yaml; use an absolute path
  zero3_init_flag: false
distributed_type: DEEPSPEED
downcast_bf16: 'no'
enable_cpu_affinity: false
machine_rank: 0
main_training_function: main
num_machines: 1
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
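For orientation, a minimal launch sketch under stated assumptions: the entry-point name train.py is a placeholder rather than a file from this upload, and deepspeed_config_file must be rewritten to an absolute path as the comment above notes.

# Hypothetical launch command (train.py is a placeholder, not a file in this upload):
#   accelerate launch --config_file finetune/accelerate_config.yaml train.py
# The launched script then only needs a plain Accelerator; the DeepSpeed and
# multi-GPU settings are injected by `accelerate launch` from this YAML file.
from accelerate import Accelerator

accelerator = Accelerator()
print(accelerator.num_processes)     # 8 with the gpu_ids above
print(accelerator.distributed_type)  # DistributedType.DEEPSPEED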
finetune/configs/zero2.yaml
ADDED
@@ -0,0 +1,38 @@
{
    "bf16": {
        "enabled": true
    },
    "optimizer": {
        "type": "AdamW",
        "params": {
            "lr": "auto",
            "weight_decay": "auto",
            "torch_adam": true,
            "adam_w_mode": true
        }
    },
    "scheduler": {
        "type": "WarmupDecayLR",
        "params": {
            "warmup_min_lr": "auto",
            "warmup_max_lr": "auto",
            "warmup_num_steps": "auto",
            "total_num_steps": "auto"
        }
    },
    "zero_optimization": {
        "stage": 2,
        "allgather_partitions": true,
        "allgather_bucket_size": 2e8,
        "overlap_comm": true,
        "reduce_scatter": true,
        "reduce_bucket_size": 5e8,
        "contiguous_gradients": true
    },
    "gradient_accumulation_steps": 1,
    "train_micro_batch_size_per_gpu": 1,
    "train_batch_size": "auto",
    "gradient_clipping": "auto",
    "steps_per_print": 2000,
    "wall_clock_breakdown": false
}
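A note on the "auto" entries: DeepSpeed does not interpret them by itself. When a config file like this is routed through Accelerate, they are resolved from the training script's objects at prepare() time. A hedged sketch of that pattern follows; the model and the concrete numbers are illustrative, not taken from this repository.

# Hedged sketch: because the config file defines "optimizer" and "scheduler" blocks,
# Accelerate expects DummyOptim/DummyScheduler stand-ins whose arguments are used
# to fill the "auto" fields (lr, weight decay, warmup/total steps) during prepare().
import torch
from accelerate import Accelerator
from accelerate.utils import DummyOptim, DummyScheduler

accelerator = Accelerator()
model = torch.nn.Linear(8, 8)  # placeholder model
optimizer = DummyOptim(model.parameters(), lr=1e-4, weight_decay=1e-2)
lr_scheduler = DummyScheduler(optimizer, total_num_steps=1000, warmup_num_steps=100)
model, optimizer, lr_scheduler = accelerator.prepare(model, optimizer, lr_scheduler)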
finetune/configs/zero2_controlnet.yaml
ADDED
@@ -0,0 +1,38 @@
{
    "bf16": {
        "enabled": true
    },
    "optimizer": {
        "type": "AdamW",
        "params": {
            "lr": "auto",
            "weight_decay": "auto",
            "torch_adam": true,
            "adam_w_mode": true
        }
    },
    "scheduler": {
        "type": "WarmupCosineLR",
        "params": {
            "warmup_min_ratio": 0.0,
            "cos_min_ratio": 0.0001,
            "warmup_num_steps": 250,
            "total_num_steps": "auto"
        }
    },
    "zero_optimization": {
        "stage": 2,
        "allgather_partitions": true,
        "allgather_bucket_size": 2e8,
        "overlap_comm": true,
        "reduce_scatter": true,
        "reduce_bucket_size": 5e8,
        "contiguous_gradients": true
    },
    "gradient_accumulation_steps": 1,
    "train_micro_batch_size_per_gpu": 1,
    "train_batch_size": "auto",
    "gradient_clipping": "auto",
    "steps_per_print": 2000,
    "wall_clock_breakdown": false
}
finetune/configs/zero2_offload.yaml
ADDED
@@ -0,0 +1,42 @@
{
    "bf16": {
        "enabled": true
    },
    "optimizer": {
        "type": "AdamW",
        "params": {
            "lr": "auto",
            "weight_decay": "auto",
            "torch_adam": true,
            "adam_w_mode": true
        }
    },
    "scheduler": {
        "type": "WarmupDecayLR",
        "params": {
            "warmup_min_lr": "auto",
            "warmup_max_lr": "auto",
            "warmup_num_steps": "auto",
            "total_num_steps": "auto"
        }
    },
    "zero_optimization": {
        "stage": 2,
        "allgather_partitions": true,
        "allgather_bucket_size": 2e8,
        "overlap_comm": true,
        "reduce_scatter": true,
        "reduce_bucket_size": 5e8,
        "contiguous_gradients": true,
        "offload_optimizer": {
            "device": "cpu",
            "pin_memory": true
        }
    },
    "gradient_accumulation_steps": 1,
    "train_micro_batch_size_per_gpu": 1,
    "train_batch_size": "auto",
    "gradient_clipping": "auto",
    "steps_per_print": 2000,
    "wall_clock_breakdown": false
}
finetune/configs/zero3.yaml
ADDED
@@ -0,0 +1,43 @@
{
    "bf16": {
        "enabled": true
    },
    "optimizer": {
        "type": "AdamW",
        "params": {
            "lr": "auto",
            "weight_decay": "auto",
            "torch_adam": true,
            "adam_w_mode": true
        }
    },
    "scheduler": {
        "type": "WarmupDecayLR",
        "params": {
            "warmup_min_lr": "auto",
            "warmup_max_lr": "auto",
            "warmup_num_steps": "auto",
            "total_num_steps": "auto"
        }
    },
    "zero_optimization": {
        "stage": 3,
        "overlap_comm": true,
        "contiguous_gradients": true,
        "reduce_bucket_size": 5e8,
        "stage3_prefetch_bucket_size": "auto",
        "stage3_param_persistence_threshold": "auto",
        "sub_group_size": 1e9,
        "stage3_max_live_parameters": 1e9,
        "stage3_max_reuse_distance": 1e9,
        "stage3_gather_16bit_weights_on_model_save": "auto",
        "stage3_prefetch_bucket_size": 5e8,
        "stage3_param_persistence_threshold": 1e5
    },
    "gradient_accumulation_steps": 1,
    "train_micro_batch_size_per_gpu": 1,
    "train_batch_size": "auto",
    "gradient_clipping": "auto",
    "steps_per_print": 2000,
    "wall_clock_breakdown": false
}
finetune/configs/zero3_offload.yaml
ADDED
@@ -0,0 +1,51 @@
{
    "bf16": {
        "enabled": true
    },
    "optimizer": {
        "type": "AdamW",
        "params": {
            "lr": "auto",
            "weight_decay": "auto",
            "torch_adam": true,
            "adam_w_mode": true
        }
    },
    "scheduler": {
        "type": "WarmupDecayLR",
        "params": {
            "warmup_min_lr": "auto",
            "warmup_max_lr": "auto",
            "warmup_num_steps": "auto",
            "total_num_steps": "auto"
        }
    },
    "zero_optimization": {
        "stage": 3,
        "offload_optimizer": {
            "device": "cpu",
            "pin_memory": true
        },
        "offload_param": {
            "device": "cpu",
            "pin_memory": true
        },
        "overlap_comm": true,
        "contiguous_gradients": true,
        "reduce_bucket_size": 5e8,
        "stage3_prefetch_bucket_size": "auto",
        "stage3_param_persistence_threshold": "auto",
        "sub_group_size": 1e9,
        "stage3_max_live_parameters": 1e9,
        "stage3_max_reuse_distance": 1e9,
        "stage3_gather_16bit_weights_on_model_save": "auto",
        "stage3_prefetch_bucket_size": 5e8,
        "stage3_param_persistence_threshold": 1e6
    },
    "gradient_accumulation_steps": 1,
    "train_micro_batch_size_per_gpu": 1,
    "train_batch_size": "auto",
    "gradient_clipping": "auto",
    "steps_per_print": 2000,
    "wall_clock_breakdown": false
}
finetune/constants.py
ADDED
@@ -0,0 +1,2 @@
LOG_NAME = "trainer"
LOG_LEVEL = "INFO"
finetune/datasets/__init__.py
ADDED
@@ -0,0 +1,14 @@
from .bucket_sampler import BucketSampler
from .i2v_dataset import I2VDatasetWithBuckets, I2VDatasetWithResize
from .t2v_dataset import T2VDatasetWithBuckets, T2VDatasetWithResize
from .i2v_flow_dataset import I2VFlowDataset


__all__ = [
    "I2VDatasetWithResize",
    "I2VDatasetWithBuckets",
    "T2VDatasetWithResize",
    "T2VDatasetWithBuckets",
    "BucketSampler",
    "I2VFlowDataset",
]
finetune/datasets/bucket_sampler.py
ADDED
@@ -0,0 +1,71 @@
import logging
import random

from torch.utils.data import Dataset, Sampler


logger = logging.getLogger(__name__)


class BucketSampler(Sampler):
    r"""
    PyTorch Sampler that groups 3D data by height, width and frames.

    Args:
        data_source (`VideoDataset`):
            A PyTorch dataset object that is an instance of `VideoDataset`.
        batch_size (`int`, defaults to `8`):
            The batch size to use for training.
        shuffle (`bool`, defaults to `True`):
            Whether or not to shuffle the data in each batch before dispatching to dataloader.
        drop_last (`bool`, defaults to `False`):
            Whether or not to drop incomplete buckets of data after completely iterating over all data
            in the dataset. If set to True, only batches that have `batch_size` number of entries will
            be yielded. If set to False, it is guaranteed that all data in the dataset will be processed
            and batches that do not have `batch_size` number of entries will also be yielded.
    """

    def __init__(
        self, data_source: Dataset, batch_size: int = 8, shuffle: bool = True, drop_last: bool = False
    ) -> None:
        self.data_source = data_source
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.drop_last = drop_last

        self.buckets = {resolution: [] for resolution in data_source.video_resolution_buckets}

        self._raised_warning_for_drop_last = False

    def __len__(self):
        if self.drop_last and not self._raised_warning_for_drop_last:
            self._raised_warning_for_drop_last = True
            logger.warning(
                "Calculating the length for bucket sampler is not possible when `drop_last` is set to True. This may cause problems when setting the number of epochs used for training."
            )
        return (len(self.data_source) + self.batch_size - 1) // self.batch_size

    def __iter__(self):
        for index, data in enumerate(self.data_source):
            video_metadata = data["video_metadata"]
            f, h, w = video_metadata["num_frames"], video_metadata["height"], video_metadata["width"]

            self.buckets[(f, h, w)].append(data)
            if len(self.buckets[(f, h, w)]) == self.batch_size:
                if self.shuffle:
                    random.shuffle(self.buckets[(f, h, w)])
                yield self.buckets[(f, h, w)]
                del self.buckets[(f, h, w)]
                self.buckets[(f, h, w)] = []

        if self.drop_last:
            return

        for fhw, bucket in list(self.buckets.items()):
            if len(bucket) == 0:
                continue
            if self.shuffle:
                random.shuffle(bucket)
            yield bucket
            del self.buckets[fhw]
            self.buckets[fhw] = []
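BucketSampler yields whole buckets of already-loaded samples rather than indices, so it is normally paired with a DataLoader whose batch_size is 1, with the dataset passing the bucket straight through (see the isinstance(index, list) check in the dataset classes below). A self-contained usage sketch with a toy stand-in dataset follows; the toy dataset, collate function, and tensor shapes are assumptions for illustration only.

# Hedged usage sketch for BucketSampler; the toy dataset and collate_fn are illustrative only.
import torch
from torch.utils.data import DataLoader, Dataset

from finetune.datasets import BucketSampler


class ToyBucketDataset(Dataset):
    # Minimal interface BucketSampler relies on: iterable samples that carry
    # "video_metadata" plus a `video_resolution_buckets` attribute.
    video_resolution_buckets = [(8, 32, 32)]

    def __len__(self):
        return 16

    def __getitem__(self, index):
        if isinstance(index, list):  # a full bucket is passed back through unchanged
            return index
        if index >= len(self):
            raise IndexError
        return {
            "encoded_video": torch.zeros(4, 8, 32, 32),
            "prompt_embedding": torch.zeros(16, 64),
            "video_metadata": {"num_frames": 8, "height": 32, "width": 32},
        }


def collate_fn(batch):
    samples = batch[0]  # batch_size=1, so batch == [bucket]; the bucket is a list of sample dicts
    return {
        "encoded_video": torch.stack([s["encoded_video"] for s in samples]),
        "prompt_embedding": torch.stack([s["prompt_embedding"] for s in samples]),
    }


dataset = ToyBucketDataset()
loader = DataLoader(dataset, sampler=BucketSampler(dataset, batch_size=8), batch_size=1, collate_fn=collate_fn)
for batch in loader:
    print(batch["encoded_video"].shape)  # torch.Size([8, 4, 8, 32, 32])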
finetune/datasets/i2v_dataset.py
ADDED
@@ -0,0 +1,311 @@
import hashlib
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, List, Tuple

import torch
from accelerate.logging import get_logger
from safetensors.torch import load_file, save_file
from torch.utils.data import Dataset
from torchvision import transforms
from typing_extensions import override

from finetune.constants import LOG_LEVEL, LOG_NAME

from .utils import (
    load_images,
    load_images_from_videos,
    load_prompts,
    load_videos,
    preprocess_image_with_resize,
    preprocess_video_with_buckets,
    preprocess_video_with_resize,
)


if TYPE_CHECKING:
    from finetune.trainer import Trainer

# Must import after torch because this can sometimes lead to a nasty segmentation fault, or stack smashing error
# Very few bug reports but it happens. Look in decord Github issues for more relevant information.
import decord  # isort:skip

decord.bridge.set_bridge("torch")

logger = get_logger(LOG_NAME, LOG_LEVEL)


class BaseI2VDataset(Dataset):
    """
    Base dataset class for Image-to-Video (I2V) training.

    This dataset loads prompts, videos and corresponding conditioning images for I2V training.

    Args:
        data_root (str): Root directory containing the dataset files
        caption_column (str): Path to file containing text prompts/captions
        video_column (str): Path to file containing video paths
        image_column (str): Path to file containing image paths
        device (torch.device): Device to load the data on
        encode_video_fn (Callable[[torch.Tensor], torch.Tensor], optional): Function to encode videos
    """

    def __init__(
        self,
        data_root: str,
        caption_column: str,
        video_column: str,
        image_column: str | None,
        device: torch.device,
        trainer: "Trainer" = None,
        *args,
        **kwargs,
    ) -> None:
        super().__init__()

        data_root = Path(data_root)
        self.prompts = load_prompts(data_root / caption_column)
        self.videos = load_videos(data_root / video_column)
        if image_column is not None:
            self.images = load_images(data_root / image_column)
        else:
            self.images = load_images_from_videos(self.videos)
        self.trainer = trainer

        self.device = device
        self.encode_video = trainer.encode_video
        self.encode_text = trainer.encode_text

        # Check if number of prompts matches number of videos and images
        if not (len(self.videos) == len(self.prompts) == len(self.images)):
            raise ValueError(
                f"Expected length of prompts, videos and images to be the same but found {len(self.prompts)=}, {len(self.videos)=} and {len(self.images)=}. Please ensure that the number of caption prompts, videos and images match in your dataset."
            )

        # Check if all video files exist
        if any(not path.is_file() for path in self.videos):
            raise ValueError(
                f"Some video files were not found. Please ensure that all video files exist in the dataset directory. Missing file: {next(path for path in self.videos if not path.is_file())}"
            )

        # Check if all image files exist
        if any(not path.is_file() for path in self.images):
            raise ValueError(
                f"Some image files were not found. Please ensure that all image files exist in the dataset directory. Missing file: {next(path for path in self.images if not path.is_file())}"
            )

    def __len__(self) -> int:
        return len(self.videos)

    def __getitem__(self, index: int) -> Dict[str, Any]:
        if isinstance(index, list):
            # Here, index is actually a list of data objects that we need to return.
            # The BucketSampler should ideally return indices. But, in the sampler, we'd like
            # to have information about num_frames, height and width. Since this is not stored
            # as metadata, we need to read the video to get this information. You could read this
            # information without loading the full video in memory, but we do it anyway. In order
            # to not load the video twice (once to get the metadata, and once to return the loaded video
            # based on sampled indices), we cache it in the BucketSampler. When the sampler is
            # to yield, we yield the cache data instead of indices. So, this special check ensures
            # that data is not loaded a second time. PRs are welcome for improvements.
            return index

        prompt = self.prompts[index]
        video = self.videos[index]
        image = self.images[index]
        train_resolution_str = "x".join(str(x) for x in self.trainer.args.train_resolution)

        cache_dir = self.trainer.args.data_root / "cache"
        video_latent_dir = cache_dir / "video_latent" / self.trainer.args.model_name / train_resolution_str
        prompt_embeddings_dir = cache_dir / "prompt_embeddings"
        video_latent_dir.mkdir(parents=True, exist_ok=True)
        prompt_embeddings_dir.mkdir(parents=True, exist_ok=True)

        prompt_hash = str(hashlib.sha256(prompt.encode()).hexdigest())
        prompt_embedding_path = prompt_embeddings_dir / (prompt_hash + ".safetensors")
        encoded_video_path = video_latent_dir / (video.stem + ".safetensors")

        if prompt_embedding_path.exists():
            prompt_embedding = load_file(prompt_embedding_path)["prompt_embedding"]
            logger.debug(
                f"process {self.trainer.accelerator.process_index}: Loaded prompt embedding from {prompt_embedding_path}",
                main_process_only=False,
            )
        else:
            prompt_embedding = self.encode_text(prompt)
            prompt_embedding = prompt_embedding.to("cpu")
            # [1, seq_len, hidden_size] -> [seq_len, hidden_size]
            prompt_embedding = prompt_embedding[0]
            save_file({"prompt_embedding": prompt_embedding}, prompt_embedding_path)
            logger.info(f"Saved prompt embedding to {prompt_embedding_path}", main_process_only=False)

        if encoded_video_path.exists():
            encoded_video = load_file(encoded_video_path)["encoded_video"]
            logger.debug(f"Loaded encoded video from {encoded_video_path}", main_process_only=False)
            # shape of image: [C, H, W]
            _, image = self.preprocess(None, self.images[index])
            image = self.image_transform(image)
        else:
            frames, image = self.preprocess(video, image)
            frames = frames.to(self.device)
            image = image.to(self.device)
            image = self.image_transform(image)
            # Current shape of frames: [F, C, H, W]
            frames = self.video_transform(frames)

            # Convert to [B, C, F, H, W]
            frames = frames.unsqueeze(0)
            frames = frames.permute(0, 2, 1, 3, 4).contiguous()
            encoded_video = self.encode_video(frames)

            # [1, C, F, H, W] -> [C, F, H, W]
            encoded_video = encoded_video[0]
            encoded_video = encoded_video.to("cpu")
            image = image.to("cpu")
            save_file({"encoded_video": encoded_video}, encoded_video_path)
            logger.info(f"Saved encoded video to {encoded_video_path}", main_process_only=False)

        # shape of encoded_video: [C, F, H, W]
        # shape of image: [C, H, W]
        return {
            "image": image,
            "prompt_embedding": prompt_embedding,
            "encoded_video": encoded_video,
            "video_metadata": {
                "num_frames": encoded_video.shape[1],
                "height": encoded_video.shape[2],
                "width": encoded_video.shape[3],
            },
        }

    def preprocess(self, video_path: Path | None, image_path: Path | None) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Loads and preprocesses a video and an image.
        If either path is None, no preprocessing will be done for that input.

        Args:
            video_path: Path to the video file to load
            image_path: Path to the image file to load

        Returns:
            A tuple containing:
                - video(torch.Tensor) of shape [F, C, H, W] where F is number of frames,
                  C is number of channels, H is height and W is width
                - image(torch.Tensor) of shape [C, H, W]
        """
        raise NotImplementedError("Subclass must implement this method")

    def video_transform(self, frames: torch.Tensor) -> torch.Tensor:
        """
        Applies transformations to a video.

        Args:
            frames (torch.Tensor): A 4D tensor representing a video
                with shape [F, C, H, W] where:
                - F is number of frames
                - C is number of channels (3 for RGB)
                - H is height
                - W is width

        Returns:
            torch.Tensor: The transformed video tensor
        """
        raise NotImplementedError("Subclass must implement this method")

    def image_transform(self, image: torch.Tensor) -> torch.Tensor:
        """
        Applies transformations to an image.

        Args:
            image (torch.Tensor): A 3D tensor representing an image
                with shape [C, H, W] where:
                - C is number of channels (3 for RGB)
                - H is height
                - W is width

        Returns:
            torch.Tensor: The transformed image tensor
        """
        raise NotImplementedError("Subclass must implement this method")


class I2VDatasetWithResize(BaseI2VDataset):
    """
    A dataset class for image-to-video generation that resizes inputs to fixed dimensions.

    This class preprocesses videos and images by resizing them to specified dimensions:
    - Videos are resized to max_num_frames x height x width
    - Images are resized to height x width

    Args:
        max_num_frames (int): Maximum number of frames to extract from videos
        height (int): Target height for resizing videos and images
        width (int): Target width for resizing videos and images
    """

    def __init__(self, max_num_frames: int, height: int, width: int, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

        self.max_num_frames = max_num_frames
        self.height = height
        self.width = width

        self.__frame_transforms = transforms.Compose([transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)])
        self.__image_transforms = self.__frame_transforms

    @override
    def preprocess(self, video_path: Path | None, image_path: Path | None) -> Tuple[torch.Tensor, torch.Tensor]:
        if video_path is not None:
            video = preprocess_video_with_resize(video_path, self.max_num_frames, self.height, self.width)
        else:
            video = None
        if image_path is not None:
            image = preprocess_image_with_resize(image_path, self.height, self.width)
        else:
            image = None
        return video, image

    @override
    def video_transform(self, frames: torch.Tensor) -> torch.Tensor:
        return torch.stack([self.__frame_transforms(f) for f in frames], dim=0)

    @override
    def image_transform(self, image: torch.Tensor) -> torch.Tensor:
        return self.__image_transforms(image)


class I2VDatasetWithBuckets(BaseI2VDataset):
    def __init__(
        self,
        video_resolution_buckets: List[Tuple[int, int, int]],
        vae_temporal_compression_ratio: int,
        vae_height_compression_ratio: int,
        vae_width_compression_ratio: int,
        *args,
        **kwargs,
    ) -> None:
        super().__init__(*args, **kwargs)

        self.video_resolution_buckets = [
            (
                int(b[0] / vae_temporal_compression_ratio),
                int(b[1] / vae_height_compression_ratio),
                int(b[2] / vae_width_compression_ratio),
            )
            for b in video_resolution_buckets
        ]
        self.__frame_transforms = transforms.Compose([transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)])
        self.__image_transforms = self.__frame_transforms

    @override
    def preprocess(self, video_path: Path, image_path: Path) -> Tuple[torch.Tensor, torch.Tensor]:
        video = preprocess_video_with_buckets(video_path, self.video_resolution_buckets)
        image = preprocess_image_with_resize(image_path, video.shape[2], video.shape[3])
        return video, image

    @override
    def video_transform(self, frames: torch.Tensor) -> torch.Tensor:
        return torch.stack([self.__frame_transforms(f) for f in frames], dim=0)

    @override
    def image_transform(self, image: torch.Tensor) -> torch.Tensor:
        return self.__image_transforms(image)
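For reference, the on-disk cache layout produced by BaseI2VDataset.__getitem__, reconstructed from the path logic above; every concrete value in the sketch is a placeholder, not a path from this upload.

# Cache-path sketch (all values are examples; the real ones come from trainer.args).
import hashlib
from pathlib import Path

data_root = Path("data/my_i2v_dataset")   # placeholder for trainer.args.data_root
model_name = "cogvideox-i2v"              # placeholder for trainer.args.model_name
train_resolution_str = "49x480x720"       # "x".join(trainer.args.train_resolution)

prompt = "A chef plates a dessert in a bright kitchen."
prompt_hash = hashlib.sha256(prompt.encode()).hexdigest()

prompt_embedding_path = data_root / "cache" / "prompt_embeddings" / f"{prompt_hash}.safetensors"
encoded_video_path = data_root / "cache" / "video_latent" / model_name / train_resolution_str / "clip_0001.safetensors"
print(prompt_embedding_path)
print(encoded_video_path)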
finetune/datasets/i2v_flow_dataset.py
ADDED
@@ -0,0 +1,188 @@
import hashlib
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, List, Tuple
import json
import random

import torch
from accelerate.logging import get_logger
from safetensors.torch import load_file, save_file
from torch.utils.data import Dataset
from torchvision import transforms
from typing_extensions import override

from finetune.constants import LOG_LEVEL, LOG_NAME

from .utils import (
    load_images,
    load_images_from_videos,
    load_prompts,
    load_videos,
    preprocess_image_with_resize,
    preprocess_video_with_buckets,
    preprocess_video_with_resize,
    load_binary_mask_compressed,
)

import pdb

if TYPE_CHECKING:
    from finetune.trainer import Trainer

# Must import after torch because this can sometimes lead to a nasty segmentation fault, or stack smashing error
# Very few bug reports but it happens. Look in decord Github issues for more relevant information.
import decord  # isort:skip

decord.bridge.set_bridge("torch")

logger = get_logger(LOG_NAME, LOG_LEVEL)


class I2VFlowDataset(Dataset):
    """
    A dataset class for (image, flow)-to-video generation or image-to-flow_video that resizes inputs to fixed dimensions.

    This class preprocesses videos and images by resizing them to specified dimensions:
    - Videos are resized to max_num_frames x height x width
    - Images are resized to height x width

    Args:
        max_num_frames (int): Maximum number of frames to extract from videos
        height (int): Target height for resizing videos and images
        width (int): Target width for resizing videos and images
    """

    def __init__(
        self,
        max_num_frames: int,
        height: int,
        width: int,
        data_root: str,
        caption_column: str,
        video_column: str,
        image_column: str | None,
        device: torch.device,
        trainer: "Trainer" = None,
        *args,
        **kwargs
    ) -> None:
        data_root = Path(data_root)
        metadata_path = data_root / "metadata_revised.jsonl"
        assert metadata_path.is_file(), "For this dataset type, you need metadata_revised.jsonl in the root path"

        # Load metadata
        # metadata = {
        #     "video_path": ...,
        #     "hash_code": ...,
        #     "prompt": ...,
        # }
        metadata = []
        with open(metadata_path, "r") as f:
            for line in f:
                metadata.append(json.loads(line))

        self.prompts = [x["prompt"] for x in metadata]
        if 'curated' in str(data_root).lower():
            self.prompt_embeddings = [data_root / "prompt_embeddings" / (x["hash_code"] + '.safetensors') for x in metadata]
        else:
            self.prompt_embeddings = [data_root / "prompt_embeddings_revised" / (x["hash_code"] + '.safetensors') for x in metadata]
        self.videos = [data_root / "video_latent" / "x".join(str(x) for x in trainer.args.train_resolution) / (x["hash_code"] + '.safetensors') for x in metadata]
        self.images = [data_root / "first_frames" / (x["hash_code"] + '.png') for x in metadata]
        self.flows = [data_root / "flow_direct_f_latent" / (x["hash_code"] + '.safetensors') for x in metadata]


        # data_root = Path(data_root)
        # self.prompts = load_prompts(data_root / caption_column)
        # self.videos = load_videos(data_root / video_column)

        self.trainer = trainer

        self.device = device
        self.encode_video = trainer.encode_video
        self.encode_text = trainer.encode_text

        # Check if number of prompts matches number of videos and images
        if not (len(self.videos) == len(self.prompts) == len(self.images) == len(self.flows)):
            raise ValueError(
                f"Expected length of prompts, videos and images to be the same but found {len(self.prompts)=}, {len(self.videos)=}, {len(self.images)=} and {len(self.flows)=}. Please ensure that the number of caption prompts, videos and images match in your dataset."
            )

        self.max_num_frames = max_num_frames
        self.height = height
        self.width = width

        self.__frame_transforms = transforms.Compose([transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)])
        self.__image_transforms = self.__frame_transforms

        self.length = len(self.videos)

        print(f"Dataset size: {self.length}")

    def __len__(self) -> int:
        return self.length

    def load_data_pair(self, index):
        # prompt = self.prompts[index]
        prompt_embedding_path = self.prompt_embeddings[index]
        encoded_video_path = self.videos[index]
        encoded_flow_path = self.flows[index]
        # mask_path = self.masks[index]
        # image_path = self.images[index]
        # train_resolution_str = "x".join(str(x) for x in self.trainer.args.train_resolution)

        prompt_embedding = load_file(prompt_embedding_path)["prompt_embedding"]
        encoded_video = load_file(encoded_video_path)["encoded_video"]  # CFHW
        encoded_flow = load_file(encoded_flow_path)["encoded_flow_f"]  # CFHW

        return prompt_embedding, encoded_video, encoded_flow

    def __getitem__(self, index: int) -> Dict[str, Any]:
        while True:
            try:
                prompt_embedding, encoded_video, encoded_flow = self.load_data_pair(index)
                break
            except Exception as e:
                print(f"Error loading {self.prompt_embeddings[index]}: {str(e)}")
                index = random.randint(0, self.length - 1)

        image_path = self.images[index]
        prompt = self.prompts[index]
        train_resolution_str = "x".join(str(x) for x in self.trainer.args.train_resolution)

        _, image = self.preprocess(None, image_path)
        image = self.image_transform(image)

        # shape of encoded_video: [C, F, H, W]
        # shape and scale of image: [C, H, W], [-1, 1]
        return {
            "image": image,
            "prompt_embedding": prompt_embedding,
            "encoded_video": encoded_video,
            "encoded_flow": encoded_flow,
            "video_metadata": {
                "num_frames": encoded_video.shape[1],
                "height": encoded_video.shape[2],
                "width": encoded_video.shape[3],
            },
        }

    @override
    def preprocess(self, video_path: Path | None, image_path: Path | None) -> Tuple[torch.Tensor, torch.Tensor]:
        if video_path is not None:
            video = preprocess_video_with_resize(video_path, self.max_num_frames, self.height, self.width)
        else:
            video = None
        if image_path is not None:
            image = preprocess_image_with_resize(image_path, self.height, self.width)
        else:
            image = None
        return video, image

    @override
    def video_transform(self, frames: torch.Tensor) -> torch.Tensor:
        return torch.stack([self.__frame_transforms(f) for f in frames], dim=0)

    @override
    def image_transform(self, image: torch.Tensor) -> torch.Tensor:
        return self.__image_transforms(image)
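Unlike the datasets above, I2VFlowDataset performs no on-the-fly encoding: everything is precomputed and looked up by the per-clip hash_code in metadata_revised.jsonl. A sketch of one record and the companion files the loader resolves follows; the hash, paths, and prompt are invented for illustration.

# Illustrative metadata_revised.jsonl record; hash_code and paths are made up.
import json

record = {
    "video_path": "videos/clip_0001.mp4",
    "hash_code": "0123456789abcdef",
    "prompt": "A chef plates a dessert in a bright kitchen.",
}
print(json.dumps(record))

# Companion files resolved from data_root for this record:
#   prompt_embeddings_revised/0123456789abcdef.safetensors  -> key "prompt_embedding"
#     (or prompt_embeddings/ when the data_root path contains "curated")
#   video_latent/<train_resolution joined with "x">/0123456789abcdef.safetensors -> key "encoded_video"
#   first_frames/0123456789abcdef.png
#   flow_direct_f_latent/0123456789abcdef.safetensors -> key "encoded_flow_f"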
finetune/datasets/t2v_dataset.py
ADDED
@@ -0,0 +1,251 @@
import hashlib
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, List, Tuple

import torch
from accelerate.logging import get_logger
from safetensors.torch import load_file, save_file
from torch.utils.data import Dataset
from torchvision import transforms
from typing_extensions import override

from finetune.constants import LOG_LEVEL, LOG_NAME

from .utils import load_prompts, load_videos, preprocess_video_with_buckets, preprocess_video_with_resize


if TYPE_CHECKING:
    from finetune.trainer import Trainer

# Must import after torch because this can sometimes lead to a nasty segmentation fault, or stack smashing error
# Very few bug reports but it happens. Look in decord Github issues for more relevant information.
import decord  # isort:skip

decord.bridge.set_bridge("torch")

logger = get_logger(LOG_NAME, LOG_LEVEL)


class BaseT2VDataset(Dataset):
    """
    Base dataset class for Text-to-Video (T2V) training.

    This dataset loads prompts and videos for T2V training.

    Args:
        data_root (str): Root directory containing the dataset files
        caption_column (str): Path to file containing text prompts/captions
        video_column (str): Path to file containing video paths
        device (torch.device): Device to load the data on
        encode_video_fn (Callable[[torch.Tensor], torch.Tensor], optional): Function to encode videos
    """

    def __init__(
        self,
        data_root: str,
        caption_column: str,
        video_column: str,
        device: torch.device = None,
        trainer: "Trainer" = None,
        *args,
        **kwargs,
    ) -> None:
        super().__init__()

        data_root = Path(data_root)
        self.prompts = load_prompts(data_root / caption_column)
        self.videos = load_videos(data_root / video_column)
        self.device = device
        self.encode_video = trainer.encode_video
        self.encode_text = trainer.encode_text
        self.trainer = trainer

        # Check if all video files exist
        if any(not path.is_file() for path in self.videos):
            raise ValueError(
                f"Some video files were not found. Please ensure that all video files exist in the dataset directory. Missing file: {next(path for path in self.videos if not path.is_file())}"
            )

        # Check if number of prompts matches number of videos
        if len(self.videos) != len(self.prompts):
            raise ValueError(
                f"Expected length of prompts and videos to be the same but found {len(self.prompts)=} and {len(self.videos)=}. Please ensure that the number of caption prompts and videos match in your dataset."
            )

    def __len__(self) -> int:
        return len(self.videos)

    def __getitem__(self, index: int) -> Dict[str, Any]:
        if isinstance(index, list):
            # Here, index is actually a list of data objects that we need to return.
            # The BucketSampler should ideally return indices. But, in the sampler, we'd like
            # to have information about num_frames, height and width. Since this is not stored
            # as metadata, we need to read the video to get this information. You could read this
            # information without loading the full video in memory, but we do it anyway. In order
            # to not load the video twice (once to get the metadata, and once to return the loaded video
            # based on sampled indices), we cache it in the BucketSampler. When the sampler is
            # to yield, we yield the cache data instead of indices. So, this special check ensures
            # that data is not loaded a second time. PRs are welcome for improvements.
            return index

        prompt = self.prompts[index]
        video = self.videos[index]
        train_resolution_str = "x".join(str(x) for x in self.trainer.args.train_resolution)

        cache_dir = self.trainer.args.data_root / "cache"
        video_latent_dir = cache_dir / "video_latent" / self.trainer.args.model_name / train_resolution_str
        prompt_embeddings_dir = cache_dir / "prompt_embeddings"
        video_latent_dir.mkdir(parents=True, exist_ok=True)
        prompt_embeddings_dir.mkdir(parents=True, exist_ok=True)

        prompt_hash = str(hashlib.sha256(prompt.encode()).hexdigest())
        prompt_embedding_path = prompt_embeddings_dir / (prompt_hash + ".safetensors")
        encoded_video_path = video_latent_dir / (video.stem + ".safetensors")

        if prompt_embedding_path.exists():
            prompt_embedding = load_file(prompt_embedding_path)["prompt_embedding"]
            logger.debug(
                f"process {self.trainer.accelerator.process_index}: Loaded prompt embedding from {prompt_embedding_path}",
                main_process_only=False,
            )
        else:
            prompt_embedding = self.encode_text(prompt)
            prompt_embedding = prompt_embedding.to("cpu")
            # [1, seq_len, hidden_size] -> [seq_len, hidden_size]
            prompt_embedding = prompt_embedding[0]
            save_file({"prompt_embedding": prompt_embedding}, prompt_embedding_path)
            logger.info(f"Saved prompt embedding to {prompt_embedding_path}", main_process_only=False)

        if encoded_video_path.exists():
            # encoded_video = torch.load(encoded_video_path, weights_only=True)
            encoded_video = load_file(encoded_video_path)["encoded_video"]
            logger.debug(f"Loaded encoded video from {encoded_video_path}", main_process_only=False)
            # shape of image: [C, H, W]
        else:
            frames = self.preprocess(video)
            frames = frames.to(self.device)
            # Current shape of frames: [F, C, H, W]
            frames = self.video_transform(frames)
            # Convert to [B, C, F, H, W]
            frames = frames.unsqueeze(0)
            frames = frames.permute(0, 2, 1, 3, 4).contiguous()
            encoded_video = self.encode_video(frames)

            # [1, C, F, H, W] -> [C, F, H, W]
            encoded_video = encoded_video[0]
            encoded_video = encoded_video.to("cpu")
            save_file({"encoded_video": encoded_video}, encoded_video_path)
            logger.info(f"Saved encoded video to {encoded_video_path}", main_process_only=False)

        # shape of encoded_video: [C, F, H, W]
        return {
            "prompt_embedding": prompt_embedding,
            "encoded_video": encoded_video,
            "video_metadata": {
                "num_frames": encoded_video.shape[1],
                "height": encoded_video.shape[2],
                "width": encoded_video.shape[3],
            },
        }

    def preprocess(self, video_path: Path) -> torch.Tensor:
        """
        Loads and preprocesses a video.

        Args:
            video_path: Path to the video file to load.

        Returns:
            torch.Tensor: Video tensor of shape [F, C, H, W] where:
                - F is number of frames
                - C is number of channels (3 for RGB)
                - H is height
                - W is width
        """
        raise NotImplementedError("Subclass must implement this method")

    def video_transform(self, frames: torch.Tensor) -> torch.Tensor:
        """
        Applies transformations to a video.

        Args:
            frames (torch.Tensor): A 4D tensor representing a video
                with shape [F, C, H, W] where:
                - F is number of frames
                - C is number of channels (3 for RGB)
                - H is height
                - W is width

        Returns:
            torch.Tensor: The transformed video tensor with the same shape as the input
        """
        raise NotImplementedError("Subclass must implement this method")


class T2VDatasetWithResize(BaseT2VDataset):
    """
    A dataset class for text-to-video generation that resizes inputs to fixed dimensions.

    This class preprocesses videos by resizing them to specified dimensions:
    - Videos are resized to max_num_frames x height x width

    Args:
        max_num_frames (int): Maximum number of frames to extract from videos
        height (int): Target height for resizing videos
        width (int): Target width for resizing videos
    """

    def __init__(self, max_num_frames: int, height: int, width: int, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

        self.max_num_frames = max_num_frames
        self.height = height
        self.width = width

        self.__frame_transform = transforms.Compose([transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)])

    @override
    def preprocess(self, video_path: Path) -> torch.Tensor:
        return preprocess_video_with_resize(
            video_path,
            self.max_num_frames,
            self.height,
            self.width,
        )

    @override
    def video_transform(self, frames: torch.Tensor) -> torch.Tensor:
        return torch.stack([self.__frame_transform(f) for f in frames], dim=0)


class T2VDatasetWithBuckets(BaseT2VDataset):
    def __init__(
        self,
        video_resolution_buckets: List[Tuple[int, int, int]],
        vae_temporal_compression_ratio: int,
        vae_height_compression_ratio: int,
        vae_width_compression_ratio: int,
        *args,
        **kwargs,
    ) -> None:
        """ """
        super().__init__(*args, **kwargs)

        self.video_resolution_buckets = [
            (
                int(b[0] / vae_temporal_compression_ratio),
                int(b[1] / vae_height_compression_ratio),
                int(b[2] / vae_width_compression_ratio),
            )
            for b in video_resolution_buckets
        ]

        self.__frame_transform = transforms.Compose([transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)])

    @override
    def preprocess(self, video_path: Path) -> torch.Tensor:
        return preprocess_video_with_buckets(video_path, self.video_resolution_buckets)

    @override
    def video_transform(self, frames: torch.Tensor) -> torch.Tensor:
        return torch.stack([self.__frame_transform(f) for f in frames], dim=0)
finetune/datasets/utils.py
ADDED
@@ -0,0 +1,211 @@
import logging
from pathlib import Path
from typing import List, Tuple

import cv2
import torch
from torchvision.transforms.functional import resize
from einops import repeat, rearrange

# Must import after torch because this can sometimes lead to a nasty segmentation fault, or stack smashing error
# Very few bug reports but it happens. Look in decord Github issues for more relevant information.
import decord  # isort:skip

decord.bridge.set_bridge("torch")

from PIL import Image
import numpy as np

########## loaders ##########


def load_prompts(prompt_path: Path) -> List[str]:
    with open(prompt_path, "r", encoding="utf-8") as file:
        return [line.strip() for line in file.readlines() if len(line.strip()) > 0]


def load_videos(video_path: Path) -> List[Path]:
    with open(video_path, "r", encoding="utf-8") as file:
        return [video_path.parent / line.strip() for line in file.readlines() if len(line.strip()) > 0]


def load_images(image_path: Path) -> List[Path]:
    with open(image_path, "r", encoding="utf-8") as file:
        return [image_path.parent / line.strip() for line in file.readlines() if len(line.strip()) > 0]


def load_images_from_videos(videos_path: List[Path]) -> List[Path]:
    first_frames_dir = videos_path[0].parent.parent / "first_frames"
    first_frames_dir.mkdir(exist_ok=True)

    first_frame_paths = []
    for video_path in videos_path:
        frame_path = first_frames_dir / f"{video_path.stem}.png"
        if frame_path.exists():
            first_frame_paths.append(frame_path)
            continue

        # Open video
        cap = cv2.VideoCapture(str(video_path))

        # Read first frame
        ret, frame = cap.read()
        if not ret:
            raise RuntimeError(f"Failed to read video: {video_path}")

        # Save frame as PNG with same name as video
        cv2.imwrite(str(frame_path), frame)
        logging.info(f"Saved first frame to {frame_path}")

        # Release video capture
        cap.release()

        first_frame_paths.append(frame_path)

    return first_frame_paths


def load_binary_mask_compressed(path, shape, device, dtype):
    # shape: (F, C, H, W), C=1
    with open(path, 'rb') as f:
        packed = np.frombuffer(f.read(), dtype=np.uint8)
    unpacked = np.unpackbits(packed)[:np.prod(shape)]
    mask_loaded = torch.from_numpy(unpacked).to(device, dtype).reshape(shape)

    mask_interp = torch.nn.functional.interpolate(rearrange(mask_loaded, 'f c h w -> c f h w').unsqueeze(0), size=(shape[0]//4+1, shape[2]//8, shape[3]//8), mode='trilinear', align_corners=False).squeeze(0)  # CFHW
    mask_interp[mask_interp >= 0.5] = 1.0
    mask_interp[mask_interp < 0.5] = 0.0

    return rearrange(mask_loaded, 'f c h w -> c f h w'), mask_interp

########## preprocessors ##########


def preprocess_image_with_resize(
    image_path: Path | str,
    height: int,
    width: int,
) -> torch.Tensor:
    """
    Loads and resizes a single image.

    Args:
        image_path: Path to the image file.
        height: Target height for resizing.
        width: Target width for resizing.

    Returns:
        torch.Tensor: Image tensor with shape [C, H, W] where:
            C = number of channels (3 for RGB)
            H = height
            W = width
    """
    if isinstance(image_path, str):
        image_path = Path(image_path)
    # image = cv2.imread(image_path.as_posix())
    # image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    # image = cv2.resize(image, (width, height))
    # image = torch.from_numpy(image).float()
    # image = image.permute(2, 0, 1).contiguous()

    image = np.array(Image.open(image_path.as_posix()).resize((width, height)))
    image = torch.from_numpy(image).float()
    image = image.permute(2, 0, 1).contiguous()

    return image


def preprocess_video_with_resize(
    video_path: Path | str,
    max_num_frames: int,
    height: int,
    width: int,
) -> torch.Tensor:
    """
    Loads and resizes a single video.

    The function processes the video through these steps:
      1. If video frame count > max_num_frames, downsample frames evenly
      2. If video dimensions don't match (height, width), resize frames

    Args:
        video_path: Path to the video file.
        max_num_frames: Maximum number of frames to keep.
        height: Target height for resizing.
        width: Target width for resizing.

    Returns:
        A torch.Tensor with shape [F, C, H, W] where:
            F = number of frames
            C = number of channels (3 for RGB)
            H = height
            W = width
    """
    if isinstance(video_path, str):
        video_path = Path(video_path)
    video_reader = decord.VideoReader(uri=video_path.as_posix(), width=width, height=height)
    video_num_frames = len(video_reader)
    if video_num_frames < max_num_frames:
        # Get all frames first
        frames = video_reader.get_batch(list(range(video_num_frames)))
        # Repeat the last frame until we reach max_num_frames
        last_frame = frames[-1:]
        num_repeats = max_num_frames - video_num_frames
        repeated_frames = last_frame.repeat(num_repeats, 1, 1, 1)
        frames = torch.cat([frames, repeated_frames], dim=0)
        return frames.float().permute(0, 3, 1, 2).contiguous()
    else:
        indices = list(range(0, video_num_frames, video_num_frames // max_num_frames))
        frames = video_reader.get_batch(indices)
        frames = frames[:max_num_frames].float()
        frames = frames.permute(0, 3, 1, 2).contiguous()
        return frames


def preprocess_video_with_buckets(
    video_path: Path,
    resolution_buckets: List[Tuple[int, int, int]],
) -> torch.Tensor:
    """
    Args:
        video_path: Path to the video file.
        resolution_buckets: List of tuples (num_frames, height, width) representing
            available resolution buckets.

    Returns:
        torch.Tensor: Video tensor with shape [F, C, H, W] where:
            F = number of frames
            C = number of channels (3 for RGB)
            H = height
            W = width

    The function processes the video through these steps:
        1. Finds nearest frame bucket <= video frame count
        2. Downsamples frames evenly to match bucket size
        3. Finds nearest resolution bucket based on dimensions
        4. Resizes frames to match bucket resolution
    """
    video_reader = decord.VideoReader(uri=video_path.as_posix())
    video_num_frames = len(video_reader)
    resolution_buckets = [bucket for bucket in resolution_buckets if bucket[0] <= video_num_frames]
    if len(resolution_buckets) == 0:
        raise ValueError(f"video frame count in {video_path} is less than all frame buckets {resolution_buckets}")

    nearest_frame_bucket = min(
        resolution_buckets,
        key=lambda bucket: video_num_frames - bucket[0],
        default=1,
    )[0]
    frame_indices = list(range(0, video_num_frames, video_num_frames // nearest_frame_bucket))
    frames = video_reader.get_batch(frame_indices)
    frames = frames[:nearest_frame_bucket].float()
    frames = frames.permute(0, 3, 1, 2).contiguous()

    nearest_res = min(resolution_buckets, key=lambda x: abs(x[1] - frames.shape[2]) + abs(x[2] - frames.shape[3]))
    nearest_res = (nearest_res[1], nearest_res[2])
    frames = torch.stack([resize(f, nearest_res) for f in frames], dim=0)

    return frames
nearest_res = (nearest_res[1], nearest_res[2])
|
209 |
+
frames = torch.stack([resize(f, nearest_res) for f in frames], dim=0)
|
210 |
+
|
211 |
+
return frames
|
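As a point of reference, here is a minimal usage sketch for the two resize preprocessors defined above. The file names and the 49-frame 480x720 resolution are illustrative assumptions for this note, not values fixed anywhere in this diff.

from pathlib import Path

from finetune.datasets.utils import (
    preprocess_image_with_resize,
    preprocess_video_with_resize,
)

if __name__ == "__main__":
    # Resize a single conditioning image -> tensor of shape [3, 480, 720]
    image = preprocess_image_with_resize(Path("sample.png"), height=480, width=720)

    # Load at most 49 frames of a clip at the same resolution -> [49, 3, 480, 720].
    # Shorter clips are padded by repeating the last frame; longer clips are sampled evenly.
    video = preprocess_video_with_resize(Path("sample.mp4"), max_num_frames=49, height=480, width=720)

    print(image.shape, video.shape)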
finetune/models/__init__.py
ADDED
@@ -0,0 +1,12 @@
import importlib
from pathlib import Path


package_dir = Path(__file__).parent

for subdir in package_dir.iterdir():
    if subdir.is_dir() and not subdir.name.startswith("_"):
        for module_path in subdir.glob("*.py"):
            module_name = module_path.stem
            full_module_name = f".{subdir.name}.{module_name}"
            importlib.import_module(full_module_name, package=__name__)
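For context, this `__init__.py` exists purely for its import side effect: importing `finetune.models` walks each non-underscore subpackage and imports every `*.py` module inside it, so the `register(...)` call at the bottom of each trainer file runs at import time. The snippet below is a hedged re-implementation of that register/lookup pattern; the real helper lives in `finetune/models/utils.py`, which is not shown in this section, so the names `_TRAINER_REGISTRY` and `get_trainer_cls` are assumptions used only for illustration.

from typing import Dict, Tuple, Type

# Illustrative registry; the repo's actual register() in finetune/models/utils.py
# may store entries differently.
_TRAINER_REGISTRY: Dict[Tuple[str, str], Type] = {}


def register(model_name: str, training_type: str, trainer_cls: Type) -> None:
    # Called at import time by each trainer module, e.g.
    # register("cogvideox-flovd-omsm", "lora", FloVDOMSMCogVideoXI2VLoraTrainer)
    _TRAINER_REGISTRY[(model_name, training_type)] = trainer_cls


def get_trainer_cls(model_name: str, training_type: str) -> Type:
    # Look up the trainer class selected by CLI arguments.
    return _TRAINER_REGISTRY[(model_name, training_type)]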
finetune/models/cogvideox_i2v/flovd_OMSM_lora_trainer.py
ADDED
@@ -0,0 +1,748 @@
from typing import Any, Dict, List, Tuple
from pathlib import Path
import os
import hashlib
import json
import random
import wandb
import math
import numpy as np
from einops import rearrange, repeat
from safetensors.torch import load_file, save_file
from accelerate.logging import get_logger

import torch

from accelerate.utils import gather_object

from diffusers import (
    AutoencoderKLCogVideoX,
    CogVideoXDPMScheduler,
    CogVideoXImageToVideoPipeline,
    CogVideoXTransformer3DModel,
)
from diffusers.utils.export_utils import export_to_video

from finetune.pipeline.flovd_OMSM_cogvideox_pipeline import FloVDOMSMCogVideoXImageToVideoPipeline
from finetune.constants import LOG_LEVEL, LOG_NAME

from diffusers.models.embeddings import get_3d_rotary_pos_embed
from PIL import Image
from numpy import dtype
from transformers import AutoTokenizer, T5EncoderModel
from typing_extensions import override

from finetune.schemas import Args, Components, State
from finetune.trainer import Trainer
from finetune.utils import (
    cast_training_params,
    free_memory,
    get_memory_statistics,
    string_to_filename,
    unwrap_model,
)
from finetune.datasets.utils import (
    preprocess_image_with_resize,
    preprocess_video_with_resize,
    load_binary_mask_compressed,
)
from finetune.modules.camera_sampler import SampleManualCam
from finetune.modules.camera_flow_generator import CameraFlowGenerator
from finetune.modules.utils import get_camera_flow_generator_input, forward_bilinear_splatting, flow_to_color

from ..utils import register

import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import pdb

logger = get_logger(LOG_NAME, LOG_LEVEL)


class FloVDOMSMCogVideoXI2VLoraTrainer(Trainer):
    UNLOAD_LIST = ["text_encoder"]

    @override
    def __init__(self, args: Args) -> None:
        super().__init__(args)

    @override
    def load_components(self) -> Dict[str, Any]:
        # TODO. Change the pipeline and ...
        components = Components()
        model_path = str(self.args.model_path)

        components.pipeline_cls = FloVDOMSMCogVideoXImageToVideoPipeline

        components.tokenizer = AutoTokenizer.from_pretrained(model_path, subfolder="tokenizer")

        components.text_encoder = T5EncoderModel.from_pretrained(model_path, subfolder="text_encoder")

        components.transformer = CogVideoXTransformer3DModel.from_pretrained(model_path, subfolder="transformer")

        components.vae = AutoencoderKLCogVideoX.from_pretrained(model_path, subfolder="vae")

        components.scheduler = CogVideoXDPMScheduler.from_pretrained(model_path, subfolder="scheduler")

        return components

    @override
    def initialize_pipeline(self) -> FloVDOMSMCogVideoXImageToVideoPipeline:
        # TODO. Change the pipeline and ...
        pipe = FloVDOMSMCogVideoXImageToVideoPipeline(
            tokenizer=self.components.tokenizer,
            text_encoder=self.components.text_encoder,
            vae=self.components.vae,
            transformer=unwrap_model(self.accelerator, self.components.transformer),
            scheduler=self.components.scheduler,
        )
        return pipe

    def initialize_flow_generator(self):
        depth_estimator_kwargs = {
            "target": 'modules.depth_warping.depth_warping.DepthWarping_wrapper',
            "kwargs": {
                "ckpt_path": '/workspace/workspace/checkpoints/depth_anything/depth_anything_v2_metric_hypersim_vitb.pth',
                "model_config": {
                    "max_depth": 20,
                    "encoder": 'vitb',
                    "features": 128,
                    "out_channels": [96, 192, 384, 768],
                },
            },
        }

        return CameraFlowGenerator(depth_estimator_kwargs)

    @override
    def collate_fn(self, samples: List[Dict[str, Any]]) -> Dict[str, Any]:
        ret = {"encoded_videos": [], "prompt_embedding": [], "images": [], "encoded_flow": []}

        for sample in samples:
            encoded_video = sample["encoded_video"]
            prompt_embedding = sample["prompt_embedding"]
            image = sample["image"]
            encoded_flow = sample["encoded_flow"]

            ret["encoded_videos"].append(encoded_video)
            ret["prompt_embedding"].append(prompt_embedding)
            ret["images"].append(image)
            ret["encoded_flow"].append(encoded_flow)

        ret["encoded_videos"] = torch.stack(ret["encoded_videos"])
        ret["prompt_embedding"] = torch.stack(ret["prompt_embedding"])
        ret["images"] = torch.stack(ret["images"])
        ret["encoded_flow"] = torch.stack(ret["encoded_flow"])

        return ret

    @override
    def compute_loss(self, batch) -> torch.Tensor:
        prompt_embedding = batch["prompt_embedding"]
        images = batch["images"]
        latent_flow = batch["encoded_flow"]

        # Shape of prompt_embedding: [B, seq_len, hidden_size]
        # Shape of images: [B, C, H, W]
        # Shape of latent_flow: [B, C, F, H, W]

        patch_size_t = self.state.transformer_config.patch_size_t  # WJ: None in i2v setting...
        if patch_size_t is not None:
            # ncopy = latent.shape[2] % patch_size_t
            # # Copy the first frame ncopy times to match patch_size_t
            # first_frame = latent[:, :, :1, :, :]  # Get first frame [B, C, 1, H, W]
            # latent = torch.cat([first_frame.repeat(1, 1, ncopy, 1, 1), latent], dim=2)
            # assert latent.shape[2] % patch_size_t == 0
            raise NotImplementedError("Do not use the case whose patch_size_t is not None")

        batch_size, num_channels, num_frames, height, width = latent_flow.shape

        # Get prompt embeddings
        _, seq_len, _ = prompt_embedding.shape
        prompt_embedding = prompt_embedding.view(batch_size, seq_len, -1).to(dtype=latent_flow.dtype)

        # Add frame dimension to images [B,C,H,W] -> [B,C,F,H,W]
        images = images.unsqueeze(2)
        # Add noise to images
        image_noise_sigma = torch.normal(mean=-3.0, std=0.5, size=(1,), device=self.accelerator.device)
        image_noise_sigma = torch.exp(image_noise_sigma).to(dtype=images.dtype)
        noisy_images = images + torch.randn_like(images) * image_noise_sigma[:, None, None, None, None]
        image_latent_dist = self.components.vae.encode(noisy_images.to(dtype=self.components.vae.dtype)).latent_dist
        image_latents = image_latent_dist.sample() * self.components.vae.config.scaling_factor

        # Sample a random timestep for each sample
        timesteps = torch.randint(
            0, self.components.scheduler.config.num_train_timesteps, (batch_size,), device=self.accelerator.device
        )
        timesteps = timesteps.long()

        # from [B, C, F, H, W] to [B, F, C, H, W]
        latent_flow = latent_flow.permute(0, 2, 1, 3, 4)
        image_latents = image_latents.permute(0, 2, 1, 3, 4)
        assert (image_latents.shape[0], *image_latents.shape[2:]) == (latent_flow.shape[0], *latent_flow.shape[2:])

        # Padding image_latents to the same frame number as latent
        padding_shape = (latent_flow.shape[0], latent_flow.shape[1] - 1, *latent_flow.shape[2:])
        latent_padding = image_latents.new_zeros(padding_shape)
        image_latents = torch.cat([image_latents, latent_padding], dim=1)

        # Add noise to latent
        noise = torch.randn_like(latent_flow)
        latent_flow_noisy = self.components.scheduler.add_noise(latent_flow, noise, timesteps)

        # Concatenate latent and image_latents in the channel dimension
        latent_flow_img_noisy = torch.cat([latent_flow_noisy, image_latents], dim=2)

        # Prepare rotary embeds
        vae_scale_factor_spatial = 2 ** (len(self.components.vae.config.block_out_channels) - 1)
        transformer_config = self.state.transformer_config
        rotary_emb = (
            self.prepare_rotary_positional_embeddings(
                height=height * vae_scale_factor_spatial,
                width=width * vae_scale_factor_spatial,
                num_frames=num_frames,
                transformer_config=transformer_config,
                vae_scale_factor_spatial=vae_scale_factor_spatial,
                device=self.accelerator.device,
            )
            if transformer_config.use_rotary_positional_embeddings
            else None
        )

        # Predict noise, For CogVideoX1.5 Only.
        ofs_emb = (
            None if self.state.transformer_config.ofs_embed_dim is None else latent_flow.new_full((1,), fill_value=2.0)
        )

        predicted_noise = self.components.transformer(
            hidden_states=latent_flow_img_noisy,
            encoder_hidden_states=prompt_embedding,
            timestep=timesteps,
            ofs=ofs_emb,
            image_rotary_emb=rotary_emb,
            return_dict=False,
        )[0]

        # Denoise
        latent_pred = self.components.scheduler.get_velocity(predicted_noise, latent_flow_noisy, timesteps)

        alphas_cumprod = self.components.scheduler.alphas_cumprod[timesteps]
        weights = 1 / (1 - alphas_cumprod)
        while len(weights.shape) < len(latent_pred.shape):
            weights = weights.unsqueeze(-1)

        loss = torch.mean((weights * (latent_pred - latent_flow) ** 2).reshape(batch_size, -1), dim=1)
        loss = loss.mean()

        return loss

    def prepare_rotary_positional_embeddings(
        self,
        height: int,
        width: int,
        num_frames: int,
        transformer_config: Dict,
        vae_scale_factor_spatial: int,
        device: torch.device,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        grid_height = height // (vae_scale_factor_spatial * transformer_config.patch_size)
        grid_width = width // (vae_scale_factor_spatial * transformer_config.patch_size)

        if transformer_config.patch_size_t is None:
            base_num_frames = num_frames
        else:
            base_num_frames = (num_frames + transformer_config.patch_size_t - 1) // transformer_config.patch_size_t

        freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
            embed_dim=transformer_config.attention_head_dim,
            crops_coords=None,
            grid_size=(grid_height, grid_width),
            temporal_size=base_num_frames,
            grid_type="slice",
            max_size=(grid_height, grid_width),
            device=device,
        )

        return freqs_cos, freqs_sin

    # Validation

    @override
    def prepare_for_validation(self):
        # Load from dataset?
        # Data_root
        #   - metadata.jsonl
        #   - video_latent / args.resolution /
        #   - prompt_embeddings /
        #   - first_frames /
        #   - flow_direct_f_latent /

        data_root = self.args.data_root
        metadata_path = data_root / "metadata_revised.jsonl"
        assert metadata_path.is_file(), "For this dataset type, you need metadata.jsonl or metadata_revised.jsonl in the root path"

        # Load metadata
        # metadata = {
        #     "video_path": ...,
        #     "hash_code": ...,
        #     "prompt": ...,
        # }
        metadata = []
        with open(metadata_path, "r") as f:
            for line in f:
                metadata.append(json.loads(line))

        metadata = random.sample(metadata, self.args.max_scene)

        prompts = [x["prompt"] for x in metadata]
        if 'curated' in str(data_root).lower():
            self.prompt_embeddings = [data_root / "prompt_embeddings" / (x["hash_code"] + '.safetensors') for x in metadata]
        else:
            self.prompt_embeddings = [data_root / "prompt_embeddings_revised" / (x["hash_code"] + '.safetensors') for x in metadata]
        videos = [data_root / "video_latent" / "x".join(str(x) for x in self.args.train_resolution) / (x["hash_code"] + '.safetensors') for x in metadata]
        images = [data_root / "first_frames" / (x["hash_code"] + '.png') for x in metadata]
        flows = [data_root / "flow_direct_f_latent" / (x["hash_code"] + '.safetensors') for x in metadata]

        # load prompt embedding
        validation_prompts = []
        validation_prompt_embeddings = []
        validation_video_latents = []
        validation_images = []
        validation_flow_latents = []
        for prompt, prompt_embedding, video_latent, image, flow_latent in zip(prompts, self.prompt_embeddings, videos, images, flows):
            validation_prompts.append(prompt)
            validation_prompt_embeddings.append(load_file(prompt_embedding)["prompt_embedding"].unsqueeze(0))
            validation_video_latents.append(load_file(video_latent)["encoded_video"].unsqueeze(0))
            validation_flow_latents.append(load_file(flow_latent)["encoded_flow_f"].unsqueeze(0))
            # validation_images.append(preprocess_image_with_resize(image, self.args.train_resolution[1], self.args.train_resolution[2]))
            validation_images.append(image)

        validation_videos = [None] * len(validation_prompts)

        self.state.validation_prompts = validation_prompts
        self.state.validation_prompt_embeddings = validation_prompt_embeddings
        self.state.validation_images = validation_images
        self.state.validation_videos = validation_videos
        self.state.validation_video_latents = validation_video_latents
        self.state.validation_flow_latents = validation_flow_latents

        # Debug..
        self.validate(0)

    @override
    def validation_step(
        self, eval_data: Dict[str, Any], pipe: FloVDOMSMCogVideoXImageToVideoPipeline
    ) -> List[Tuple[str, Image.Image | List[Image.Image]]]:
        """
        Return the data that needs to be saved. For videos, the data format is List[PIL],
        and for images, the data format is PIL
        """

        prompt_embedding, image = eval_data["prompt_embedding"], eval_data["image"]

        flow_latent_generate = pipe(
            num_frames=self.state.train_frames,
            height=self.state.train_height,
            width=self.state.train_width,
            prompt=None,
            prompt_embeds=prompt_embedding,
            image=image,
            generator=self.state.generator,
            num_inference_steps=50,
            output_type='latent',
        ).frames[0]

        flow_generate = decode_flow(flow_latent_generate.unsqueeze(0).to(self.accelerator.device), self.components.vae, flow_scale_factor=[60, 36])  # BF,C,H,W

        return [("synthesized_flow", flow_generate)]

    @override
    def validate(self, step: int) -> None:
        # TODO. Fix the codes!!!!
        logger.info("Starting validation")

        accelerator = self.accelerator
        num_validation_samples = len(self.state.validation_prompts)

        if num_validation_samples == 0:
            logger.warning("No validation samples found. Skipping validation.")
            return

        self.components.transformer.eval()
        torch.set_grad_enabled(False)

        memory_statistics = get_memory_statistics()
        logger.info(f"Memory before validation start: {json.dumps(memory_statistics, indent=4)}")

        ##### Initialize pipeline #####
        pipe = self.initialize_pipeline()
        camera_flow_generator = self.initialize_flow_generator().to(device=self.accelerator.device, dtype=self.state.weight_dtype)

        if self.state.using_deepspeed:
            # Can't use model_cpu_offload in deepspeed,
            # so we need to move all components in pipe to device
            # pipe.to(self.accelerator.device, dtype=self.state.weight_dtype)
            self.__move_components_to_device(dtype=self.state.weight_dtype, ignore_list=["transformer"])
        else:
            # if not using deepspeed, use model_cpu_offload to further reduce memory usage
            # Or use pipe.enable_sequential_cpu_offload() to further reduce memory usage
            pipe.enable_model_cpu_offload(device=self.accelerator.device)

            # Convert all model weights to training dtype
            # Note, this will change LoRA weights in self.components.transformer to training dtype, rather than keep them in fp32
            pipe = pipe.to(dtype=self.state.weight_dtype)

        #################################
        all_processes_artifacts = []
        for i in range(num_validation_samples):
            if self.state.using_deepspeed and self.accelerator.deepspeed_plugin.zero_stage != 3:
                # Skip current validation on all processes but one
                if i % accelerator.num_processes != accelerator.process_index:
                    continue

            prompt = self.state.validation_prompts[i]
            image = self.state.validation_images[i]
            video = self.state.validation_videos[i]
            video_latent = self.state.validation_video_latents[i].permute(0, 2, 1, 3, 4)  # [B,F,C,H,W] (e.g., [B, 13, 16, 60, 90])
            prompt_embedding = self.state.validation_prompt_embeddings[i]
            flow_latent = self.state.validation_flow_latents[i].permute(0, 2, 1, 3, 4)  # [B,F,C,H,W] (e.g., [B, 13, 16, 60, 90])

            if image is not None:
                image = preprocess_image_with_resize(image, self.state.train_height, self.state.train_width)
                image_torch = image.detach().clone()
                # Convert image tensor (C, H, W) to PIL images
                image = image.to(torch.uint8)
                image = image.permute(1, 2, 0).cpu().numpy()
                image = Image.fromarray(image)

            if video is not None:
                video = preprocess_video_with_resize(
                    video, self.state.train_frames, self.state.train_height, self.state.train_width
                )
                # Convert video tensor (F, C, H, W) to list of PIL images
                video = video.round().clamp(0, 255).to(torch.uint8)
                video = [Image.fromarray(frame.permute(1, 2, 0).cpu().numpy()) for frame in video]
            else:
                with torch.cuda.amp.autocast(enabled=True, dtype=self.state.weight_dtype):
                    try:
                        video_decoded = decode_latents(video_latent.to(self.accelerator.device), self.components.vae)
                    except:
                        pass
                        video_decoded = decode_latents(video_latent.to(self.accelerator.device), self.components.vae)
                video = ((video_decoded + 1.) / 2. * 255.)[0].permute(1, 0, 2, 3).float().clip(0., 255.).to(torch.uint8)
                video = [Image.fromarray(frame.permute(1, 2, 0).cpu().numpy()) for frame in video]

            with torch.cuda.amp.autocast(enabled=True, dtype=self.state.weight_dtype):
                try:
                    flow_decoded = decode_flow(flow_latent.to(self.accelerator.device), self.components.vae, flow_scale_factor=[60, 36])
                except:
                    pass
                    flow_decoded = decode_flow(flow_latent.to(self.accelerator.device), self.components.vae, flow_scale_factor=[60, 36])  # (BF)CHW (C=2)

            logger.debug(
                f"Validating sample {i + 1}/{num_validation_samples} on process {accelerator.process_index}. Prompt: {prompt}",
                main_process_only=False,
            )
            # validation_artifacts = self.validation_step({"prompt": prompt, "image": image, "video": video}, pipe)
            validation_artifacts = self.validation_step({"prompt_embedding": prompt_embedding, "image": image}, pipe)

            if (
                self.state.using_deepspeed
                and self.accelerator.deepspeed_plugin.zero_stage == 3
                and not accelerator.is_main_process
            ):
                continue

            prompt_filename = string_to_filename(prompt)[:25]
            # Calculate hash of reversed prompt as a unique identifier
            reversed_prompt = prompt[::-1]
            hash_suffix = hashlib.md5(reversed_prompt.encode()).hexdigest()[:5]

            artifacts = {
                "image": {"type": "image", "value": image},
                "video": {"type": "video", "value": video},
            }
            for i, (artifact_type, artifact_value) in enumerate(validation_artifacts):
                artifacts.update({f"artifact_{i}": {"type": artifact_type, "value": artifact_value}})

                # Log flow
                artifacts.update({f"artifact_flow_{i}": {"type": 'flow', "value": flow_decoded}})

                # Log flow_warped_frames
                image_tensor = repeat(rearrange(torch.tensor(np.array(image)).to(flow_decoded.device, torch.float), 'h w c -> 1 c h w'), 'b c h w -> (b f) c h w', f=flow_decoded.size(0))  # scale~(0,255) (BF) C H W
                warped_video = forward_bilinear_splatting(image_tensor, flow_decoded.to(torch.float))  # if we have an occlusion mask from dataset, we can use it.
                frame_list = []
                for frame in warped_video:
                    frame = (frame.permute(1, 2, 0).float().detach().cpu().numpy()).astype(np.uint8).clip(0, 255)
                    frame_list.append(Image.fromarray(frame))

                artifacts.update({f"artifact_warped_video_{i}": {"type": 'warped_video', "value": frame_list}})

                # Log synthesized_flow_warped_frames
                # artifact_value: synthesized optical flow
                warped_video2 = forward_bilinear_splatting(image_tensor, artifact_value.to(torch.float))  # if we have an occlusion mask from dataset, we can use it. For OMSM, do not use.
                frame_list2 = []
                for frame in warped_video2:
                    frame = (frame.permute(1, 2, 0).float().detach().cpu().numpy()).astype(np.uint8).clip(0, 255)
                    frame_list2.append(Image.fromarray(frame))

                artifacts.update({f"artifact_synthesized_flow_warped_video_{i}": {"type": 'synthesized_flow_warped_video', "value": frame_list2}})

            logger.debug(
                f"Validation artifacts on process {accelerator.process_index}: {list(artifacts.keys())}",
                main_process_only=False,
            )

            for key, value in list(artifacts.items()):
                artifact_type = value["type"]
                artifact_value = value["value"]
                if artifact_type not in ["image", "video", "flow", "warped_video", "synthesized_flow", "synthesized_flow_warped_video"] or artifact_value is None:
                    continue

                extension = "png" if artifact_type == "image" else "mp4"
                if artifact_type == "warped_video" or artifact_type == "synthesized_flow_warped_video":
                    filename = f"validation-{step}-{accelerator.process_index}-{prompt_filename}-{hash_suffix}_{artifact_type}.{extension}"
                elif artifact_type == "synthesized_flow":
                    filename = f"validation-{step}-{accelerator.process_index}-{prompt_filename}-{hash_suffix}_synthesized_flow.{extension}"
                elif artifact_type == "flow":
                    filename = f"validation-{step}-{accelerator.process_index}-{prompt_filename}-{hash_suffix}_original_flow.{extension}"
                else:
                    filename = f"validation-{step}-{accelerator.process_index}-{prompt_filename}-{hash_suffix}.{extension}"
                validation_path = self.args.output_dir / "validation_res"
                validation_path.mkdir(parents=True, exist_ok=True)
                filename = str(validation_path / filename)

                if artifact_type == "image":
                    logger.debug(f"Saving image to {filename}")
                    artifact_value.save(filename)
                    artifact_value = wandb.Image(filename)
                elif artifact_type == "video" or artifact_type == "warped_video" or artifact_type == "synthesized_flow_warped_video":
                    logger.debug(f"Saving video to {filename}")
                    export_to_video(artifact_value, filename, fps=self.args.gen_fps)
                    artifact_value = wandb.Video(filename, caption=f"[{artifact_type}]--{prompt}")
                elif artifact_type == "synthesized_flow" or artifact_type == "flow":
                    # TODO. RGB Visualization of optical flow. (F,2,H,W)
                    artifact_value_RGB = flow_to_color(artifact_value)  # BF,C,H,W (B=1)

                    frame_list = []
                    for frame in artifact_value_RGB:
                        frame = (frame.permute(1, 2, 0).float().detach().cpu().numpy()).astype(np.uint8).clip(0, 255)
                        frame_list.append(Image.fromarray(frame))

                    logger.debug(f"Saving video to {filename}")
                    export_to_video(frame_list, filename, fps=self.args.gen_fps)
                    artifact_value = wandb.Video(filename, caption=f"[{artifact_type}]--{prompt}")

                all_processes_artifacts.append(artifact_value)

        all_artifacts = gather_object(all_processes_artifacts)

        if accelerator.is_main_process:
            tracker_key = "validation"
            for tracker in accelerator.trackers:
                if tracker.name == "wandb":
                    image_artifacts = [artifact for artifact in all_artifacts if isinstance(artifact, wandb.Image)]
                    video_artifacts = [artifact for artifact in all_artifacts if isinstance(artifact, wandb.Video)]
                    tracker.log(
                        {
                            tracker_key: {"images": image_artifacts, "videos": video_artifacts},
                        },
                        step=step,
                    )

        ########## Clean up ##########
        if self.state.using_deepspeed:
            del pipe
            # Unload models except those needed for training
            self.__move_components_to_cpu(unload_list=self.UNLOAD_LIST)
        else:
            pipe.remove_all_hooks()
            del pipe
            # Load models except those not needed for training
            self.__move_components_to_device(dtype=self.state.weight_dtype, ignore_list=self.UNLOAD_LIST)
            self.components.transformer.to(self.accelerator.device, dtype=self.state.weight_dtype)

            # Change trainable weights back to fp32 to keep with dtype after prepare the model
            cast_training_params([self.components.transformer], dtype=torch.float32)

        del camera_flow_generator

        free_memory()
        accelerator.wait_for_everyone()
        ################################

        memory_statistics = get_memory_statistics()
        logger.info(f"Memory after validation end: {json.dumps(memory_statistics, indent=4)}")
        torch.cuda.reset_peak_memory_stats(accelerator.device)

        torch.set_grad_enabled(True)
        self.components.transformer.train()

    # mangling
    def __move_components_to_device(self, dtype, ignore_list: List[str] = []):
        ignore_list = set(ignore_list)
        components = self.components.model_dump()
        for name, component in components.items():
            if not isinstance(component, type) and hasattr(component, "to"):
                if name not in ignore_list:
                    setattr(self.components, name, component.to(self.accelerator.device, dtype=dtype))

    # mangling
    def __move_components_to_cpu(self, unload_list: List[str] = []):
        unload_list = set(unload_list)
        components = self.components.model_dump()
        for name, component in components.items():
            if not isinstance(component, type) and hasattr(component, "to"):
                if name in unload_list:
                    setattr(self.components, name, component.to("cpu"))


register("cogvideox-flovd-omsm", "lora", FloVDOMSMCogVideoXI2VLoraTrainer)


#--------------------------------------------------------------------------------------------------
# Extract function
def encode_text(prompt: str, components, device) -> torch.Tensor:
    prompt_token_ids = components.tokenizer(
        prompt,
        padding="max_length",
        max_length=components.transformer.config.max_text_seq_length,
        truncation=True,
        add_special_tokens=True,
        return_tensors="pt",
    )
    prompt_token_ids = prompt_token_ids.input_ids
    prompt_embedding = components.text_encoder(prompt_token_ids.to(device))[0]
    return prompt_embedding


def encode_video(video: torch.Tensor, vae) -> torch.Tensor:
    # shape of input video: [B, C, F, H, W]
    video = video.to(vae.device, dtype=vae.dtype)
    latent_dist = vae.encode(video).latent_dist
    latent = latent_dist.sample() * vae.config.scaling_factor
    return latent


def decode_latents(latents: torch.Tensor, vae) -> torch.Tensor:
    latents = latents.permute(0, 2, 1, 3, 4)  # [batch_size, num_channels, num_frames, height, width]
    latents = 1 / vae.config.scaling_factor * latents

    frames = vae.decode(latents).sample
    return frames


def compute_optical_flow(raft, ctxt, trgt, raft_iter=20, chunk=2, only_forward=True):
    num_frames = ctxt.shape[0]
    chunk_size = (num_frames // chunk) + 1

    flow_f_list = []
    if not only_forward:
        flow_b_list = []
    for i in range(chunk):
        start = chunk_size * i
        end = chunk_size * (i + 1)

        with torch.no_grad():
            flow_f = raft(ctxt[start:end], trgt[start:end], num_flow_updates=raft_iter)[-1]
            if not only_forward:
                flow_b = raft(trgt[start:end], ctxt[start:end], num_flow_updates=raft_iter)[-1]

        flow_f_list.append(flow_f)
        if not only_forward:
            flow_b_list.append(flow_b)

    flow_f = torch.cat(flow_f_list)
    if not only_forward:
        flow_b = torch.cat(flow_b_list)

    if not only_forward:
        return flow_f, flow_b
    else:
        return flow_f, None


def encode_flow(flow, vae, flow_scale_factor):
    # flow: BF,C,H,W
    # flow_scale_factor [sf_x, sf_y]
    assert flow.ndim == 4
    num_frames, _, height, width = flow.shape

    # Normalize optical flow
    # ndim: 4 -> 5
    flow = rearrange(flow, '(b f) c h w -> b f c h w', b=1)
    flow_norm = adaptive_normalize(flow, flow_scale_factor[0], flow_scale_factor[1])

    # ndim: 5 -> 4
    flow_norm = rearrange(flow_norm, 'b f c h w -> (b f) c h w', b=1)

    # Duplicate mean value for third channel
    num_frames, _, H, W = flow_norm.shape
    flow_norm_extended = torch.empty((num_frames, 3, height, width)).to(flow_norm)
    flow_norm_extended[:, :2] = flow_norm
    flow_norm_extended[:, -1:] = flow_norm.mean(dim=1, keepdim=True)
    flow_norm_extended = rearrange(flow_norm_extended, '(b f) c h w -> b c f h w', f=num_frames)

    return encode_video(flow_norm_extended, vae)


def decode_flow(flow_latent, vae, flow_scale_factor):
    flow_latent = flow_latent.permute(0, 2, 1, 3, 4)  # [batch_size, num_channels, num_frames, height, width]
    flow_latent = 1 / vae.config.scaling_factor * flow_latent

    flow = vae.decode(flow_latent).sample  # BCFHW

    # discard third channel (which is a mean value of f_x and f_y)
    flow = flow[:, :2].detach().clone()

    # Unnormalize optical flow
    flow = rearrange(flow, 'b c f h w -> b f c h w')
    flow = adaptive_unnormalize(flow, flow_scale_factor[0], flow_scale_factor[1])

    flow = rearrange(flow, 'b f c h w -> (b f) c h w')
    return flow  # BF,C,H,W


def adaptive_normalize(flow, sf_x, sf_y):
    # x: BFCHW, optical flow
    assert flow.ndim == 5, 'Set the shape of the flow input as (B, F, C, H, W)'
    assert sf_x is not None and sf_y is not None
    b, f, c, h, w = flow.shape

    max_clip_x = math.sqrt(w / sf_x) * 1.0
    max_clip_y = math.sqrt(h / sf_y) * 1.0

    flow_norm = flow.detach().clone()
    flow_x = flow[:, :, 0].detach().clone()
    flow_y = flow[:, :, 1].detach().clone()

    flow_x_norm = torch.sign(flow_x) * torch.sqrt(torch.abs(flow_x) / sf_x + 1e-7)
    flow_y_norm = torch.sign(flow_y) * torch.sqrt(torch.abs(flow_y) / sf_y + 1e-7)

    flow_norm[:, :, 0] = torch.clamp(flow_x_norm, min=-max_clip_x, max=max_clip_x)
    flow_norm[:, :, 1] = torch.clamp(flow_y_norm, min=-max_clip_y, max=max_clip_y)

    return flow_norm


def adaptive_unnormalize(flow, sf_x, sf_y):
    # x: BFCHW, optical flow
    assert flow.ndim == 5, 'Set the shape of the flow input as (B, F, C, H, W)'
    assert sf_x is not None and sf_y is not None

    flow_orig = flow.detach().clone()
    flow_x = flow[:, :, 0].detach().clone()
    flow_y = flow[:, :, 1].detach().clone()

    flow_orig[:, :, 0] = torch.sign(flow_x) * sf_x * (flow_x**2 - 1e-7)
    flow_orig[:, :, 1] = torch.sign(flow_y) * sf_y * (flow_y**2 - 1e-7)

    return flow_orig

#--------------------------------------------------------------------------------------------------
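A small sanity-check sketch (not part of the diff) for the flow normalization helpers defined at the bottom of this file: adaptive_normalize applies a signed square-root compression followed by clamping, and adaptive_unnormalize inverts it exactly as long as no value hits the clamp, so a round trip with the same [60, 36] scale factors used by decode_flow should reproduce the input. Running it assumes the repo's dependencies are installed and importable; the tensor shape and magnitude below are illustrative.

import torch

from finetune.models.cogvideox_i2v.flovd_OMSM_lora_trainer import (
    adaptive_normalize,
    adaptive_unnormalize,
)

# Modest synthetic motion: [B, F, C, H, W] with magnitudes far below the clamp threshold
flow = torch.randn(1, 13, 2, 480, 720) * 5.0
flow_norm = adaptive_normalize(flow, 60, 36)          # sign(x) * sqrt(|x|/sf + eps), then clamp
flow_back = adaptive_unnormalize(flow_norm, 60, 36)   # sign(y) * sf * (y**2 - eps)

print(torch.allclose(flow, flow_back, atol=1e-3))     # True when nothing was clamped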
finetune/models/cogvideox_i2v/flovd_controlnet_trainer.py
ADDED
@@ -0,0 +1,814 @@
1 |
+
from typing import Any, Dict, List, Tuple
|
2 |
+
from pathlib import Path
|
3 |
+
import os
|
4 |
+
import hashlib
|
5 |
+
import json
|
6 |
+
import random
|
7 |
+
import wandb
|
8 |
+
import math
|
9 |
+
import numpy as np
|
10 |
+
from einops import rearrange, repeat
|
11 |
+
from safetensors.torch import load_file, save_file
|
12 |
+
from accelerate.logging import get_logger
|
13 |
+
|
14 |
+
import torch
|
15 |
+
|
16 |
+
from accelerate.utils import gather_object
|
17 |
+
|
18 |
+
from diffusers import (
|
19 |
+
AutoencoderKLCogVideoX,
|
20 |
+
CogVideoXDPMScheduler,
|
21 |
+
CogVideoXImageToVideoPipeline,
|
22 |
+
CogVideoXTransformer3DModel,
|
23 |
+
)
|
24 |
+
from diffusers.utils.export_utils import export_to_video
|
25 |
+
|
26 |
+
from finetune.pipeline.flovd_FVSM_cogvideox_controlnet_pipeline import FloVDCogVideoXControlnetImageToVideoPipeline
|
27 |
+
from finetune.constants import LOG_LEVEL, LOG_NAME
|
28 |
+
|
29 |
+
from diffusers.models.embeddings import get_3d_rotary_pos_embed
|
30 |
+
from PIL import Image
|
31 |
+
from numpy import dtype
|
32 |
+
from transformers import AutoTokenizer, T5EncoderModel
|
33 |
+
from typing_extensions import override
|
34 |
+
|
35 |
+
from finetune.schemas import Args, Components, State
|
36 |
+
from finetune.trainer import Trainer
|
37 |
+
from finetune.utils import (
|
38 |
+
cast_training_params,
|
39 |
+
free_memory,
|
40 |
+
get_memory_statistics,
|
41 |
+
string_to_filename,
|
42 |
+
unwrap_model,
|
43 |
+
)
|
44 |
+
from finetune.datasets.utils import (
|
45 |
+
preprocess_image_with_resize,
|
46 |
+
load_binary_mask_compressed,
|
47 |
+
)
|
48 |
+
|
49 |
+
from finetune.modules.cogvideox_controlnet import CogVideoXControlnet
|
50 |
+
from finetune.modules.cogvideox_custom_model import CustomCogVideoXTransformer3DModel
|
51 |
+
from finetune.modules.camera_sampler import SampleManualCam
|
52 |
+
from finetune.modules.camera_flow_generator import CameraFlowGenerator
|
53 |
+
from finetune.modules.utils import get_camera_flow_generator_input, forward_bilinear_splatting
|
54 |
+
|
55 |
+
from ..utils import register
|
56 |
+
|
57 |
+
import sys
|
58 |
+
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
59 |
+
|
60 |
+
import pdb
|
61 |
+
|
62 |
+
logger = get_logger(LOG_NAME, LOG_LEVEL)
|
63 |
+
|
64 |
+
class FloVDCogVideoXI2VControlnetTrainer(Trainer):
|
65 |
+
UNLOAD_LIST = ["text_encoder"]
|
66 |
+
|
67 |
+
@override
|
68 |
+
def __init__(self, args: Args) -> None:
|
69 |
+
super().__init__(args)
|
70 |
+
|
71 |
+
# For validation
|
72 |
+
self.CameraSampler = SampleManualCam()
|
73 |
+
|
74 |
+
|
75 |
+
|
76 |
+
@override
|
77 |
+
def load_components(self) -> Dict[str, Any]:
|
78 |
+
# TODO. Change the pipeline and ...
|
79 |
+
components = Components()
|
80 |
+
model_path = str(self.args.model_path)
|
81 |
+
|
82 |
+
components.pipeline_cls = FloVDCogVideoXControlnetImageToVideoPipeline
|
83 |
+
|
84 |
+
components.tokenizer = AutoTokenizer.from_pretrained(model_path, subfolder="tokenizer")
|
85 |
+
|
86 |
+
components.text_encoder = T5EncoderModel.from_pretrained(model_path, subfolder="text_encoder")
|
87 |
+
|
88 |
+
# components.transformer = CogVideoXTransformer3DModel.from_pretrained(model_path, subfolder="transformer")
|
89 |
+
|
90 |
+
components.transformer = CustomCogVideoXTransformer3DModel.from_pretrained(model_path, subfolder="transformer")
|
91 |
+
|
92 |
+
additional_kwargs = {
|
93 |
+
'num_layers': self.args.controlnet_transformer_num_layers,
|
94 |
+
'out_proj_dim_factor': self.args.controlnet_out_proj_dim_factor,
|
95 |
+
'out_proj_dim_zero_init': self.args.controlnet_out_proj_zero_init,
|
96 |
+
'notextinflow': self.args.notextinflow,
|
97 |
+
}
|
98 |
+
components.controlnet = CogVideoXControlnet.from_pretrained(model_path, subfolder="transformer", **additional_kwargs)
|
99 |
+
|
100 |
+
components.vae = AutoencoderKLCogVideoX.from_pretrained(model_path, subfolder="vae")
|
101 |
+
|
102 |
+
components.scheduler = CogVideoXDPMScheduler.from_pretrained(model_path, subfolder="scheduler")
|
103 |
+
|
104 |
+
return components
|
105 |
+
|
106 |
+
|
107 |
+
@override
|
108 |
+
def initialize_pipeline(self) -> FloVDCogVideoXControlnetImageToVideoPipeline:
|
109 |
+
# TODO. Change the pipeline and ...
|
110 |
+
pipe = FloVDCogVideoXControlnetImageToVideoPipeline(
|
111 |
+
tokenizer=self.components.tokenizer,
|
112 |
+
text_encoder=unwrap_model(self.accelerator, self.components.text_encoder),
|
113 |
+
vae=unwrap_model(self.accelerator, self.components.vae),
|
114 |
+
transformer=unwrap_model(self.accelerator, self.components.transformer),
|
115 |
+
controlnet=unwrap_model(self.accelerator, self.components.controlnet),
|
116 |
+
scheduler=self.components.scheduler,
|
117 |
+
)
|
118 |
+
return pipe
|
119 |
+
|
120 |
+
def initialize_flow_generator(self, ckpt_path):
|
121 |
+
depth_estimator_kwargs = {
|
122 |
+
"target": 'modules.depth_warping.depth_warping.DepthWarping_wrapper',
|
123 |
+
"kwargs": {
|
124 |
+
"ckpt_path": ckpt_path,
|
125 |
+
"model_config": {
|
126 |
+
"max_depth": 20,
|
127 |
+
"encoder": 'vitb',
|
128 |
+
"features": 128,
|
129 |
+
"out_channels": [96, 192, 384, 768],
|
130 |
+
}
|
131 |
+
|
132 |
+
}
|
133 |
+
}
|
134 |
+
|
135 |
+
return CameraFlowGenerator(depth_estimator_kwargs)
|
136 |
+
|
137 |
+
@override
|
138 |
+
def collate_fn(self, samples: List[Dict[str, Any]]) -> Dict[str, Any]:
|
139 |
+
ret = {"encoded_videos": [], "prompt_embedding": [], "images": [], "encoded_flow": []}
|
140 |
+
|
141 |
+
for sample in samples:
|
142 |
+
encoded_video = sample["encoded_video"]
|
143 |
+
prompt_embedding = sample["prompt_embedding"]
|
144 |
+
image = sample["image"]
|
145 |
+
encoded_flow = sample["encoded_flow"]
|
146 |
+
|
147 |
+
ret["encoded_videos"].append(encoded_video)
|
148 |
+
ret["prompt_embedding"].append(prompt_embedding)
|
149 |
+
ret["images"].append(image)
|
150 |
+
ret["encoded_flow"].append(encoded_flow)
|
151 |
+
|
152 |
+
|
153 |
+
ret["encoded_videos"] = torch.stack(ret["encoded_videos"])
|
154 |
+
ret["prompt_embedding"] = torch.stack(ret["prompt_embedding"])
|
155 |
+
ret["images"] = torch.stack(ret["images"])
|
156 |
+
ret["encoded_flow"] = torch.stack(ret["encoded_flow"])
|
157 |
+
|
158 |
+
return ret
|
159 |
+
|
160 |
+
|
161 |
+
@override
|
162 |
+
def compute_loss(self, batch) -> torch.Tensor:
|
163 |
+
prompt_embedding = batch["prompt_embedding"]
|
164 |
+
latent = batch["encoded_videos"]
|
165 |
+
images = batch["images"]
|
166 |
+
latent_flow = batch["encoded_flow"]
|
167 |
+
|
168 |
+
# Shape of prompt_embedding: [B, seq_len, hidden_size]
|
169 |
+
# Shape of latent: [B, C, F, H, W]
|
170 |
+
# Shape of images: [B, C, H, W]
|
171 |
+
# Shape of latent_flow: [B, C, F, H, W]
|
172 |
+
|
173 |
+
patch_size_t = self.state.transformer_config.patch_size_t # WJ: None in i2v setting...
|
174 |
+
if patch_size_t is not None:
|
175 |
+
ncopy = latent.shape[2] % patch_size_t
|
176 |
+
# Copy the first frame ncopy times to match patch_size_t
|
177 |
+
first_frame = latent[:, :, :1, :, :] # Get first frame [B, C, 1, H, W]
|
178 |
+
latent = torch.cat([first_frame.repeat(1, 1, ncopy, 1, 1), latent], dim=2)
|
179 |
+
assert latent.shape[2] % patch_size_t == 0
|
180 |
+
|
181 |
+
batch_size, num_channels, num_frames, height, width = latent.shape
|
182 |
+
|
183 |
+
# Get prompt embeddings
|
184 |
+
_, seq_len, _ = prompt_embedding.shape
|
185 |
+
prompt_embedding = prompt_embedding.view(batch_size, seq_len, -1).to(dtype=latent.dtype)
|
186 |
+
|
187 |
+
# Add frame dimension to images [B,C,H,W] -> [B,C,F,H,W]
|
188 |
+
images = images.unsqueeze(2)
|
189 |
+
# Add noise to images
|
190 |
+
image_noise_sigma = torch.normal(mean=-3.0, std=0.5, size=(1,), device=self.accelerator.device)
|
191 |
+
image_noise_sigma = torch.exp(image_noise_sigma).to(dtype=images.dtype)
|
192 |
+
noisy_images = images + torch.randn_like(images) * image_noise_sigma[:, None, None, None, None]
|
193 |
+
image_latent_dist = self.components.vae.encode(noisy_images.to(dtype=self.components.vae.dtype)).latent_dist
|
194 |
+
image_latents = image_latent_dist.sample() * self.components.vae.config.scaling_factor
|
195 |
+
|
196 |
+
"""
|
197 |
+
Modify below
|
198 |
+
"""
|
199 |
+
# Sample a random timestep for each sample
|
200 |
+
# timesteps = torch.randint(
|
201 |
+
# 0, self.components.scheduler.config.num_train_timesteps, (batch_size,), device=self.accelerator.device
|
202 |
+
# )
|
203 |
+
if self.args.enable_time_sampling:
|
204 |
+
if self.args.time_sampling_type == "truncated_normal":
|
205 |
+
time_sampling_dict = {
|
206 |
+
'mean': self.args.time_sampling_mean,
|
207 |
+
'std': self.args.time_sampling_std,
|
208 |
+
'a': 1 - self.args.controlnet_guidance_end,
|
209 |
+
'b': 1 - self.args.controlnet_guidance_start,
|
210 |
+
}
|
211 |
+
timesteps = torch.nn.init.trunc_normal_(
|
212 |
+
torch.empty(batch_size, device=latent.device), **time_sampling_dict
|
213 |
+
) * self.components.scheduler.config.num_train_timesteps
|
214 |
+
elif self.args.time_sampling_type == "truncated_uniform":
|
215 |
+
timesteps = torch.randint(
|
216 |
+
int((1- self.args.controlnet_guidance_end) * self.components.scheduler.config.num_train_timesteps),
|
217 |
+
int((1 - self.args.controlnet_guidance_start) * self.components.scheduler.config.num_train_timesteps),
|
218 |
+
(batch_size,), device=latent.device
|
219 |
+
)
|
220 |
+
else:
|
221 |
+
timesteps = torch.randint(
|
222 |
+
0, self.components.scheduler.config.num_train_timesteps, (batch_size,), device=self.accelerator.device
|
223 |
+
)
|
224 |
+
timesteps = timesteps.long()
|
225 |
+
|
226 |
+
# from [B, C, F, H, W] to [B, F, C, H, W]
|
227 |
+
latent = latent.permute(0, 2, 1, 3, 4)
|
228 |
+
latent_flow = latent_flow.permute(0, 2, 1, 3, 4)
|
229 |
+
image_latents = image_latents.permute(0, 2, 1, 3, 4)
|
230 |
+
assert (latent.shape[0], *latent.shape[2:]) == (image_latents.shape[0], *image_latents.shape[2:]) == (latent_flow.shape[0], *latent_flow.shape[2:])
|
231 |
+
|
232 |
+
# Padding image_latents to the same frame number as latent
|
233 |
+
padding_shape = (latent.shape[0], latent.shape[1] - 1, *latent.shape[2:])
|
234 |
+
latent_padding = image_latents.new_zeros(padding_shape)
|
235 |
+
image_latents = torch.cat([image_latents, latent_padding], dim=1)
|
236 |
+
|
237 |
+
# Add noise to latent
|
238 |
+
noise = torch.randn_like(latent)
|
239 |
+
latent_noisy = self.components.scheduler.add_noise(latent, noise, timesteps)
|
240 |
+
|
241 |
+
|
242 |
+
# Concatenate latent and image_latents in the channel dimension
|
243 |
+
# latent_img_flow_noisy = torch.cat([latent_noisy, image_latents, latent_flow], dim=2)
|
244 |
+
latent_img_noisy = torch.cat([latent_noisy, image_latents], dim=2)
|
245 |
+
|
246 |
+
# Prepare rotary embeds
|
247 |
+
vae_scale_factor_spatial = 2 ** (len(self.components.vae.config.block_out_channels) - 1)
|
248 |
+
transformer_config = self.state.transformer_config
|
249 |
+
rotary_emb = (
|
250 |
+
self.prepare_rotary_positional_embeddings(
|
251 |
+
height=height * vae_scale_factor_spatial,
|
252 |
+
width=width * vae_scale_factor_spatial,
|
253 |
+
num_frames=num_frames,
|
254 |
+
transformer_config=transformer_config,
|
255 |
+
vae_scale_factor_spatial=vae_scale_factor_spatial,
|
256 |
+
device=self.accelerator.device,
|
257 |
+
)
|
258 |
+
if transformer_config.use_rotary_positional_embeddings
|
259 |
+
else None
|
260 |
+
)
|
261 |
+
|
262 |
+
# Predict noise, For CogVideoX1.5 Only.
|
263 |
+
ofs_emb = (
|
264 |
+
None if self.state.transformer_config.ofs_embed_dim is None else latent.new_full((1,), fill_value=2.0)
|
265 |
+
)
|
266 |
+
|
267 |
+
# Controlnet feedforward
|
268 |
+
controlnet_states = self.components.controlnet(
|
269 |
+
hidden_states=latent_noisy,
|
270 |
+
encoder_hidden_states=prompt_embedding,
|
271 |
+
image_rotary_emb=rotary_emb,
|
272 |
+
controlnet_hidden_states=latent_flow,
|
273 |
+
timestep=timesteps,
|
274 |
+
return_dict=False,
|
275 |
+
)[0]
|
276 |
+
if isinstance(controlnet_states, (tuple, list)):
|
277 |
+
controlnet_states = [x.to(dtype=self.state.weight_dtype) for x in controlnet_states]
|
278 |
+
else:
|
279 |
+
controlnet_states = controlnet_states.to(dtype=self.state.weight_dtype)
|
280 |
+
|
281 |
+
|
282 |
+
# Transformer feedforward
|
283 |
+
predicted_noise = self.components.transformer(
|
284 |
+
hidden_states=latent_img_noisy,
|
285 |
+
encoder_hidden_states=prompt_embedding,
|
286 |
+
controlnet_states=controlnet_states,
|
287 |
+
controlnet_weights=self.args.controlnet_weights,
|
288 |
+
timestep=timesteps,
|
289 |
+
# ofs=ofs_emb,
|
290 |
+
image_rotary_emb=rotary_emb,
|
291 |
+
return_dict=False,
|
292 |
+
)[0]
|
293 |
+
|
294 |
+
|
295 |
+
        # Denoise: recover the predicted clean latent from the model's v-prediction
        latent_pred = self.components.scheduler.get_velocity(predicted_noise, latent_noisy, timesteps)

        alphas_cumprod = self.components.scheduler.alphas_cumprod[timesteps]
        weights = 1 / (1 - alphas_cumprod)
        while len(weights.shape) < len(latent_pred.shape):
            weights = weights.unsqueeze(-1)

        loss = torch.mean((weights * (latent_pred - latent) ** 2).reshape(batch_size, -1), dim=1)
        loss = loss.mean()

        return loss
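A standalone sketch (not part of the diff) of how the 1 / (1 - alphas_cumprod[t]) weight above behaves, using a toy linear-beta schedule rather than the scheduler configured in this repo: low-noise (small t) samples are weighted heavily, while the weight decays toward 1 for large t.

import torch

betas = torch.linspace(1e-4, 2e-2, 1000)           # toy schedule, not the CogVideoX config
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
t = torch.tensor([10, 500, 990])
weights = 1.0 / (1.0 - alphas_cumprod[t])
print(weights.tolist())                             # large at t=10, close to 1 by t=990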
|
307 |
+
|
308 |
+
def prepare_rotary_positional_embeddings(
|
309 |
+
self,
|
310 |
+
height: int,
|
311 |
+
width: int,
|
312 |
+
num_frames: int,
|
313 |
+
transformer_config: Dict,
|
314 |
+
vae_scale_factor_spatial: int,
|
315 |
+
device: torch.device,
|
316 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
317 |
+
grid_height = height // (vae_scale_factor_spatial * transformer_config.patch_size)
|
318 |
+
grid_width = width // (vae_scale_factor_spatial * transformer_config.patch_size)
|
319 |
+
|
320 |
+
if transformer_config.patch_size_t is None:
|
321 |
+
base_num_frames = num_frames
|
322 |
+
else:
|
323 |
+
base_num_frames = (num_frames + transformer_config.patch_size_t - 1) // transformer_config.patch_size_t
|
324 |
+
|
325 |
+
freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
|
326 |
+
embed_dim=transformer_config.attention_head_dim,
|
327 |
+
crops_coords=None,
|
328 |
+
grid_size=(grid_height, grid_width),
|
329 |
+
temporal_size=base_num_frames,
|
330 |
+
grid_type="slice",
|
331 |
+
max_size=(grid_height, grid_width),
|
332 |
+
device=device,
|
333 |
+
)
|
334 |
+
|
335 |
+
return freqs_cos, freqs_sin
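A standalone worked example (not part of the diff) of the grid sizes this helper passes to get_3d_rotary_pos_embed, for a hypothetical 480x720 training resolution with the usual CogVideoX factors (8x spatial VAE downsampling, patch_size 2, 49 video frames compressed to 13 latent frames).

height, width, latent_frames = 480, 720, 13
vae_scale_factor_spatial, patch_size = 8, 2
grid_height = height // (vae_scale_factor_spatial * patch_size)  # 30
grid_width = width // (vae_scale_factor_spatial * patch_size)    # 45
print(grid_height, grid_width, latent_frames)                    # 30 x 45 x 13 = 17550 positions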
|
336 |
+
|
337 |
+
# Validation
|
338 |
+
|
339 |
+
@override
|
340 |
+
def prepare_for_validation(self):
|
341 |
+
# Load from dataset?
|
342 |
+
# Data_root
|
343 |
+
# - metadata.jsonl
|
344 |
+
# - video_latent / args.resolution /
|
345 |
+
# - prompt_embeddings /
|
346 |
+
# - first_frames /
|
347 |
+
# - flow_direct_f_latent /
|
348 |
+
|
349 |
+
data_root = self.args.data_root
|
350 |
+
metadata_path = data_root / "metadata_revised.jsonl"
|
351 |
+
assert metadata_path.is_file(), "For this dataset type, you need metadata.jsonl or metadata_revised.jsonl in the root path"
|
352 |
+
|
353 |
+
# Load metadata
|
354 |
+
# metadata = {
|
355 |
+
# "video_path": ...,
|
356 |
+
# "hash_code": ...,
|
357 |
+
# "prompt": ...,
|
358 |
+
# }
|
359 |
+
metadata = []
|
360 |
+
with open(metadata_path, "r") as f:
|
361 |
+
for line in f:
|
362 |
+
metadata.append( json.loads(line) )
|
363 |
+
|
364 |
+
metadata = random.sample(metadata, self.args.max_scene)
|
365 |
+
|
366 |
+
prompts = [x["prompt"] for x in metadata]
|
367 |
+
prompt_embeddings = [data_root / "prompt_embeddings_revised" / (x["hash_code"] + '.safetensors') for x in metadata]
|
368 |
+
videos = [data_root / "video_latent" / "x".join(str(x) for x in self.args.train_resolution) / (x["hash_code"] + '.safetensors') for x in metadata]
|
369 |
+
images = [data_root / "first_frames" / (x["hash_code"] + '.png') for x in metadata]
|
370 |
+
flows = [data_root / "flow_direct_f_latent" / (x["hash_code"] + '.safetensors') for x in metadata]
|
371 |
+
|
372 |
+
# load prompt embedding
|
373 |
+
validation_prompts = []
|
374 |
+
validation_prompt_embeddings = []
|
375 |
+
validation_video_latents = []
|
376 |
+
validation_images = []
|
377 |
+
validation_flow_latents = []
|
378 |
+
for prompt, prompt_embedding, video_latent, image, flow_latent in zip(prompts, prompt_embeddings, videos, images, flows):
|
379 |
+
validation_prompts.append(prompt)
|
380 |
+
validation_prompt_embeddings.append(load_file(prompt_embedding)["prompt_embedding"].unsqueeze(0))
|
381 |
+
validation_video_latents.append(load_file(video_latent)["encoded_video"].unsqueeze(0))
|
382 |
+
validation_flow_latents.append(load_file(flow_latent)["encoded_flow_f"].unsqueeze(0))
|
383 |
+
# validation_images.append(preprocess_image_with_resize(image, self.args.train_resolution[1], self.args.train_resolution[2]))
|
384 |
+
validation_images.append(image)
|
385 |
+
|
386 |
+
|
387 |
+
validation_videos = [None] * len(validation_prompts)
|
388 |
+
|
389 |
+
|
390 |
+
self.state.validation_prompts = validation_prompts
|
391 |
+
self.state.validation_prompt_embeddings = validation_prompt_embeddings
|
392 |
+
self.state.validation_images = validation_images
|
393 |
+
self.state.validation_videos = validation_videos
|
394 |
+
self.state.validation_video_latents = validation_video_latents
|
395 |
+
self.state.validation_flow_latents = validation_flow_latents
|
396 |
+
|
397 |
+
# Debug..
|
398 |
+
# self.validate(0)
|
399 |
+
|
400 |
+
|
401 |
+
@override
|
402 |
+
def validation_step(
|
403 |
+
self, eval_data: Dict[str, Any], pipe: FloVDCogVideoXControlnetImageToVideoPipeline
|
404 |
+
) -> List[Tuple[str, Image.Image | List[Image.Image]]]:
|
405 |
+
"""
|
406 |
+
Return the data that needs to be saved. For videos, the data format is List[PIL],
|
407 |
+
and for images, the data format is PIL
|
408 |
+
"""
|
409 |
+
|
410 |
+
prompt_embedding, image, flow_latent = eval_data["prompt_embedding"], eval_data["image"], eval_data["flow_latent"]
|
411 |
+
|
412 |
+
video_generate = pipe(
|
413 |
+
num_frames=self.state.train_frames,
|
414 |
+
height=self.state.train_height,
|
415 |
+
width=self.state.train_width,
|
416 |
+
prompt=None,
|
417 |
+
prompt_embeds=prompt_embedding,
|
418 |
+
image=image,
|
419 |
+
flow_latent=flow_latent,
|
420 |
+
generator=self.state.generator,
|
421 |
+
num_inference_steps=50,
|
422 |
+
controlnet_guidance_start = self.args.controlnet_guidance_start,
|
423 |
+
controlnet_guidance_end = self.args.controlnet_guidance_end,
|
424 |
+
).frames[0]
|
425 |
+
return [("synthesized_video", video_generate)]
|
426 |
+
|
427 |
+
|
428 |
+
@override
|
429 |
+
def validate(self, step: int) -> None:
|
430 |
+
        # TODO: clean up this validation routine.
|
431 |
+
logger.info("Starting validation")
|
432 |
+
|
433 |
+
accelerator = self.accelerator
|
434 |
+
num_validation_samples = len(self.state.validation_prompts)
|
435 |
+
|
436 |
+
if num_validation_samples == 0:
|
437 |
+
logger.warning("No validation samples found. Skipping validation.")
|
438 |
+
return
|
439 |
+
|
440 |
+
self.components.controlnet.eval()
|
441 |
+
torch.set_grad_enabled(False)
|
442 |
+
|
443 |
+
memory_statistics = get_memory_statistics()
|
444 |
+
logger.info(f"Memory before validation start: {json.dumps(memory_statistics, indent=4)}")
|
445 |
+
|
446 |
+
##### Initialize pipeline #####
|
447 |
+
pipe = self.initialize_pipeline()
|
448 |
+
camera_flow_generator = self.initialize_flow_generator(ckpt_path=self.args.depth_ckpt_path).to(device=self.accelerator.device, dtype=self.state.weight_dtype)
|
449 |
+
|
450 |
+
        if self.state.using_deepspeed:
            # model_cpu_offload cannot be used with DeepSpeed,
            # so we move all components of the pipe to the device instead.
            # pipe.to(self.accelerator.device, dtype=self.state.weight_dtype)
            self.__move_components_to_device(dtype=self.state.weight_dtype, ignore_list=["controlnet"])
            # self.__move_components_to_device(dtype=self.state.weight_dtype, ignore_list=["transformer", "controlnet"])
        else:
            # If not using DeepSpeed, use model_cpu_offload to further reduce memory usage,
            # or pipe.enable_sequential_cpu_offload() to reduce it even further.
            pipe.enable_model_cpu_offload(device=self.accelerator.device)
|
460 |
+
|
461 |
+
# Convert all model weights to training dtype
|
462 |
+
# Note, this will change LoRA weights in self.components.transformer to training dtype, rather than keep them in fp32
|
463 |
+
pipe = pipe.to(dtype=self.state.weight_dtype)
|
464 |
+
|
465 |
+
|
466 |
+
#################################
|
467 |
+
inference_type = ['training', 'inference']
|
468 |
+
# inference_type = ['inference']
|
469 |
+
for infer_type in inference_type:
|
470 |
+
|
471 |
+
|
472 |
+
all_processes_artifacts = []
|
473 |
+
for i in range(num_validation_samples):
|
474 |
+
if self.state.using_deepspeed and self.accelerator.deepspeed_plugin.zero_stage != 3:
|
475 |
+
# Skip current validation on all processes but one
|
476 |
+
if i % accelerator.num_processes != accelerator.process_index:
|
477 |
+
continue
|
478 |
+
|
479 |
+
prompt = self.state.validation_prompts[i]
|
480 |
+
image = self.state.validation_images[i]
|
481 |
+
video = self.state.validation_videos[i]
|
482 |
+
video_latent = self.state.validation_video_latents[i].permute(0,2,1,3,4) # [B,F,C,H,W] (e.g., [B, 13, 16, 60, 90])
|
483 |
+
prompt_embedding = self.state.validation_prompt_embeddings[i]
|
484 |
+
flow_latent = self.state.validation_flow_latents[i].permute(0,2,1,3,4) # [B,F,C,H,W] (e.g., [B, 13, 16, 60, 90])
|
485 |
+
|
486 |
+
|
487 |
+
if image is not None:
|
488 |
+
image = preprocess_image_with_resize(image, self.state.train_height, self.state.train_width)
|
489 |
+
image_torch = image.detach().clone()
|
490 |
+
# Convert image tensor (C, H, W) to PIL images
|
491 |
+
image = image.to(torch.uint8)
|
492 |
+
image = image.permute(1, 2, 0).cpu().numpy()
|
493 |
+
image = Image.fromarray(image)
|
494 |
+
|
495 |
+
if video is not None:
|
496 |
+
video = preprocess_video_with_resize(
|
497 |
+
video, self.state.train_frames, self.state.train_height, self.state.train_width
|
498 |
+
)
|
499 |
+
# Convert video tensor (F, C, H, W) to list of PIL images
|
500 |
+
video = video.round().clamp(0, 255).to(torch.uint8)
|
501 |
+
video = [Image.fromarray(frame.permute(1, 2, 0).cpu().numpy()) for frame in video]
|
502 |
+
                else:
                    if infer_type == 'training':
                        with torch.cuda.amp.autocast(enabled=True, dtype=self.state.weight_dtype):
                            # The first VAE decode occasionally fails (unknown intermittent bug); warm up once and retry.
                            try:
                                video_decoded = decode_latents(video_latent.to(self.accelerator.device), self.components.vae)
                            except Exception:
                                pass
                            video_decoded = decode_latents(video_latent.to(self.accelerator.device), self.components.vae)
                        video = ((video_decoded + 1.) / 2. * 255.)[0].permute(1, 0, 2, 3).float().clip(0., 255.).to(torch.uint8)
                        video = [Image.fromarray(frame.permute(1, 2, 0).cpu().numpy()) for frame in video]

                        with torch.cuda.amp.autocast(enabled=True, dtype=self.state.weight_dtype):
                            try:
                                flow_decoded = decode_flow(flow_latent.to(self.accelerator.device), self.components.vae, flow_scale_factor=[60, 36])
                            except Exception:
                                pass
                            flow_decoded = decode_flow(flow_latent.to(self.accelerator.device), self.components.vae, flow_scale_factor=[60, 36])  # (B F) C H W, C=2
|
519 |
+
|
520 |
+
|
521 |
+
                # Prepare camera flow
                if infer_type == 'inference':
                    with torch.cuda.amp.autocast(enabled=True, dtype=self.state.weight_dtype):
                        camparam, cam_name = self.CameraSampler.sample()
                        camera_flow_generator_input = get_camera_flow_generator_input(image_torch, camparam, device=self.accelerator.device, speed=0.5)
                        image_torch = ((image_torch.unsqueeze(0) / 255.) * 2. - 1.).to(self.accelerator.device)
                        camera_flow, log_dict = camera_flow_generator(image_torch, camera_flow_generator_input)
                        camera_flow = camera_flow.to(self.accelerator.device)
                        # The first encode occasionally fails for unknown reasons, so warm up once and retry.
                        try:
                            flow_latent = rearrange(encode_flow(camera_flow, self.components.vae, flow_scale_factor=[60, 36]), 'b c f h w -> b f c h w').to(self.accelerator.device, self.state.weight_dtype)
                        except Exception:
                            pass
                        flow_latent = rearrange(encode_flow(camera_flow, self.components.vae, flow_scale_factor=[60, 36]), 'b c f h w -> b f c h w').to(self.accelerator.device, self.state.weight_dtype)
|
535 |
+
|
536 |
+
|
537 |
+
logger.debug(
|
538 |
+
f"Validating sample {i + 1}/{num_validation_samples} on process {accelerator.process_index}. Prompt: {prompt}",
|
539 |
+
main_process_only=False,
|
540 |
+
)
|
541 |
+
# validation_artifacts = self.validation_step({"prompt": prompt, "image": image, "video": video}, pipe)
|
542 |
+
validation_artifacts = self.validation_step({"prompt_embedding": prompt_embedding, "image": image, "flow_latent": flow_latent}, pipe)
|
543 |
+
|
544 |
+
if (
|
545 |
+
self.state.using_deepspeed
|
546 |
+
and self.accelerator.deepspeed_plugin.zero_stage == 3
|
547 |
+
and not accelerator.is_main_process
|
548 |
+
):
|
549 |
+
continue
|
550 |
+
|
551 |
+
prompt_filename = string_to_filename(prompt)[:25]
|
552 |
+
# Calculate hash of reversed prompt as a unique identifier
|
553 |
+
reversed_prompt = prompt[::-1]
|
554 |
+
hash_suffix = hashlib.md5(reversed_prompt.encode()).hexdigest()[:5]
|
555 |
+
|
556 |
+
artifacts = {
|
557 |
+
"image": {"type": "image", "value": image},
|
558 |
+
"video": {"type": "video", "value": video},
|
559 |
+
}
|
560 |
+
for i, (artifact_type, artifact_value) in enumerate(validation_artifacts):
|
561 |
+
artifacts.update({f"artifact_{i}": {"type": artifact_type, "value": artifact_value}})
|
562 |
+
if infer_type == 'training':
|
563 |
+
# Log flow_warped_frames
|
564 |
+
image_tensor = repeat(rearrange(torch.tensor(np.array(image)).to(flow_decoded.device, torch.float), 'h w c -> 1 c h w'), 'b c h w -> (b f) c h w', f=flow_decoded.size(0)) # scale~(0,255) (BF) C H W
|
565 |
+
warped_video = forward_bilinear_splatting(image_tensor, flow_decoded.to(torch.float)) # if we have an occlusion mask from dataset, we can use it.
|
566 |
+
frame_list = []
|
567 |
+
for frame in warped_video:
|
568 |
+
frame = (frame.permute(1,2,0).float().detach().cpu().numpy()).astype(np.uint8).clip(0,255)
|
569 |
+
frame_list.append(Image.fromarray(frame))
|
570 |
+
|
571 |
+
artifacts.update({f"artifact_warped_video_{i}": {"type": 'warped_video', "value": frame_list}})
|
572 |
+
|
573 |
+
if infer_type == 'inference':
|
574 |
+
warped_video = log_dict['depth_warped_frames']
|
575 |
+
frame_list = []
|
576 |
+
for frame in warped_video:
|
577 |
+
frame = (frame + 1.)/2. * 255.
|
578 |
+
frame = (frame.permute(1,2,0).float().detach().cpu().numpy()).astype(np.uint8).clip(0,255)
|
579 |
+
frame_list.append(Image.fromarray(frame))
|
580 |
+
|
581 |
+
artifacts.update({f"artifact_warped_video_{i}": {"type": 'warped_video', "value": frame_list}})
|
582 |
+
logger.debug(
|
583 |
+
f"Validation artifacts on process {accelerator.process_index}: {list(artifacts.keys())}",
|
584 |
+
main_process_only=False,
|
585 |
+
)
|
586 |
+
|
587 |
+
for key, value in list(artifacts.items()):
|
588 |
+
artifact_type = value["type"]
|
589 |
+
artifact_value = value["value"]
|
590 |
+
if artifact_type not in ["image", "video", "warped_video", "synthesized_video"] or artifact_value is None:
|
591 |
+
continue
|
592 |
+
|
593 |
+
extension = "png" if artifact_type == "image" else "mp4"
|
594 |
+
if artifact_type == "warped_video":
|
595 |
+
filename = f"validation-{step}-{accelerator.process_index}-{prompt_filename}-{hash_suffix}-{infer_type}_warped_video.{extension}"
|
596 |
+
elif artifact_type == "synthesized_video":
|
597 |
+
filename = f"validation-{step}-{accelerator.process_index}-{prompt_filename}-{hash_suffix}-{infer_type}_synthesized_video.{extension}"
|
598 |
+
else:
|
599 |
+
filename = f"validation-{step}-{accelerator.process_index}-{prompt_filename}-{hash_suffix}-{infer_type}.{extension}"
|
600 |
+
validation_path = self.args.output_dir / "validation_res"
|
601 |
+
validation_path.mkdir(parents=True, exist_ok=True)
|
602 |
+
filename = str(validation_path / filename)
|
603 |
+
|
604 |
+
if artifact_type == "image":
|
605 |
+
logger.debug(f"Saving image to {filename}")
|
606 |
+
artifact_value.save(filename)
|
607 |
+
artifact_value = wandb.Image(filename)
|
608 |
+
elif artifact_type == "video" or artifact_type == "warped_video" or artifact_type == "synthesized_video":
|
609 |
+
logger.debug(f"Saving video to {filename}")
|
610 |
+
export_to_video(artifact_value, filename, fps=self.args.gen_fps)
|
611 |
+
artifact_value = wandb.Video(filename, caption=prompt)
|
612 |
+
|
613 |
+
all_processes_artifacts.append(artifact_value)
|
614 |
+
|
615 |
+
all_artifacts = gather_object(all_processes_artifacts)
|
616 |
+
|
617 |
+
if accelerator.is_main_process:
|
618 |
+
tracker_key = "validation"
|
619 |
+
for tracker in accelerator.trackers:
|
620 |
+
if tracker.name == "wandb":
|
621 |
+
image_artifacts = [artifact for artifact in all_artifacts if isinstance(artifact, wandb.Image)]
|
622 |
+
video_artifacts = [artifact for artifact in all_artifacts if isinstance(artifact, wandb.Video)]
|
623 |
+
tracker.log(
|
624 |
+
{
|
625 |
+
tracker_key: {f"images_{infer_type}": image_artifacts, f"videos_{infer_type}": video_artifacts},
|
626 |
+
},
|
627 |
+
step=step,
|
628 |
+
)
|
629 |
+
|
630 |
+
########## Clean up ##########
|
631 |
+
if self.state.using_deepspeed:
|
632 |
+
del pipe
|
633 |
+
# Unload models except those needed for training
|
634 |
+
self.__move_components_to_cpu(unload_list=self.UNLOAD_LIST)
|
635 |
+
else:
|
636 |
+
pipe.remove_all_hooks()
|
637 |
+
del pipe
|
638 |
+
# Load models except those not needed for training
|
639 |
+
self.__move_components_to_device(dtype=self.state.weight_dtype, ignore_list=self.UNLOAD_LIST)
|
640 |
+
self.components.controlnet.to(self.accelerator.device, dtype=self.state.weight_dtype)
|
641 |
+
|
642 |
+
# Change trainable weights back to fp32 to keep with dtype after prepare the model
|
643 |
+
cast_training_params([self.components.controlnet], dtype=torch.float32)
|
644 |
+
|
645 |
+
del camera_flow_generator
|
646 |
+
|
647 |
+
free_memory()
|
648 |
+
accelerator.wait_for_everyone()
|
649 |
+
################################
|
650 |
+
|
651 |
+
memory_statistics = get_memory_statistics()
|
652 |
+
logger.info(f"Memory after validation end: {json.dumps(memory_statistics, indent=4)}")
|
653 |
+
torch.cuda.reset_peak_memory_stats(accelerator.device)
|
654 |
+
|
655 |
+
torch.set_grad_enabled(True)
|
656 |
+
self.components.controlnet.train()
|
657 |
+
|
658 |
+
|
659 |
+
# mangling
|
660 |
+
def __move_components_to_device(self, dtype, ignore_list: List[str] = []):
|
661 |
+
ignore_list = set(ignore_list)
|
662 |
+
components = self.components.model_dump()
|
663 |
+
for name, component in components.items():
|
664 |
+
if not isinstance(component, type) and hasattr(component, "to"):
|
665 |
+
if name not in ignore_list:
|
666 |
+
setattr(self.components, name, component.to(self.accelerator.device, dtype=dtype))
|
667 |
+
|
668 |
+
# mangling
|
669 |
+
def __move_components_to_cpu(self, unload_list: List[str] = []):
|
670 |
+
unload_list = set(unload_list)
|
671 |
+
components = self.components.model_dump()
|
672 |
+
for name, component in components.items():
|
673 |
+
if not isinstance(component, type) and hasattr(component, "to"):
|
674 |
+
if name in unload_list:
|
675 |
+
setattr(self.components, name, component.to("cpu"))
|
676 |
+
|
677 |
+
|
678 |
+
register("cogvideox-flovd", "controlnet", FloVDCogVideoXI2VControlnetTrainer)
|
679 |
+
|
680 |
+
|
681 |
+
#--------------------------------------------------------------------------------------------------
|
682 |
+
# Extract function
|
683 |
+
def encode_text(prompt: str, components, device) -> torch.Tensor:
|
684 |
+
prompt_token_ids = components.tokenizer(
|
685 |
+
prompt,
|
686 |
+
padding="max_length",
|
687 |
+
max_length=components.transformer.config.max_text_seq_length,
|
688 |
+
truncation=True,
|
689 |
+
add_special_tokens=True,
|
690 |
+
return_tensors="pt",
|
691 |
+
)
|
692 |
+
prompt_token_ids = prompt_token_ids.input_ids
|
693 |
+
prompt_embedding = components.text_encoder(prompt_token_ids.to(device))[0]
|
694 |
+
return prompt_embedding
|
695 |
+
|
696 |
+
def encode_video(video: torch.Tensor, vae) -> torch.Tensor:
|
697 |
+
# shape of input video: [B, C, F, H, W]
|
698 |
+
video = video.to(vae.device, dtype=vae.dtype)
|
699 |
+
latent_dist = vae.encode(video).latent_dist
|
700 |
+
latent = latent_dist.sample() * vae.config.scaling_factor
|
701 |
+
return latent
|
702 |
+
|
703 |
+
def decode_latents(latents: torch.Tensor, vae) -> torch.Tensor:
|
704 |
+
latents = latents.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width]
|
705 |
+
latents = 1 / vae.config.scaling_factor * latents
|
706 |
+
|
707 |
+
frames = vae.decode(latents).sample
|
708 |
+
return frames
|
709 |
+
|
710 |
+
def compute_optical_flow(raft, ctxt, trgt, raft_iter=20, chunk=2, only_forward=True):
    num_frames = ctxt.shape[0]
    chunk_size = (num_frames // chunk) + 1

    flow_f_list = []
    if not only_forward:
        flow_b_list = []
    for i in range(chunk):
        start = chunk_size * i
        end = chunk_size * (i + 1)

        with torch.no_grad():
            flow_f = raft(ctxt[start:end], trgt[start:end], num_flow_updates=raft_iter)[-1]
            if not only_forward:
                flow_b = raft(trgt[start:end], ctxt[start:end], num_flow_updates=raft_iter)[-1]

        flow_f_list.append(flow_f)
        if not only_forward:
            flow_b_list.append(flow_b)

    flow_f = torch.cat(flow_f_list)
    if not only_forward:
        flow_b = torch.cat(flow_b_list)

    if not only_forward:
        return flow_f, flow_b
    else:
        return flow_f, None
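A standalone usage sketch (not part of the diff) of the chunked RAFT helper above, assuming torchvision's pretrained RAFT and dummy frames already scaled to [-1, 1]; the shapes and iteration counts are illustrative only.

import torch
from torchvision.models.optical_flow import raft_large, Raft_Large_Weights

raft = raft_large(weights=Raft_Large_Weights.DEFAULT).eval()
frames = torch.rand(9, 3, 256, 256) * 2 - 1        # dummy 9-frame clip in [-1, 1]
ctxt, trgt = frames[:-1], frames[1:]               # consecutive frame pairs
flow_f, _ = compute_optical_flow(raft, ctxt, trgt, raft_iter=20, chunk=2, only_forward=True)
print(flow_f.shape)                                # torch.Size([8, 2, 256, 256])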
|
738 |
+
|
739 |
+
def encode_flow(flow, vae, flow_scale_factor):
    # flow: (B F), C, H, W
    # flow_scale_factor: [sf_x, sf_y]
    assert flow.ndim == 4
    num_frames, _, height, width = flow.shape

    # Normalize optical flow
    # ndim: 4 -> 5
    flow = rearrange(flow, '(b f) c h w -> b f c h w', b=1)
    flow_norm = adaptive_normalize(flow, flow_scale_factor[0], flow_scale_factor[1])

    # ndim: 5 -> 4
    flow_norm = rearrange(flow_norm, 'b f c h w -> (b f) c h w', b=1)

    # Duplicate the mean value as a third channel
    num_frames, _, H, W = flow_norm.shape
    flow_norm_extended = torch.empty((num_frames, 3, height, width)).to(flow_norm)
    flow_norm_extended[:, :2] = flow_norm
    flow_norm_extended[:, -1:] = flow_norm.mean(dim=1, keepdim=True)
    flow_norm_extended = rearrange(flow_norm_extended, '(b f) c h w -> b c f h w', f=num_frames)

    return encode_video(flow_norm_extended, vae)

def decode_flow(flow_latent, vae, flow_scale_factor):
    flow_latent = flow_latent.permute(0, 2, 1, 3, 4)  # [batch_size, num_channels, num_frames, height, width]
    flow_latent = 1 / vae.config.scaling_factor * flow_latent

    flow = vae.decode(flow_latent).sample  # B, C, F, H, W

    # Discard the third channel (it only holds the mean of f_x and f_y)
    flow = flow[:, :2].detach().clone()

    # Unnormalize optical flow
    flow = rearrange(flow, 'b c f h w -> b f c h w')
    flow = adaptive_unnormalize(flow, flow_scale_factor[0], flow_scale_factor[1])

    flow = rearrange(flow, 'b f c h w -> (b f) c h w')
    return flow  # (B F), C, H, W

def adaptive_normalize(flow, sf_x, sf_y):
    # flow: (B, F, C, H, W), optical flow
    assert flow.ndim == 5, 'Set the shape of the flow input as (B, F, C, H, W)'
    assert sf_x is not None and sf_y is not None
    b, f, c, h, w = flow.shape

    max_clip_x = math.sqrt(w / sf_x) * 1.0
    max_clip_y = math.sqrt(h / sf_y) * 1.0

    flow_norm = flow.detach().clone()
    flow_x = flow[:, :, 0].detach().clone()
    flow_y = flow[:, :, 1].detach().clone()

    flow_x_norm = torch.sign(flow_x) * torch.sqrt(torch.abs(flow_x) / sf_x + 1e-7)
    flow_y_norm = torch.sign(flow_y) * torch.sqrt(torch.abs(flow_y) / sf_y + 1e-7)

    flow_norm[:, :, 0] = torch.clamp(flow_x_norm, min=-max_clip_x, max=max_clip_x)
    flow_norm[:, :, 1] = torch.clamp(flow_y_norm, min=-max_clip_y, max=max_clip_y)

    return flow_norm


def adaptive_unnormalize(flow, sf_x, sf_y):
    # flow: (B, F, C, H, W), optical flow
    assert flow.ndim == 5, 'Set the shape of the flow input as (B, F, C, H, W)'
    assert sf_x is not None and sf_y is not None

    flow_orig = flow.detach().clone()
    flow_x = flow[:, :, 0].detach().clone()
    flow_y = flow[:, :, 1].detach().clone()

    flow_orig[:, :, 0] = torch.sign(flow_x) * sf_x * (flow_x ** 2 - 1e-7)
    flow_orig[:, :, 1] = torch.sign(flow_y) * sf_y * (flow_y ** 2 - 1e-7)

    return flow_orig

#--------------------------------------------------------------------------------------------------
|
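A standalone sketch (not part of the diff), assuming the module-level helpers above are in scope: adaptive_normalize followed by adaptive_unnormalize recovers the flow up to float error for values inside the clipping range.

import torch

flow = torch.randn(1, 13, 2, 60, 90) * 5.0                   # (B, F, C, H, W), made-up magnitudes
flow_norm = adaptive_normalize(flow, sf_x=60, sf_y=36)
flow_back = adaptive_unnormalize(flow_norm, sf_x=60, sf_y=36)
print((flow - flow_back).abs().max().item())                  # ~1e-6 when nothing is clipped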
finetune/models/cogvideox_i2v/lora_trainer.py
ADDED
@@ -0,0 +1,246 @@
1 |
+
from typing import Any, Dict, List, Tuple
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from diffusers import (
|
5 |
+
AutoencoderKLCogVideoX,
|
6 |
+
CogVideoXDPMScheduler,
|
7 |
+
CogVideoXImageToVideoPipeline,
|
8 |
+
CogVideoXTransformer3DModel,
|
9 |
+
)
|
10 |
+
from diffusers.models.embeddings import get_3d_rotary_pos_embed
|
11 |
+
from PIL import Image
|
12 |
+
from numpy import dtype
|
13 |
+
from transformers import AutoTokenizer, T5EncoderModel
|
14 |
+
from typing_extensions import override
|
15 |
+
|
16 |
+
from finetune.schemas import Components
|
17 |
+
from finetune.trainer import Trainer
|
18 |
+
from finetune.utils import unwrap_model
|
19 |
+
|
20 |
+
from ..utils import register
|
21 |
+
|
22 |
+
|
23 |
+
class CogVideoXI2VLoraTrainer(Trainer):
|
24 |
+
UNLOAD_LIST = ["text_encoder"]
|
25 |
+
|
26 |
+
@override
|
27 |
+
def load_components(self) -> Dict[str, Any]:
|
28 |
+
components = Components()
|
29 |
+
model_path = str(self.args.model_path)
|
30 |
+
|
31 |
+
components.pipeline_cls = CogVideoXImageToVideoPipeline
|
32 |
+
|
33 |
+
components.tokenizer = AutoTokenizer.from_pretrained(model_path, subfolder="tokenizer")
|
34 |
+
|
35 |
+
components.text_encoder = T5EncoderModel.from_pretrained(model_path, subfolder="text_encoder")
|
36 |
+
|
37 |
+
components.transformer = CogVideoXTransformer3DModel.from_pretrained(model_path, subfolder="transformer")
|
38 |
+
|
39 |
+
components.vae = AutoencoderKLCogVideoX.from_pretrained(model_path, subfolder="vae")
|
40 |
+
|
41 |
+
components.scheduler = CogVideoXDPMScheduler.from_pretrained(model_path, subfolder="scheduler")
|
42 |
+
|
43 |
+
return components
|
44 |
+
|
45 |
+
@override
|
46 |
+
def initialize_pipeline(self) -> CogVideoXImageToVideoPipeline:
|
47 |
+
pipe = CogVideoXImageToVideoPipeline(
|
48 |
+
tokenizer=self.components.tokenizer,
|
49 |
+
text_encoder=self.components.text_encoder,
|
50 |
+
vae=self.components.vae,
|
51 |
+
transformer=unwrap_model(self.accelerator, self.components.transformer),
|
52 |
+
scheduler=self.components.scheduler,
|
53 |
+
)
|
54 |
+
return pipe
|
55 |
+
|
56 |
+
@override
|
57 |
+
def encode_video(self, video: torch.Tensor) -> torch.Tensor:
|
58 |
+
# shape of input video: [B, C, F, H, W]
|
59 |
+
vae = self.components.vae
|
60 |
+
video = video.to(vae.device, dtype=vae.dtype)
|
61 |
+
latent_dist = vae.encode(video).latent_dist
|
62 |
+
latent = latent_dist.sample() * vae.config.scaling_factor
|
63 |
+
return latent
|
64 |
+
|
65 |
+
@override
|
66 |
+
def encode_text(self, prompt: str) -> torch.Tensor:
|
67 |
+
prompt_token_ids = self.components.tokenizer(
|
68 |
+
prompt,
|
69 |
+
padding="max_length",
|
70 |
+
max_length=self.state.transformer_config.max_text_seq_length,
|
71 |
+
truncation=True,
|
72 |
+
add_special_tokens=True,
|
73 |
+
return_tensors="pt",
|
74 |
+
)
|
75 |
+
prompt_token_ids = prompt_token_ids.input_ids
|
76 |
+
prompt_embedding = self.components.text_encoder(prompt_token_ids.to(self.accelerator.device))[0]
|
77 |
+
return prompt_embedding
|
78 |
+
|
79 |
+
@override
|
80 |
+
def collate_fn(self, samples: List[Dict[str, Any]]) -> Dict[str, Any]:
|
81 |
+
ret = {"encoded_videos": [], "prompt_embedding": [], "images": []}
|
82 |
+
|
83 |
+
for sample in samples:
|
84 |
+
encoded_video = sample["encoded_video"]
|
85 |
+
prompt_embedding = sample["prompt_embedding"]
|
86 |
+
image = sample["image"]
|
87 |
+
|
88 |
+
ret["encoded_videos"].append(encoded_video)
|
89 |
+
ret["prompt_embedding"].append(prompt_embedding)
|
90 |
+
ret["images"].append(image)
|
91 |
+
|
92 |
+
ret["encoded_videos"] = torch.stack(ret["encoded_videos"])
|
93 |
+
ret["prompt_embedding"] = torch.stack(ret["prompt_embedding"])
|
94 |
+
ret["images"] = torch.stack(ret["images"])
|
95 |
+
|
96 |
+
return ret
|
97 |
+
|
98 |
+
@override
|
99 |
+
def compute_loss(self, batch) -> torch.Tensor:
|
100 |
+
prompt_embedding = batch["prompt_embedding"]
|
101 |
+
latent = batch["encoded_videos"]
|
102 |
+
images = batch["images"]
|
103 |
+
|
104 |
+
# Shape of prompt_embedding: [B, seq_len, hidden_size]
|
105 |
+
# Shape of latent: [B, C, F, H, W]
|
106 |
+
# Shape of images: [B, C, H, W]
|
107 |
+
|
108 |
+
patch_size_t = self.state.transformer_config.patch_size_t
|
109 |
+
if patch_size_t is not None:
|
110 |
+
ncopy = latent.shape[2] % patch_size_t
|
111 |
+
# Copy the first frame ncopy times to match patch_size_t
|
112 |
+
first_frame = latent[:, :, :1, :, :] # Get first frame [B, C, 1, H, W]
|
113 |
+
latent = torch.cat([first_frame.repeat(1, 1, ncopy, 1, 1), latent], dim=2)
|
114 |
+
assert latent.shape[2] % patch_size_t == 0
|
115 |
+
|
116 |
+
batch_size, num_channels, num_frames, height, width = latent.shape
|
117 |
+
|
118 |
+
# Get prompt embeddings
|
119 |
+
_, seq_len, _ = prompt_embedding.shape
|
120 |
+
prompt_embedding = prompt_embedding.view(batch_size, seq_len, -1).to(dtype=latent.dtype)
|
121 |
+
|
122 |
+
# Add frame dimension to images [B,C,H,W] -> [B,C,F,H,W]
|
123 |
+
images = images.unsqueeze(2)
|
124 |
+
# Add noise to images
|
125 |
+
image_noise_sigma = torch.normal(mean=-3.0, std=0.5, size=(1,), device=self.accelerator.device)
|
126 |
+
image_noise_sigma = torch.exp(image_noise_sigma).to(dtype=images.dtype)
|
127 |
+
noisy_images = images + torch.randn_like(images) * image_noise_sigma[:, None, None, None, None]
|
128 |
+
image_latent_dist = self.components.vae.encode(noisy_images.to(dtype=self.components.vae.dtype)).latent_dist
|
129 |
+
image_latents = image_latent_dist.sample() * self.components.vae.config.scaling_factor
|
130 |
+
|
131 |
+
# Sample a random timestep for each sample
|
132 |
+
timesteps = torch.randint(
|
133 |
+
0, self.components.scheduler.config.num_train_timesteps, (batch_size,), device=self.accelerator.device
|
134 |
+
)
|
135 |
+
timesteps = timesteps.long()
|
136 |
+
|
137 |
+
# from [B, C, F, H, W] to [B, F, C, H, W]
|
138 |
+
latent = latent.permute(0, 2, 1, 3, 4)
|
139 |
+
image_latents = image_latents.permute(0, 2, 1, 3, 4)
|
140 |
+
assert (latent.shape[0], *latent.shape[2:]) == (image_latents.shape[0], *image_latents.shape[2:])
|
141 |
+
|
142 |
+
# Padding image_latents to the same frame number as latent
|
143 |
+
padding_shape = (latent.shape[0], latent.shape[1] - 1, *latent.shape[2:])
|
144 |
+
latent_padding = image_latents.new_zeros(padding_shape)
|
145 |
+
image_latents = torch.cat([image_latents, latent_padding], dim=1)
|
146 |
+
|
147 |
+
# Add noise to latent
|
148 |
+
noise = torch.randn_like(latent)
|
149 |
+
latent_noisy = self.components.scheduler.add_noise(latent, noise, timesteps)
|
150 |
+
|
151 |
+
# Concatenate latent and image_latents in the channel dimension
|
152 |
+
latent_img_noisy = torch.cat([latent_noisy, image_latents], dim=2)
|
153 |
+
|
154 |
+
# Prepare rotary embeds
|
155 |
+
vae_scale_factor_spatial = 2 ** (len(self.components.vae.config.block_out_channels) - 1)
|
156 |
+
transformer_config = self.state.transformer_config
|
157 |
+
rotary_emb = (
|
158 |
+
self.prepare_rotary_positional_embeddings(
|
159 |
+
height=height * vae_scale_factor_spatial,
|
160 |
+
width=width * vae_scale_factor_spatial,
|
161 |
+
num_frames=num_frames,
|
162 |
+
transformer_config=transformer_config,
|
163 |
+
vae_scale_factor_spatial=vae_scale_factor_spatial,
|
164 |
+
device=self.accelerator.device,
|
165 |
+
)
|
166 |
+
if transformer_config.use_rotary_positional_embeddings
|
167 |
+
else None
|
168 |
+
)
|
169 |
+
|
170 |
+
# Predict noise, For CogVideoX1.5 Only.
|
171 |
+
ofs_emb = (
|
172 |
+
None if self.state.transformer_config.ofs_embed_dim is None else latent.new_full((1,), fill_value=2.0)
|
173 |
+
)
|
174 |
+
predicted_noise = self.components.transformer(
|
175 |
+
hidden_states=latent_img_noisy,
|
176 |
+
encoder_hidden_states=prompt_embedding,
|
177 |
+
timestep=timesteps,
|
178 |
+
ofs=ofs_emb,
|
179 |
+
image_rotary_emb=rotary_emb,
|
180 |
+
return_dict=False,
|
181 |
+
)[0]
|
182 |
+
|
183 |
+
# Denoise
|
184 |
+
latent_pred = self.components.scheduler.get_velocity(predicted_noise, latent_noisy, timesteps)
|
185 |
+
|
186 |
+
alphas_cumprod = self.components.scheduler.alphas_cumprod[timesteps]
|
187 |
+
weights = 1 / (1 - alphas_cumprod)
|
188 |
+
while len(weights.shape) < len(latent_pred.shape):
|
189 |
+
weights = weights.unsqueeze(-1)
|
190 |
+
|
191 |
+
loss = torch.mean((weights * (latent_pred - latent) ** 2).reshape(batch_size, -1), dim=1)
|
192 |
+
loss = loss.mean()
|
193 |
+
|
194 |
+
return loss
|
195 |
+
|
196 |
+
@override
|
197 |
+
def validation_step(
|
198 |
+
self, eval_data: Dict[str, Any], pipe: CogVideoXImageToVideoPipeline
|
199 |
+
) -> List[Tuple[str, Image.Image | List[Image.Image]]]:
|
200 |
+
"""
|
201 |
+
Return the data that needs to be saved. For videos, the data format is List[PIL],
|
202 |
+
and for images, the data format is PIL
|
203 |
+
"""
|
204 |
+
prompt, image, video = eval_data["prompt"], eval_data["image"], eval_data["video"]
|
205 |
+
|
206 |
+
video_generate = pipe(
|
207 |
+
num_frames=self.state.train_frames,
|
208 |
+
height=self.state.train_height,
|
209 |
+
width=self.state.train_width,
|
210 |
+
prompt=prompt,
|
211 |
+
image=image,
|
212 |
+
generator=self.state.generator,
|
213 |
+
).frames[0]
|
214 |
+
return [("video", video_generate)]
|
215 |
+
|
216 |
+
def prepare_rotary_positional_embeddings(
|
217 |
+
self,
|
218 |
+
height: int,
|
219 |
+
width: int,
|
220 |
+
num_frames: int,
|
221 |
+
transformer_config: Dict,
|
222 |
+
vae_scale_factor_spatial: int,
|
223 |
+
device: torch.device,
|
224 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
225 |
+
grid_height = height // (vae_scale_factor_spatial * transformer_config.patch_size)
|
226 |
+
grid_width = width // (vae_scale_factor_spatial * transformer_config.patch_size)
|
227 |
+
|
228 |
+
if transformer_config.patch_size_t is None:
|
229 |
+
base_num_frames = num_frames
|
230 |
+
else:
|
231 |
+
base_num_frames = (num_frames + transformer_config.patch_size_t - 1) // transformer_config.patch_size_t
|
232 |
+
|
233 |
+
freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
|
234 |
+
embed_dim=transformer_config.attention_head_dim,
|
235 |
+
crops_coords=None,
|
236 |
+
grid_size=(grid_height, grid_width),
|
237 |
+
temporal_size=base_num_frames,
|
238 |
+
grid_type="slice",
|
239 |
+
max_size=(grid_height, grid_width),
|
240 |
+
device=device,
|
241 |
+
)
|
242 |
+
|
243 |
+
return freqs_cos, freqs_sin
|
244 |
+
|
245 |
+
|
246 |
+
register("cogvideox-i2v", "lora", CogVideoXI2VLoraTrainer)
|
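A standalone sketch (not part of the diff) of the conditioning-image noise augmentation used in compute_loss above: sigma is log-normal, exp(N(-3.0, 0.5)), so most draws are small and the first frame is only lightly perturbed before VAE encoding.

import torch

images = torch.rand(2, 3, 480, 720) * 2 - 1                     # dummy batch in [-1, 1]
image_noise_sigma = torch.exp(torch.normal(mean=-3.0, std=0.5, size=(1,)))
noisy_images = images + torch.randn_like(images) * image_noise_sigma
print(image_noise_sigma.item())                                 # typically around 0.03-0.08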
finetune/models/cogvideox_i2v/sft_trainer.py
ADDED
@@ -0,0 +1,9 @@
from ..cogvideox_i2v.lora_trainer import CogVideoXI2VLoraTrainer
from ..utils import register


class CogVideoXI2VSftTrainer(CogVideoXI2VLoraTrainer):
    pass


register("cogvideox-i2v", "sft", CogVideoXI2VSftTrainer)
|
finetune/models/utils.py
ADDED
@@ -0,0 +1,57 @@
from typing import Dict, Literal

from finetune.trainer import Trainer


SUPPORTED_MODELS: Dict[str, Dict[str, Trainer]] = {}


def register(model_name: str, training_type: Literal["lora", "sft", "controlnet"], trainer_cls: Trainer):
    """Register a model and its associated functions for a specific training type.

    Args:
        model_name (str): Name of the model to register (e.g. "cogvideox-5b")
        training_type (Literal["lora", "sft", "controlnet"]): Type of training: "lora", "sft", or "controlnet"
        trainer_cls (Trainer): Trainer class to register.
    """

    # Check whether model_name and training_type already exist in SUPPORTED_MODELS
    if model_name not in SUPPORTED_MODELS:
        SUPPORTED_MODELS[model_name] = {}
    else:
        if training_type in SUPPORTED_MODELS[model_name]:
            raise ValueError(f"Training type {training_type} already exists for model {model_name}")

    SUPPORTED_MODELS[model_name][training_type] = trainer_cls


def show_supported_models():
    """Print all currently supported models and their training types."""

    print("\nSupported Models:")
    print("================")

    for model_name, training_types in SUPPORTED_MODELS.items():
        print(f"\n{model_name}")
        print("-" * len(model_name))
        for training_type in training_types:
            print(f"  • {training_type}")


def get_model_cls(model_type: str, training_type: Literal["lora", "sft", "controlnet"]) -> Trainer:
    """Get the trainer class for a specific model and training type."""
    if model_type not in SUPPORTED_MODELS:
        print(f"\nModel '{model_type}' is not supported.")
        print("\nSupported models are:")
        for supported_model in SUPPORTED_MODELS:
            print(f"  • {supported_model}")
        raise ValueError(f"Model '{model_type}' is not supported")

    if training_type not in SUPPORTED_MODELS[model_type]:
        print(f"\nTraining type '{training_type}' is not supported for model '{model_type}'.")
        print(f"\nSupported training types for '{model_type}' are:")
        for supported_type in SUPPORTED_MODELS[model_type]:
            print(f"  • {supported_type}")
        raise ValueError(f"Training type '{training_type}' is not supported for model '{model_type}'")

    return SUPPORTED_MODELS[model_type][training_type]
|
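A standalone usage sketch (not part of the diff) of the registry above; the model and trainer names here are hypothetical, only meant to show the register / get_model_cls round trip.

from finetune.models.utils import register, get_model_cls, show_supported_models
from finetune.trainer import Trainer

class MyToyTrainer(Trainer):        # hypothetical trainer class, never instantiated here
    pass

register("my-toy-model", "lora", MyToyTrainer)
show_supported_models()
assert get_model_cls("my-toy-model", "lora") is MyToyTrainer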
finetune/modules/__init__.py
ADDED
File without changes
|
finetune/modules/camera_flow_generator.py
ADDED
@@ -0,0 +1,46 @@
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat

from .utils import instantiate_from_config, get_camera_flow_generator_input, warp_image

import pdb

class CameraFlowGenerator(nn.Module):
    def __init__(
        self,
        depth_estimator_kwargs,
        use_observed_mask=False,
        cycle_th=3.,
    ):
        super().__init__()

        self.depth_warping_module = instantiate_from_config(depth_estimator_kwargs)
        self.use_observed_mask = use_observed_mask
        self.cycle_th = cycle_th

    def forward(self, condition_image, camera_flow_generator_input):
        # NOTE: camera_flow_generator_input is a dict of network inputs with keys:
        #   - image
        #   - intrinsics
        #   - extrinsics
        with torch.no_grad():
            flow_f, flow_b, depth_warped_frames, depth_ctxt, depth_trgt = self.depth_warping_module(camera_flow_generator_input)
            image_ctxt = repeat(condition_image, "b c h w -> (b v) c h w", v=(depth_warped_frames.shape[0] // condition_image.shape[0]))

        log_dict = {
            'depth_warped_frames': depth_warped_frames,
            'depth_ctxt': depth_ctxt,
            'depth_trgt': depth_trgt,
        }

        # if self.use_observed_mask:
        #     observed_mask = run_filtering(flow_f, flow_b, cycle_th=self.cycle_th)
        #     log_dict['observed_mask'] = observed_mask

        return flow_f, log_dict
|
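A heavily hedged usage sketch (not part of the diff) for CameraFlowGenerator above. It assumes instantiate_from_config follows the common {"target": "<import path>", "params": {...}} convention; the depth-warping target, checkpoint path, and input shapes below are placeholders, so check finetune/modules/utils.py and the depth_warping package for the real values.

import torch
from finetune.modules.camera_flow_generator import CameraFlowGenerator

depth_estimator_kwargs = {
    "target": "finetune.modules.depth_warping.depth_warping.DepthWarping_wrapper",  # placeholder path
    "params": {"ckpt_path": "checkpoints/depth_anything_v2_vitl.pth"},              # placeholder ckpt
}
generator = CameraFlowGenerator(depth_estimator_kwargs).eval()

condition_image = torch.rand(1, 3, 480, 720) * 2 - 1   # conditioning frame in [-1, 1]
# camera_flow_generator_input comes from get_camera_flow_generator_input(image, camparam, ...)
# flow_f, log_dict = generator(condition_image, camera_flow_generator_input)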
finetune/modules/camera_sampler.py
ADDED
@@ -0,0 +1,52 @@
import numpy as np
from glob import glob
import random
import os
import pdb
random.seed(7777)

class SampleManualCam:
    def __init__(
        self,
        pose_type='manual',
        root_path='../assets/manual_poses',
    ):
        self.root_path = root_path
        if pose_type == 'manual':
            self.MANUAL_CAM = ['I', 'D', 'L', 'O', 'R', 'U']
        elif pose_type == 're10k':
            self.RE10K_CAM = os.listdir(root_path)
            # self.pose_path = glob(root_path, "*.txt")

        self.pose_type = pose_type

    def sample(self, order=None, name=None):
        # Sample camera parameters (W2C)
        if self.pose_type == 'manual':
            if name is not None:
                assert name in self.MANUAL_CAM
                cam_name = name
            elif order is not None:
                order = order % len(self.MANUAL_CAM)
                cam_name = self.MANUAL_CAM[order]
            else:
                cam_name = random.choice(self.MANUAL_CAM)
            path = os.path.join(self.root_path, f"camera_{cam_name}.txt")
        elif self.pose_type == 're10k':
            if name is not None:
                assert name in self.RE10K_CAM
                cam_name = name
            elif order is not None:
                order = order % len(self.RE10K_CAM)
                cam_name = self.RE10K_CAM[order]
            else:
                cam_name = random.choice(self.RE10K_CAM)
            path = os.path.join(self.root_path, cam_name)
        with open(path, 'r') as f:
            poses = f.readlines()

        poses = [pose.strip().split(' ') for pose in poses]
        poses = [[float(x) for x in pose] for pose in poses]

        return poses, cam_name
|
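A standalone sketch (not part of the diff) of consuming SampleManualCam above. It assumes each line of the pose .txt follows the RealEstate10K layout (timestamp, four intrinsics, two unused zeros, then a flattened 3x4 world-to-camera matrix); verify against the files in assets/manual_poses before relying on the slice below.

import numpy as np
from finetune.modules.camera_sampler import SampleManualCam

sampler = SampleManualCam(pose_type='manual', root_path='../assets/manual_poses')
poses, cam_name = sampler.sample(name='L')                        # 'L' is one of the manual trajectories
w2c = np.array(poses[0][7:19], dtype=np.float32).reshape(3, 4)    # assumed [R | t] slice
print(cam_name, w2c.shape)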
finetune/modules/cogvideox_controlnet.py
ADDED
@@ -0,0 +1,353 @@
1 |
+
from typing import Any, Dict, Optional, Tuple, Union
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from torch import nn
|
5 |
+
from einops import rearrange
|
6 |
+
import torch.nn.functional as F
|
7 |
+
|
8 |
+
from diffusers.configuration_utils import FrozenDict
|
9 |
+
|
10 |
+
from diffusers import CogVideoXTransformer3DModel
|
11 |
+
from diffusers.models.transformers.cogvideox_transformer_3d import Transformer2DModelOutput, CogVideoXBlock
|
12 |
+
from diffusers.utils import is_torch_version
|
13 |
+
from diffusers.loaders import PeftAdapterMixin
|
14 |
+
from diffusers.utils.torch_utils import maybe_allow_in_graph
|
15 |
+
from diffusers.models.embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps, get_3d_sincos_pos_embed
|
16 |
+
from diffusers.models.modeling_utils import ModelMixin
|
17 |
+
from diffusers.models.attention import Attention, FeedForward
|
18 |
+
from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor2_0
|
19 |
+
from diffusers.models.normalization import AdaLayerNorm, CogVideoXLayerNormZero, AdaLayerNormZeroSingle
|
20 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
21 |
+
|
22 |
+
from .cogvideox_custom_modules import CustomCogVideoXPatchEmbed, CustomCogVideoXBlock
|
23 |
+
|
24 |
+
import pdb
|
25 |
+
|
26 |
+
class CogVideoXControlnet(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
27 |
+
_supports_gradient_checkpointing = True
|
28 |
+
|
29 |
+
@register_to_config
|
30 |
+
def __init__(
|
31 |
+
self,
|
32 |
+
num_attention_heads: int = 30, # 48 for 5B, 30 for 2B.
|
33 |
+
attention_head_dim: int = 64,
|
34 |
+
# in_channels: int = 3,
|
35 |
+
in_channels: int = 16,
|
36 |
+
out_channels: Optional[int] = 16, # Not used
|
37 |
+
flip_sin_to_cos: bool = True,
|
38 |
+
freq_shift: int = 0,
|
39 |
+
time_embed_dim: int = 512,
|
40 |
+
ofs_embed_dim: Optional[int] = None,
|
41 |
+
text_embed_dim: int = 4096,
|
42 |
+
num_layers: int = 30,
|
43 |
+
dropout: float = 0.0,
|
44 |
+
attention_bias: bool = True,
|
45 |
+
sample_width: int = 90,
|
46 |
+
sample_height: int = 60,
|
47 |
+
sample_frames: int = 49,
|
48 |
+
patch_size: int = 2,
|
49 |
+
patch_size_t: Optional[int] = None,
|
50 |
+
temporal_compression_ratio: int = 4,
|
51 |
+
max_text_seq_length: int = 226,
|
52 |
+
activation_fn: str = "gelu-approximate",
|
53 |
+
timestep_activation_fn: str = "silu",
|
54 |
+
norm_elementwise_affine: bool = True,
|
55 |
+
norm_eps: float = 1e-5,
|
56 |
+
spatial_interpolation_scale: float = 1.875,
|
57 |
+
temporal_interpolation_scale: float = 1.0,
|
58 |
+
use_rotary_positional_embeddings: bool = False,
|
59 |
+
use_learned_positional_embeddings: bool = False,
|
60 |
+
patch_bias: bool = True,
|
61 |
+
out_proj_dim_factor: int = 8,
|
62 |
+
out_proj_dim_zero_init: bool = True,
|
63 |
+
notextinflow: bool = False,
|
64 |
+
):
|
65 |
+
super().__init__()
|
66 |
+
inner_dim = num_attention_heads * attention_head_dim
|
67 |
+
|
68 |
+
self.notextinflow = notextinflow
|
69 |
+
|
70 |
+
if not use_rotary_positional_embeddings and use_learned_positional_embeddings:
|
71 |
+
raise ValueError(
|
72 |
+
"There are no CogVideoX checkpoints available with disable rotary embeddings and learned positional "
|
73 |
+
"embeddings. If you're using a custom model and/or believe this should be supported, please open an "
|
74 |
+
"issue at https://github.com/huggingface/diffusers/issues."
|
75 |
+
)
|
76 |
+
|
77 |
+
"""
|
78 |
+
Delete below.
|
79 |
+
In our case, FloVD, controlnet_hidden_states is already flow_latents encoded by 3D-Causal-VAE
|
80 |
+
"""
|
81 |
+
# start_channels = in_channels * (downscale_coef ** 2)
|
82 |
+
# input_channels = [start_channels, start_channels // 2, start_channels // 4]
|
83 |
+
# self.unshuffle = nn.PixelUnshuffle(downscale_coef)
|
84 |
+
|
85 |
+
# self.controlnet_encode_first = nn.Sequential(
|
86 |
+
# nn.Conv2d(input_channels[0], input_channels[1], kernel_size=1, stride=1, padding=0),
|
87 |
+
# nn.GroupNorm(2, input_channels[1]),
|
88 |
+
# nn.ReLU(),
|
89 |
+
# )
|
90 |
+
|
91 |
+
# self.controlnet_encode_second = nn.Sequential(
|
92 |
+
# nn.Conv2d(input_channels[1], input_channels[2], kernel_size=1, stride=1, padding=0),
|
93 |
+
# nn.GroupNorm(2, input_channels[2]),
|
94 |
+
# nn.ReLU(),
|
95 |
+
# )
|
96 |
+
|
97 |
+
# """
|
98 |
+
# Modify below.
|
99 |
+
# In our case, patch_embed takes encoder_hidden_states, hidden_states, controlnet_hidden_states (flow)
|
100 |
+
# """
|
101 |
+
# 1. Patch embedding
|
102 |
+
self.patch_embed = CogVideoXPatchEmbed(
|
103 |
+
patch_size=patch_size,
|
104 |
+
in_channels=in_channels,
|
105 |
+
embed_dim=inner_dim,
|
106 |
+
bias=True,
|
107 |
+
sample_width=sample_width,
|
108 |
+
sample_height=sample_height,
|
109 |
+
sample_frames=sample_frames,
|
110 |
+
temporal_compression_ratio=temporal_compression_ratio,
|
111 |
+
spatial_interpolation_scale=spatial_interpolation_scale,
|
112 |
+
temporal_interpolation_scale=temporal_interpolation_scale,
|
113 |
+
use_positional_embeddings=not use_rotary_positional_embeddings,
|
114 |
+
use_learned_positional_embeddings=use_learned_positional_embeddings,
|
115 |
+
)
|
116 |
+
# self.patch_embed = CustomCogVideoXPatchEmbed(
|
117 |
+
# patch_size=patch_size,
|
118 |
+
# patch_size_t=patch_size_t,
|
119 |
+
# in_channels=in_channels,
|
120 |
+
# embed_dim=inner_dim,
|
121 |
+
# text_embed_dim=text_embed_dim,
|
122 |
+
# bias=patch_bias,
|
123 |
+
# sample_width=sample_width,
|
124 |
+
# sample_height=sample_height,
|
125 |
+
# sample_frames=sample_frames,
|
126 |
+
# temporal_compression_ratio=temporal_compression_ratio,
|
127 |
+
# max_text_seq_length=max_text_seq_length,
|
128 |
+
# spatial_interpolation_scale=spatial_interpolation_scale,
|
129 |
+
# temporal_interpolation_scale=temporal_interpolation_scale,
|
130 |
+
# use_positional_embeddings=not use_rotary_positional_embeddings,
|
131 |
+
# use_learned_positional_embeddings=use_learned_positional_embeddings,
|
132 |
+
# )
|
133 |
+
|
134 |
+
self.embedding_dropout = nn.Dropout(dropout)
|
135 |
+
|
136 |
+
# 2. Time embeddings
|
137 |
+
self.time_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift)
|
138 |
+
self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, timestep_activation_fn)
|
139 |
+
|
140 |
+
# 3. Define spatio-temporal transformers blocks
|
141 |
+
# self.transformer_blocks = nn.ModuleList(
|
142 |
+
# [
|
143 |
+
# CogVideoXBlock(
|
144 |
+
# dim=inner_dim,
|
145 |
+
# num_attention_heads=num_attention_heads,
|
146 |
+
# attention_head_dim=attention_head_dim,
|
147 |
+
# time_embed_dim=time_embed_dim,
|
148 |
+
# dropout=dropout,
|
149 |
+
# activation_fn=activation_fn,
|
150 |
+
# attention_bias=attention_bias,
|
151 |
+
# norm_elementwise_affine=norm_elementwise_affine,
|
152 |
+
# norm_eps=norm_eps,
|
153 |
+
# )
|
154 |
+
# for _ in range(num_layers)
|
155 |
+
# ]
|
156 |
+
# )
|
157 |
+
self.transformer_blocks = nn.ModuleList(
|
158 |
+
[
|
159 |
+
CustomCogVideoXBlock(
|
160 |
+
dim=inner_dim,
|
161 |
+
num_attention_heads=num_attention_heads,
|
162 |
+
attention_head_dim=attention_head_dim,
|
163 |
+
time_embed_dim=time_embed_dim,
|
164 |
+
dropout=dropout,
|
165 |
+
activation_fn=activation_fn,
|
166 |
+
attention_bias=attention_bias,
|
167 |
+
norm_elementwise_affine=norm_elementwise_affine,
|
168 |
+
norm_eps=norm_eps,
|
169 |
+
)
|
170 |
+
for _ in range(num_layers)
|
171 |
+
]
|
172 |
+
)
|
173 |
+
|
174 |
+
self.out_projectors = None
|
175 |
+
if out_proj_dim_factor is not None:
|
176 |
+
out_proj_dim = num_attention_heads * out_proj_dim_factor
|
177 |
+
self.out_projectors = nn.ModuleList(
|
178 |
+
[nn.Linear(inner_dim, out_proj_dim) for _ in range(num_layers)]
|
179 |
+
)
|
180 |
+
if out_proj_dim_zero_init:
|
181 |
+
for out_projector in self.out_projectors:
|
182 |
+
self.zeros_init_linear(out_projector)
|
183 |
+
|
184 |
+
self.gradient_checkpointing = False
|
185 |
+
|
186 |
+
def zeros_init_linear(self, linear: nn.Module):
|
187 |
+
if isinstance(linear, (nn.Linear, nn.Conv1d)):
|
188 |
+
if hasattr(linear, "weight"):
|
189 |
+
nn.init.zeros_(linear.weight)
|
190 |
+
if hasattr(linear, "bias"):
|
191 |
+
nn.init.zeros_(linear.bias)
|
192 |
+
|
193 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
194 |
+
self.gradient_checkpointing = value
|
195 |
+
|
196 |
+
def compress_time(self, x, num_frames):
|
197 |
+
x = rearrange(x, '(b f) c h w -> b f c h w', f=num_frames)
|
198 |
+
batch_size, frames, channels, height, width = x.shape
|
199 |
+
x = rearrange(x, 'b f c h w -> (b h w) c f')
|
200 |
+
|
201 |
+
if x.shape[-1] % 2 == 1:
|
202 |
+
x_first, x_rest = x[..., 0], x[..., 1:]
|
203 |
+
if x_rest.shape[-1] > 0:
|
204 |
+
x_rest = F.avg_pool1d(x_rest, kernel_size=2, stride=2)
|
205 |
+
|
206 |
+
x = torch.cat([x_first[..., None], x_rest], dim=-1)
|
207 |
+
else:
|
208 |
+
x = F.avg_pool1d(x, kernel_size=2, stride=2)
|
209 |
+
x = rearrange(x, '(b h w) c f -> (b f) c h w', b=batch_size, h=height, w=width)
|
210 |
+
return x
|
211 |
+
|
212 |
+
# """
|
213 |
+
# Add below.
|
214 |
+
# Load pre-trained weight from Diffusers
|
215 |
+
# For patch_embed, copy a projection layer for controlnet_states
|
216 |
+
# """
|
217 |
+
@classmethod
|
218 |
+
def from_pretrained(cls, model_path, subfolder, **additional_kwargs):
|
219 |
+
base = CogVideoXTransformer3DModel.from_pretrained(model_path, subfolder=subfolder)
|
220 |
+
controlnet_config = FrozenDict({**base.config, **additional_kwargs})
|
221 |
+
model = cls(**controlnet_config)
|
222 |
+
|
223 |
+
missing, unexpected = model.load_state_dict(base.state_dict(), strict=False)
|
224 |
+
print(f"Load CogVideoXTransformer3DModel.")
|
225 |
+
# if len(missing) != 0 or len(unexpected) != 0:
|
226 |
+
# print(f"Missing keys: {missing}")
|
227 |
+
# print(f"Unexpected keys: {unexpected}")
|
228 |
+
|
229 |
+
del base
|
230 |
+
torch.cuda.empty_cache()
|
231 |
+
|
232 |
+
|
233 |
+
return model
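A standalone usage sketch (not part of the diff) of the from_pretrained hook above; the checkpoint id and every override below are hypothetical, only meant to show that extra config fields are merged on top of the base transformer config before the weights are copied with strict=False.

from finetune.modules.cogvideox_controlnet import CogVideoXControlnet

controlnet = CogVideoXControlnet.from_pretrained(
    "THUDM/CogVideoX-2b",        # any CogVideoX checkpoint with a "transformer" subfolder
    subfolder="transformer",
    num_layers=8,                # hypothetical: keep fewer blocks than the base model
    out_proj_dim_factor=64,      # hypothetical override
    notextinflow=False,
)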
|
234 |
+
|
235 |
+
def forward(
|
236 |
+
self,
|
237 |
+
hidden_states: torch.Tensor,
|
238 |
+
encoder_hidden_states: torch.Tensor,
|
239 |
+
controlnet_hidden_states: torch.Tensor,
|
240 |
+
timestep: Union[int, float, torch.LongTensor],
|
241 |
+
controlnet_valid_mask: torch.Tensor = None,
|
242 |
+
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
243 |
+
timestep_cond: Optional[torch.Tensor] = None,
|
244 |
+
return_dict: bool = True,
|
245 |
+
):
|
246 |
+
"""
|
247 |
+
Delete below.
|
248 |
+
In our case, FloVD, controlnet_hidden_states is already flow_latents encoded by 3D-Causal-VAE
|
249 |
+
"""
|
250 |
+
# batch_size, num_frames, channels, height, width = controlnet_states.shape
|
251 |
+
# # 0. Controlnet encoder
|
252 |
+
# controlnet_states = rearrange(controlnet_states, 'b f c h w -> (b f) c h w')
|
253 |
+
# controlnet_states = self.unshuffle(controlnet_states)
|
254 |
+
# controlnet_states = self.controlnet_encode_first(controlnet_states)
|
255 |
+
# controlnet_states = self.compress_time(controlnet_states, num_frames=num_frames)
|
256 |
+
# num_frames = controlnet_states.shape[0] // batch_size
|
257 |
+
|
258 |
+
# controlnet_states = self.controlnet_encode_second(controlnet_states)
|
259 |
+
# controlnet_states = self.compress_time(controlnet_states, num_frames=num_frames)
|
260 |
+
# controlnet_states = rearrange(controlnet_states, '(b f) c h w -> b f c h w', b=batch_size)
|
261 |
+
|
262 |
+
batch_size, num_frames, channels, height, width = hidden_states.shape
|
263 |
+
|
264 |
+
|
265 |
+
# """
|
266 |
+
# Modify below.
|
267 |
+
# Distinguish hidden_states and controlnet_states (i.e., flow_hidden_states)
|
268 |
+
# """
|
269 |
+
hidden_states = torch.cat([hidden_states, controlnet_hidden_states], dim=2) # instead of image_latents, we use flow_latents for condition.
|
270 |
+
|
271 |
+
# controlnet_states = self.controlnext_encoder(controlnet_states, timestep=timestep)
|
272 |
+
# 1. Time embedding
|
273 |
+
timesteps = timestep
|
274 |
+
t_emb = self.time_proj(timesteps)
|
275 |
+
|
276 |
+
# timesteps does not contain any weights and will always return f32 tensors
|
277 |
+
# but time_embedding might actually be running in fp16. so we need to cast here.
|
278 |
+
# there might be better ways to encapsulate this.
|
279 |
+
t_emb = t_emb.to(dtype=hidden_states.dtype)
|
280 |
+
emb = self.time_embedding(t_emb, timestep_cond)
|
281 |
+
|
282 |
+
# """
|
283 |
+
# Modify below.
|
284 |
+
# patch_embed takes encoder, hidden_states, controlnet_hidden_states
|
285 |
+
# """
|
286 |
+
hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
|
287 |
+
# hidden_states = self.patch_embed(encoder_hidden_states, hidden_states, controlnet_hidden_states) # output: [text_embeds, image_embeds, flow_embeds] [B, 35326, 3072]
|
288 |
+
hidden_states = self.embedding_dropout(hidden_states)
|
289 |
+
|
290 |
+
"""
|
291 |
+
Not modified below.
|
292 |
+
hidden_states include both hidden_states and controlnet_hidden_states
|
293 |
+
"""
|
294 |
+
text_seq_length = encoder_hidden_states.shape[1]
|
295 |
+
encoder_hidden_states = hidden_states[:, :text_seq_length] # [text_embeds] [B, 226, 3072]
|
296 |
+
hidden_states = hidden_states[:, text_seq_length:] # [image_embeds, flow_embeds] [B, 35100, 3072]
|
297 |
+
|
298 |
+
# attention mask
|
299 |
+
if controlnet_valid_mask is not None:
|
300 |
+
mask_shape = controlnet_valid_mask.shape
|
301 |
+
attention_mask = torch.nn.functional.interpolate(controlnet_valid_mask, size=(mask_shape[2], mask_shape[3]//2, mask_shape[4]//2), mode='trilinear', align_corners=False) # CFHW
|
302 |
+
attention_mask[attention_mask>=0.5] = 1
|
303 |
+
attention_mask[attention_mask<0.5] = 0
|
304 |
+
attention_mask = attention_mask.to(torch.bool)
|
305 |
+
attention_mask = rearrange(attention_mask.squeeze(1), 'b f h w -> b (f h w)') # (B, N=(fxhxw))
|
306 |
+
|
307 |
+
# Consider encoder_hidden_states.. or do not use?? not sure..
|
308 |
+
if not self.notextinflow:
|
309 |
+
attention_mask = F.pad(attention_mask, (text_seq_length, 0), value=0.0)
|
310 |
+
|
311 |
+
attention_kwargs = {
|
312 |
+
'attention_mask': attention_mask if controlnet_valid_mask is not None else None,
|
313 |
+
'notextinflow': self.notextinflow,
|
314 |
+
}
|
315 |
+
|
316 |
+
controlnet_hidden_states = ()
|
317 |
+
# 3. Transformer blocks
|
318 |
+
for i, block in enumerate(self.transformer_blocks):
|
319 |
+
if self.training and self.gradient_checkpointing:
|
320 |
+
|
321 |
+
def create_custom_forward(module):
|
322 |
+
def custom_forward(*inputs):
|
323 |
+
return module(*inputs)
|
324 |
+
|
325 |
+
return custom_forward
|
326 |
+
|
327 |
+
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
328 |
+
hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
|
329 |
+
create_custom_forward(block),
|
330 |
+
hidden_states,
|
331 |
+
encoder_hidden_states,
|
332 |
+
emb,
|
333 |
+
image_rotary_emb,
|
334 |
+
attention_kwargs,
|
335 |
+
**ckpt_kwargs,
|
336 |
+
)
|
337 |
+
else:
|
338 |
+
hidden_states, encoder_hidden_states = block(
|
339 |
+
hidden_states=hidden_states,
|
340 |
+
encoder_hidden_states=encoder_hidden_states,
|
341 |
+
temb=emb,
|
342 |
+
image_rotary_emb=image_rotary_emb,
|
343 |
+
attention_kwargs=attention_kwargs,
|
344 |
+
)
|
345 |
+
|
346 |
+
if self.out_projectors is not None:
|
347 |
+
controlnet_hidden_states += (self.out_projectors[i](hidden_states),)
|
348 |
+
else:
|
349 |
+
controlnet_hidden_states += (hidden_states,)
|
350 |
+
|
351 |
+
if not return_dict:
|
352 |
+
return (controlnet_hidden_states,)
|
353 |
+
return Transformer2DModelOutput(sample=controlnet_hidden_states)
|
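For reference, a hedged usage sketch of the classmethod above. `CogVideoXControlnet` is a placeholder for the controlnet class defined in this file (its definition sits above this excerpt), and the override kwargs are illustrative rather than the exact config keys used by the trainers.

import torch

# Hypothetical builder; the class name and the extra kwargs are placeholders.
def build_flow_controlnet(pretrained_path="THUDM/CogVideoX-5b-I2V"):
    controlnet = CogVideoXControlnet.from_pretrained(
        pretrained_path,
        subfolder="transformer",       # copy the base transformer weights (strict=False)
        out_proj_dim_zero_init=True,   # zero-init the per-block output projectors, as in __init__ above
    )
    return controlnet.to(dtype=torch.bfloat16)
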
finetune/modules/cogvideox_custom_model.py
ADDED
@@ -0,0 +1,109 @@
from typing import Any, Dict, Optional, Tuple, Union

import torch
import numpy as np
from diffusers.utils import is_torch_version
from diffusers.models.transformers.cogvideox_transformer_3d import CogVideoXTransformer3DModel, Transformer2DModelOutput

import pdb


class CustomCogVideoXTransformer3DModel(CogVideoXTransformer3DModel):
    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        timestep: Union[int, float, torch.LongTensor],
        start_frame=None,
        timestep_cond: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        controlnet_states: torch.Tensor = None,
        controlnet_weights: Optional[Union[float, int, list, np.ndarray, torch.FloatTensor]] = 1.0,
        return_dict: bool = True,
    ):
        batch_size, num_frames, channels, height, width = hidden_states.shape

        if start_frame is not None:
            hidden_states = torch.cat([start_frame, hidden_states], dim=2)
        # 1. Time embedding
        timesteps = timestep
        t_emb = self.time_proj(timesteps)

        # timesteps does not contain any weights and will always return f32 tensors
        # but time_embedding might actually be running in fp16. so we need to cast here.
        # there might be better ways to encapsulate this.
        t_emb = t_emb.to(dtype=hidden_states.dtype)
        emb = self.time_embedding(t_emb, timestep_cond)

        # 2. Patch embedding
        hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
        hidden_states = self.embedding_dropout(hidden_states)

        text_seq_length = encoder_hidden_states.shape[1]
        encoder_hidden_states = hidden_states[:, :text_seq_length]
        hidden_states = hidden_states[:, text_seq_length:]

        # 3. Transformer blocks
        for i, block in enumerate(self.transformer_blocks):
            if self.training and self.gradient_checkpointing:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs)

                    return custom_forward

                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states,
                    encoder_hidden_states,
                    emb,
                    image_rotary_emb,
                    **ckpt_kwargs,
                )
            else:
                hidden_states, encoder_hidden_states = block(
                    hidden_states=hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    temb=emb,
                    image_rotary_emb=image_rotary_emb,
                )

            if (controlnet_states is not None) and (i < len(controlnet_states)):
                controlnet_states_block = controlnet_states[i]
                controlnet_block_weight = 1.0
                if isinstance(controlnet_weights, (list, np.ndarray)) or torch.is_tensor(controlnet_weights):
                    controlnet_block_weight = controlnet_weights[i]
                elif isinstance(controlnet_weights, (float, int)):
                    controlnet_block_weight = controlnet_weights
                hidden_states = hidden_states + controlnet_states_block * controlnet_block_weight

        if not self.config.use_rotary_positional_embeddings:
            # CogVideoX-2B
            hidden_states = self.norm_final(hidden_states)
        else:
            # CogVideoX-5B
            hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
            hidden_states = self.norm_final(hidden_states)
            hidden_states = hidden_states[:, text_seq_length:]

        # 4. Final block
        hidden_states = self.norm_out(hidden_states, temb=emb)
        hidden_states = self.proj_out(hidden_states)

        # 5. Unpatchify
        p = self.config.patch_size
        p_t = self.config.patch_size_t

        if p_t is None:
            output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p)
            output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
        else:
            output = hidden_states.reshape(
                batch_size, (num_frames + p_t - 1) // p_t, height // p, width // p, -1, p_t, p, p
            )
            output = output.permute(0, 1, 5, 4, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(1, 2)

        if not return_dict:
            return (output,)
        return Transformer2DModelOutput(sample=output)

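A minimal sketch of how the per-block controlnet residuals produced by the controlnet module above are consumed by CustomCogVideoXTransformer3DModel. The calling convention follows the two forward signatures shown in this diff; the tensor shapes, the surrounding sampling loop, and the variable names are assumptions, not the exact pipeline code of this repo.

import torch

@torch.no_grad()
def denoise_step(transformer, controlnet, latents, flow_latents, text_emb, t, rotary_emb):
    # Controlnet returns one residual tensor per transformer block (a tuple).
    controlnet_out = controlnet(
        hidden_states=latents,
        encoder_hidden_states=text_emb,
        controlnet_hidden_states=flow_latents,   # flow latents from the 3D causal VAE
        timestep=t,
        image_rotary_emb=rotary_emb,
        return_dict=False,
    )[0]
    # The base transformer adds residual i to its hidden states after block i,
    # scaled either by a single float or by a per-block weight list.
    noise_pred = transformer(
        hidden_states=latents,
        encoder_hidden_states=text_emb,
        timestep=t,
        image_rotary_emb=rotary_emb,
        controlnet_states=controlnet_out,
        controlnet_weights=[1.0] * len(controlnet_out),
        return_dict=False,
    )[0]
    return noise_pred
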
finetune/modules/cogvideox_custom_modules.py
ADDED
@@ -0,0 +1,357 @@
import math
from typing import List, Optional, Tuple, Union, Dict, Any
import copy

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from diffusers import CogVideoXTransformer3DModel
from diffusers.models.transformers.cogvideox_transformer_3d import CogVideoXBlock
from diffusers.models.normalization import CogVideoXLayerNormZero
from diffusers.models.attention import FeedForward
from diffusers.models.attention_processor import CogVideoXAttnProcessor2_0, Attention
from diffusers.models.embeddings import CogVideoXPatchEmbed
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffusers.utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from diffusers.utils.torch_utils import maybe_allow_in_graph

from contextlib import contextmanager
from peft.tuners.lora.layer import LoraLayer  # base class of PEFT LoRA layers

import pdb

# Code heavily borrowed from https://github.com/huggingface/diffusers


class enable_lora:
    def __init__(self, modules, enable=True):
        self.modules = modules
        self.enable = enable
        self.prev_states = {}

    def __enter__(self):
        for module in self.modules:
            self.prev_states[module] = getattr(module, "lora_enabled", True)
            setattr(module, "lora_enabled", self.enable)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        for module in self.modules:
            setattr(module, "lora_enabled", self.prev_states[module])
        return False


class CustomCogVideoXPatchEmbed(CogVideoXPatchEmbed):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        patch_size = kwargs['patch_size']
        patch_size_t = kwargs['patch_size_t']
        bias = kwargs['bias']
        in_channels = kwargs['in_channels']
        embed_dim = kwargs['embed_dim']

        # projection layer for flow latents
        if patch_size_t is None:
            # CogVideoX 1.0 checkpoints
            self.flow_proj = nn.Conv2d(in_channels//2, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias)
        else:
            # CogVideoX 1.5 checkpoints
            self.flow_proj = nn.Linear(in_channels//2 * patch_size * patch_size * patch_size_t, embed_dim)

        # Add positional embedding for flow_embeds
        if self.use_positional_embeddings or self.use_learned_positional_embeddings:
            flow_pos_embedding = self._get_positional_embeddings(self.sample_height, self.sample_width, self.sample_frames)[:, self.max_text_seq_length:]  # shape: [1, 17550, 3072]
            self.flow_pos_embedding = nn.Parameter(flow_pos_embedding)

    def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor, flow_embeds: torch.Tensor):
        r"""
        Args:
            text_embeds (`torch.Tensor`):
                Input text embeddings. Expected shape: (batch_size, seq_length, embedding_dim).
            image_embeds (`torch.Tensor`):
                Input image embeddings. Expected shape: (batch_size, num_frames, channels, height, width).
            flow_embeds (`torch.Tensor`):
                Input flow embeddings. Expected shape: (batch_size, num_frames, channels, height, width).
        """
        text_embeds = self.text_proj(text_embeds)

        batch_size, num_frames, channels, height, width = image_embeds.shape

        if self.patch_size_t is None:
            # embed video latents
            image_embeds = image_embeds.reshape(-1, channels, height, width)
            image_embeds = self.proj(image_embeds)
            image_embeds = image_embeds.view(batch_size, num_frames, *image_embeds.shape[1:])
            image_embeds = image_embeds.flatten(3).transpose(2, 3)  # [batch, num_frames, height x width, channels]
            image_embeds = image_embeds.flatten(1, 2)  # [batch, num_frames x height x width, channels]

            # embed flow latents
            flow_embeds = flow_embeds.reshape(-1, channels//2, height, width)
            flow_embeds = self.flow_proj(flow_embeds)
            flow_embeds = flow_embeds.view(batch_size, num_frames, *flow_embeds.shape[1:])
            flow_embeds = flow_embeds.flatten(3).transpose(2, 3)  # [batch, num_frames, height x width, channels]
            flow_embeds = flow_embeds.flatten(1, 2)  # [batch, num_frames x height x width, channels]
        else:
            p = self.patch_size
            p_t = self.patch_size_t

            # embed video latents
            image_embeds = image_embeds.permute(0, 1, 3, 4, 2)
            image_embeds = image_embeds.reshape(
                batch_size, num_frames // p_t, p_t, height // p, p, width // p, p, channels
            )
            image_embeds = image_embeds.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(4, 7).flatten(1, 3)
            image_embeds = self.proj(image_embeds)

            # embed flow latents
            flow_embeds = flow_embeds.permute(0, 1, 3, 4, 2)
            flow_embeds = flow_embeds.reshape(
                batch_size, num_frames // p_t, p_t, height // p, p, width // p, p, channels//2
            )
            flow_embeds = flow_embeds.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(4, 7).flatten(1, 3)
            flow_embeds = self.flow_proj(flow_embeds)

        # Curriculum learning of flow token
        # flow_embeds = self.flow_scale * flow_embeds

        embeds = torch.cat(
            [text_embeds, image_embeds, flow_embeds], dim=1
        ).contiguous()  # [batch, seq_length + num_frames x height x width + num_frames x height x width, channels]

        if self.use_positional_embeddings or self.use_learned_positional_embeddings:
            if self.use_learned_positional_embeddings and (self.sample_width != width or self.sample_height != height):
                raise ValueError(
                    "It is currently not possible to generate videos at a different resolution than the defaults. This should only be the case with 'THUDM/CogVideoX-5b-I2V'."
                    "If you think this is incorrect, please open an issue at https://github.com/huggingface/diffusers/issues."
                )

            pre_time_compression_frames = (num_frames - 1) * self.temporal_compression_ratio + 1

            if (
                self.sample_height != height
                or self.sample_width != width
                or self.sample_frames != pre_time_compression_frames
            ):
                pos_embedding = self._get_positional_embeddings(
                    height, width, pre_time_compression_frames, device=embeds.device
                )
            else:
                pos_embedding = self.pos_embedding

            # Previous version..
            # pos_embedding = pos_embedding.to(dtype=embeds.dtype)
            # embeds = embeds + pos_embedding

            # Add flow positional embedding..
            # flow_pos_embedding = self.flow_pos_scale * self.flow_pos_embedding
            flow_pos_embedding = self.flow_pos_embedding
            pos_embedding_total = torch.cat([pos_embedding, flow_pos_embedding], dim=1).to(dtype=embeds.dtype)
            embeds = embeds + pos_embedding_total

        return embeds


@maybe_allow_in_graph
class CustomCogVideoXBlock(nn.Module):
    r"""
    Transformer block used in [CogVideoX](https://github.com/THUDM/CogVideo) model.

    Parameters:
        dim (`int`):
            The number of channels in the input and output.
        num_attention_heads (`int`):
            The number of heads to use for multi-head attention.
        attention_head_dim (`int`):
            The number of channels in each head.
        time_embed_dim (`int`):
            The number of channels in timestep embedding.
        dropout (`float`, defaults to `0.0`):
            The dropout probability to use.
        activation_fn (`str`, defaults to `"gelu-approximate"`):
            Activation function to be used in feed-forward.
        attention_bias (`bool`, defaults to `False`):
            Whether or not to use bias in attention projection layers.
        qk_norm (`bool`, defaults to `True`):
            Whether or not to use normalization after query and key projections in Attention.
        norm_elementwise_affine (`bool`, defaults to `True`):
            Whether to use learnable elementwise affine parameters for normalization.
        norm_eps (`float`, defaults to `1e-5`):
            Epsilon value for normalization layers.
        final_dropout (`bool`, defaults to `False`):
            Whether to apply a final dropout after the last feed-forward layer.
        ff_inner_dim (`int`, *optional*, defaults to `None`):
            Custom hidden dimension of Feed-forward layer. If not provided, `4 * dim` is used.
        ff_bias (`bool`, defaults to `True`):
            Whether or not to use bias in Feed-forward layer.
        attention_out_bias (`bool`, defaults to `True`):
            Whether or not to use bias in Attention output projection layer.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        time_embed_dim: int,
        dropout: float = 0.0,
        activation_fn: str = "gelu-approximate",
        attention_bias: bool = False,
        qk_norm: bool = True,
        norm_elementwise_affine: bool = True,
        norm_eps: float = 1e-5,
        final_dropout: bool = True,
        ff_inner_dim: Optional[int] = None,
        ff_bias: bool = True,
        attention_out_bias: bool = True,
    ):
        super().__init__()

        # 1. Self Attention
        self.norm1 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)

        self.attn1 = Attention(
            query_dim=dim,
            dim_head=attention_head_dim,
            heads=num_attention_heads,
            qk_norm="layer_norm" if qk_norm else None,
            eps=1e-6,
            bias=attention_bias,
            out_bias=attention_out_bias,
            processor=CustomCogVideoXAttnProcessor2_0(),
        )

        # 2. Feed Forward
        self.norm2 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)

        self.ff = FeedForward(
            dim,
            dropout=dropout,
            activation_fn=activation_fn,
            final_dropout=final_dropout,
            inner_dim=ff_inner_dim,
            bias=ff_bias,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        temb: torch.Tensor,
        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        attention_kwargs: Optional[Dict[str, Any]] = None,
    ) -> torch.Tensor:
        text_seq_length = encoder_hidden_states.size(1)
        attention_kwargs = attention_kwargs or {}

        # norm & modulate
        norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1(
            hidden_states, encoder_hidden_states, temb
        )

        # attention
        attn_hidden_states, attn_encoder_hidden_states = self.attn1(
            hidden_states=norm_hidden_states,
            encoder_hidden_states=norm_encoder_hidden_states,
            image_rotary_emb=image_rotary_emb,
            **attention_kwargs,
        )

        hidden_states = hidden_states + gate_msa * attn_hidden_states
        encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_encoder_hidden_states

        # norm & modulate
        norm_hidden_states, norm_encoder_hidden_states, gate_ff, enc_gate_ff = self.norm2(
            hidden_states, encoder_hidden_states, temb
        )

        # feed-forward
        norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1)
        ff_output = self.ff(norm_hidden_states)

        hidden_states = hidden_states + gate_ff * ff_output[:, text_seq_length:]
        encoder_hidden_states = encoder_hidden_states + enc_gate_ff * ff_output[:, :text_seq_length]

        return hidden_states, encoder_hidden_states


class CustomCogVideoXAttnProcessor2_0:
    r"""
    Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
    query and key vectors, but does not include spatial normalization.
    """

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
        notextinflow: Optional[bool] = False,
    ) -> torch.Tensor:
        text_seq_length = encoder_hidden_states.size(1)

        if not notextinflow:
            hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)

        batch_size, sequence_length, _ = hidden_states.shape

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # Apply RoPE if needed
        if image_rotary_emb is not None:
            from diffusers.models.embeddings import apply_rotary_emb

            if not notextinflow:
                query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
                if not attn.is_cross_attention:
                    key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
            else:
                query[:, :, :] = apply_rotary_emb(query[:, :, :], image_rotary_emb)
                if not attn.is_cross_attention:
                    key[:, :, :] = apply_rotary_emb(key[:, :, :], image_rotary_emb)

        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if not notextinflow:
            encoder_hidden_states, hidden_states = hidden_states.split(
                [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
            )

        return hidden_states, encoder_hidden_states

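A small, hedged illustration of the enable_lora context manager defined above. The modules handed to it are assumed to expose a `lora_enabled` flag that their forward checks; the toy module below exists purely for demonstration and is not part of this repository.

import torch
from torch import nn

class ToyLoraLinear(nn.Module):
    def __init__(self, dim=8):
        super().__init__()
        self.base = nn.Linear(dim, dim)
        self.lora = nn.Linear(dim, dim, bias=False)
        self.lora_enabled = True               # flag toggled by enable_lora

    def forward(self, x):
        out = self.base(x)
        if getattr(self, "lora_enabled", True):
            out = out + self.lora(x)           # add the LoRA branch only when enabled
        return out

layer = ToyLoraLinear()
x = torch.randn(2, 8)
with enable_lora([layer], enable=False):       # temporarily bypass the LoRA branch
    y_base = layer(x)
y_full = layer(x)                              # flag is restored on exit
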
finetune/modules/depth_warping/__init__.py
ADDED
File without changes
|
finetune/modules/depth_warping/camera/Camera.py
ADDED
@@ -0,0 +1,70 @@
import os
import random
import json
import torch

import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.transforms.functional as F
import numpy as np
from einops import rearrange, repeat

class Camera(object):
    def __init__(self, entry):
        fx, fy, cx, cy = entry[1:5]
        self.fx = fx
        self.fy = fy
        self.cx = cx
        self.cy = cy
        w2c_mat = np.array(entry[7:]).reshape(3, 4)
        w2c_mat_4x4 = np.eye(4)
        w2c_mat_4x4[:3, :] = w2c_mat
        self.w2c_mat = w2c_mat_4x4
        self.c2w_mat = np.linalg.inv(w2c_mat_4x4)

def load_cameras(path):
    with open(path, 'r') as f:
        poses = f.readlines()
    poses = [pose.strip().split(' ') for pose in poses[1:]]
    cam_params = [[float(x) for x in pose] for pose in poses]
    cam_params = [Camera(cam_param) for cam_param in cam_params]
    return cam_params

def get_relative_pose(cam_params):
    abs_w2cs = [cam_param.w2c_mat for cam_param in cam_params]
    abs_c2ws = [cam_param.c2w_mat for cam_param in cam_params]
    source_cam_c2w = abs_c2ws[0]
    cam_to_origin = 0
    target_cam_c2w = np.array([
        [1, 0, 0, 0],
        [0, 1, 0, -cam_to_origin],
        [0, 0, 1, 0],
        [0, 0, 0, 1]
    ])
    abs2rel = target_cam_c2w @ abs_w2cs[0]
    ret_poses = [target_cam_c2w, ] + [abs2rel @ abs_c2w for abs_c2w in abs_c2ws[1:]]
    ret_poses = np.array(ret_poses, dtype=np.float32)
    return ret_poses

def get_K(intrinsics, size, do_normalize=False):
    def normalize_intrinsic(x, size):
        h, w = size
        x[:,:,0:1] = x[:,:,0:1] / w
        x[:,:,1:2] = x[:,:,1:2] / h
        return x

    b, _, t, _ = intrinsics.shape
    K = torch.zeros((b, t, 9), dtype=intrinsics.dtype, device=intrinsics.device)
    fx, fy, cx, cy = intrinsics.squeeze(1).chunk(4, dim=-1)

    K[:,:,0:1] = fx
    K[:,:,2:3] = cx
    K[:,:,4:5] = fy
    K[:,:,5:6] = cy
    K[:,:,8:9] = 1.0

    K = rearrange(K, "b t (h w) -> b t h w", h=3, w=3)
    if do_normalize:
        K = normalize_intrinsic(K, size)

    return K

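A hedged sketch of the pose-line format that Camera and get_relative_pose above expect: each entry is one line of a RealEstate10K-style trajectory file, i.e. timestamp, fx, fy, cx, cy, two unused zeros, then a row-major 3x4 world-to-camera matrix in entry[7:]. The numbers below are synthetic and exist only to show the shapes.

import numpy as np

identity_w2c = np.eye(4)[:3].reshape(-1).tolist()            # 12 values, row-major 3x4
entry_0 = [0, 0.9, 1.6, 0.5, 0.5, 0.0, 0.0] + identity_w2c

shifted_w2c = np.eye(4)
shifted_w2c[0, 3] = 0.1                                       # small x-translation for the second view
entry_1 = [1, 0.9, 1.6, 0.5, 0.5, 0.0, 0.0] + shifted_w2c[:3].reshape(-1).tolist()

cams = [Camera(entry_0), Camera(entry_1)]
rel_c2w = get_relative_pose(cams)                             # (2, 4, 4); the first pose is the identity
print(rel_c2w.shape)
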
finetune/modules/depth_warping/camera/WarperPytorch.py
ADDED
@@ -0,0 +1,416 @@
# Shree KRISHNAya Namaha
# Differentiable warper implemented in PyTorch. Warping is done on batches.
# Tested on PyTorch 1.8.1
# Author: Nagabhushan S N
# Last Modified: 27/09/2021
# Code from https://github.com/NagabhushanSN95/Pose-Warping

import datetime
import time
import traceback
from pathlib import Path
from typing import Tuple, Optional

import numpy
# import skimage.io
import torch
import torch.nn.functional as F
from einops import rearrange, repeat
# import Imath
# import OpenEXR

import pdb

class Warper:
    def __init__(self, resolution: tuple = None):
        self.resolution = resolution

    def forward_warp(self, frame1: torch.Tensor, mask1: Optional[torch.Tensor], depth1: torch.Tensor,
                     transformation1: torch.Tensor, transformation2: torch.Tensor, intrinsic1: torch.Tensor,
                     intrinsic2: Optional[torch.Tensor], is_image=True) -> \
            Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Given a frame1 and global transformations transformation1 and transformation2, warps frame1 to the next view
        using bilinear splatting.
        All arrays should be torch tensors with batch dimension and channel first.
        :param frame1: (b, 3, h, w). If frame1 is not in the range [-1, 1], either set is_image=False when calling
                        bilinear_splatting on frame within this function, or modify clipping in bilinear_splatting()
                        method accordingly.
        :param mask1: (b, 1, h, w) - 1 for known, 0 for unknown. Optional
        :param depth1: (b, 1, h, w)
        :param transformation1: (b, 4, 4) extrinsic transformation matrix of first view: [R, t; 0, 1]
        :param transformation2: (b, 4, 4) extrinsic transformation matrix of second view: [R, t; 0, 1]
        :param intrinsic1: (b, 3, 3) camera intrinsic matrix
        :param intrinsic2: (b, 3, 3) camera intrinsic matrix. Optional
        """
        self.device = frame1.device

        if self.resolution is not None:
            assert frame1.shape[2:4] == self.resolution
        b, c, h, w = frame1.shape
        if mask1 is None:
            mask1 = torch.ones(size=(b, 1, h, w)).to(frame1)
        if intrinsic2 is None:
            intrinsic2 = intrinsic1.clone()

        assert frame1.shape == (b, 3, h, w) or frame1.shape == (b, 2, h, w)  # flow b2hw
        assert mask1.shape == (b, 1, h, w)
        assert depth1.shape == (b, 1, h, w)
        assert transformation1.shape == (b, 4, 4)
        assert transformation2.shape == (b, 4, 4)
        assert intrinsic1.shape == (b, 3, 3)
        assert intrinsic2.shape == (b, 3, 3)

        frame1 = frame1.to(self.device)
        mask1 = mask1.to(self.device)
        depth1 = depth1.to(self.device)
        transformation1 = transformation1.to(self.device)
        transformation2 = transformation2.to(self.device)
        intrinsic1 = intrinsic1.to(self.device)
        intrinsic2 = intrinsic2.to(self.device)

        trans_points1 = self.compute_transformed_points(depth1, transformation1, transformation2, intrinsic1,
                                                        intrinsic2)
        # trans_coordinates = trans_points1[:, :, :2, 0] / trans_points1[:, :, 2:3, 0]
        trans_coordinates = trans_points1[:, :, :, :2, 0] / (trans_points1[:, :, :, 2:3, 0] + 1e-7)
        trans_depth1 = rearrange(trans_points1[:, :, :, 2:3, 0], "b h w c -> b c h w")

        grid = self.create_grid(b, h, w).to(trans_coordinates)
        flow12 = rearrange(trans_coordinates, "b h w c -> b c h w") - grid

        warped_frame2, mask2 = self.bilinear_splatting(frame1, mask1, trans_depth1, flow12, None, is_image=is_image)
        warped_depth2 = self.bilinear_splatting(trans_depth1, mask1, trans_depth1, flow12, None, is_image=False)[0]  # [0][:, :, 0]

        return warped_frame2, mask2, warped_depth2, flow12

    def forward_warp_displacement(self, depth1: torch.Tensor, flow1: torch.Tensor,
                                  transformation1: torch.Tensor, transformation2: torch.Tensor, intrinsic1: torch.Tensor, intrinsic2: Optional[torch.Tensor],):
        """
        Given depth1, an optical flow flow1, and global transformations transformation1 and transformation2,
        computes the displacement between the reprojected pixel coordinates with and without flow1 applied.
        All arrays should be torch tensors with batch dimension and channel first.
        :param depth1: (b, 1, h, w)
        :param flow1: (b, 2, h, w)
        :param transformation1: (b, 4, 4) extrinsic transformation matrix of first view: [R, t; 0, 1]
        :param transformation2: (b, 4, 4) extrinsic transformation matrix of second view: [R, t; 0, 1]
        :param intrinsic1: (b, 3, 3) camera intrinsic matrix
        :param intrinsic2: (b, 3, 3) camera intrinsic matrix. Optional
        """
        self.device = flow1.device

        if self.resolution is not None:
            assert flow1.shape[2:4] == self.resolution
        b, c, h, w = flow1.shape
        if intrinsic2 is None:
            intrinsic2 = intrinsic1.clone()

        assert flow1.shape == (b, 2, h, w)
        assert depth1.shape == (b, 1, h, w)
        assert transformation1.shape == (b, 4, 4)
        assert transformation2.shape == (b, 4, 4)
        assert intrinsic1.shape == (b, 3, 3)
        assert intrinsic2.shape == (b, 3, 3)

        depth1 = depth1.to(self.device)
        flow1 = flow1.to(self.device)
        transformation1 = transformation1.to(self.device)
        transformation2 = transformation2.to(self.device)
        intrinsic1 = intrinsic1.to(self.device)
        intrinsic2 = intrinsic2.to(self.device)

        trans_points1 = self.compute_transformed_points(depth1, transformation1, transformation2, intrinsic1, intrinsic2)
        trans_coordinates1 = trans_points1[:, :, :, :2, 0] / (trans_points1[:, :, :, 2:3, 0] + 1e-7)

        trans_points2 = self.compute_transformed_points(depth1, transformation1, transformation2, intrinsic1, intrinsic2, flow1)
        trans_coordinates2 = trans_points2[:, :, :, :2, 0] / (trans_points2[:, :, :, 2:3, 0] + 1e-7)

        flow12_displacement = rearrange(trans_coordinates2 - trans_coordinates1, "b h w c -> b c h w")

        return flow12_displacement

    def compute_transformed_points(self, depth1: torch.Tensor, transformation1: torch.Tensor, transformation2: torch.Tensor,
                                   intrinsic1: torch.Tensor, intrinsic2: Optional[torch.Tensor], flow1: Optional[torch.Tensor] = None):
        """
        Computes transformed position for each pixel location
        """
        if self.resolution is not None:
            assert depth1.shape[2:4] == self.resolution
        b, _, h, w = depth1.shape
        if intrinsic2 is None:
            intrinsic2 = intrinsic1.clone()
        transformation = torch.bmm(transformation2, torch.linalg.inv(transformation1)).to(transformation1.dtype)  # (b, 4, 4)

        x1d = torch.arange(0, w)[None]
        y1d = torch.arange(0, h)[:, None]
        x2d = x1d.repeat([h, 1]).to(depth1)  # (h, w)
        y2d = y1d.repeat([1, w]).to(depth1)  # (h, w)

        ones_2d = torch.ones(size=(h, w)).to(depth1)  # (h, w)
        ones_4d = ones_2d[None, :, :, None, None].repeat([b, 1, 1, 1, 1])  # (b, h, w, 1, 1)

        if flow1 is not None:
            x4d = repeat(x2d[None, :, :, None], '1 h w c -> b h w c', b=b)
            y4d = repeat(y2d[None, :, :, None], '1 h w c -> b h w c', b=b)
            flow1_x4d = rearrange(flow1[:, :1].detach().clone(), "b c h w -> b h w c")
            flow1_y4d = rearrange(flow1[:, 1:].detach().clone(), "b c h w -> b h w c")

            x4d = x4d + flow1_x4d
            y4d = y4d + flow1_y4d

            pos_vectors_homo = torch.stack([x4d, y4d, ones_4d.squeeze(-1)], dim=3)  # (b, h, w, 3, 1)
        else:
            pos_vectors_homo = torch.stack([x2d, y2d, ones_2d], dim=2)[None, :, :, :, None]  # (1, h, w, 3, 1)

        intrinsic1_inv = torch.linalg.inv(intrinsic1)  # (b, 3, 3)
        intrinsic1_inv_4d = intrinsic1_inv[:, None, None]  # (b, 1, 1, 3, 3)
        intrinsic2_4d = intrinsic2[:, None, None]  # (b, 1, 1, 3, 3)
        depth_4d = depth1[:, 0][:, :, :, None, None]  # (b, h, w, 1, 1)
        trans_4d = transformation[:, None, None]  # (b, 1, 1, 4, 4)

        unnormalized_pos = torch.matmul(intrinsic1_inv_4d, pos_vectors_homo).to(transformation1.dtype)  # (b, h, w, 3, 1)
        world_points = depth_4d * unnormalized_pos  # (b, h, w, 3, 1)
        world_points_homo = torch.cat([world_points, ones_4d], dim=3)  # (b, h, w, 4, 1)
        trans_world_homo = torch.matmul(trans_4d, world_points_homo).to(transformation1.dtype)  # (b, h, w, 4, 1)
        trans_world = trans_world_homo[:, :, :, :3]  # (b, h, w, 3, 1)
        trans_norm_points = torch.matmul(intrinsic2_4d, trans_world).to(transformation1.dtype)  # (b, h, w, 3, 1)
        return trans_norm_points

    def bilinear_splatting(self, frame1: torch.Tensor, mask1: Optional[torch.Tensor], depth1: torch.Tensor,
                           flow12: torch.Tensor, flow12_mask: Optional[torch.Tensor], is_image: bool = False) -> \
            Tuple[torch.Tensor, torch.Tensor]:
        """
        Bilinear splatting
        :param frame1: (b,c,h,w)
        :param mask1: (b,1,h,w): 1 for known, 0 for unknown. Optional
        :param depth1: (b,1,h,w)
        :param flow12: (b,2,h,w)
        :param flow12_mask: (b,1,h,w): 1 for valid flow, 0 for invalid flow. Optional
        :param is_image: if true, output will be clipped to (-1,1) range
        :return: warped_frame2: (b,c,h,w)
                 mask2: (b,1,h,w): 1 for known and 0 for unknown
        """
        if self.resolution is not None:
            assert frame1.shape[2:4] == self.resolution
        b, c, h, w = frame1.shape
        if mask1 is None:
            mask1 = torch.ones(size=(b, 1, h, w)).to(frame1)
        if flow12_mask is None:
            flow12_mask = torch.ones(size=(b, 1, h, w)).to(flow12)
        grid = self.create_grid(b, h, w).to(frame1)
        trans_pos = flow12 + grid

        trans_pos_offset = trans_pos + 1
        trans_pos_floor = torch.floor(trans_pos_offset).long()
        trans_pos_ceil = torch.ceil(trans_pos_offset).long()
        trans_pos_offset = torch.stack([
            torch.clamp(trans_pos_offset[:, 0], min=0, max=w + 1),
            torch.clamp(trans_pos_offset[:, 1], min=0, max=h + 1)], dim=1)
        trans_pos_floor = torch.stack([
            torch.clamp(trans_pos_floor[:, 0], min=0, max=w + 1),
            torch.clamp(trans_pos_floor[:, 1], min=0, max=h + 1)], dim=1)
        trans_pos_ceil = torch.stack([
            torch.clamp(trans_pos_ceil[:, 0], min=0, max=w + 1),
            torch.clamp(trans_pos_ceil[:, 1], min=0, max=h + 1)], dim=1)

        prox_weight_nw = (1 - (trans_pos_offset[:, 1:2] - trans_pos_floor[:, 1:2])) * \
                         (1 - (trans_pos_offset[:, 0:1] - trans_pos_floor[:, 0:1]))
        prox_weight_sw = (1 - (trans_pos_ceil[:, 1:2] - trans_pos_offset[:, 1:2])) * \
                         (1 - (trans_pos_offset[:, 0:1] - trans_pos_floor[:, 0:1]))
        prox_weight_ne = (1 - (trans_pos_offset[:, 1:2] - trans_pos_floor[:, 1:2])) * \
                         (1 - (trans_pos_ceil[:, 0:1] - trans_pos_offset[:, 0:1]))
        prox_weight_se = (1 - (trans_pos_ceil[:, 1:2] - trans_pos_offset[:, 1:2])) * \
                         (1 - (trans_pos_ceil[:, 0:1] - trans_pos_offset[:, 0:1]))

        sat_depth1 = torch.clamp(depth1, min=0, max=1000)
        log_depth1 = torch.log(1 + sat_depth1)
        depth_weights = torch.exp(log_depth1 / log_depth1.max() * 50)

        weight_nw = torch.moveaxis(prox_weight_nw * mask1 * flow12_mask / depth_weights, [0, 1, 2, 3], [0, 3, 1, 2])
        weight_sw = torch.moveaxis(prox_weight_sw * mask1 * flow12_mask / depth_weights, [0, 1, 2, 3], [0, 3, 1, 2])
        weight_ne = torch.moveaxis(prox_weight_ne * mask1 * flow12_mask / depth_weights, [0, 1, 2, 3], [0, 3, 1, 2])
        weight_se = torch.moveaxis(prox_weight_se * mask1 * flow12_mask / depth_weights, [0, 1, 2, 3], [0, 3, 1, 2])

        warped_frame = torch.zeros(size=(b, h + 2, w + 2, c), dtype=torch.float32).to(frame1)
        warped_weights = torch.zeros(size=(b, h + 2, w + 2, 1), dtype=torch.float32).to(frame1)

        frame1_cl = torch.moveaxis(frame1, [0, 1, 2, 3], [0, 3, 1, 2])
        batch_indices = torch.arange(b)[:, None, None].to(frame1.device)
        warped_frame.index_put_((batch_indices, trans_pos_floor[:, 1], trans_pos_floor[:, 0]),
                                frame1_cl * weight_nw, accumulate=True)
        warped_frame.index_put_((batch_indices, trans_pos_ceil[:, 1], trans_pos_floor[:, 0]),
                                frame1_cl * weight_sw, accumulate=True)
        warped_frame.index_put_((batch_indices, trans_pos_floor[:, 1], trans_pos_ceil[:, 0]),
                                frame1_cl * weight_ne, accumulate=True)
        warped_frame.index_put_((batch_indices, trans_pos_ceil[:, 1], trans_pos_ceil[:, 0]),
                                frame1_cl * weight_se, accumulate=True)

        warped_weights.index_put_((batch_indices, trans_pos_floor[:, 1], trans_pos_floor[:, 0]),
                                  weight_nw, accumulate=True)
        warped_weights.index_put_((batch_indices, trans_pos_ceil[:, 1], trans_pos_floor[:, 0]),
                                  weight_sw, accumulate=True)
        warped_weights.index_put_((batch_indices, trans_pos_floor[:, 1], trans_pos_ceil[:, 0]),
                                  weight_ne, accumulate=True)
        warped_weights.index_put_((batch_indices, trans_pos_ceil[:, 1], trans_pos_ceil[:, 0]),
                                  weight_se, accumulate=True)

        warped_frame_cf = torch.moveaxis(warped_frame, [0, 1, 2, 3], [0, 2, 3, 1])
        warped_weights_cf = torch.moveaxis(warped_weights, [0, 1, 2, 3], [0, 2, 3, 1])
        cropped_warped_frame = warped_frame_cf[:, :, 1:-1, 1:-1]
        cropped_weights = warped_weights_cf[:, :, 1:-1, 1:-1]

        mask = cropped_weights > 0
        zero_value = -1 if is_image else 0
        zero_tensor = torch.tensor(zero_value, dtype=frame1.dtype, device=frame1.device)
        warped_frame2 = torch.where(mask, cropped_warped_frame / cropped_weights, zero_tensor)
        mask2 = mask.to(frame1)

        if is_image:
            assert warped_frame2.min() >= -1.1  # Allow for rounding errors
            assert warped_frame2.max() <= 1.1
            warped_frame2 = torch.clamp(warped_frame2, min=-1, max=1)
        return warped_frame2, mask2

    def bilinear_interpolation(self, frame2: torch.Tensor, mask2: Optional[torch.Tensor], flow12: torch.Tensor,
                               flow12_mask: Optional[torch.Tensor], is_image: bool = False) -> \
            Tuple[torch.Tensor, torch.Tensor]:
        """
        Bilinear interpolation
        :param frame2: (b, c, h, w)
        :param mask2: (b, 1, h, w): 1 for known, 0 for unknown. Optional
        :param flow12: (b, 2, h, w)
        :param flow12_mask: (b, 1, h, w): 1 for valid flow, 0 for invalid flow. Optional
        :param is_image: if true, output will be clipped to (-1,1) range
        :return: warped_frame1: (b, c, h, w)
                 mask1: (b, 1, h, w): 1 for known and 0 for unknown
        """
        if self.resolution is not None:
            assert frame2.shape[2:4] == self.resolution
        b, c, h, w = frame2.shape
        if mask2 is None:
            mask2 = torch.ones(size=(b, 1, h, w)).to(frame2)
        if flow12_mask is None:
            flow12_mask = torch.ones(size=(b, 1, h, w)).to(flow12)
        grid = self.create_grid(b, h, w).to(frame2)
        trans_pos = flow12 + grid

        trans_pos_offset = trans_pos + 1
        trans_pos_floor = torch.floor(trans_pos_offset).long()
        trans_pos_ceil = torch.ceil(trans_pos_offset).long()
        trans_pos_offset = torch.stack([
            torch.clamp(trans_pos_offset[:, 0], min=0, max=w + 1),
            torch.clamp(trans_pos_offset[:, 1], min=0, max=h + 1)], dim=1)
        trans_pos_floor = torch.stack([
            torch.clamp(trans_pos_floor[:, 0], min=0, max=w + 1),
            torch.clamp(trans_pos_floor[:, 1], min=0, max=h + 1)], dim=1)
        trans_pos_ceil = torch.stack([
            torch.clamp(trans_pos_ceil[:, 0], min=0, max=w + 1),
            torch.clamp(trans_pos_ceil[:, 1], min=0, max=h + 1)], dim=1)

        prox_weight_nw = (1 - (trans_pos_offset[:, 1:2] - trans_pos_floor[:, 1:2])) * \
                         (1 - (trans_pos_offset[:, 0:1] - trans_pos_floor[:, 0:1]))
        prox_weight_sw = (1 - (trans_pos_ceil[:, 1:2] - trans_pos_offset[:, 1:2])) * \
                         (1 - (trans_pos_offset[:, 0:1] - trans_pos_floor[:, 0:1]))
        prox_weight_ne = (1 - (trans_pos_offset[:, 1:2] - trans_pos_floor[:, 1:2])) * \
                         (1 - (trans_pos_ceil[:, 0:1] - trans_pos_offset[:, 0:1]))
        prox_weight_se = (1 - (trans_pos_ceil[:, 1:2] - trans_pos_offset[:, 1:2])) * \
                         (1 - (trans_pos_ceil[:, 0:1] - trans_pos_offset[:, 0:1]))

        weight_nw = torch.moveaxis(prox_weight_nw * flow12_mask, [0, 1, 2, 3], [0, 3, 1, 2])
        weight_sw = torch.moveaxis(prox_weight_sw * flow12_mask, [0, 1, 2, 3], [0, 3, 1, 2])
        weight_ne = torch.moveaxis(prox_weight_ne * flow12_mask, [0, 1, 2, 3], [0, 3, 1, 2])
        weight_se = torch.moveaxis(prox_weight_se * flow12_mask, [0, 1, 2, 3], [0, 3, 1, 2])

        frame2_offset = F.pad(frame2, [1, 1, 1, 1])
        mask2_offset = F.pad(mask2, [1, 1, 1, 1])
        bi = torch.arange(b)[:, None, None]

        f2_nw = frame2_offset[bi, :, trans_pos_floor[:, 1], trans_pos_floor[:, 0]]
        f2_sw = frame2_offset[bi, :, trans_pos_ceil[:, 1], trans_pos_floor[:, 0]]
        f2_ne = frame2_offset[bi, :, trans_pos_floor[:, 1], trans_pos_ceil[:, 0]]
        f2_se = frame2_offset[bi, :, trans_pos_ceil[:, 1], trans_pos_ceil[:, 0]]

        m2_nw = mask2_offset[bi, :, trans_pos_floor[:, 1], trans_pos_floor[:, 0]]
        m2_sw = mask2_offset[bi, :, trans_pos_ceil[:, 1], trans_pos_floor[:, 0]]
        m2_ne = mask2_offset[bi, :, trans_pos_floor[:, 1], trans_pos_ceil[:, 0]]
        m2_se = mask2_offset[bi, :, trans_pos_ceil[:, 1], trans_pos_ceil[:, 0]]

        nr = weight_nw * f2_nw * m2_nw + weight_sw * f2_sw * m2_sw + \
             weight_ne * f2_ne * m2_ne + weight_se * f2_se * m2_se
        dr = weight_nw * m2_nw + weight_sw * m2_sw + weight_ne * m2_ne + weight_se * m2_se

        zero_value = -1 if is_image else 0
        zero_tensor = torch.tensor(zero_value, dtype=nr.dtype, device=nr.device)
        warped_frame1 = torch.where(dr > 0, nr / dr, zero_tensor)
        mask1 = (dr > 0).to(frame2)

        # Convert to channel first
        warped_frame1 = torch.moveaxis(warped_frame1, [0, 1, 2, 3], [0, 2, 3, 1])
        mask1 = torch.moveaxis(mask1, [0, 1, 2, 3], [0, 2, 3, 1])

        if is_image:
            assert warped_frame1.min() >= -1.1  # Allow for rounding errors
            assert warped_frame1.max() <= 1.1
            warped_frame1 = torch.clamp(warped_frame1, min=-1, max=1)
        return warped_frame1, mask1

    @staticmethod
    def create_grid(b, h, w):
        x_1d = torch.arange(0, w)[None]
        y_1d = torch.arange(0, h)[:, None]
        x_2d = x_1d.repeat([h, 1])
        y_2d = y_1d.repeat([1, w])
        grid = torch.stack([x_2d, y_2d], dim=0)
        batch_grid = grid[None].repeat([b, 1, 1, 1])
        return batch_grid

    # @staticmethod
    # def read_image(path: Path) -> torch.Tensor:
    #     image = skimage.io.imread(path.as_posix())
    #     return image

    # @staticmethod
    # def read_depth(path: Path) -> torch.Tensor:
    #     if path.suffix == '.png':
    #         depth = skimage.io.imread(path.as_posix())
    #     elif path.suffix == '.npy':
    #         depth = numpy.load(path.as_posix())
    #     elif path.suffix == '.npz':
    #         with numpy.load(path.as_posix()) as depth_data:
    #             depth = depth_data['depth']
    #     elif path.suffix == '.exr':
    #         exr_file = OpenEXR.InputFile(path.as_posix())
    #         raw_bytes = exr_file.channel('B', Imath.PixelType(Imath.PixelType.FLOAT))
    #         depth_vector = numpy.frombuffer(raw_bytes, dtype=numpy.float32)
    #         height = exr_file.header()['displayWindow'].max.y + 1 - exr_file.header()['displayWindow'].min.y
    #         width = exr_file.header()['displayWindow'].max.x + 1 - exr_file.header()['displayWindow'].min.x
    #         depth = numpy.reshape(depth_vector, (height, width))
    #     else:
    #         raise RuntimeError(f'Unknown depth format: {path.suffix}')
    #     return depth

    # @staticmethod
    # def camera_intrinsic_transform(capture_width=1920, capture_height=1080, patch_start_point: tuple = (0, 0)):
    #     start_y, start_x = patch_start_point
    #     camera_intrinsics = numpy.eye(4)
    #     camera_intrinsics[0, 0] = 2100
    #     camera_intrinsics[0, 2] = capture_width / 2.0 - start_x
    #     camera_intrinsics[1, 1] = 2100
    #     camera_intrinsics[1, 2] = capture_height / 2.0 - start_y
    #     return camera_intrinsics

    # @staticmethod
    # def get_device(device: str):
    #     """
    #     Returns torch device object
    #     :param device: cpu/gpu0/gpu1
    #     :return:
    #     """
    #     if device == 'cpu':
    #         device = torch.device('cpu')
    #     elif device.startswith('gpu') and torch.cuda.is_available():
    #         gpu_num = int(device[3:])
    #         device = torch.device(f'cuda:{gpu_num}')
    #     else:
    #         device = torch.device('cpu')
    #     return device

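A minimal, hedged usage sketch for Warper.forward_warp, following the shapes documented in its docstring. All tensors below are random placeholders (pixel-space intrinsics, identity and slightly translated extrinsics); it only demonstrates the calling convention and output shapes.

import torch

b, h, w = 1, 64, 64
warper = Warper()
frame1 = torch.rand(b, 3, h, w) * 2 - 1                  # image in [-1, 1]
depth1 = torch.rand(b, 1, h, w) * 5 + 1                  # positive depths
K = torch.tensor([[[100.0, 0.0, w / 2],
                   [0.0, 100.0, h / 2],
                   [0.0, 0.0, 1.0]]]).repeat(b, 1, 1)     # (b, 3, 3) pixel-space intrinsics
T1 = torch.eye(4).unsqueeze(0).repeat(b, 1, 1)            # source extrinsics [R, t; 0, 1]
T2 = T1.clone()
T2[:, 0, 3] = 0.05                                        # translate the second view slightly
warped, mask, warped_depth, flow12 = warper.forward_warp(
    frame1, None, depth1, T1, T2, K, None, is_image=True
)
print(warped.shape, mask.shape, flow12.shape)             # (b,3,h,w), (b,1,h,w), (b,2,h,w)
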
finetune/modules/depth_warping/depth_anything_v2/depth_anything_wrapper.py
ADDED
@@ -0,0 +1,12 @@
import os
import torch.nn as nn  # needed for the nn.Module base class below
from cameractrl.modules.depth_anything_v2.dpt import DepthAnythingV2

class MVSplat_wrapper(nn.Module):
    def __init__(
        self,
        model_configs,
        ckpt_path,
    ):
        super().__init__()

        depth_anything = DepthAnythingV2(model_configs)

finetune/modules/depth_warping/depth_anything_v2/dinov2.py
ADDED
@@ -0,0 +1,415 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
# References:
|
7 |
+
# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
|
8 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
|
9 |
+
|
10 |
+
from functools import partial
|
11 |
+
import math
|
12 |
+
import logging
|
13 |
+
from typing import Sequence, Tuple, Union, Callable
|
14 |
+
|
15 |
+
import torch
|
16 |
+
import torch.nn as nn
|
17 |
+
import torch.utils.checkpoint
|
18 |
+
from torch.nn.init import trunc_normal_
|
19 |
+
|
20 |
+
from .dinov2_layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block
|
21 |
+
|
22 |
+
|
23 |
+
logger = logging.getLogger("dinov2")
|
24 |
+
|
25 |
+
|
26 |
+
def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module:
|
27 |
+
if not depth_first and include_root:
|
28 |
+
fn(module=module, name=name)
|
29 |
+
for child_name, child_module in module.named_children():
|
30 |
+
child_name = ".".join((name, child_name)) if name else child_name
|
31 |
+
named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)
|
32 |
+
if depth_first and include_root:
|
33 |
+
fn(module=module, name=name)
|
34 |
+
return module
|
35 |
+
|
36 |
+
|
37 |
+
class BlockChunk(nn.ModuleList):
|
38 |
+
def forward(self, x):
|
39 |
+
for b in self:
|
40 |
+
x = b(x)
|
41 |
+
return x
|
42 |
+
|
43 |
+
|
44 |
+
class DinoVisionTransformer(nn.Module):
|
45 |
+
def __init__(
|
46 |
+
self,
|
47 |
+
img_size=224,
|
48 |
+
patch_size=16,
|
49 |
+
in_chans=3,
|
50 |
+
embed_dim=768,
|
51 |
+
depth=12,
|
52 |
+
num_heads=12,
|
53 |
+
mlp_ratio=4.0,
|
54 |
+
qkv_bias=True,
|
55 |
+
ffn_bias=True,
|
56 |
+
proj_bias=True,
|
57 |
+
drop_path_rate=0.0,
|
58 |
+
drop_path_uniform=False,
|
59 |
+
init_values=None, # for layerscale: None or 0 => no layerscale
|
60 |
+
embed_layer=PatchEmbed,
|
61 |
+
act_layer=nn.GELU,
|
62 |
+
block_fn=Block,
|
63 |
+
ffn_layer="mlp",
|
64 |
+
block_chunks=1,
|
65 |
+
num_register_tokens=0,
|
66 |
+
interpolate_antialias=False,
|
67 |
+
interpolate_offset=0.1,
|
68 |
+
):
|
69 |
+
"""
|
70 |
+
Args:
|
71 |
+
img_size (int, tuple): input image size
|
72 |
+
patch_size (int, tuple): patch size
|
73 |
+
in_chans (int): number of input channels
|
74 |
+
embed_dim (int): embedding dimension
|
75 |
+
depth (int): depth of transformer
|
76 |
+
num_heads (int): number of attention heads
|
77 |
+
mlp_ratio (int): ratio of mlp hidden dim to embedding dim
|
78 |
+
qkv_bias (bool): enable bias for qkv if True
|
79 |
+
proj_bias (bool): enable bias for proj in attn if True
|
80 |
+
ffn_bias (bool): enable bias for ffn if True
|
81 |
+
drop_path_rate (float): stochastic depth rate
|
82 |
+
drop_path_uniform (bool): apply uniform drop rate across blocks
|
83 |
+
weight_init (str): weight init scheme
|
84 |
+
init_values (float): layer-scale init values
|
85 |
+
embed_layer (nn.Module): patch embedding layer
|
86 |
+
act_layer (nn.Module): MLP activation layer
|
87 |
+
block_fn (nn.Module): transformer block class
|
88 |
+
ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
|
89 |
+
block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
|
90 |
+
num_register_tokens: (int) number of extra cls tokens (so-called "registers")
|
91 |
+
interpolate_antialias: (bool) flag to apply anti-aliasing when interpolating positional embeddings
|
92 |
+
interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings
|
93 |
+
"""
|
94 |
+
super().__init__()
|
95 |
+
norm_layer = partial(nn.LayerNorm, eps=1e-6)
|
96 |
+
|
97 |
+
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
|
98 |
+
self.num_tokens = 1
|
99 |
+
self.n_blocks = depth
|
100 |
+
self.num_heads = num_heads
|
101 |
+
self.patch_size = patch_size
|
102 |
+
self.num_register_tokens = num_register_tokens
|
103 |
+
self.interpolate_antialias = interpolate_antialias
|
104 |
+
self.interpolate_offset = interpolate_offset
|
105 |
+
|
106 |
+
self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
|
107 |
+
num_patches = self.patch_embed.num_patches
|
108 |
+
|
109 |
+
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
|
110 |
+
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
|
111 |
+
assert num_register_tokens >= 0
|
112 |
+
self.register_tokens = (
|
113 |
+
nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None
|
114 |
+
)
|
115 |
+
|
116 |
+
if drop_path_uniform is True:
|
117 |
+
dpr = [drop_path_rate] * depth
|
118 |
+
else:
|
119 |
+
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
|
120 |
+
|
121 |
+
if ffn_layer == "mlp":
|
122 |
+
logger.info("using MLP layer as FFN")
|
123 |
+
ffn_layer = Mlp
|
124 |
+
elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
|
125 |
+
logger.info("using SwiGLU layer as FFN")
|
126 |
+
ffn_layer = SwiGLUFFNFused
|
127 |
+
elif ffn_layer == "identity":
|
128 |
+
logger.info("using Identity layer as FFN")
|
129 |
+
|
130 |
+
def f(*args, **kwargs):
|
131 |
+
return nn.Identity()
|
132 |
+
|
133 |
+
ffn_layer = f
|
134 |
+
else:
|
135 |
+
raise NotImplementedError
|
136 |
+
|
137 |
+
blocks_list = [
|
138 |
+
block_fn(
|
139 |
+
dim=embed_dim,
|
140 |
+
num_heads=num_heads,
|
141 |
+
mlp_ratio=mlp_ratio,
|
142 |
+
qkv_bias=qkv_bias,
|
143 |
+
proj_bias=proj_bias,
|
144 |
+
ffn_bias=ffn_bias,
|
145 |
+
drop_path=dpr[i],
|
146 |
+
norm_layer=norm_layer,
|
147 |
+
act_layer=act_layer,
|
148 |
+
ffn_layer=ffn_layer,
|
149 |
+
init_values=init_values,
|
150 |
+
)
|
151 |
+
for i in range(depth)
|
152 |
+
]
|
153 |
+
if block_chunks > 0:
|
154 |
+
self.chunked_blocks = True
|
155 |
+
chunked_blocks = []
|
156 |
+
chunksize = depth // block_chunks
|
157 |
+
for i in range(0, depth, chunksize):
|
158 |
+
# this is to keep the block index consistent if we chunk the block list
|
159 |
+
chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize])
|
160 |
+
self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
|
161 |
+
else:
|
162 |
+
self.chunked_blocks = False
|
163 |
+
self.blocks = nn.ModuleList(blocks_list)
|
164 |
+
|
165 |
+
self.norm = norm_layer(embed_dim)
|
166 |
+
self.head = nn.Identity()
|
167 |
+
|
168 |
+
self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))
|
169 |
+
|
170 |
+
self.init_weights()
|
171 |
+
|
172 |
+
def init_weights(self):
|
173 |
+
trunc_normal_(self.pos_embed, std=0.02)
|
174 |
+
nn.init.normal_(self.cls_token, std=1e-6)
|
175 |
+
if self.register_tokens is not None:
|
176 |
+
nn.init.normal_(self.register_tokens, std=1e-6)
|
177 |
+
named_apply(init_weights_vit_timm, self)
|
178 |
+
|
179 |
+
def interpolate_pos_encoding(self, x, w, h):
|
180 |
+
previous_dtype = x.dtype
|
181 |
+
npatch = x.shape[1] - 1
|
182 |
+
N = self.pos_embed.shape[1] - 1
|
183 |
+
if npatch == N and w == h:
|
184 |
+
return self.pos_embed
|
185 |
+
pos_embed = self.pos_embed.float()
|
186 |
+
class_pos_embed = pos_embed[:, 0]
|
187 |
+
patch_pos_embed = pos_embed[:, 1:]
|
188 |
+
dim = x.shape[-1]
|
189 |
+
w0 = w // self.patch_size
|
190 |
+
h0 = h // self.patch_size
|
191 |
+
# we add a small number to avoid floating point error in the interpolation
|
192 |
+
# see discussion at https://github.com/facebookresearch/dino/issues/8
|
193 |
+
# DINOv2 with register modify the interpolate_offset from 0.1 to 0.0
|
194 |
+
w0, h0 = w0 + self.interpolate_offset, h0 + self.interpolate_offset
|
195 |
+
# w0, h0 = w0 + 0.1, h0 + 0.1
|
196 |
+
|
197 |
+
sqrt_N = math.sqrt(N)
|
198 |
+
sx, sy = float(w0) / sqrt_N, float(h0) / sqrt_N
|
199 |
+
patch_pos_embed = nn.functional.interpolate(
|
200 |
+
patch_pos_embed.reshape(1, int(sqrt_N), int(sqrt_N), dim).permute(0, 3, 1, 2),
|
201 |
+
scale_factor=(sx, sy),
|
202 |
+
# (int(w0), int(h0)), # to solve the upsampling shape issue
|
203 |
+
mode="bicubic",
|
204 |
+
antialias=self.interpolate_antialias
|
205 |
+
)
|
206 |
+
|
207 |
+
assert int(w0) == patch_pos_embed.shape[-2]
|
208 |
+
assert int(h0) == patch_pos_embed.shape[-1]
|
209 |
+
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
|
210 |
+
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)
|
211 |
+
|
212 |
+
def prepare_tokens_with_masks(self, x, masks=None):
|
213 |
+
B, nc, w, h = x.shape
|
214 |
+
x = self.patch_embed(x)
|
215 |
+
if masks is not None:
|
216 |
+
x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)
|
217 |
+
|
218 |
+
x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
|
219 |
+
x = x + self.interpolate_pos_encoding(x, w, h)
|
220 |
+
|
221 |
+
if self.register_tokens is not None:
|
222 |
+
x = torch.cat(
|
223 |
+
(
|
224 |
+
x[:, :1],
|
225 |
+
self.register_tokens.expand(x.shape[0], -1, -1),
|
226 |
+
x[:, 1:],
|
227 |
+
),
|
228 |
+
dim=1,
|
229 |
+
)
|
230 |
+
|
231 |
+
return x
|
232 |
+
|
233 |
+
def forward_features_list(self, x_list, masks_list):
|
234 |
+
x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]
|
235 |
+
for blk in self.blocks:
|
236 |
+
x = blk(x)
|
237 |
+
|
238 |
+
all_x = x
|
239 |
+
output = []
|
240 |
+
for x, masks in zip(all_x, masks_list):
|
241 |
+
x_norm = self.norm(x)
|
242 |
+
output.append(
|
243 |
+
{
|
244 |
+
"x_norm_clstoken": x_norm[:, 0],
|
245 |
+
"x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
|
246 |
+
"x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
|
247 |
+
"x_prenorm": x,
|
248 |
+
"masks": masks,
|
249 |
+
}
|
250 |
+
)
|
251 |
+
return output
|
252 |
+
|
253 |
+
def forward_features(self, x, masks=None):
|
254 |
+
if isinstance(x, list):
|
255 |
+
return self.forward_features_list(x, masks)
|
256 |
+
|
257 |
+
x = self.prepare_tokens_with_masks(x, masks)
|
258 |
+
|
259 |
+
for blk in self.blocks:
|
260 |
+
x = blk(x)
|
261 |
+
|
262 |
+
x_norm = self.norm(x)
|
263 |
+
return {
|
264 |
+
"x_norm_clstoken": x_norm[:, 0],
|
265 |
+
"x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
|
266 |
+
"x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
|
267 |
+
"x_prenorm": x,
|
268 |
+
"masks": masks,
|
269 |
+
}
|
270 |
+
|
271 |
+
def _get_intermediate_layers_not_chunked(self, x, n=1):
|
272 |
+
x = self.prepare_tokens_with_masks(x)
|
273 |
+
# If n is an int, take the n last blocks. If it's a list, take them
|
274 |
+
output, total_block_len = [], len(self.blocks)
|
275 |
+
blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
|
276 |
+
for i, blk in enumerate(self.blocks):
|
277 |
+
x = blk(x)
|
278 |
+
if i in blocks_to_take:
|
279 |
+
output.append(x)
|
280 |
+
assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
|
281 |
+
return output
|
282 |
+
|
283 |
+
def _get_intermediate_layers_chunked(self, x, n=1):
|
284 |
+
x = self.prepare_tokens_with_masks(x)
|
285 |
+
output, i, total_block_len = [], 0, len(self.blocks[-1])
|
286 |
+
# If n is an int, take the n last blocks. If it's a list, take them
|
287 |
+
blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
|
288 |
+
for block_chunk in self.blocks:
|
289 |
+
for blk in block_chunk[i:]: # Passing the nn.Identity()
|
290 |
+
x = blk(x)
|
291 |
+
if i in blocks_to_take:
|
292 |
+
output.append(x)
|
293 |
+
i += 1
|
294 |
+
assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
|
295 |
+
return output
|
296 |
+
|
297 |
+
def get_intermediate_layers(
|
298 |
+
self,
|
299 |
+
x: torch.Tensor,
|
300 |
+
n: Union[int, Sequence] = 1, # Layers or n last layers to take
|
301 |
+
reshape: bool = False,
|
302 |
+
return_class_token: bool = False,
|
303 |
+
norm=True
|
304 |
+
) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
|
305 |
+
if self.chunked_blocks:
|
306 |
+
outputs = self._get_intermediate_layers_chunked(x, n)
|
307 |
+
else:
|
308 |
+
outputs = self._get_intermediate_layers_not_chunked(x, n)
|
309 |
+
if norm:
|
310 |
+
outputs = [self.norm(out) for out in outputs]
|
311 |
+
class_tokens = [out[:, 0] for out in outputs]
|
312 |
+
outputs = [out[:, 1 + self.num_register_tokens:] for out in outputs]
|
313 |
+
if reshape:
|
314 |
+
B, _, w, h = x.shape
|
315 |
+
outputs = [
|
316 |
+
out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()
|
317 |
+
for out in outputs
|
318 |
+
]
|
319 |
+
if return_class_token:
|
320 |
+
return tuple(zip(outputs, class_tokens))
|
321 |
+
return tuple(outputs)
|
322 |
+
|
323 |
+
def forward(self, *args, is_training=False, **kwargs):
|
324 |
+
ret = self.forward_features(*args, **kwargs)
|
325 |
+
if is_training:
|
326 |
+
return ret
|
327 |
+
else:
|
328 |
+
return self.head(ret["x_norm_clstoken"])
|
329 |
+
|
330 |
+
|
331 |
+
def init_weights_vit_timm(module: nn.Module, name: str = ""):
|
332 |
+
"""ViT weight initialization, original timm impl (for reproducibility)"""
|
333 |
+
if isinstance(module, nn.Linear):
|
334 |
+
trunc_normal_(module.weight, std=0.02)
|
335 |
+
if module.bias is not None:
|
336 |
+
nn.init.zeros_(module.bias)
|
337 |
+
|
338 |
+
|
339 |
+
def vit_small(patch_size=16, num_register_tokens=0, **kwargs):
|
340 |
+
model = DinoVisionTransformer(
|
341 |
+
patch_size=patch_size,
|
342 |
+
embed_dim=384,
|
343 |
+
depth=12,
|
344 |
+
num_heads=6,
|
345 |
+
mlp_ratio=4,
|
346 |
+
block_fn=partial(Block, attn_class=MemEffAttention),
|
347 |
+
num_register_tokens=num_register_tokens,
|
348 |
+
**kwargs,
|
349 |
+
)
|
350 |
+
return model
|
351 |
+
|
352 |
+
|
353 |
+
def vit_base(patch_size=16, num_register_tokens=0, **kwargs):
|
354 |
+
model = DinoVisionTransformer(
|
355 |
+
patch_size=patch_size,
|
356 |
+
embed_dim=768,
|
357 |
+
depth=12,
|
358 |
+
num_heads=12,
|
359 |
+
mlp_ratio=4,
|
360 |
+
block_fn=partial(Block, attn_class=MemEffAttention),
|
361 |
+
num_register_tokens=num_register_tokens,
|
362 |
+
**kwargs,
|
363 |
+
)
|
364 |
+
return model
|
365 |
+
|
366 |
+
|
367 |
+
def vit_large(patch_size=16, num_register_tokens=0, **kwargs):
|
368 |
+
model = DinoVisionTransformer(
|
369 |
+
patch_size=patch_size,
|
370 |
+
embed_dim=1024,
|
371 |
+
depth=24,
|
372 |
+
num_heads=16,
|
373 |
+
mlp_ratio=4,
|
374 |
+
block_fn=partial(Block, attn_class=MemEffAttention),
|
375 |
+
num_register_tokens=num_register_tokens,
|
376 |
+
**kwargs,
|
377 |
+
)
|
378 |
+
return model
|
379 |
+
|
380 |
+
|
381 |
+
def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs):
|
382 |
+
"""
|
383 |
+
Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64
|
384 |
+
"""
|
385 |
+
model = DinoVisionTransformer(
|
386 |
+
patch_size=patch_size,
|
387 |
+
embed_dim=1536,
|
388 |
+
depth=40,
|
389 |
+
num_heads=24,
|
390 |
+
mlp_ratio=4,
|
391 |
+
block_fn=partial(Block, attn_class=MemEffAttention),
|
392 |
+
num_register_tokens=num_register_tokens,
|
393 |
+
**kwargs,
|
394 |
+
)
|
395 |
+
return model
|
396 |
+
|
397 |
+
|
398 |
+
def DINOv2(model_name):
|
399 |
+
model_zoo = {
|
400 |
+
"vits": vit_small,
|
401 |
+
"vitb": vit_base,
|
402 |
+
"vitl": vit_large,
|
403 |
+
"vitg": vit_giant2
|
404 |
+
}
|
405 |
+
|
406 |
+
return model_zoo[model_name](
|
407 |
+
img_size=518,
|
408 |
+
patch_size=14,
|
409 |
+
init_values=1.0,
|
410 |
+
ffn_layer="mlp" if model_name != "vitg" else "swiglufused",
|
411 |
+
block_chunks=0,
|
412 |
+
num_register_tokens=0,
|
413 |
+
interpolate_antialias=False,
|
414 |
+
interpolate_offset=0.1
|
415 |
+
)
|
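A short sketch of how the `DINOv2` factory above is used as a feature backbone (randomly initialised here; in this repo the weights arrive via the Depth Anything V2 checkpoint; the import path assumes the repo root is on `PYTHONPATH`):

```python
import torch

from finetune.modules.depth_warping.depth_anything_v2.dinov2 import DINOv2

backbone = DINOv2("vits")        # ViT-S/14, 384-dim tokens, randomly initialised here
backbone.eval()

x = torch.randn(1, 3, 224, 224)  # H and W must be multiples of the patch size (14)
with torch.no_grad():
    feats = backbone.get_intermediate_layers(
        x, n=[2, 5, 8, 11], return_class_token=True
    )

# One (patch_tokens, cls_token) pair per requested block.
for tokens, cls in feats:
    print(tokens.shape, cls.shape)   # (1, 256, 384) and (1, 384) for a 224x224 input
```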
finetune/modules/depth_warping/depth_anything_v2/dinov2_layers/__init__.py
ADDED
@@ -0,0 +1,11 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from .mlp import Mlp
from .patch_embed import PatchEmbed
from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused
from .block import NestedTensorBlock
from .attention import MemEffAttention
finetune/modules/depth_warping/depth_anything_v2/dinov2_layers/attention.py
ADDED
@@ -0,0 +1,83 @@
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
# References:
|
8 |
+
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
9 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
|
10 |
+
|
11 |
+
import logging
|
12 |
+
|
13 |
+
from torch import Tensor
|
14 |
+
from torch import nn
|
15 |
+
|
16 |
+
|
17 |
+
logger = logging.getLogger("dinov2")
|
18 |
+
|
19 |
+
|
20 |
+
try:
|
21 |
+
from xformers.ops import memory_efficient_attention, unbind, fmha
|
22 |
+
|
23 |
+
XFORMERS_AVAILABLE = True
|
24 |
+
except ImportError:
|
25 |
+
logger.warning("xFormers not available")
|
26 |
+
XFORMERS_AVAILABLE = False
|
27 |
+
|
28 |
+
|
29 |
+
class Attention(nn.Module):
|
30 |
+
def __init__(
|
31 |
+
self,
|
32 |
+
dim: int,
|
33 |
+
num_heads: int = 8,
|
34 |
+
qkv_bias: bool = False,
|
35 |
+
proj_bias: bool = True,
|
36 |
+
attn_drop: float = 0.0,
|
37 |
+
proj_drop: float = 0.0,
|
38 |
+
) -> None:
|
39 |
+
super().__init__()
|
40 |
+
self.num_heads = num_heads
|
41 |
+
head_dim = dim // num_heads
|
42 |
+
self.scale = head_dim**-0.5
|
43 |
+
|
44 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
45 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
46 |
+
self.proj = nn.Linear(dim, dim, bias=proj_bias)
|
47 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
48 |
+
|
49 |
+
def forward(self, x: Tensor) -> Tensor:
|
50 |
+
B, N, C = x.shape
|
51 |
+
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
52 |
+
|
53 |
+
q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
|
54 |
+
attn = q @ k.transpose(-2, -1)
|
55 |
+
|
56 |
+
attn = attn.softmax(dim=-1)
|
57 |
+
attn = self.attn_drop(attn)
|
58 |
+
|
59 |
+
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
|
60 |
+
x = self.proj(x)
|
61 |
+
x = self.proj_drop(x)
|
62 |
+
return x
|
63 |
+
|
64 |
+
|
65 |
+
class MemEffAttention(Attention):
|
66 |
+
def forward(self, x: Tensor, attn_bias=None) -> Tensor:
|
67 |
+
if not XFORMERS_AVAILABLE:
|
68 |
+
assert attn_bias is None, "xFormers is required for nested tensors usage"
|
69 |
+
return super().forward(x)
|
70 |
+
|
71 |
+
B, N, C = x.shape
|
72 |
+
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
|
73 |
+
|
74 |
+
q, k, v = unbind(qkv, 2)
|
75 |
+
|
76 |
+
x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
|
77 |
+
x = x.reshape([B, N, C])
|
78 |
+
|
79 |
+
x = self.proj(x)
|
80 |
+
x = self.proj_drop(x)
|
81 |
+
return x
|
82 |
+
|
83 |
+
|
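A quick sanity-check sketch: `MemEffAttention` only diverges from the plain `Attention` path when xFormers is installed; with identical weights the two modules should agree numerically (import path assumes the repo root is on `PYTHONPATH`):

```python
import torch

from finetune.modules.depth_warping.depth_anything_v2.dinov2_layers.attention import (
    Attention,
    MemEffAttention,
)

torch.manual_seed(0)
x = torch.randn(2, 16, 64)                   # (batch, tokens, dim)

attn = Attention(dim=64, num_heads=8, qkv_bias=True)
mem_attn = MemEffAttention(dim=64, num_heads=8, qkv_bias=True)
mem_attn.load_state_dict(attn.state_dict())  # same weights for a fair comparison

with torch.no_grad():
    y_ref = attn(x)
    y_mem = mem_attn(x)                      # identical to y_ref if xFormers is absent

print(torch.allclose(y_ref, y_mem, atol=1e-5))   # expected: True
```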
finetune/modules/depth_warping/depth_anything_v2/dinov2_layers/block.py
ADDED
@@ -0,0 +1,252 @@
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
# References:
|
8 |
+
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
9 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
|
10 |
+
|
11 |
+
import logging
|
12 |
+
from typing import Callable, List, Any, Tuple, Dict
|
13 |
+
|
14 |
+
import torch
|
15 |
+
from torch import nn, Tensor
|
16 |
+
|
17 |
+
from .attention import Attention, MemEffAttention
|
18 |
+
from .drop_path import DropPath
|
19 |
+
from .layer_scale import LayerScale
|
20 |
+
from .mlp import Mlp
|
21 |
+
|
22 |
+
|
23 |
+
logger = logging.getLogger("dinov2")
|
24 |
+
|
25 |
+
|
26 |
+
try:
|
27 |
+
from xformers.ops import fmha
|
28 |
+
from xformers.ops import scaled_index_add, index_select_cat
|
29 |
+
|
30 |
+
XFORMERS_AVAILABLE = True
|
31 |
+
except ImportError:
|
32 |
+
logger.warning("xFormers not available")
|
33 |
+
XFORMERS_AVAILABLE = False
|
34 |
+
|
35 |
+
|
36 |
+
class Block(nn.Module):
|
37 |
+
def __init__(
|
38 |
+
self,
|
39 |
+
dim: int,
|
40 |
+
num_heads: int,
|
41 |
+
mlp_ratio: float = 4.0,
|
42 |
+
qkv_bias: bool = False,
|
43 |
+
proj_bias: bool = True,
|
44 |
+
ffn_bias: bool = True,
|
45 |
+
drop: float = 0.0,
|
46 |
+
attn_drop: float = 0.0,
|
47 |
+
init_values=None,
|
48 |
+
drop_path: float = 0.0,
|
49 |
+
act_layer: Callable[..., nn.Module] = nn.GELU,
|
50 |
+
norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
|
51 |
+
attn_class: Callable[..., nn.Module] = Attention,
|
52 |
+
ffn_layer: Callable[..., nn.Module] = Mlp,
|
53 |
+
) -> None:
|
54 |
+
super().__init__()
|
55 |
+
# print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}")
|
56 |
+
self.norm1 = norm_layer(dim)
|
57 |
+
self.attn = attn_class(
|
58 |
+
dim,
|
59 |
+
num_heads=num_heads,
|
60 |
+
qkv_bias=qkv_bias,
|
61 |
+
proj_bias=proj_bias,
|
62 |
+
attn_drop=attn_drop,
|
63 |
+
proj_drop=drop,
|
64 |
+
)
|
65 |
+
self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
|
66 |
+
self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
67 |
+
|
68 |
+
self.norm2 = norm_layer(dim)
|
69 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
70 |
+
self.mlp = ffn_layer(
|
71 |
+
in_features=dim,
|
72 |
+
hidden_features=mlp_hidden_dim,
|
73 |
+
act_layer=act_layer,
|
74 |
+
drop=drop,
|
75 |
+
bias=ffn_bias,
|
76 |
+
)
|
77 |
+
self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
|
78 |
+
self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
79 |
+
|
80 |
+
self.sample_drop_ratio = drop_path
|
81 |
+
|
82 |
+
def forward(self, x: Tensor) -> Tensor:
|
83 |
+
def attn_residual_func(x: Tensor) -> Tensor:
|
84 |
+
return self.ls1(self.attn(self.norm1(x)))
|
85 |
+
|
86 |
+
def ffn_residual_func(x: Tensor) -> Tensor:
|
87 |
+
return self.ls2(self.mlp(self.norm2(x)))
|
88 |
+
|
89 |
+
if self.training and self.sample_drop_ratio > 0.1:
|
90 |
+
# the overhead is compensated only for a drop path rate larger than 0.1
|
91 |
+
x = drop_add_residual_stochastic_depth(
|
92 |
+
x,
|
93 |
+
residual_func=attn_residual_func,
|
94 |
+
sample_drop_ratio=self.sample_drop_ratio,
|
95 |
+
)
|
96 |
+
x = drop_add_residual_stochastic_depth(
|
97 |
+
x,
|
98 |
+
residual_func=ffn_residual_func,
|
99 |
+
sample_drop_ratio=self.sample_drop_ratio,
|
100 |
+
)
|
101 |
+
elif self.training and self.sample_drop_ratio > 0.0:
|
102 |
+
x = x + self.drop_path1(attn_residual_func(x))
|
103 |
+
x = x + self.drop_path2(ffn_residual_func(x))  # was drop_path1 (upstream FIXME); use the FFN-branch DropPath
|
104 |
+
else:
|
105 |
+
x = x + attn_residual_func(x)
|
106 |
+
x = x + ffn_residual_func(x)
|
107 |
+
return x
|
108 |
+
|
109 |
+
|
110 |
+
def drop_add_residual_stochastic_depth(
|
111 |
+
x: Tensor,
|
112 |
+
residual_func: Callable[[Tensor], Tensor],
|
113 |
+
sample_drop_ratio: float = 0.0,
|
114 |
+
) -> Tensor:
|
115 |
+
# 1) extract subset using permutation
|
116 |
+
b, n, d = x.shape
|
117 |
+
sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
|
118 |
+
brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
|
119 |
+
x_subset = x[brange]
|
120 |
+
|
121 |
+
# 2) apply residual_func to get residual
|
122 |
+
residual = residual_func(x_subset)
|
123 |
+
|
124 |
+
x_flat = x.flatten(1)
|
125 |
+
residual = residual.flatten(1)
|
126 |
+
|
127 |
+
residual_scale_factor = b / sample_subset_size
|
128 |
+
|
129 |
+
# 3) add the residual
|
130 |
+
x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
|
131 |
+
return x_plus_residual.view_as(x)
|
132 |
+
|
133 |
+
|
134 |
+
def get_branges_scales(x, sample_drop_ratio=0.0):
|
135 |
+
b, n, d = x.shape
|
136 |
+
sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
|
137 |
+
brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
|
138 |
+
residual_scale_factor = b / sample_subset_size
|
139 |
+
return brange, residual_scale_factor
|
140 |
+
|
141 |
+
|
142 |
+
def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
|
143 |
+
if scaling_vector is None:
|
144 |
+
x_flat = x.flatten(1)
|
145 |
+
residual = residual.flatten(1)
|
146 |
+
x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
|
147 |
+
else:
|
148 |
+
x_plus_residual = scaled_index_add(
|
149 |
+
x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
|
150 |
+
)
|
151 |
+
return x_plus_residual
|
152 |
+
|
153 |
+
|
154 |
+
attn_bias_cache: Dict[Tuple, Any] = {}
|
155 |
+
|
156 |
+
|
157 |
+
def get_attn_bias_and_cat(x_list, branges=None):
|
158 |
+
"""
|
159 |
+
this will perform the index select, cat the tensors, and provide the attn_bias from cache
|
160 |
+
"""
|
161 |
+
batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list]
|
162 |
+
all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
|
163 |
+
if all_shapes not in attn_bias_cache.keys():
|
164 |
+
seqlens = []
|
165 |
+
for b, x in zip(batch_sizes, x_list):
|
166 |
+
for _ in range(b):
|
167 |
+
seqlens.append(x.shape[1])
|
168 |
+
attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
|
169 |
+
attn_bias._batch_sizes = batch_sizes
|
170 |
+
attn_bias_cache[all_shapes] = attn_bias
|
171 |
+
|
172 |
+
if branges is not None:
|
173 |
+
cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1])
|
174 |
+
else:
|
175 |
+
tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
|
176 |
+
cat_tensors = torch.cat(tensors_bs1, dim=1)
|
177 |
+
|
178 |
+
return attn_bias_cache[all_shapes], cat_tensors
|
179 |
+
|
180 |
+
|
181 |
+
def drop_add_residual_stochastic_depth_list(
|
182 |
+
x_list: List[Tensor],
|
183 |
+
residual_func: Callable[[Tensor, Any], Tensor],
|
184 |
+
sample_drop_ratio: float = 0.0,
|
185 |
+
scaling_vector=None,
|
186 |
+
) -> Tensor:
|
187 |
+
# 1) generate random set of indices for dropping samples in the batch
|
188 |
+
branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
|
189 |
+
branges = [s[0] for s in branges_scales]
|
190 |
+
residual_scale_factors = [s[1] for s in branges_scales]
|
191 |
+
|
192 |
+
# 2) get attention bias and index+concat the tensors
|
193 |
+
attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)
|
194 |
+
|
195 |
+
# 3) apply residual_func to get residual, and split the result
|
196 |
+
residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore
|
197 |
+
|
198 |
+
outputs = []
|
199 |
+
for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors):
|
200 |
+
outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x))
|
201 |
+
return outputs
|
202 |
+
|
203 |
+
|
204 |
+
class NestedTensorBlock(Block):
|
205 |
+
def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
|
206 |
+
"""
|
207 |
+
x_list contains a list of tensors to nest together and run
|
208 |
+
"""
|
209 |
+
assert isinstance(self.attn, MemEffAttention)
|
210 |
+
|
211 |
+
if self.training and self.sample_drop_ratio > 0.0:
|
212 |
+
|
213 |
+
def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
|
214 |
+
return self.attn(self.norm1(x), attn_bias=attn_bias)
|
215 |
+
|
216 |
+
def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
|
217 |
+
return self.mlp(self.norm2(x))
|
218 |
+
|
219 |
+
x_list = drop_add_residual_stochastic_depth_list(
|
220 |
+
x_list,
|
221 |
+
residual_func=attn_residual_func,
|
222 |
+
sample_drop_ratio=self.sample_drop_ratio,
|
223 |
+
scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,
|
224 |
+
)
|
225 |
+
x_list = drop_add_residual_stochastic_depth_list(
|
226 |
+
x_list,
|
227 |
+
residual_func=ffn_residual_func,
|
228 |
+
sample_drop_ratio=self.sample_drop_ratio,
|
229 |
+
scaling_vector=self.ls2.gamma if isinstance(self.ls2, LayerScale) else None,
|
230 |
+
)
|
231 |
+
return x_list
|
232 |
+
else:
|
233 |
+
|
234 |
+
def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
|
235 |
+
return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))
|
236 |
+
|
237 |
+
def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
|
238 |
+
return self.ls2(self.mlp(self.norm2(x)))
|
239 |
+
|
240 |
+
attn_bias, x = get_attn_bias_and_cat(x_list)
|
241 |
+
x = x + attn_residual_func(x, attn_bias=attn_bias)
|
242 |
+
x = x + ffn_residual_func(x)
|
243 |
+
return attn_bias.split(x)
|
244 |
+
|
245 |
+
def forward(self, x_or_x_list):
|
246 |
+
if isinstance(x_or_x_list, Tensor):
|
247 |
+
return super().forward(x_or_x_list)
|
248 |
+
elif isinstance(x_or_x_list, list):
|
249 |
+
assert XFORMERS_AVAILABLE, "Please install xFormers for nested tensors usage"
|
250 |
+
return self.forward_nested(x_or_x_list)
|
251 |
+
else:
|
252 |
+
raise AssertionError
|
finetune/modules/depth_warping/depth_anything_v2/dinov2_layers/drop_path.py
ADDED
@@ -0,0 +1,35 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# References:
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py


from torch import nn


def drop_path(x, drop_prob: float = 0.0, training: bool = False):
    if drop_prob == 0.0 or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
    if keep_prob > 0.0:
        random_tensor.div_(keep_prob)
    output = x * random_tensor
    return output


class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""

    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)
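A tiny demonstration of the inverse-keep-probability scaling in `drop_path`, which keeps the expected activation unchanged during training and is a no-op at inference (import path assumes the repo root is on `PYTHONPATH`):

```python
import torch

from finetune.modules.depth_warping.depth_anything_v2.dinov2_layers.drop_path import DropPath

torch.manual_seed(0)
dp = DropPath(drop_prob=0.5)
x = torch.ones(10000, 4)

dp.train()
y = dp(x)                                    # each row is either 0 or 1 / keep_prob = 2.0
print(sorted(y[:, 0].unique().tolist()))     # [0.0, 2.0]
print(y.mean().item())                       # close to 1.0 in expectation

dp.eval()
print(torch.equal(dp(x), x))                 # identity at inference time
```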
finetune/modules/depth_warping/depth_anything_v2/dinov2_layers/layer_scale.py
ADDED
@@ -0,0 +1,28 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110

from typing import Union

import torch
from torch import Tensor
from torch import nn


class LayerScale(nn.Module):
    def __init__(
        self,
        dim: int,
        init_values: Union[float, Tensor] = 1e-5,
        inplace: bool = False,
    ) -> None:
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x: Tensor) -> Tensor:
        return x.mul_(self.gamma) if self.inplace else x * self.gamma
finetune/modules/depth_warping/depth_anything_v2/dinov2_layers/mlp.py
ADDED
@@ -0,0 +1,41 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# References:
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py


from typing import Callable, Optional

from torch import Tensor, nn


class Mlp(nn.Module):
    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
        self.drop = nn.Dropout(drop)

    def forward(self, x: Tensor) -> Tensor:
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x
finetune/modules/depth_warping/depth_anything_v2/dinov2_layers/patch_embed.py
ADDED
@@ -0,0 +1,89 @@
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
# References:
|
8 |
+
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
9 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
|
10 |
+
|
11 |
+
from typing import Callable, Optional, Tuple, Union
|
12 |
+
|
13 |
+
from torch import Tensor
|
14 |
+
import torch.nn as nn
|
15 |
+
|
16 |
+
|
17 |
+
def make_2tuple(x):
|
18 |
+
if isinstance(x, tuple):
|
19 |
+
assert len(x) == 2
|
20 |
+
return x
|
21 |
+
|
22 |
+
assert isinstance(x, int)
|
23 |
+
return (x, x)
|
24 |
+
|
25 |
+
|
26 |
+
class PatchEmbed(nn.Module):
|
27 |
+
"""
|
28 |
+
2D image to patch embedding: (B,C,H,W) -> (B,N,D)
|
29 |
+
|
30 |
+
Args:
|
31 |
+
img_size: Image size.
|
32 |
+
patch_size: Patch token size.
|
33 |
+
in_chans: Number of input image channels.
|
34 |
+
embed_dim: Number of linear projection output channels.
|
35 |
+
norm_layer: Normalization layer.
|
36 |
+
"""
|
37 |
+
|
38 |
+
def __init__(
|
39 |
+
self,
|
40 |
+
img_size: Union[int, Tuple[int, int]] = 224,
|
41 |
+
patch_size: Union[int, Tuple[int, int]] = 16,
|
42 |
+
in_chans: int = 3,
|
43 |
+
embed_dim: int = 768,
|
44 |
+
norm_layer: Optional[Callable] = None,
|
45 |
+
flatten_embedding: bool = True,
|
46 |
+
) -> None:
|
47 |
+
super().__init__()
|
48 |
+
|
49 |
+
image_HW = make_2tuple(img_size)
|
50 |
+
patch_HW = make_2tuple(patch_size)
|
51 |
+
patch_grid_size = (
|
52 |
+
image_HW[0] // patch_HW[0],
|
53 |
+
image_HW[1] // patch_HW[1],
|
54 |
+
)
|
55 |
+
|
56 |
+
self.img_size = image_HW
|
57 |
+
self.patch_size = patch_HW
|
58 |
+
self.patches_resolution = patch_grid_size
|
59 |
+
self.num_patches = patch_grid_size[0] * patch_grid_size[1]
|
60 |
+
|
61 |
+
self.in_chans = in_chans
|
62 |
+
self.embed_dim = embed_dim
|
63 |
+
|
64 |
+
self.flatten_embedding = flatten_embedding
|
65 |
+
|
66 |
+
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
|
67 |
+
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
|
68 |
+
|
69 |
+
def forward(self, x: Tensor) -> Tensor:
|
70 |
+
_, _, H, W = x.shape
|
71 |
+
patch_H, patch_W = self.patch_size
|
72 |
+
|
73 |
+
assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
|
74 |
+
assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}"
|
75 |
+
|
76 |
+
x = self.proj(x) # B C H W
|
77 |
+
H, W = x.size(2), x.size(3)
|
78 |
+
x = x.flatten(2).transpose(1, 2) # B HW C
|
79 |
+
x = self.norm(x)
|
80 |
+
if not self.flatten_embedding:
|
81 |
+
x = x.reshape(-1, H, W, self.embed_dim) # B H W C
|
82 |
+
return x
|
83 |
+
|
84 |
+
def flops(self) -> float:
|
85 |
+
Ho, Wo = self.patches_resolution
|
86 |
+
flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
|
87 |
+
if self.norm is not None:
|
88 |
+
flops += Ho * Wo * self.embed_dim
|
89 |
+
return flops
|
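A short sketch of the token bookkeeping `PatchEmbed` provides, matching the 518 px / patch-14 setting used by the `DINOv2` factory above (import path assumes the repo root is on `PYTHONPATH`):

```python
import torch

from finetune.modules.depth_warping.depth_anything_v2.dinov2_layers.patch_embed import PatchEmbed

embed = PatchEmbed(img_size=518, patch_size=14, in_chans=3, embed_dim=384)
print(embed.patches_resolution, embed.num_patches)   # (37, 37) 1369

x = torch.randn(1, 3, 518, 518)   # spatial dims must be multiples of the patch size
tokens = embed(x)
print(tokens.shape)               # (1, 1369, 384): one embedding per 14x14 patch
```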
finetune/modules/depth_warping/depth_anything_v2/dinov2_layers/swiglu_ffn.py
ADDED
@@ -0,0 +1,63 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from typing import Callable, Optional

from torch import Tensor, nn
import torch.nn.functional as F


class SwiGLUFFN(nn.Module):
    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = None,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
        self.w3 = nn.Linear(hidden_features, out_features, bias=bias)

    def forward(self, x: Tensor) -> Tensor:
        x12 = self.w12(x)
        x1, x2 = x12.chunk(2, dim=-1)
        hidden = F.silu(x1) * x2
        return self.w3(hidden)


try:
    from xformers.ops import SwiGLU

    XFORMERS_AVAILABLE = True
except ImportError:
    SwiGLU = SwiGLUFFN
    XFORMERS_AVAILABLE = False


class SwiGLUFFNFused(SwiGLU):
    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = None,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
        super().__init__(
            in_features=in_features,
            hidden_features=hidden_features,
            out_features=out_features,
            bias=bias,
        )
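`SwiGLUFFNFused` shrinks the hidden width to two thirds and rounds up to a multiple of 8 before delegating to xFormers' `SwiGLU` (or the pure-PyTorch `SwiGLUFFN` fallback). A quick check of that arithmetic and of the fallback module:

```python
import torch

from finetune.modules.depth_warping.depth_anything_v2.dinov2_layers.swiglu_ffn import SwiGLUFFN

hidden = 4 * 1536                          # mlp_ratio * embed_dim for the "giant" model
print((int(hidden * 2 / 3) + 7) // 8 * 8)  # 4096: the width SwiGLUFFNFused actually allocates

ffn = SwiGLUFFN(in_features=1536, hidden_features=hidden)
y = ffn(torch.randn(2, 10, 1536))
print(y.shape)                             # (2, 10, 1536)
```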
finetune/modules/depth_warping/depth_anything_v2/dpt.py
ADDED
@@ -0,0 +1,235 @@
1 |
+
import cv2
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import torch.nn.functional as F
|
5 |
+
import torchvision.transforms as tf
|
6 |
+
from torchvision.transforms import Compose
|
7 |
+
|
8 |
+
from .dinov2 import DINOv2
|
9 |
+
from .util.blocks import FeatureFusionBlock, _make_scratch
|
10 |
+
from .util.transform import Resize, NormalizeImage, PrepareForNet
|
11 |
+
|
12 |
+
|
13 |
+
def _make_fusion_block(features, use_bn, size=None):
|
14 |
+
return FeatureFusionBlock(
|
15 |
+
features,
|
16 |
+
nn.ReLU(False),
|
17 |
+
deconv=False,
|
18 |
+
bn=use_bn,
|
19 |
+
expand=False,
|
20 |
+
align_corners=True,
|
21 |
+
size=size,
|
22 |
+
)
|
23 |
+
|
24 |
+
|
25 |
+
class ConvBlock(nn.Module):
|
26 |
+
def __init__(self, in_feature, out_feature):
|
27 |
+
super().__init__()
|
28 |
+
|
29 |
+
self.conv_block = nn.Sequential(
|
30 |
+
nn.Conv2d(in_feature, out_feature, kernel_size=3, stride=1, padding=1),
|
31 |
+
nn.BatchNorm2d(out_feature),
|
32 |
+
nn.ReLU(True)
|
33 |
+
)
|
34 |
+
|
35 |
+
def forward(self, x):
|
36 |
+
return self.conv_block(x)
|
37 |
+
|
38 |
+
|
39 |
+
class DPTHead(nn.Module):
|
40 |
+
def __init__(
|
41 |
+
self,
|
42 |
+
in_channels,
|
43 |
+
features=256,
|
44 |
+
use_bn=False,
|
45 |
+
out_channels=[256, 512, 1024, 1024],
|
46 |
+
use_clstoken=False
|
47 |
+
):
|
48 |
+
super(DPTHead, self).__init__()
|
49 |
+
|
50 |
+
self.use_clstoken = use_clstoken
|
51 |
+
|
52 |
+
self.projects = nn.ModuleList([
|
53 |
+
nn.Conv2d(
|
54 |
+
in_channels=in_channels,
|
55 |
+
out_channels=out_channel,
|
56 |
+
kernel_size=1,
|
57 |
+
stride=1,
|
58 |
+
padding=0,
|
59 |
+
) for out_channel in out_channels
|
60 |
+
])
|
61 |
+
|
62 |
+
self.resize_layers = nn.ModuleList([
|
63 |
+
nn.ConvTranspose2d(
|
64 |
+
in_channels=out_channels[0],
|
65 |
+
out_channels=out_channels[0],
|
66 |
+
kernel_size=4,
|
67 |
+
stride=4,
|
68 |
+
padding=0),
|
69 |
+
nn.ConvTranspose2d(
|
70 |
+
in_channels=out_channels[1],
|
71 |
+
out_channels=out_channels[1],
|
72 |
+
kernel_size=2,
|
73 |
+
stride=2,
|
74 |
+
padding=0),
|
75 |
+
nn.Identity(),
|
76 |
+
nn.Conv2d(
|
77 |
+
in_channels=out_channels[3],
|
78 |
+
out_channels=out_channels[3],
|
79 |
+
kernel_size=3,
|
80 |
+
stride=2,
|
81 |
+
padding=1)
|
82 |
+
])
|
83 |
+
|
84 |
+
if use_clstoken:
|
85 |
+
self.readout_projects = nn.ModuleList()
|
86 |
+
for _ in range(len(self.projects)):
|
87 |
+
self.readout_projects.append(
|
88 |
+
nn.Sequential(
|
89 |
+
nn.Linear(2 * in_channels, in_channels),
|
90 |
+
nn.GELU()))
|
91 |
+
|
92 |
+
self.scratch = _make_scratch(
|
93 |
+
out_channels,
|
94 |
+
features,
|
95 |
+
groups=1,
|
96 |
+
expand=False,
|
97 |
+
)
|
98 |
+
|
99 |
+
self.scratch.stem_transpose = None
|
100 |
+
|
101 |
+
self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
|
102 |
+
self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
|
103 |
+
self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
|
104 |
+
self.scratch.refinenet4 = _make_fusion_block(features, use_bn)
|
105 |
+
|
106 |
+
head_features_1 = features
|
107 |
+
head_features_2 = 32
|
108 |
+
|
109 |
+
self.scratch.output_conv1 = nn.Conv2d(head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1)
|
110 |
+
self.scratch.output_conv2 = nn.Sequential(
|
111 |
+
nn.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1),
|
112 |
+
nn.ReLU(True),
|
113 |
+
nn.Conv2d(head_features_2, 1, kernel_size=1, stride=1, padding=0),
|
114 |
+
nn.Sigmoid()
|
115 |
+
)
|
116 |
+
|
117 |
+
def forward(self, out_features, patch_h, patch_w):
|
118 |
+
out = []
|
119 |
+
for i, x in enumerate(out_features):
|
120 |
+
if self.use_clstoken:
|
121 |
+
x, cls_token = x[0], x[1]
|
122 |
+
readout = cls_token.unsqueeze(1).expand_as(x)
|
123 |
+
x = self.readout_projects[i](torch.cat((x, readout), -1))
|
124 |
+
else:
|
125 |
+
x = x[0]
|
126 |
+
|
127 |
+
x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w))
|
128 |
+
|
129 |
+
x = self.projects[i](x)
|
130 |
+
x = self.resize_layers[i](x)
|
131 |
+
|
132 |
+
out.append(x)
|
133 |
+
|
134 |
+
layer_1, layer_2, layer_3, layer_4 = out
|
135 |
+
|
136 |
+
layer_1_rn = self.scratch.layer1_rn(layer_1)
|
137 |
+
layer_2_rn = self.scratch.layer2_rn(layer_2)
|
138 |
+
layer_3_rn = self.scratch.layer3_rn(layer_3)
|
139 |
+
layer_4_rn = self.scratch.layer4_rn(layer_4)
|
140 |
+
|
141 |
+
path_4 = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])
|
142 |
+
path_3 = self.scratch.refinenet3(path_4, layer_3_rn, size=layer_2_rn.shape[2:])
|
143 |
+
path_2 = self.scratch.refinenet2(path_3, layer_2_rn, size=layer_1_rn.shape[2:])
|
144 |
+
path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
|
145 |
+
|
146 |
+
out = self.scratch.output_conv1(path_1)
|
147 |
+
out = F.interpolate(out, (int(patch_h * 14), int(patch_w * 14)), mode="bilinear", align_corners=True)
|
148 |
+
out = self.scratch.output_conv2(out)
|
149 |
+
|
150 |
+
return out
|
151 |
+
|
152 |
+
|
153 |
+
class DepthAnythingV2(nn.Module):
|
154 |
+
def __init__(
|
155 |
+
self,
|
156 |
+
encoder='vitl',
|
157 |
+
features=256,
|
158 |
+
out_channels=[256, 512, 1024, 1024],
|
159 |
+
use_bn=False,
|
160 |
+
use_clstoken=False,
|
161 |
+
max_depth=20.0
|
162 |
+
):
|
163 |
+
super(DepthAnythingV2, self).__init__()
|
164 |
+
|
165 |
+
self.intermediate_layer_idx = {
|
166 |
+
'vits': [2, 5, 8, 11],
|
167 |
+
'vitb': [2, 5, 8, 11],
|
168 |
+
'vitl': [4, 11, 17, 23],
|
169 |
+
'vitg': [9, 19, 29, 39]
|
170 |
+
}
|
171 |
+
|
172 |
+
self.max_depth = max_depth
|
173 |
+
|
174 |
+
self.encoder = encoder
|
175 |
+
self.pretrained = DINOv2(model_name=encoder)
|
176 |
+
|
177 |
+
self.depth_head = DPTHead(self.pretrained.embed_dim, features, use_bn, out_channels=out_channels, use_clstoken=use_clstoken)
|
178 |
+
|
179 |
+
def forward(self, x):
|
180 |
+
patch_h, patch_w = x.shape[-2] // 14, x.shape[-1] // 14
|
181 |
+
|
182 |
+
features = self.pretrained.get_intermediate_layers(x, self.intermediate_layer_idx[self.encoder], return_class_token=True)
|
183 |
+
|
184 |
+
depth = self.depth_head(features, patch_h, patch_w) * self.max_depth
|
185 |
+
|
186 |
+
return depth.squeeze(1)
|
187 |
+
|
188 |
+
@torch.no_grad()
|
189 |
+
def infer_image(self, raw_image, input_size=518):
|
190 |
+
image, (h, w) = self.image2tensor(raw_image, input_size)
|
191 |
+
|
192 |
+
depth = self.forward(image)
|
193 |
+
|
194 |
+
depth = F.interpolate(depth[:, None], (h, w), mode="bilinear", align_corners=True)[0, 0]
|
195 |
+
|
196 |
+
return depth
|
197 |
+
# return depth.cpu().numpy()
|
198 |
+
|
199 |
+
|
200 |
+
# TODO. transform for torch.Tensor
|
201 |
+
# TODO. inference for torch.Tensor
|
202 |
+
# def image2tensor_pt(self, raw_image, input_size=518):
|
203 |
+
# transform = Compose([
|
204 |
+
# tf
|
205 |
+
# ])
|
206 |
+
|
207 |
+
|
208 |
+
def image2tensor(self, raw_image, input_size=518):
|
209 |
+
transform = Compose([
|
210 |
+
Resize(
|
211 |
+
width=input_size,
|
212 |
+
height=input_size,
|
213 |
+
resize_target=False,
|
214 |
+
keep_aspect_ratio=True,
|
215 |
+
ensure_multiple_of=14,
|
216 |
+
resize_method='lower_bound',
|
217 |
+
image_interpolation_method=cv2.INTER_CUBIC,
|
218 |
+
),
|
219 |
+
NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
220 |
+
PrepareForNet(),
|
221 |
+
])
|
222 |
+
|
223 |
+
h, w = raw_image.shape[:2]
|
224 |
+
|
225 |
+
# raw_image already has RGB order, [0,255]
|
226 |
+
# image = cv2.cvtColor(raw_image, cv2.COLOR_BGR2RGB) / 255.0
|
227 |
+
image = raw_image / 255.0
|
228 |
+
|
229 |
+
image = transform({'image': image})['image']
|
230 |
+
image = torch.from_numpy(image).unsqueeze(0)
|
231 |
+
|
232 |
+
DEVICE = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
|
233 |
+
image = image.to(DEVICE)
|
234 |
+
|
235 |
+
return image, (h, w)
|
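An end-to-end inference sketch for `DepthAnythingV2.infer_image` (assumptions: the `vitb` config values below follow the upstream Depth Anything V2 recipe, the checkpoint path is a placeholder, and the input is an RGB array in [0, 255] as `image2tensor` expects):

```python
import numpy as np
import torch

from finetune.modules.depth_warping.depth_anything_v2.dpt import DepthAnythingV2

device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'

model = DepthAnythingV2(encoder='vitb', features=128, out_channels=[96, 192, 384, 768])
# In practice, load the matching Depth Anything V2 checkpoint (path is a placeholder):
# model.load_state_dict(torch.load("depth_anything_v2_vitb.pth", map_location="cpu"))
model = model.to(device).eval()

# image2tensor() resizes the image so both sides are >= 518 and multiples of 14.
rgb = (np.random.rand(480, 640, 3) * 255).astype(np.uint8)
depth = model.infer_image(rgb)      # torch.Tensor of shape (480, 640) on `device`
print(depth.shape, float(depth.max()))
```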
finetune/modules/depth_warping/depth_anything_v2/util/blocks.py
ADDED
@@ -0,0 +1,148 @@
1 |
+
import torch.nn as nn
|
2 |
+
|
3 |
+
|
4 |
+
def _make_scratch(in_shape, out_shape, groups=1, expand=False):
|
5 |
+
scratch = nn.Module()
|
6 |
+
|
7 |
+
out_shape1 = out_shape
|
8 |
+
out_shape2 = out_shape
|
9 |
+
out_shape3 = out_shape
|
10 |
+
if len(in_shape) >= 4:
|
11 |
+
out_shape4 = out_shape
|
12 |
+
|
13 |
+
if expand:
|
14 |
+
out_shape1 = out_shape
|
15 |
+
out_shape2 = out_shape * 2
|
16 |
+
out_shape3 = out_shape * 4
|
17 |
+
if len(in_shape) >= 4:
|
18 |
+
out_shape4 = out_shape * 8
|
19 |
+
|
20 |
+
scratch.layer1_rn = nn.Conv2d(in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups)
|
21 |
+
scratch.layer2_rn = nn.Conv2d(in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups)
|
22 |
+
scratch.layer3_rn = nn.Conv2d(in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups)
|
23 |
+
if len(in_shape) >= 4:
|
24 |
+
scratch.layer4_rn = nn.Conv2d(in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups)
|
25 |
+
|
26 |
+
return scratch
|
27 |
+
|
28 |
+
|
29 |
+
class ResidualConvUnit(nn.Module):
|
30 |
+
"""Residual convolution module.
|
31 |
+
"""
|
32 |
+
|
33 |
+
def __init__(self, features, activation, bn):
|
34 |
+
"""Init.
|
35 |
+
|
36 |
+
Args:
|
37 |
+
features (int): number of features
|
38 |
+
"""
|
39 |
+
super().__init__()
|
40 |
+
|
41 |
+
self.bn = bn
|
42 |
+
|
43 |
+
self.groups=1
|
44 |
+
|
45 |
+
self.conv1 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
|
46 |
+
|
47 |
+
self.conv2 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
|
48 |
+
|
49 |
+
if self.bn == True:
|
50 |
+
self.bn1 = nn.BatchNorm2d(features)
|
51 |
+
self.bn2 = nn.BatchNorm2d(features)
|
52 |
+
|
53 |
+
self.activation = activation
|
54 |
+
|
55 |
+
self.skip_add = nn.quantized.FloatFunctional()
|
56 |
+
|
57 |
+
def forward(self, x):
|
58 |
+
"""Forward pass.
|
59 |
+
|
60 |
+
Args:
|
61 |
+
x (tensor): input
|
62 |
+
|
63 |
+
Returns:
|
64 |
+
tensor: output
|
65 |
+
"""
|
66 |
+
|
67 |
+
out = self.activation(x)
|
68 |
+
out = self.conv1(out)
|
69 |
+
if self.bn == True:
|
70 |
+
out = self.bn1(out)
|
71 |
+
|
72 |
+
out = self.activation(out)
|
73 |
+
out = self.conv2(out)
|
74 |
+
if self.bn == True:
|
75 |
+
out = self.bn2(out)
|
76 |
+
|
77 |
+
if self.groups > 1:
|
78 |
+
out = self.conv_merge(out)
|
79 |
+
|
80 |
+
return self.skip_add.add(out, x)
|
81 |
+
|
82 |
+
|
83 |
+
class FeatureFusionBlock(nn.Module):
|
84 |
+
"""Feature fusion block.
|
85 |
+
"""
|
86 |
+
|
87 |
+
def __init__(
|
88 |
+
self,
|
89 |
+
features,
|
90 |
+
activation,
|
91 |
+
deconv=False,
|
92 |
+
bn=False,
|
93 |
+
expand=False,
|
94 |
+
align_corners=True,
|
95 |
+
size=None
|
96 |
+
):
|
97 |
+
"""Init.
|
98 |
+
|
99 |
+
Args:
|
100 |
+
features (int): number of features
|
101 |
+
"""
|
102 |
+
super(FeatureFusionBlock, self).__init__()
|
103 |
+
|
104 |
+
self.deconv = deconv
|
105 |
+
self.align_corners = align_corners
|
106 |
+
|
107 |
+
self.groups=1
|
108 |
+
|
109 |
+
self.expand = expand
|
110 |
+
out_features = features
|
111 |
+
if self.expand == True:
|
112 |
+
out_features = features // 2
|
113 |
+
|
114 |
+
self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1)
|
115 |
+
|
116 |
+
self.resConfUnit1 = ResidualConvUnit(features, activation, bn)
|
117 |
+
self.resConfUnit2 = ResidualConvUnit(features, activation, bn)
|
118 |
+
|
119 |
+
self.skip_add = nn.quantized.FloatFunctional()
|
120 |
+
|
121 |
+
self.size=size
|
122 |
+
|
123 |
+
def forward(self, *xs, size=None):
|
124 |
+
"""Forward pass.
|
125 |
+
|
126 |
+
Returns:
|
127 |
+
tensor: output
|
128 |
+
"""
|
129 |
+
output = xs[0]
|
130 |
+
|
131 |
+
if len(xs) == 2:
|
132 |
+
res = self.resConfUnit1(xs[1])
|
133 |
+
output = self.skip_add.add(output, res)
|
134 |
+
|
135 |
+
output = self.resConfUnit2(output)
|
136 |
+
|
137 |
+
if (size is None) and (self.size is None):
|
138 |
+
modifier = {"scale_factor": 2}
|
139 |
+
elif size is None:
|
140 |
+
modifier = {"size": self.size}
|
141 |
+
else:
|
142 |
+
modifier = {"size": size}
|
143 |
+
|
144 |
+
output = nn.functional.interpolate(output, **modifier, mode="bilinear", align_corners=self.align_corners)
|
145 |
+
|
146 |
+
output = self.out_conv(output)
|
147 |
+
|
148 |
+
return output
|
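A small sketch of how `FeatureFusionBlock` merges two pyramid levels, mirroring the `refinenet4 -> refinenet1` chain in `DPTHead` above (a single module is reused here purely for illustration; `DPTHead` builds one block per level):

```python
import torch
import torch.nn as nn

from finetune.modules.depth_warping.depth_anything_v2.util.blocks import FeatureFusionBlock

fuse = FeatureFusionBlock(features=64, activation=nn.ReLU(False), bn=False, align_corners=True)

coarse = torch.randn(1, 64, 16, 16)   # deepest pyramid level (refinenet4 takes a single input)
skip = torch.randn(1, 64, 32, 32)     # next, finer pyramid level

path4 = fuse(coarse, size=skip.shape[2:])   # single input: refine, then upsample to 32x32
path3 = fuse(path4, skip, size=(64, 64))    # two inputs: add the skip, refine, upsample to 64x64

print(path4.shape, path3.shape)   # (1, 64, 32, 32) (1, 64, 64, 64)
```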
finetune/modules/depth_warping/depth_anything_v2/util/transform.py
ADDED
@@ -0,0 +1,158 @@
1 |
+
import numpy as np
|
2 |
+
import cv2
|
3 |
+
|
4 |
+
|
5 |
+
class Resize(object):
|
6 |
+
"""Resize sample to given size (width, height).
|
7 |
+
"""
|
8 |
+
|
9 |
+
def __init__(
|
10 |
+
self,
|
11 |
+
width,
|
12 |
+
height,
|
13 |
+
resize_target=True,
|
14 |
+
keep_aspect_ratio=False,
|
15 |
+
ensure_multiple_of=1,
|
16 |
+
resize_method="lower_bound",
|
17 |
+
image_interpolation_method=cv2.INTER_AREA,
|
18 |
+
):
|
19 |
+
"""Init.
|
20 |
+
|
21 |
+
Args:
|
22 |
+
width (int): desired output width
|
23 |
+
height (int): desired output height
|
24 |
+
resize_target (bool, optional):
|
25 |
+
True: Resize the full sample (image, mask, target).
|
26 |
+
False: Resize image only.
|
27 |
+
Defaults to True.
|
28 |
+
keep_aspect_ratio (bool, optional):
|
29 |
+
True: Keep the aspect ratio of the input sample.
|
30 |
+
Output sample might not have the given width and height, and
|
31 |
+
resize behaviour depends on the parameter 'resize_method'.
|
32 |
+
Defaults to False.
|
33 |
+
ensure_multiple_of (int, optional):
|
34 |
+
Output width and height is constrained to be multiple of this parameter.
|
35 |
+
Defaults to 1.
|
36 |
+
resize_method (str, optional):
|
37 |
+
"lower_bound": Output will be at least as large as the given size.
|
38 |
+
"upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.)
|
39 |
+
"minimal": Scale as least as possible. (Output size might be smaller than given size.)
|
40 |
+
Defaults to "lower_bound".
|
41 |
+
"""
|
42 |
+
self.__width = width
|
43 |
+
self.__height = height
|
44 |
+
|
45 |
+
self.__resize_target = resize_target
|
46 |
+
self.__keep_aspect_ratio = keep_aspect_ratio
|
47 |
+
self.__multiple_of = ensure_multiple_of
|
48 |
+
self.__resize_method = resize_method
|
49 |
+
self.__image_interpolation_method = image_interpolation_method
|
50 |
+
|
51 |
+
def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
|
52 |
+
y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)
|
53 |
+
|
54 |
+
if max_val is not None and y > max_val:
|
55 |
+
y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int)
|
56 |
+
|
57 |
+
if y < min_val:
|
58 |
+
y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int)
|
59 |
+
|
60 |
+
return y
|
61 |
+
|
62 |
+
def get_size(self, width, height):
|
63 |
+
# determine new height and width
|
64 |
+
scale_height = self.__height / height
|
65 |
+
scale_width = self.__width / width
|
66 |
+
|
67 |
+
if self.__keep_aspect_ratio:
|
68 |
+
if self.__resize_method == "lower_bound":
|
69 |
+
# scale such that output size is lower bound
|
70 |
+
if scale_width > scale_height:
|
71 |
+
# fit width
|
72 |
+
scale_height = scale_width
|
73 |
+
else:
|
74 |
+
# fit height
|
75 |
+
scale_width = scale_height
|
76 |
+
elif self.__resize_method == "upper_bound":
|
77 |
+
# scale such that output size is upper bound
|
78 |
+
if scale_width < scale_height:
|
79 |
+
# fit width
|
80 |
+
scale_height = scale_width
|
81 |
+
else:
|
82 |
+
# fit height
|
83 |
+
scale_width = scale_height
|
84 |
+
elif self.__resize_method == "minimal":
|
85 |
+
# scale as least as possbile
|
86 |
+
if abs(1 - scale_width) < abs(1 - scale_height):
|
87 |
+
# fit width
|
88 |
+
scale_height = scale_width
|
89 |
+
else:
|
90 |
+
# fit height
|
91 |
+
scale_width = scale_height
|
92 |
+
else:
|
93 |
+
raise ValueError(f"resize_method {self.__resize_method} not implemented")
|
94 |
+
|
95 |
+
if self.__resize_method == "lower_bound":
|
96 |
+
new_height = self.constrain_to_multiple_of(scale_height * height, min_val=self.__height)
|
97 |
+
new_width = self.constrain_to_multiple_of(scale_width * width, min_val=self.__width)
|
98 |
+
elif self.__resize_method == "upper_bound":
|
99 |
+
new_height = self.constrain_to_multiple_of(scale_height * height, max_val=self.__height)
|
100 |
+
new_width = self.constrain_to_multiple_of(scale_width * width, max_val=self.__width)
|
101 |
+
elif self.__resize_method == "minimal":
|
102 |
+
new_height = self.constrain_to_multiple_of(scale_height * height)
|
103 |
+
new_width = self.constrain_to_multiple_of(scale_width * width)
|
104 |
+
else:
|
105 |
+
raise ValueError(f"resize_method {self.__resize_method} not implemented")
|
106 |
+
|
107 |
+
return (new_width, new_height)
|
108 |
+
|
109 |
+
def __call__(self, sample):
|
110 |
+
width, height = self.get_size(sample["image"].shape[1], sample["image"].shape[0])
|
111 |
+
|
112 |
+
# resize sample
|
113 |
+
sample["image"] = cv2.resize(sample["image"], (width, height), interpolation=self.__image_interpolation_method)
|
114 |
+
|
115 |
+
if self.__resize_target:
|
116 |
+
if "depth" in sample:
|
117 |
+
sample["depth"] = cv2.resize(sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST)
|
118 |
+
|
119 |
+
if "mask" in sample:
|
120 |
+
sample["mask"] = cv2.resize(sample["mask"].astype(np.float32), (width, height), interpolation=cv2.INTER_NEAREST)
|
121 |
+
|
122 |
+
return sample
|
123 |
+
|
124 |
+
|
125 |
+
class NormalizeImage(object):
|
126 |
+
"""Normlize image by given mean and std.
|
127 |
+
"""
|
128 |
+
|
129 |
+
def __init__(self, mean, std):
|
130 |
+
self.__mean = mean
|
131 |
+
self.__std = std
|
132 |
+
|
133 |
+
def __call__(self, sample):
|
134 |
+
sample["image"] = (sample["image"] - self.__mean) / self.__std
|
135 |
+
|
136 |
+
return sample
|
137 |
+
|
138 |
+
|
139 |
+
class PrepareForNet(object):
|
140 |
+
"""Prepare sample for usage as network input.
|
141 |
+
"""
|
142 |
+
|
143 |
+
def __init__(self):
|
144 |
+
pass
|
145 |
+
|
146 |
+
def __call__(self, sample):
|
147 |
+
image = np.transpose(sample["image"], (2, 0, 1))
|
148 |
+
sample["image"] = np.ascontiguousarray(image).astype(np.float32)
|
149 |
+
|
150 |
+
if "depth" in sample:
|
151 |
+
depth = sample["depth"].astype(np.float32)
|
152 |
+
sample["depth"] = np.ascontiguousarray(depth)
|
153 |
+
|
154 |
+
if "mask" in sample:
|
155 |
+
sample["mask"] = sample["mask"].astype(np.float32)
|
156 |
+
sample["mask"] = np.ascontiguousarray(sample["mask"])
|
157 |
+
|
158 |
+
return sample
|
finetune/modules/depth_warping/depth_pro/__init__.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (C) 2024 Apple Inc. All Rights Reserved.
|
2 |
+
"""Depth Pro package."""
|
3 |
+
|
4 |
+
from .depth_pro import create_model_and_transforms # noqa
|
5 |
+
from .utils import load_rgb # noqa
|
finetune/modules/depth_warping/depth_pro/cli/__init__.py
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (C) 2024 Apple Inc. All Rights Reserved.
|
2 |
+
"""Depth Pro CLI and tools."""
|
3 |
+
|
4 |
+
from .run import main as run_main # noqa
|
finetune/modules/depth_warping/depth_pro/cli/run.py
ADDED
@@ -0,0 +1,154 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
"""Sample script to run DepthPro.
|
3 |
+
|
4 |
+
Copyright (C) 2024 Apple Inc. All Rights Reserved.
|
5 |
+
"""
|
6 |
+
|
7 |
+
|
8 |
+
import argparse
|
9 |
+
import logging
|
10 |
+
from pathlib import Path
|
11 |
+
|
12 |
+
import numpy as np
|
13 |
+
import PIL.Image
|
14 |
+
import torch
|
15 |
+
from matplotlib import pyplot as plt
|
16 |
+
from tqdm import tqdm
|
17 |
+
|
18 |
+
from depth_pro import create_model_and_transforms, load_rgb
|
19 |
+
|
20 |
+
LOGGER = logging.getLogger(__name__)
|
21 |
+
|
22 |
+
|
23 |
+
def get_torch_device() -> torch.device:
|
24 |
+
"""Get the Torch device."""
|
25 |
+
device = torch.device("cpu")
|
26 |
+
if torch.cuda.is_available():
|
27 |
+
device = torch.device("cuda:0")
|
28 |
+
elif torch.backends.mps.is_available():
|
29 |
+
device = torch.device("mps")
|
30 |
+
return device
|
31 |
+
|
32 |
+
|
33 |
+
def run(args):
|
34 |
+
"""Run Depth Pro on a sample image."""
|
35 |
+
if args.verbose:
|
36 |
+
logging.basicConfig(level=logging.INFO)
|
37 |
+
|
38 |
+
# Load model.
|
39 |
+
model, transform = create_model_and_transforms(
|
40 |
+
device=get_torch_device(),
|
41 |
+
precision=torch.half,
|
42 |
+
)
|
43 |
+
model.eval()
|
44 |
+
|
45 |
+
image_paths = [args.image_path]
|
46 |
+
if args.image_path.is_dir():
|
47 |
+
image_paths = args.image_path.glob("**/*")
|
48 |
+
relative_path = args.image_path
|
49 |
+
else:
|
50 |
+
relative_path = args.image_path.parent
|
51 |
+
|
52 |
+
if not args.skip_display:
|
53 |
+
plt.ion()
|
54 |
+
fig = plt.figure()
|
55 |
+
ax_rgb = fig.add_subplot(121)
|
56 |
+
ax_disp = fig.add_subplot(122)
|
57 |
+
|
58 |
+
for image_path in tqdm(image_paths):
|
59 |
+
# Load image and focal length from exif info (if found.).
|
60 |
+
try:
|
61 |
+
LOGGER.info(f"Loading image {image_path} ...")
|
62 |
+
image, _, f_px = load_rgb(image_path)
|
63 |
+
except Exception as e:
|
64 |
+
LOGGER.error(str(e))
|
65 |
+
continue
|
66 |
+
# Run prediction. If `f_px` is provided, it is used to estimate the final metric depth,
|
67 |
+
# otherwise the model estimates `f_px` to compute the depth metricness.
|
68 |
+
prediction = model.infer(transform(image), f_px=f_px)
|
69 |
+
|
70 |
+
# Extract the depth and focal length.
|
71 |
+
depth = prediction["depth"].detach().cpu().numpy().squeeze()
|
72 |
+
if f_px is not None:
|
73 |
+
LOGGER.debug(f"Focal length (from exif): {f_px:0.2f}")
|
74 |
+
elif prediction["focallength_px"] is not None:
|
75 |
+
focallength_px = prediction["focallength_px"].detach().cpu().item()
|
76 |
+
LOGGER.info(f"Estimated focal length: {focallength_px}")
|
77 |
+
|
78 |
+
inverse_depth = 1 / depth
|
79 |
+
# Visualize inverse depth instead of depth, clipped to [0.1m;250m] range for better visualization.
|
80 |
+
max_invdepth_vizu = min(inverse_depth.max(), 1 / 0.1)
|
81 |
+
min_invdepth_vizu = max(1 / 250, inverse_depth.min())
|
82 |
+
inverse_depth_normalized = (inverse_depth - min_invdepth_vizu) / (
|
83 |
+
max_invdepth_vizu - min_invdepth_vizu
|
84 |
+
)
|
85 |
+
|
86 |
+
# Save Depth as npz file.
|
87 |
+
if args.output_path is not None:
|
88 |
+
output_file = (
|
89 |
+
args.output_path
|
90 |
+
/ image_path.relative_to(relative_path).parent
|
91 |
+
/ image_path.stem
|
92 |
+
)
|
93 |
+
LOGGER.info(f"Saving depth map to: {str(output_file)}")
|
94 |
+
output_file.parent.mkdir(parents=True, exist_ok=True)
|
95 |
+
np.savez_compressed(output_file, depth=depth)
|
96 |
+
|
97 |
+
# Save as color-mapped "turbo" jpg image.
|
98 |
+
cmap = plt.get_cmap("turbo")
|
99 |
+
color_depth = (cmap(inverse_depth_normalized)[..., :3] * 255).astype(
|
100 |
+
np.uint8
|
101 |
+
)
|
102 |
+
color_map_output_file = str(output_file) + ".jpg"
|
103 |
+
LOGGER.info(f"Saving color-mapped depth to: : {color_map_output_file}")
|
104 |
+
PIL.Image.fromarray(color_depth).save(
|
105 |
+
color_map_output_file, format="JPEG", quality=90
|
106 |
+
)
|
107 |
+
|
108 |
+
# Display the image and estimated depth map.
|
109 |
+
if not args.skip_display:
|
110 |
+
ax_rgb.imshow(image)
|
111 |
+
ax_disp.imshow(inverse_depth_normalized, cmap="turbo")
|
112 |
+
fig.canvas.draw()
|
113 |
+
fig.canvas.flush_events()
|
114 |
+
|
115 |
+
LOGGER.info("Done predicting depth!")
|
116 |
+
if not args.skip_display:
|
117 |
+
plt.show(block=True)
|
118 |
+
|
119 |
+
|
120 |
+
def main():
|
121 |
+
"""Run DepthPro inference example."""
|
122 |
+
parser = argparse.ArgumentParser(
|
123 |
+
description="Inference scripts of DepthPro with PyTorch models."
|
124 |
+
)
|
125 |
+
parser.add_argument(
|
126 |
+
"-i",
|
127 |
+
"--image-path",
|
128 |
+
type=Path,
|
129 |
+
default="./data/example.jpg",
|
130 |
+
help="Path to input image.",
|
131 |
+
)
|
132 |
+
parser.add_argument(
|
133 |
+
"-o",
|
134 |
+
"--output-path",
|
135 |
+
type=Path,
|
136 |
+
help="Path to store output files.",
|
137 |
+
)
|
138 |
+
parser.add_argument(
|
139 |
+
"--skip-display",
|
140 |
+
action="store_true",
|
141 |
+
help="Skip matplotlib display.",
|
142 |
+
)
|
143 |
+
parser.add_argument(
|
144 |
+
"-v",
|
145 |
+
"--verbose",
|
146 |
+
action="store_true",
|
147 |
+
help="Show verbose output."
|
148 |
+
)
|
149 |
+
|
150 |
+
run(parser.parse_args())
|
151 |
+
|
152 |
+
|
153 |
+
if __name__ == "__main__":
|
154 |
+
main()
|
finetune/modules/depth_warping/depth_pro/depth_pro.py
ADDED
@@ -0,0 +1,298 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (C) 2024 Apple Inc. All Rights Reserved.
|
2 |
+
# Depth Pro: Sharp Monocular Metric Depth in Less Than a Second
|
3 |
+
|
4 |
+
|
5 |
+
from __future__ import annotations
|
6 |
+
|
7 |
+
from dataclasses import dataclass
|
8 |
+
from typing import Mapping, Optional, Tuple, Union
|
9 |
+
|
10 |
+
import torch
|
11 |
+
from torch import nn
|
12 |
+
from torchvision.transforms import (
|
13 |
+
Compose,
|
14 |
+
ConvertImageDtype,
|
15 |
+
Lambda,
|
16 |
+
Normalize,
|
17 |
+
ToTensor,
|
18 |
+
)
|
19 |
+
|
20 |
+
from .network.decoder import MultiresConvDecoder
|
21 |
+
from .network.encoder import DepthProEncoder
|
22 |
+
from .network.fov import FOVNetwork
|
23 |
+
from .network.vit_factory import VIT_CONFIG_DICT, ViTPreset, create_vit
|
24 |
+
|
25 |
+
|
26 |
+
@dataclass
|
27 |
+
class DepthProConfig:
|
28 |
+
"""Configuration for DepthPro."""
|
29 |
+
|
30 |
+
patch_encoder_preset: ViTPreset
|
31 |
+
image_encoder_preset: ViTPreset
|
32 |
+
decoder_features: int
|
33 |
+
|
34 |
+
checkpoint_uri: Optional[str] = None
|
35 |
+
fov_encoder_preset: Optional[ViTPreset] = None
|
36 |
+
use_fov_head: bool = True
|
37 |
+
|
38 |
+
|
39 |
+
DEFAULT_MONODEPTH_CONFIG_DICT = DepthProConfig(
|
40 |
+
patch_encoder_preset="dinov2l16_384",
|
41 |
+
image_encoder_preset="dinov2l16_384",
|
42 |
+
checkpoint_uri="./checkpoints/depth_pro.pt",
|
43 |
+
decoder_features=256,
|
44 |
+
use_fov_head=True,
|
45 |
+
fov_encoder_preset="dinov2l16_384",
|
46 |
+
)
|
47 |
+
|
48 |
+
|
49 |
+
def create_backbone_model(
|
50 |
+
preset: ViTPreset
|
51 |
+
) -> Tuple[nn.Module, ViTPreset]:
|
52 |
+
"""Create and load a backbone model given a config.
|
53 |
+
|
54 |
+
Args:
|
55 |
+
----
|
56 |
+
preset: A backbone preset to load pre-defind configs.
|
57 |
+
|
58 |
+
Returns:
|
59 |
+
-------
|
60 |
+
A Torch module and the associated config.
|
61 |
+
|
62 |
+
"""
|
63 |
+
if preset in VIT_CONFIG_DICT:
|
64 |
+
config = VIT_CONFIG_DICT[preset]
|
65 |
+
model = create_vit(preset=preset, use_pretrained=False)
|
66 |
+
else:
|
67 |
+
raise KeyError(f"Preset {preset} not found.")
|
68 |
+
|
69 |
+
return model, config
|
70 |
+
|
71 |
+
|
72 |
+
def create_model_and_transforms(
|
73 |
+
config: DepthProConfig = DEFAULT_MONODEPTH_CONFIG_DICT,
|
74 |
+
device: torch.device = torch.device("cpu"),
|
75 |
+
precision: torch.dtype = torch.float32,
|
76 |
+
) -> Tuple[DepthPro, Compose]:
|
77 |
+
"""Create a DepthPro model and load weights from `config.checkpoint_uri`.
|
78 |
+
|
79 |
+
Args:
|
80 |
+
----
|
81 |
+
config: The configuration for the DPT model architecture.
|
82 |
+
device: The optional Torch device to load the model onto, default runs on "cpu".
|
83 |
+
precision: The optional precision used for the model, default is FP32.
|
84 |
+
|
85 |
+
Returns:
|
86 |
+
-------
|
87 |
+
The Torch DepthPro model and associated Transform.
|
88 |
+
|
89 |
+
"""
|
90 |
+
patch_encoder, patch_encoder_config = create_backbone_model(
|
91 |
+
preset=config.patch_encoder_preset
|
92 |
+
)
|
93 |
+
image_encoder, _ = create_backbone_model(
|
94 |
+
preset=config.image_encoder_preset
|
95 |
+
)
|
96 |
+
|
97 |
+
fov_encoder = None
|
98 |
+
if config.use_fov_head and config.fov_encoder_preset is not None:
|
99 |
+
fov_encoder, _ = create_backbone_model(preset=config.fov_encoder_preset)
|
100 |
+
|
101 |
+
dims_encoder = patch_encoder_config.encoder_feature_dims
|
102 |
+
hook_block_ids = patch_encoder_config.encoder_feature_layer_ids
|
103 |
+
encoder = DepthProEncoder(
|
104 |
+
dims_encoder=dims_encoder,
|
105 |
+
patch_encoder=patch_encoder,
|
106 |
+
image_encoder=image_encoder,
|
107 |
+
hook_block_ids=hook_block_ids,
|
108 |
+
decoder_features=config.decoder_features,
|
109 |
+
)
|
110 |
+
decoder = MultiresConvDecoder(
|
111 |
+
dims_encoder=[config.decoder_features] + list(encoder.dims_encoder),
|
112 |
+
dim_decoder=config.decoder_features,
|
113 |
+
)
|
114 |
+
model = DepthPro(
|
115 |
+
encoder=encoder,
|
116 |
+
decoder=decoder,
|
117 |
+
last_dims=(32, 1),
|
118 |
+
use_fov_head=config.use_fov_head,
|
119 |
+
fov_encoder=fov_encoder,
|
120 |
+
).to(device)
|
121 |
+
|
122 |
+
if precision == torch.half:
|
123 |
+
model.half()
|
124 |
+
|
125 |
+
transform = Compose(
|
126 |
+
[
|
127 |
+
ToTensor(),
|
128 |
+
Lambda(lambda x: x.to(device)),
|
129 |
+
Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
|
130 |
+
ConvertImageDtype(precision),
|
131 |
+
]
|
132 |
+
)
|
133 |
+
|
134 |
+
if config.checkpoint_uri is not None:
|
135 |
+
state_dict = torch.load(config.checkpoint_uri, map_location="cpu")
|
136 |
+
missing_keys, unexpected_keys = model.load_state_dict(
|
137 |
+
state_dict=state_dict, strict=True
|
138 |
+
)
|
139 |
+
|
140 |
+
if len(unexpected_keys) != 0:
|
141 |
+
raise KeyError(
|
142 |
+
f"Found unexpected keys when loading monodepth: {unexpected_keys}"
|
143 |
+
)
|
144 |
+
|
145 |
+
# fc_norm is only for the classification head,
|
146 |
+
# which we would not use. We only use the encoding.
|
147 |
+
missing_keys = [key for key in missing_keys if "fc_norm" not in key]
|
148 |
+
if len(missing_keys) != 0:
|
149 |
+
raise KeyError(f"Keys are missing when loading monodepth: {missing_keys}")
|
150 |
+
|
151 |
+
return model, transform
|
152 |
+
|
153 |
+
|
154 |
+
class DepthPro(nn.Module):
|
155 |
+
"""DepthPro network."""
|
156 |
+
|
157 |
+
def __init__(
|
158 |
+
self,
|
159 |
+
encoder: DepthProEncoder,
|
160 |
+
decoder: MultiresConvDecoder,
|
161 |
+
last_dims: tuple[int, int],
|
162 |
+
use_fov_head: bool = True,
|
163 |
+
fov_encoder: Optional[nn.Module] = None,
|
164 |
+
):
|
165 |
+
"""Initialize DepthPro.
|
166 |
+
|
167 |
+
Args:
|
168 |
+
----
|
169 |
+
encoder: The DepthProEncoder backbone.
|
170 |
+
decoder: The MultiresConvDecoder decoder.
|
171 |
+
last_dims: The dimension for the last convolution layers.
|
172 |
+
use_fov_head: Whether to use the field-of-view head.
|
173 |
+
fov_encoder: A separate encoder for the field of view.
|
174 |
+
|
175 |
+
"""
|
176 |
+
super().__init__()
|
177 |
+
|
178 |
+
self.encoder = encoder
|
179 |
+
self.decoder = decoder
|
180 |
+
|
181 |
+
dim_decoder = decoder.dim_decoder
|
182 |
+
self.head = nn.Sequential(
|
183 |
+
nn.Conv2d(
|
184 |
+
dim_decoder, dim_decoder // 2, kernel_size=3, stride=1, padding=1
|
185 |
+
),
|
186 |
+
nn.ConvTranspose2d(
|
187 |
+
in_channels=dim_decoder // 2,
|
188 |
+
out_channels=dim_decoder // 2,
|
189 |
+
kernel_size=2,
|
190 |
+
stride=2,
|
191 |
+
padding=0,
|
192 |
+
bias=True,
|
193 |
+
),
|
194 |
+
nn.Conv2d(
|
195 |
+
dim_decoder // 2,
|
196 |
+
last_dims[0],
|
197 |
+
kernel_size=3,
|
198 |
+
stride=1,
|
199 |
+
padding=1,
|
200 |
+
),
|
201 |
+
nn.ReLU(True),
|
202 |
+
nn.Conv2d(last_dims[0], last_dims[1], kernel_size=1, stride=1, padding=0),
|
203 |
+
nn.ReLU(),
|
204 |
+
)
|
205 |
+
|
206 |
+
# Set the final convolution layer's bias to be 0.
|
207 |
+
self.head[4].bias.data.fill_(0)
|
208 |
+
|
209 |
+
# Set the FOV estimation head.
|
210 |
+
if use_fov_head:
|
211 |
+
self.fov = FOVNetwork(num_features=dim_decoder, fov_encoder=fov_encoder)
|
212 |
+
|
213 |
+
@property
|
214 |
+
def img_size(self) -> int:
|
215 |
+
"""Return the internal image size of the network."""
|
216 |
+
return self.encoder.img_size
|
217 |
+
|
218 |
+
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
219 |
+
"""Decode by projection and fusion of multi-resolution encodings.
|
220 |
+
|
221 |
+
Args:
|
222 |
+
----
|
223 |
+
x (torch.Tensor): Input image.
|
224 |
+
|
225 |
+
Returns:
|
226 |
+
-------
|
227 |
+
The canonical inverse depth map [m] and the optional estimated field of view [deg].
|
228 |
+
|
229 |
+
"""
|
230 |
+
_, _, H, W = x.shape
|
231 |
+
assert H == self.img_size and W == self.img_size
|
232 |
+
|
233 |
+
encodings = self.encoder(x)
|
234 |
+
features, features_0 = self.decoder(encodings)
|
235 |
+
canonical_inverse_depth = self.head(features)
|
236 |
+
|
237 |
+
fov_deg = None
|
238 |
+
if hasattr(self, "fov"):
|
239 |
+
fov_deg = self.fov.forward(x, features_0.detach())
|
240 |
+
|
241 |
+
return canonical_inverse_depth, fov_deg
|
242 |
+
|
243 |
+
@torch.no_grad()
|
244 |
+
def infer(
|
245 |
+
self,
|
246 |
+
x: torch.Tensor,
|
247 |
+
f_px: Optional[Union[float, torch.Tensor]] = None,
|
248 |
+
interpolation_mode="bilinear",
|
249 |
+
) -> Mapping[str, torch.Tensor]:
|
250 |
+
"""Infer depth and fov for a given image.
|
251 |
+
|
252 |
+
If the image is not at network resolution, it is resized to 1536x1536 and
|
253 |
+
the estimated depth is resized to the original image resolution.
|
254 |
+
Note: if the focal length is given, the estimated value is ignored and the provided
|
255 |
+
focal length is use to generate the metric depth values.
|
256 |
+
|
257 |
+
Args:
|
258 |
+
----
|
259 |
+
x (torch.Tensor): Input image
|
260 |
+
f_px (torch.Tensor): Optional focal length in pixels corresponding to `x`.
|
261 |
+
interpolation_mode (str): Interpolation function for downsampling/upsampling.
|
262 |
+
|
263 |
+
Returns:
|
264 |
+
-------
|
265 |
+
Tensor dictionary (torch.Tensor): depth [m], focallength [pixels].
|
266 |
+
|
267 |
+
"""
|
268 |
+
if len(x.shape) == 3:
|
269 |
+
x = x.unsqueeze(0)
|
270 |
+
_, _, H, W = x.shape
|
271 |
+
resize = H != self.img_size or W != self.img_size
|
272 |
+
|
273 |
+
if resize:
|
274 |
+
x = nn.functional.interpolate(
|
275 |
+
x,
|
276 |
+
size=(self.img_size, self.img_size),
|
277 |
+
mode=interpolation_mode,
|
278 |
+
align_corners=False,
|
279 |
+
)
|
280 |
+
|
281 |
+
canonical_inverse_depth, fov_deg = self.forward(x)
|
282 |
+
if f_px is None:
|
283 |
+
f_px = 0.5 * W / torch.tan(0.5 * torch.deg2rad(fov_deg.to(torch.float)))
|
284 |
+
|
285 |
+
inverse_depth = canonical_inverse_depth * (W / f_px)
|
286 |
+
f_px = f_px.squeeze()
|
287 |
+
|
288 |
+
if resize:
|
289 |
+
inverse_depth = nn.functional.interpolate(
|
290 |
+
inverse_depth, size=(H, W), mode=interpolation_mode, align_corners=False
|
291 |
+
)
|
292 |
+
|
293 |
+
depth = 1.0 / torch.clamp(inverse_depth, min=1e-4, max=1e4)
|
294 |
+
|
295 |
+
return {
|
296 |
+
"depth": depth.squeeze(),
|
297 |
+
"focallength_px": f_px,
|
298 |
+
}
|
finetune/modules/depth_warping/depth_pro/eval/boundary_metrics.py
ADDED
@@ -0,0 +1,332 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List, Tuple
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
|
5 |
+
|
6 |
+
def connected_component(r: np.ndarray, c: np.ndarray) -> List[List[int]]:
|
7 |
+
"""Find connected components in the given row and column indices.
|
8 |
+
|
9 |
+
Args:
|
10 |
+
----
|
11 |
+
r (np.ndarray): Row indices.
|
12 |
+
c (np.ndarray): Column indices.
|
13 |
+
|
14 |
+
Yields:
|
15 |
+
------
|
16 |
+
List[int]: Indices of connected components.
|
17 |
+
|
18 |
+
"""
|
19 |
+
indices = [0]
|
20 |
+
for i in range(1, r.size):
|
21 |
+
if r[i] == r[indices[-1]] and c[i] == c[indices[-1]] + 1:
|
22 |
+
indices.append(i)
|
23 |
+
else:
|
24 |
+
yield indices
|
25 |
+
indices = [i]
|
26 |
+
yield indices
|
27 |
+
|
28 |
+
|
29 |
+
def nms_horizontal(ratio: np.ndarray, threshold: float) -> np.ndarray:
|
30 |
+
"""Apply Non-Maximum Suppression (NMS) horizontally on the given ratio matrix.
|
31 |
+
|
32 |
+
Args:
|
33 |
+
----
|
34 |
+
ratio (np.ndarray): Input ratio matrix.
|
35 |
+
threshold (float): Threshold for NMS.
|
36 |
+
|
37 |
+
Returns:
|
38 |
+
-------
|
39 |
+
np.ndarray: Binary mask after applying NMS.
|
40 |
+
|
41 |
+
"""
|
42 |
+
mask = np.zeros_like(ratio, dtype=bool)
|
43 |
+
r, c = np.nonzero(ratio > threshold)
|
44 |
+
if len(r) == 0:
|
45 |
+
return mask
|
46 |
+
for ids in connected_component(r, c):
|
47 |
+
values = [ratio[r[i], c[i]] for i in ids]
|
48 |
+
mi = np.argmax(values)
|
49 |
+
mask[r[ids[mi]], c[ids[mi]]] = True
|
50 |
+
return mask
|
51 |
+
|
52 |
+
|
53 |
+
def nms_vertical(ratio: np.ndarray, threshold: float) -> np.ndarray:
|
54 |
+
"""Apply Non-Maximum Suppression (NMS) vertically on the given ratio matrix.
|
55 |
+
|
56 |
+
Args:
|
57 |
+
----
|
58 |
+
ratio (np.ndarray): Input ratio matrix.
|
59 |
+
threshold (float): Threshold for NMS.
|
60 |
+
|
61 |
+
Returns:
|
62 |
+
-------
|
63 |
+
np.ndarray: Binary mask after applying NMS.
|
64 |
+
|
65 |
+
"""
|
66 |
+
return np.transpose(nms_horizontal(np.transpose(ratio), threshold))
|
67 |
+
|
68 |
+
|
69 |
+
def fgbg_depth(
|
70 |
+
d: np.ndarray, t: float
|
71 |
+
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
|
72 |
+
"""Find foreground-background relations between neighboring pixels.
|
73 |
+
|
74 |
+
Args:
|
75 |
+
----
|
76 |
+
d (np.ndarray): Depth matrix.
|
77 |
+
t (float): Threshold for comparison.
|
78 |
+
|
79 |
+
Returns:
|
80 |
+
-------
|
81 |
+
Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: Four matrices indicating
|
82 |
+
left, top, right, and bottom foreground-background relations.
|
83 |
+
|
84 |
+
"""
|
85 |
+
right_is_big_enough = (d[..., :, 1:] / d[..., :, :-1]) > t
|
86 |
+
left_is_big_enough = (d[..., :, :-1] / d[..., :, 1:]) > t
|
87 |
+
bottom_is_big_enough = (d[..., 1:, :] / d[..., :-1, :]) > t
|
88 |
+
top_is_big_enough = (d[..., :-1, :] / d[..., 1:, :]) > t
|
89 |
+
return (
|
90 |
+
left_is_big_enough,
|
91 |
+
top_is_big_enough,
|
92 |
+
right_is_big_enough,
|
93 |
+
bottom_is_big_enough,
|
94 |
+
)
|
95 |
+
|
96 |
+
|
97 |
+
def fgbg_depth_thinned(
|
98 |
+
d: np.ndarray, t: float
|
99 |
+
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
|
100 |
+
"""Find foreground-background relations between neighboring pixels with Non-Maximum Suppression.
|
101 |
+
|
102 |
+
Args:
|
103 |
+
----
|
104 |
+
d (np.ndarray): Depth matrix.
|
105 |
+
t (float): Threshold for NMS.
|
106 |
+
|
107 |
+
Returns:
|
108 |
+
-------
|
109 |
+
Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: Four matrices indicating
|
110 |
+
left, top, right, and bottom foreground-background relations with NMS applied.
|
111 |
+
|
112 |
+
"""
|
113 |
+
right_is_big_enough = nms_horizontal(d[..., :, 1:] / d[..., :, :-1], t)
|
114 |
+
left_is_big_enough = nms_horizontal(d[..., :, :-1] / d[..., :, 1:], t)
|
115 |
+
bottom_is_big_enough = nms_vertical(d[..., 1:, :] / d[..., :-1, :], t)
|
116 |
+
top_is_big_enough = nms_vertical(d[..., :-1, :] / d[..., 1:, :], t)
|
117 |
+
return (
|
118 |
+
left_is_big_enough,
|
119 |
+
top_is_big_enough,
|
120 |
+
right_is_big_enough,
|
121 |
+
bottom_is_big_enough,
|
122 |
+
)
|
123 |
+
|
124 |
+
|
125 |
+
def fgbg_binary_mask(
|
126 |
+
d: np.ndarray,
|
127 |
+
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
|
128 |
+
"""Find foreground-background relations between neighboring pixels in binary masks.
|
129 |
+
|
130 |
+
Args:
|
131 |
+
----
|
132 |
+
d (np.ndarray): Binary depth matrix.
|
133 |
+
|
134 |
+
Returns:
|
135 |
+
-------
|
136 |
+
Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: Four matrices indicating
|
137 |
+
left, top, right, and bottom foreground-background relations in binary masks.
|
138 |
+
|
139 |
+
"""
|
140 |
+
assert d.dtype == bool
|
141 |
+
right_is_big_enough = d[..., :, 1:] & ~d[..., :, :-1]
|
142 |
+
left_is_big_enough = d[..., :, :-1] & ~d[..., :, 1:]
|
143 |
+
bottom_is_big_enough = d[..., 1:, :] & ~d[..., :-1, :]
|
144 |
+
top_is_big_enough = d[..., :-1, :] & ~d[..., 1:, :]
|
145 |
+
return (
|
146 |
+
left_is_big_enough,
|
147 |
+
top_is_big_enough,
|
148 |
+
right_is_big_enough,
|
149 |
+
bottom_is_big_enough,
|
150 |
+
)
|
151 |
+
|
152 |
+
|
153 |
+
def edge_recall_matting(pr: np.ndarray, gt: np.ndarray, t: float) -> float:
|
154 |
+
"""Calculate edge recall for image matting.
|
155 |
+
|
156 |
+
Args:
|
157 |
+
----
|
158 |
+
pr (np.ndarray): Predicted depth matrix.
|
159 |
+
gt (np.ndarray): Ground truth binary mask.
|
160 |
+
t (float): Threshold for NMS.
|
161 |
+
|
162 |
+
Returns:
|
163 |
+
-------
|
164 |
+
float: Edge recall value.
|
165 |
+
|
166 |
+
"""
|
167 |
+
assert gt.dtype == bool
|
168 |
+
ap, bp, cp, dp = fgbg_depth_thinned(pr, t)
|
169 |
+
ag, bg, cg, dg = fgbg_binary_mask(gt)
|
170 |
+
return 0.25 * (
|
171 |
+
np.count_nonzero(ap & ag) / max(np.count_nonzero(ag), 1)
|
172 |
+
+ np.count_nonzero(bp & bg) / max(np.count_nonzero(bg), 1)
|
173 |
+
+ np.count_nonzero(cp & cg) / max(np.count_nonzero(cg), 1)
|
174 |
+
+ np.count_nonzero(dp & dg) / max(np.count_nonzero(dg), 1)
|
175 |
+
)
|
176 |
+
|
177 |
+
|
178 |
+
def boundary_f1(
|
179 |
+
pr: np.ndarray,
|
180 |
+
gt: np.ndarray,
|
181 |
+
t: float,
|
182 |
+
return_p: bool = False,
|
183 |
+
return_r: bool = False,
|
184 |
+
) -> float:
|
185 |
+
"""Calculate Boundary F1 score.
|
186 |
+
|
187 |
+
Args:
|
188 |
+
----
|
189 |
+
pr (np.ndarray): Predicted depth matrix.
|
190 |
+
gt (np.ndarray): Ground truth depth matrix.
|
191 |
+
t (float): Threshold for comparison.
|
192 |
+
return_p (bool, optional): If True, return precision. Defaults to False.
|
193 |
+
return_r (bool, optional): If True, return recall. Defaults to False.
|
194 |
+
|
195 |
+
Returns:
|
196 |
+
-------
|
197 |
+
float: Boundary F1 score, or precision, or recall depending on the flags.
|
198 |
+
|
199 |
+
"""
|
200 |
+
ap, bp, cp, dp = fgbg_depth(pr, t)
|
201 |
+
ag, bg, cg, dg = fgbg_depth(gt, t)
|
202 |
+
|
203 |
+
r = 0.25 * (
|
204 |
+
np.count_nonzero(ap & ag) / max(np.count_nonzero(ag), 1)
|
205 |
+
+ np.count_nonzero(bp & bg) / max(np.count_nonzero(bg), 1)
|
206 |
+
+ np.count_nonzero(cp & cg) / max(np.count_nonzero(cg), 1)
|
207 |
+
+ np.count_nonzero(dp & dg) / max(np.count_nonzero(dg), 1)
|
208 |
+
)
|
209 |
+
p = 0.25 * (
|
210 |
+
np.count_nonzero(ap & ag) / max(np.count_nonzero(ap), 1)
|
211 |
+
+ np.count_nonzero(bp & bg) / max(np.count_nonzero(bp), 1)
|
212 |
+
+ np.count_nonzero(cp & cg) / max(np.count_nonzero(cp), 1)
|
213 |
+
+ np.count_nonzero(dp & dg) / max(np.count_nonzero(dp), 1)
|
214 |
+
)
|
215 |
+
if r + p == 0:
|
216 |
+
return 0.0
|
217 |
+
if return_p:
|
218 |
+
return p
|
219 |
+
if return_r:
|
220 |
+
return r
|
221 |
+
return 2 * (r * p) / (r + p)
|
222 |
+
|
223 |
+
|
224 |
+
def get_thresholds_and_weights(
|
225 |
+
t_min: float, t_max: float, N: int
|
226 |
+
) -> Tuple[np.ndarray, np.ndarray]:
|
227 |
+
"""Generate thresholds and weights for the given range.
|
228 |
+
|
229 |
+
Args:
|
230 |
+
----
|
231 |
+
t_min (float): Minimum threshold.
|
232 |
+
t_max (float): Maximum threshold.
|
233 |
+
N (int): Number of thresholds.
|
234 |
+
|
235 |
+
Returns:
|
236 |
+
-------
|
237 |
+
Tuple[np.ndarray, np.ndarray]: Array of thresholds and corresponding weights.
|
238 |
+
|
239 |
+
"""
|
240 |
+
thresholds = np.linspace(t_min, t_max, N)
|
241 |
+
weights = thresholds / thresholds.sum()
|
242 |
+
return thresholds, weights
|
243 |
+
|
244 |
+
|
245 |
+
def invert_depth(depth: np.ndarray, eps: float = 1e-6) -> np.ndarray:
|
246 |
+
"""Inverts a depth map with numerical stability.
|
247 |
+
|
248 |
+
Args:
|
249 |
+
----
|
250 |
+
depth (np.ndarray): Depth map to be inverted.
|
251 |
+
eps (float): Minimum value to avoid division by zero (default is 1e-6).
|
252 |
+
|
253 |
+
Returns:
|
254 |
+
-------
|
255 |
+
np.ndarray: Inverted depth map.
|
256 |
+
|
257 |
+
"""
|
258 |
+
inverse_depth = 1.0 / depth.clip(min=eps)
|
259 |
+
return inverse_depth
|
260 |
+
|
261 |
+
|
262 |
+
def SI_boundary_F1(
|
263 |
+
predicted_depth: np.ndarray,
|
264 |
+
target_depth: np.ndarray,
|
265 |
+
t_min: float = 1.05,
|
266 |
+
t_max: float = 1.25,
|
267 |
+
N: int = 10,
|
268 |
+
) -> float:
|
269 |
+
"""Calculate Scale-Invariant Boundary F1 Score for depth-based ground-truth.
|
270 |
+
|
271 |
+
Args:
|
272 |
+
----
|
273 |
+
predicted_depth (np.ndarray): Predicted depth matrix.
|
274 |
+
target_depth (np.ndarray): Ground truth depth matrix.
|
275 |
+
t_min (float, optional): Minimum threshold. Defaults to 1.05.
|
276 |
+
t_max (float, optional): Maximum threshold. Defaults to 1.25.
|
277 |
+
N (int, optional): Number of thresholds. Defaults to 10.
|
278 |
+
|
279 |
+
Returns:
|
280 |
+
-------
|
281 |
+
float: Scale-Invariant Boundary F1 Score.
|
282 |
+
|
283 |
+
"""
|
284 |
+
assert predicted_depth.ndim == target_depth.ndim == 2
|
285 |
+
thresholds, weights = get_thresholds_and_weights(t_min, t_max, N)
|
286 |
+
f1_scores = np.array(
|
287 |
+
[
|
288 |
+
boundary_f1(invert_depth(predicted_depth), invert_depth(target_depth), t)
|
289 |
+
for t in thresholds
|
290 |
+
]
|
291 |
+
)
|
292 |
+
return np.sum(f1_scores * weights)
|
293 |
+
|
294 |
+
|
295 |
+
def SI_boundary_Recall(
|
296 |
+
predicted_depth: np.ndarray,
|
297 |
+
target_mask: np.ndarray,
|
298 |
+
t_min: float = 1.05,
|
299 |
+
t_max: float = 1.25,
|
300 |
+
N: int = 10,
|
301 |
+
alpha_threshold: float = 0.1,
|
302 |
+
) -> float:
|
303 |
+
"""Calculate Scale-Invariant Boundary Recall Score for mask-based ground-truth.
|
304 |
+
|
305 |
+
Args:
|
306 |
+
----
|
307 |
+
predicted_depth (np.ndarray): Predicted depth matrix.
|
308 |
+
target_mask (np.ndarray): Ground truth binary mask.
|
309 |
+
t_min (float, optional): Minimum threshold. Defaults to 1.05.
|
310 |
+
t_max (float, optional): Maximum threshold. Defaults to 1.25.
|
311 |
+
N (int, optional): Number of thresholds. Defaults to 10.
|
312 |
+
alpha_threshold (float, optional): Threshold for alpha masking. Defaults to 0.1.
|
313 |
+
|
314 |
+
Returns:
|
315 |
+
-------
|
316 |
+
float: Scale-Invariant Boundary Recall Score.
|
317 |
+
|
318 |
+
"""
|
319 |
+
assert predicted_depth.ndim == target_mask.ndim == 2
|
320 |
+
thresholds, weights = get_thresholds_and_weights(t_min, t_max, N)
|
321 |
+
thresholded_target = target_mask > alpha_threshold
|
322 |
+
|
323 |
+
recall_scores = np.array(
|
324 |
+
[
|
325 |
+
edge_recall_matting(
|
326 |
+
invert_depth(predicted_depth), thresholded_target, t=float(t)
|
327 |
+
)
|
328 |
+
for t in thresholds
|
329 |
+
]
|
330 |
+
)
|
331 |
+
weighted_recall = np.sum(recall_scores * weights)
|
332 |
+
return weighted_recall
|
finetune/modules/depth_warping/depth_pro/eval/dis5k_sample_list.txt
ADDED
@@ -0,0 +1,200 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
DIS5K/DIS-TE1/im/12#Graphics#4#TrafficSign#8245751856_821be14f86_o.jpg
|
2 |
+
DIS5K/DIS-TE1/im/13#Insect#4#Butterfly#16023994688_7ff8cdccb1_o.jpg
|
3 |
+
DIS5K/DIS-TE1/im/14#Kitchenware#4#Kitchenware#IMG_20210520_205538.jpg
|
4 |
+
DIS5K/DIS-TE1/im/14#Kitchenware#8#SweetStand#4848284981_fc90f54b50_o.jpg
|
5 |
+
DIS5K/DIS-TE1/im/17#Non-motor Vehicle#4#Cart#15012855035_d10b57014f_o.jpg
|
6 |
+
DIS5K/DIS-TE1/im/2#Aircraft#5#Kite#13104545564_5afceec9bd_o.jpg
|
7 |
+
DIS5K/DIS-TE1/im/20#Sports#10#Skateboarding#8472763540_bb2390e928_o.jpg
|
8 |
+
DIS5K/DIS-TE1/im/21#Tool#14#Sword#32473146960_dcc6b77848_o.jpg
|
9 |
+
DIS5K/DIS-TE1/im/21#Tool#15#Tapeline#9680492386_2d2020f282_o.jpg
|
10 |
+
DIS5K/DIS-TE1/im/21#Tool#4#Flag#507752845_ef852100f0_o.jpg
|
11 |
+
DIS5K/DIS-TE1/im/21#Tool#6#Key#11966089533_3becd78b44_o.jpg
|
12 |
+
DIS5K/DIS-TE1/im/21#Tool#8#Scale#31946428472_d28def471b_o.jpg
|
13 |
+
DIS5K/DIS-TE1/im/22#Weapon#4#Rifle#8472656430_3eb908b211_o.jpg
|
14 |
+
DIS5K/DIS-TE1/im/8#Electronics#3#Earphone#1177468301_641df8c267_o.jpg
|
15 |
+
DIS5K/DIS-TE1/im/8#Electronics#9#MusicPlayer#2235782872_7d47847bb4_o.jpg
|
16 |
+
DIS5K/DIS-TE2/im/11#Furniture#13#Ladder#3878434417_2ed740586e_o.jpg
|
17 |
+
DIS5K/DIS-TE2/im/13#Insect#1#Ant#27047700955_3b3a1271f8_o.jpg
|
18 |
+
DIS5K/DIS-TE2/im/13#Insect#11#Spider#5567179191_38d1f65589_o.jpg
|
19 |
+
DIS5K/DIS-TE2/im/13#Insect#8#Locust#5237933769_e6687c05e4_o.jpg
|
20 |
+
DIS5K/DIS-TE2/im/14#Kitchenware#2#DishRack#70838854_40cf689da7_o.jpg
|
21 |
+
DIS5K/DIS-TE2/im/14#Kitchenware#8#SweetStand#8467929412_fef7f4275d_o.jpg
|
22 |
+
DIS5K/DIS-TE2/im/16#Music Instrument#2#Harp#28058219806_28e05ff24a_o.jpg
|
23 |
+
DIS5K/DIS-TE2/im/17#Non-motor Vehicle#1#BabyCarriage#29794777180_2e1695a0cf_o.jpg
|
24 |
+
DIS5K/DIS-TE2/im/19#Ship#3#Sailboat#22442908623_5977e3becf_o.jpg
|
25 |
+
DIS5K/DIS-TE2/im/2#Aircraft#5#Kite#44654358051_1400e71cc4_o.jpg
|
26 |
+
DIS5K/DIS-TE2/im/21#Tool#11#Stand#IMG_20210520_205442.jpg
|
27 |
+
DIS5K/DIS-TE2/im/21#Tool#17#Tripod#9318977876_34615ec9a0_o.jpg
|
28 |
+
DIS5K/DIS-TE2/im/5#Artifact#3#Handcraft#50860882577_8482143b1b_o.jpg
|
29 |
+
DIS5K/DIS-TE2/im/8#Electronics#10#Robot#3093360210_fee54dc5c5_o.jpg
|
30 |
+
DIS5K/DIS-TE2/im/8#Electronics#6#Microphone#47411477652_6da66cbc10_o.jpg
|
31 |
+
DIS5K/DIS-TE3/im/14#Kitchenware#4#Kitchenware#2451122898_ef883175dd_o.jpg
|
32 |
+
DIS5K/DIS-TE3/im/15#Machine#4#SewingMachine#9311164128_97ba1d3947_o.jpg
|
33 |
+
DIS5K/DIS-TE3/im/16#Music Instrument#2#Harp#7670920550_59e992fd7b_o.jpg
|
34 |
+
DIS5K/DIS-TE3/im/17#Non-motor Vehicle#1#BabyCarriage#8389984877_1fddf8715c_o.jpg
|
35 |
+
DIS5K/DIS-TE3/im/17#Non-motor Vehicle#3#Carriage#5947122724_98e0fc3d1f_o.jpg
|
36 |
+
DIS5K/DIS-TE3/im/2#Aircraft#2#Balloon#2487168092_641505883f_o.jpg
|
37 |
+
DIS5K/DIS-TE3/im/2#Aircraft#4#Helicopter#8401177591_06c71c8df2_o.jpg
|
38 |
+
DIS5K/DIS-TE3/im/20#Sports#1#Archery#12520003103_faa43ea3e0_o.jpg
|
39 |
+
DIS5K/DIS-TE3/im/21#Tool#11#Stand#IMG_20210709_221507.jpg
|
40 |
+
DIS5K/DIS-TE3/im/21#Tool#2#Clip#5656649687_63d0c6696d_o.jpg
|
41 |
+
DIS5K/DIS-TE3/im/21#Tool#6#Key#12878459244_6387a140ea_o.jpg
|
42 |
+
DIS5K/DIS-TE3/im/3#Aquatic#1#Lobster#109214461_f52b4b6093_o.jpg
|
43 |
+
DIS5K/DIS-TE3/im/4#Architecture#19#Windmill#20195851863_2627117e0e_o.jpg
|
44 |
+
DIS5K/DIS-TE3/im/5#Artifact#2#Cage#5821476369_ea23927487_o.jpg
|
45 |
+
DIS5K/DIS-TE3/im/8#Electronics#7#MobileHolder#49732997896_7f53c290b5_o.jpg
|
46 |
+
DIS5K/DIS-TE4/im/13#Insect#6#Centipede#15302179708_a267850881_o.jpg
|
47 |
+
DIS5K/DIS-TE4/im/17#Non-motor Vehicle#11#Tricycle#5771069105_a3aef6f665_o.jpg
|
48 |
+
DIS5K/DIS-TE4/im/17#Non-motor Vehicle#2#Bicycle#4245936196_fdf812dcb7_o.jpg
|
49 |
+
DIS5K/DIS-TE4/im/17#Non-motor Vehicle#9#ShoppingCart#4674052920_a5b7a2b236_o.jpg
|
50 |
+
DIS5K/DIS-TE4/im/18#Plant#1#Bonsai#3539420884_ca8973e2c0_o.jpg
|
51 |
+
DIS5K/DIS-TE4/im/2#Aircraft#6#Parachute#33590416634_9d6f2325e7_o.jpg
|
52 |
+
DIS5K/DIS-TE4/im/20#Sports#1#Archery#46924476515_0be1caa684_o.jpg
|
53 |
+
DIS5K/DIS-TE4/im/20#Sports#8#Racket#19337607166_dd1985fb59_o.jpg
|
54 |
+
DIS5K/DIS-TE4/im/21#Tool#6#Key#3193329588_839b0c74ce_o.jpg
|
55 |
+
DIS5K/DIS-TE4/im/5#Artifact#2#Cage#5821886526_0573ba2d0d_o.jpg
|
56 |
+
DIS5K/DIS-TE4/im/5#Artifact#3#Handcraft#50105138282_3c1d02c968_o.jpg
|
57 |
+
DIS5K/DIS-TE4/im/8#Electronics#1#Antenna#4305034305_874f21a701_o.jpg
|
58 |
+
DIS5K/DIS-TR/im/1#Accessories#1#Bag#15554964549_3105e51b6f_o.jpg
|
59 |
+
DIS5K/DIS-TR/im/1#Accessories#1#Bag#41104261980_098a6c4a56_o.jpg
|
60 |
+
DIS5K/DIS-TR/im/1#Accessories#2#Clothes#2284764037_871b2e8ca4_o.jpg
|
61 |
+
DIS5K/DIS-TR/im/1#Accessories#3#Eyeglasses#1824643784_70d0134156_o.jpg
|
62 |
+
DIS5K/DIS-TR/im/1#Accessories#3#Eyeglasses#3590020230_37b09a29b3_o.jpg
|
63 |
+
DIS5K/DIS-TR/im/1#Accessories#3#Eyeglasses#4809652879_4da8a69f3b_o.jpg
|
64 |
+
DIS5K/DIS-TR/im/1#Accessories#3#Eyeglasses#792204934_f9b28f99b4_o.jpg
|
65 |
+
DIS5K/DIS-TR/im/1#Accessories#5#Jewelry#13909132974_c4750c5fb7_o.jpg
|
66 |
+
DIS5K/DIS-TR/im/1#Accessories#7#Shoe#2483391615_9199ece8d6_o.jpg
|
67 |
+
DIS5K/DIS-TR/im/1#Accessories#8#Watch#4343266960_f6633b029b_o.jpg
|
68 |
+
DIS5K/DIS-TR/im/10#Frame#2#BicycleFrame#17897573_42964dd104_o.jpg
|
69 |
+
DIS5K/DIS-TR/im/10#Frame#5#Rack#15898634812_64807069ff_o.jpg
|
70 |
+
DIS5K/DIS-TR/im/10#Frame#5#Rack#23928546819_c184cb0b60_o.jpg
|
71 |
+
DIS5K/DIS-TR/im/11#Furniture#19#Shower#6189119596_77bcfe80ee_o.jpg
|
72 |
+
DIS5K/DIS-TR/im/11#Furniture#2#Bench#3263647075_9306e280b5_o.jpg
|
73 |
+
DIS5K/DIS-TR/im/11#Furniture#5#CoatHanger#12774091054_cd5ff520ef_o.jpg
|
74 |
+
DIS5K/DIS-TR/im/11#Furniture#6#DentalChair#13878156865_d0439dcb32_o.jpg
|
75 |
+
DIS5K/DIS-TR/im/11#Furniture#9#Easel#5861024714_2070cd480c_o.jpg
|
76 |
+
DIS5K/DIS-TR/im/12#Graphics#4#TrafficSign#40621867334_f3c32ec189_o.jpg
|
77 |
+
DIS5K/DIS-TR/im/13#Insect#1#Ant#3295038190_db5dd0d4f4_o.jpg
|
78 |
+
DIS5K/DIS-TR/im/13#Insect#10#Mosquito#24341339_a88a1dad4c_o.jpg
|
79 |
+
DIS5K/DIS-TR/im/13#Insect#11#Spider#27171518270_63b78069ff_o.jpg
|
80 |
+
DIS5K/DIS-TR/im/13#Insect#11#Spider#49925050281_fa727c154e_o.jpg
|
81 |
+
DIS5K/DIS-TR/im/13#Insect#2#Beatle#279616486_2f1e64f591_o.jpg
|
82 |
+
DIS5K/DIS-TR/im/13#Insect#3#Bee#43892067695_82cf3e536b_o.jpg
|
83 |
+
DIS5K/DIS-TR/im/13#Insect#6#Centipede#20874281788_3e15c90a1c_o.jpg
|
84 |
+
DIS5K/DIS-TR/im/13#Insect#7#Dragonfly#14106671120_1b824d77e4_o.jpg
|
85 |
+
DIS5K/DIS-TR/im/13#Insect#8#Locust#21637491048_676ef7c9f7_o.jpg
|
86 |
+
DIS5K/DIS-TR/im/13#Insect#9#Mantis#1381120202_9dff6987b2_o.jpg
|
87 |
+
DIS5K/DIS-TR/im/14#Kitchenware#1#Cup#12812517473_327d6474b8_o.jpg
|
88 |
+
DIS5K/DIS-TR/im/14#Kitchenware#10#WineGlass#6402491641_389275d4d1_o.jpg
|
89 |
+
DIS5K/DIS-TR/im/14#Kitchenware#3#Hydrovalve#3129932040_8c05825004_o.jpg
|
90 |
+
DIS5K/DIS-TR/im/14#Kitchenware#4#Kitchenware#2881934780_87d5218ebb_o.jpg
|
91 |
+
DIS5K/DIS-TR/im/14#Kitchenware#4#Kitchenware#IMG_20210520_205527.jpg
|
92 |
+
DIS5K/DIS-TR/im/14#Kitchenware#6#Spoon#32989113501_b69eccf0df_o.jpg
|
93 |
+
DIS5K/DIS-TR/im/14#Kitchenware#8#SweetStand#2867322189_c56d1e0b87_o.jpg
|
94 |
+
DIS5K/DIS-TR/im/15#Machine#1#Gear#19217846720_f5f2807475_o.jpg
|
95 |
+
DIS5K/DIS-TR/im/15#Machine#2#Machine#1620160659_9571b7a7ab_o.jpg
|
96 |
+
DIS5K/DIS-TR/im/16#Music Instrument#2#Harp#6012801603_1a6e2c16a6_o.jpg
|
97 |
+
DIS5K/DIS-TR/im/16#Music Instrument#5#Trombone#8683292118_d223c17ccb_o.jpg
|
98 |
+
DIS5K/DIS-TR/im/16#Music Instrument#6#Trumpet#8393262740_b8c216142c_o.jpg
|
99 |
+
DIS5K/DIS-TR/im/16#Music Instrument#8#Violin#1511267391_40e4949d68_o.jpg
|
100 |
+
DIS5K/DIS-TR/im/17#Non-motor Vehicle#1#BabyCarriage#6989512997_38b3dbc88b_o.jpg
|
101 |
+
DIS5K/DIS-TR/im/17#Non-motor Vehicle#12#Wheel#14627183228_b2d68cf501_o.jpg
|
102 |
+
DIS5K/DIS-TR/im/17#Non-motor Vehicle#12#Wheel#2932226475_1b2403e549_o.jpg
|
103 |
+
DIS5K/DIS-TR/im/17#Non-motor Vehicle#12#Wheel#5420155648_86459905b8_o.jpg
|
104 |
+
DIS5K/DIS-TR/im/17#Non-motor Vehicle#2#Bicycle#IMG_20210513_134904.jpg
|
105 |
+
DIS5K/DIS-TR/im/17#Non-motor Vehicle#3#Carriage#3311962551_6f211b7bd6_o.jpg
|
106 |
+
DIS5K/DIS-TR/im/17#Non-motor Vehicle#4#Cart#2609732026_baf7fff3a1_o.jpg
|
107 |
+
DIS5K/DIS-TR/im/17#Non-motor Vehicle#5#Handcart#5821282211_201cefeaf2_o.jpg
|
108 |
+
DIS5K/DIS-TR/im/17#Non-motor Vehicle#7#Mower#5779003232_3bb3ae531a_o.jpg
|
109 |
+
DIS5K/DIS-TR/im/17#Non-motor Vehicle#9#ShoppingCart#10051622843_ace07e32b8_o.jpg
|
110 |
+
DIS5K/DIS-TR/im/17#Non-motor Vehicle#9#ShoppingCart#8075259294_f23e243849_o.jpg
|
111 |
+
DIS5K/DIS-TR/im/18#Plant#2#Tree#44800999741_e377e16dbb_o.jpg
|
112 |
+
DIS5K/DIS-TR/im/2#Aircraft#1#Airplane#2631761913_3ac67d0223_o.jpg
|
113 |
+
DIS5K/DIS-TR/im/2#Aircraft#1#Airplane#37707911566_e908a261b6_o.jpg
|
114 |
+
DIS5K/DIS-TR/im/2#Aircraft#3#HangGlider#2557220131_b8506920c5_o.jpg
|
115 |
+
DIS5K/DIS-TR/im/2#Aircraft#4#Helicopter#6215659280_5dbd9b4546_o.jpg
|
116 |
+
DIS5K/DIS-TR/im/2#Aircraft#6#Parachute#20185790493_e56fcaf8c6_o.jpg
|
117 |
+
DIS5K/DIS-TR/im/20#Sports#1#Archery#3871269982_ae4c59a7eb_o.jpg
|
118 |
+
DIS5K/DIS-TR/im/20#Sports#9#RockClimbing#9662433268_51299bc50e_o.jpg
|
119 |
+
DIS5K/DIS-TR/im/21#Tool#14#Sword#26258479365_2950d7fa37_o.jpg
|
120 |
+
DIS5K/DIS-TR/im/21#Tool#15#Tapeline#15505703447_e0fdeaa5a6_o.jpg
|
121 |
+
DIS5K/DIS-TR/im/21#Tool#4#Flag#26678602024_9b665742de_o.jpg
|
122 |
+
DIS5K/DIS-TR/im/21#Tool#4#Flag#5774823110_d603ce3cc8_o.jpg
|
123 |
+
DIS5K/DIS-TR/im/21#Tool#5#Hook#6867989814_dba18d673c_o.jpg
|
124 |
+
DIS5K/DIS-TR/im/22#Weapon#4#Rifle#4451713125_cd91719189_o.jpg
|
125 |
+
DIS5K/DIS-TR/im/3#Aquatic#2#Seadragon#4910944581_913139b238_o.jpg
|
126 |
+
DIS5K/DIS-TR/im/4#Architecture#12#Scaffold#3661448960_8aff24cc4d_o.jpg
|
127 |
+
DIS5K/DIS-TR/im/4#Architecture#13#Sculpture#6385318715_9a88d4eba7_o.jpg
|
128 |
+
DIS5K/DIS-TR/im/4#Architecture#17#Well#5011603479_75cf42808a_o.jpg
|
129 |
+
DIS5K/DIS-TR/im/5#Artifact#2#Cage#4892828841_7f1bc05682_o.jpg
|
130 |
+
DIS5K/DIS-TR/im/5#Artifact#3#Handcraft#15404211628_9e9ff2ce2e_o.jpg
|
131 |
+
DIS5K/DIS-TR/im/5#Artifact#3#Handcraft#3200169865_7c84cfcccf_o.jpg
|
132 |
+
DIS5K/DIS-TR/im/5#Artifact#3#Handcraft#5859295071_c217e7c22f_o.jpg
|
133 |
+
DIS5K/DIS-TR/im/6#Automobile#10#SteeringWheel#17200338026_f1e2122d8e_o.jpg
|
134 |
+
DIS5K/DIS-TR/im/6#Automobile#3#Car#3780893425_1a7d275e09_o.jpg
|
135 |
+
DIS5K/DIS-TR/im/6#Automobile#5#Crane#15282506502_1b1132a7c3_o.jpg
|
136 |
+
DIS5K/DIS-TR/im/7#Electrical#1#Cable#16767791875_8e6df41752_o.jpg
|
137 |
+
DIS5K/DIS-TR/im/7#Electrical#1#Cable#3291433361_38747324c4_o.jpg
|
138 |
+
DIS5K/DIS-TR/im/7#Electrical#1#Cable#4195104238_12a754c61a_o.jpg
|
139 |
+
DIS5K/DIS-TR/im/7#Electrical#1#Cable#49645415132_61e5664ecf_o.jpg
|
140 |
+
DIS5K/DIS-TR/im/7#Electrical#1#Cable#IMG_20210521_232406.jpg
|
141 |
+
DIS5K/DIS-TR/im/7#Electrical#10#UtilityPole#3298312021_92f431e3e9_o.jpg
|
142 |
+
DIS5K/DIS-TR/im/7#Electrical#10#UtilityPole#47950134773_fbfff63f4e_o.jpg
|
143 |
+
DIS5K/DIS-TR/im/7#Electrical#11#VacuumCleaner#5448403677_6a29e21881_o.jpg
|
144 |
+
DIS5K/DIS-TR/im/7#Electrical#2#CeilingLamp#611568868_680ed5d39f_o.jpg
|
145 |
+
DIS5K/DIS-TR/im/7#Electrical#3#Fan#3391683115_990525a693_o.jpg
|
146 |
+
DIS5K/DIS-TR/im/7#Electrical#6#StreetLamp#150049122_0692266618_o.jpg
|
147 |
+
DIS5K/DIS-TR/im/7#Electrical#9#TransmissionTower#31433908671_7e7e277dfe_o.jpg
|
148 |
+
DIS5K/DIS-TR/im/8#Electronics#1#Antenna#8727884873_e0622ee5c4_o.jpg
|
149 |
+
DIS5K/DIS-TR/im/8#Electronics#2#Camcorder#4172690390_7e5f280ace_o.jpg
|
150 |
+
DIS5K/DIS-TR/im/8#Electronics#3#Earphone#413984555_f290febdf5_o.jpg
|
151 |
+
DIS5K/DIS-TR/im/8#Electronics#5#Headset#30574225373_3717ed9fa4_o.jpg
|
152 |
+
DIS5K/DIS-TR/im/8#Electronics#6#Microphone#538006482_4aae4f5bd6_o.jpg
|
153 |
+
DIS5K/DIS-TR/im/8#Electronics#9#MusicPlayer#1306012480_2ea80d2afd_o.jpg
|
154 |
+
DIS5K/DIS-TR/im/9#Entertainment#1#GymEquipment#33071754135_8f3195cbd1_o.jpg
|
155 |
+
DIS5K/DIS-TR/im/9#Entertainment#2#KidsPlayground#2305807849_be53d724ea_o.jpg
|
156 |
+
DIS5K/DIS-TR/im/9#Entertainment#2#KidsPlayground#3862040422_5bbf903204_o.jpg
|
157 |
+
DIS5K/DIS-TR/im/9#Entertainment#3#OutdoorFitnessEquipment#10814507005_3dacaa28b3_o.jpg
|
158 |
+
DIS5K/DIS-TR/im/9#Entertainment#4#FerrisWheel#81640293_4b0ee62040_o.jpg
|
159 |
+
DIS5K/DIS-TR/im/9#Entertainment#5#Swing#49867339188_08073f4b76_o.jpg
|
160 |
+
DIS5K/DIS-VD/im/1#Accessories#1#Bag#6815402415_e01c1a41e6_o.jpg
|
161 |
+
DIS5K/DIS-VD/im/1#Accessories#5#Jewelry#2744070193_1486582e8d_o.jpg
|
162 |
+
DIS5K/DIS-VD/im/10#Frame#1#BasketballHoop#IMG_20210521_232650.jpg
|
163 |
+
DIS5K/DIS-VD/im/10#Frame#5#Rack#6156611713_49ebf12b1e_o.jpg
|
164 |
+
DIS5K/DIS-VD/im/11#Furniture#11#Handrail#3276641240_1b84b5af85_o.jpg
|
165 |
+
DIS5K/DIS-VD/im/11#Furniture#13#Ladder#33423266_5391cf47e9_o.jpg
|
166 |
+
DIS5K/DIS-VD/im/11#Furniture#17#Table#3725111755_4fc101e7ab_o.jpg
|
167 |
+
DIS5K/DIS-VD/im/11#Furniture#2#Bench#35556410400_7235b58070_o.jpg
|
168 |
+
DIS5K/DIS-VD/im/11#Furniture#4#Chair#3301769985_e49de6739f_o.jpg
|
169 |
+
DIS5K/DIS-VD/im/11#Furniture#6#DentalChair#23811071619_2a95c3a688_o.jpg
|
170 |
+
DIS5K/DIS-VD/im/11#Furniture#9#Easel#8322807354_df6d56542e_o.jpg
|
171 |
+
DIS5K/DIS-VD/im/13#Insect#10#Mosquito#12391674863_0cdf430d3f_o.jpg
|
172 |
+
DIS5K/DIS-VD/im/13#Insect#7#Dragonfly#14693028899_344ea118f2_o.jpg
|
173 |
+
DIS5K/DIS-VD/im/14#Kitchenware#10#WineGlass#4450148455_8f460f541a_o.jpg
|
174 |
+
DIS5K/DIS-VD/im/14#Kitchenware#3#Hydrovalve#IMG_20210520_203410.jpg
|
175 |
+
DIS5K/DIS-VD/im/15#Machine#3#PlowHarrow#34521712846_df4babb024_o.jpg
|
176 |
+
DIS5K/DIS-VD/im/16#Music Instrument#5#Trombone#6222242743_e7189405cd_o.jpg
|
177 |
+
DIS5K/DIS-VD/im/17#Non-motor Vehicle#12#Wheel#25677578797_ea47e1d9e8_o.jpg
|
178 |
+
DIS5K/DIS-VD/im/17#Non-motor Vehicle#2#Bicycle#5153474856_21560b081b_o.jpg
|
179 |
+
DIS5K/DIS-VD/im/17#Non-motor Vehicle#7#Mower#16992510572_8a6ff27398_o.jpg
|
180 |
+
DIS5K/DIS-VD/im/19#Ship#2#Canoe#40571458163_7faf8b73d9_o.jpg
|
181 |
+
DIS5K/DIS-VD/im/2#Aircraft#1#Airplane#4270588164_66a619e834_o.jpg
|
182 |
+
DIS5K/DIS-VD/im/2#Aircraft#4#Helicopter#86789665_650b94b2ee_o.jpg
|
183 |
+
DIS5K/DIS-VD/im/20#Sports#14#Wakesurfing#5589577652_5061c168d2_o.jpg
|
184 |
+
DIS5K/DIS-VD/im/21#Tool#10#Spade#37018312543_63b21b0784_o.jpg
|
185 |
+
DIS5K/DIS-VD/im/21#Tool#14#Sword#24789047250_42df9bf422_o.jpg
|
186 |
+
DIS5K/DIS-VD/im/21#Tool#18#Umbrella#IMG_20210513_140445.jpg
|
187 |
+
DIS5K/DIS-VD/im/21#Tool#6#Key#43939732715_5a6e28b518_o.jpg
|
188 |
+
DIS5K/DIS-VD/im/22#Weapon#1#Cannon#12758066705_90b54295e7_o.jpg
|
189 |
+
DIS5K/DIS-VD/im/22#Weapon#4#Rifle#8019368790_fb6dc469a7_o.jpg
|
190 |
+
DIS5K/DIS-VD/im/3#Aquatic#5#Shrimp#2582833427_7a99e7356e_o.jpg
|
191 |
+
DIS5K/DIS-VD/im/4#Architecture#12#Scaffold#1013402687_590750354e_o.jpg
|
192 |
+
DIS5K/DIS-VD/im/4#Architecture#13#Sculpture#17176841759_272a3ed6e3_o.jpg
|
193 |
+
DIS5K/DIS-VD/im/4#Architecture#14#Stair#15079108505_0d11281624_o.jpg
|
194 |
+
DIS5K/DIS-VD/im/4#Architecture#19#Windmill#2928111082_ceb3051c04_o.jpg
|
195 |
+
DIS5K/DIS-VD/im/4#Architecture#3#Crack#3551574032_17dd106d31_o.jpg
|
196 |
+
DIS5K/DIS-VD/im/4#Architecture#5#GasStation#4564307581_c3069bdc62_o.jpg
|
197 |
+
DIS5K/DIS-VD/im/4#Architecture#8#ObservationTower#2704526950_d4f0ddc807_o.jpg
|
198 |
+
DIS5K/DIS-VD/im/5#Artifact#3#Handcraft#10873642323_1bafce3aa5_o.jpg
|
199 |
+
DIS5K/DIS-VD/im/6#Automobile#11#Tractor#8594504006_0c2c557d85_o.jpg
|
200 |
+
DIS5K/DIS-VD/im/8#Electronics#3#Earphone#8106454803_1178d867cc_o.jpg
|
finetune/modules/depth_warping/depth_pro/network/__init__.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
# Copyright (C) 2024 Apple Inc. All Rights Reserved.
|
2 |
+
"""Depth Pro network blocks."""
|
finetune/modules/depth_warping/depth_pro/network/decoder.py
ADDED
@@ -0,0 +1,206 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Copyright (C) 2024 Apple Inc. All Rights Reserved.
|
2 |
+
|
3 |
+
Dense Prediction Transformer Decoder architecture.
|
4 |
+
|
5 |
+
Implements a variant of Vision Transformers for Dense Prediction, https://arxiv.org/abs/2103.13413
|
6 |
+
"""
|
7 |
+
|
8 |
+
from __future__ import annotations
|
9 |
+
|
10 |
+
from typing import Iterable
|
11 |
+
|
12 |
+
import torch
|
13 |
+
from torch import nn
|
14 |
+
|
15 |
+
|
16 |
+
class MultiresConvDecoder(nn.Module):
|
17 |
+
"""Decoder for multi-resolution encodings."""
|
18 |
+
|
19 |
+
def __init__(
|
20 |
+
self,
|
21 |
+
dims_encoder: Iterable[int],
|
22 |
+
dim_decoder: int,
|
23 |
+
):
|
24 |
+
"""Initialize multiresolution convolutional decoder.
|
25 |
+
|
26 |
+
Args:
|
27 |
+
----
|
28 |
+
dims_encoder: Expected dims at each level from the encoder.
|
29 |
+
dim_decoder: Dim of decoder features.
|
30 |
+
|
31 |
+
"""
|
32 |
+
super().__init__()
|
33 |
+
self.dims_encoder = list(dims_encoder)
|
34 |
+
self.dim_decoder = dim_decoder
|
35 |
+
self.dim_out = dim_decoder
|
36 |
+
|
37 |
+
num_encoders = len(self.dims_encoder)
|
38 |
+
|
39 |
+
# At the highest resolution, i.e. level 0, we apply projection w/ 1x1 convolution
|
40 |
+
# when the dimensions mismatch. Otherwise we do not do anything, which is
|
41 |
+
# the default behavior of monodepth.
|
42 |
+
conv0 = (
|
43 |
+
nn.Conv2d(self.dims_encoder[0], dim_decoder, kernel_size=1, bias=False)
|
44 |
+
if self.dims_encoder[0] != dim_decoder
|
45 |
+
else nn.Identity()
|
46 |
+
)
|
47 |
+
|
48 |
+
convs = [conv0]
|
49 |
+
for i in range(1, num_encoders):
|
50 |
+
convs.append(
|
51 |
+
nn.Conv2d(
|
52 |
+
self.dims_encoder[i],
|
53 |
+
dim_decoder,
|
54 |
+
kernel_size=3,
|
55 |
+
stride=1,
|
56 |
+
padding=1,
|
57 |
+
bias=False,
|
58 |
+
)
|
59 |
+
)
|
60 |
+
|
61 |
+
self.convs = nn.ModuleList(convs)
|
62 |
+
|
63 |
+
fusions = []
|
64 |
+
for i in range(num_encoders):
|
65 |
+
fusions.append(
|
66 |
+
FeatureFusionBlock2d(
|
67 |
+
num_features=dim_decoder,
|
68 |
+
deconv=(i != 0),
|
69 |
+
batch_norm=False,
|
70 |
+
)
|
71 |
+
)
|
72 |
+
self.fusions = nn.ModuleList(fusions)
|
73 |
+
|
74 |
+
def forward(self, encodings: torch.Tensor) -> torch.Tensor:
|
75 |
+
"""Decode the multi-resolution encodings."""
|
76 |
+
num_levels = len(encodings)
|
77 |
+
num_encoders = len(self.dims_encoder)
|
78 |
+
|
79 |
+
if num_levels != num_encoders:
|
80 |
+
raise ValueError(
|
81 |
+
f"Got encoder output levels={num_levels}, expected levels={num_encoders+1}."
|
82 |
+
)
|
83 |
+
|
84 |
+
# Project features of different encoder dims to the same decoder dim.
|
85 |
+
# Fuse features from the lowest resolution (num_levels-1)
|
86 |
+
# to the highest (0).
|
87 |
+
features = self.convs[-1](encodings[-1])
|
88 |
+
lowres_features = features
|
89 |
+
features = self.fusions[-1](features)
|
90 |
+
for i in range(num_levels - 2, -1, -1):
|
91 |
+
features_i = self.convs[i](encodings[i])
|
92 |
+
features = self.fusions[i](features, features_i)
|
93 |
+
return features, lowres_features
|
94 |
+
|
95 |
+
|
96 |
+
class ResidualBlock(nn.Module):
|
97 |
+
"""Generic implementation of residual blocks.
|
98 |
+
|
99 |
+
This implements a generic residual block from
|
100 |
+
He et al. - Identity Mappings in Deep Residual Networks (2016),
|
101 |
+
https://arxiv.org/abs/1603.05027
|
102 |
+
which can be further customized via factory functions.
|
103 |
+
"""
|
104 |
+
|
105 |
+
def __init__(self, residual: nn.Module, shortcut: nn.Module | None = None) -> None:
|
106 |
+
"""Initialize ResidualBlock."""
|
107 |
+
super().__init__()
|
108 |
+
self.residual = residual
|
109 |
+
self.shortcut = shortcut
|
110 |
+
|
111 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
112 |
+
"""Apply residual block."""
|
113 |
+
delta_x = self.residual(x)
|
114 |
+
|
115 |
+
if self.shortcut is not None:
|
116 |
+
x = self.shortcut(x)
|
117 |
+
|
118 |
+
return x + delta_x
|
119 |
+
|
120 |
+
|
121 |
+
class FeatureFusionBlock2d(nn.Module):
|
122 |
+
"""Feature fusion for DPT."""
|
123 |
+
|
124 |
+
def __init__(
|
125 |
+
self,
|
126 |
+
num_features: int,
|
127 |
+
deconv: bool = False,
|
128 |
+
batch_norm: bool = False,
|
129 |
+
):
|
130 |
+
"""Initialize feature fusion block.
|
131 |
+
|
132 |
+
Args:
|
133 |
+
----
|
134 |
+
num_features: Input and output dimensions.
|
135 |
+
deconv: Whether to use deconv before the final output conv.
|
136 |
+
batch_norm: Whether to use batch normalization in resnet blocks.
|
137 |
+
|
138 |
+
"""
|
139 |
+
super().__init__()
|
140 |
+
|
141 |
+
self.resnet1 = self._residual_block(num_features, batch_norm)
|
142 |
+
self.resnet2 = self._residual_block(num_features, batch_norm)
|
143 |
+
|
144 |
+
self.use_deconv = deconv
|
145 |
+
if deconv:
|
146 |
+
self.deconv = nn.ConvTranspose2d(
|
147 |
+
in_channels=num_features,
|
148 |
+
out_channels=num_features,
|
149 |
+
kernel_size=2,
|
150 |
+
stride=2,
|
151 |
+
padding=0,
|
152 |
+
bias=False,
|
153 |
+
)
|
154 |
+
|
155 |
+
self.out_conv = nn.Conv2d(
|
156 |
+
num_features,
|
157 |
+
num_features,
|
158 |
+
kernel_size=1,
|
159 |
+
stride=1,
|
160 |
+
padding=0,
|
161 |
+
bias=True,
|
162 |
+
)
|
163 |
+
|
164 |
+
self.skip_add = nn.quantized.FloatFunctional()
|
165 |
+
|
166 |
+
def forward(self, x0: torch.Tensor, x1: torch.Tensor | None = None) -> torch.Tensor:
|
167 |
+
"""Process and fuse input features."""
|
168 |
+
x = x0
|
169 |
+
|
170 |
+
if x1 is not None:
|
171 |
+
res = self.resnet1(x1)
|
172 |
+
x = self.skip_add.add(x, res)
|
173 |
+
|
174 |
+
x = self.resnet2(x)
|
175 |
+
|
176 |
+
if self.use_deconv:
|
177 |
+
x = self.deconv(x)
|
178 |
+
x = self.out_conv(x)
|
179 |
+
|
180 |
+
return x
|
181 |
+
|
182 |
+
@staticmethod
|
183 |
+
def _residual_block(num_features: int, batch_norm: bool):
|
184 |
+
"""Create a residual block."""
|
185 |
+
|
186 |
+
def _create_block(dim: int, batch_norm: bool) -> list[nn.Module]:
|
187 |
+
layers = [
|
188 |
+
nn.ReLU(False),
|
189 |
+
nn.Conv2d(
|
190 |
+
num_features,
|
191 |
+
num_features,
|
192 |
+
kernel_size=3,
|
193 |
+
stride=1,
|
194 |
+
padding=1,
|
195 |
+
bias=not batch_norm,
|
196 |
+
),
|
197 |
+
]
|
198 |
+
if batch_norm:
|
199 |
+
layers.append(nn.BatchNorm2d(dim))
|
200 |
+
return layers
|
201 |
+
|
202 |
+
residual = nn.Sequential(
|
203 |
+
*_create_block(dim=num_features, batch_norm=batch_norm),
|
204 |
+
*_create_block(dim=num_features, batch_norm=batch_norm),
|
205 |
+
)
|
206 |
+
return ResidualBlock(residual)
|