roll-ai committed
Commit 59d751c · verified · 1 parent: 70e2d21

Upload 177 files

This view is limited to 50 files because the commit contains too many changes.

Files changed (50):
  1. .gitattributes +4 -0
  2. finetune/accelerate_config.yaml +21 -0
  3. finetune/configs/zero2.yaml +38 -0
  4. finetune/configs/zero2_controlnet.yaml +38 -0
  5. finetune/configs/zero2_offload.yaml +42 -0
  6. finetune/configs/zero3.yaml +43 -0
  7. finetune/configs/zero3_offload.yaml +51 -0
  8. finetune/constants.py +2 -0
  9. finetune/datasets/__init__.py +14 -0
  10. finetune/datasets/bucket_sampler.py +71 -0
  11. finetune/datasets/i2v_dataset.py +311 -0
  12. finetune/datasets/i2v_flow_dataset.py +188 -0
  13. finetune/datasets/t2v_dataset.py +251 -0
  14. finetune/datasets/utils.py +211 -0
  15. finetune/models/__init__.py +12 -0
  16. finetune/models/cogvideox_i2v/flovd_OMSM_lora_trainer.py +748 -0
  17. finetune/models/cogvideox_i2v/flovd_controlnet_trainer.py +814 -0
  18. finetune/models/cogvideox_i2v/lora_trainer.py +246 -0
  19. finetune/models/cogvideox_i2v/sft_trainer.py +9 -0
  20. finetune/models/utils.py +57 -0
  21. finetune/modules/__init__.py +0 -0
  22. finetune/modules/camera_flow_generator.py +46 -0
  23. finetune/modules/camera_sampler.py +52 -0
  24. finetune/modules/cogvideox_controlnet.py +353 -0
  25. finetune/modules/cogvideox_custom_model.py +109 -0
  26. finetune/modules/cogvideox_custom_modules.py +357 -0
  27. finetune/modules/depth_warping/__init__.py +0 -0
  28. finetune/modules/depth_warping/camera/Camera.py +70 -0
  29. finetune/modules/depth_warping/camera/WarperPytorch.py +416 -0
  30. finetune/modules/depth_warping/depth_anything_v2/depth_anything_wrapper.py +12 -0
  31. finetune/modules/depth_warping/depth_anything_v2/dinov2.py +415 -0
  32. finetune/modules/depth_warping/depth_anything_v2/dinov2_layers/__init__.py +11 -0
  33. finetune/modules/depth_warping/depth_anything_v2/dinov2_layers/attention.py +83 -0
  34. finetune/modules/depth_warping/depth_anything_v2/dinov2_layers/block.py +252 -0
  35. finetune/modules/depth_warping/depth_anything_v2/dinov2_layers/drop_path.py +35 -0
  36. finetune/modules/depth_warping/depth_anything_v2/dinov2_layers/layer_scale.py +28 -0
  37. finetune/modules/depth_warping/depth_anything_v2/dinov2_layers/mlp.py +41 -0
  38. finetune/modules/depth_warping/depth_anything_v2/dinov2_layers/patch_embed.py +89 -0
  39. finetune/modules/depth_warping/depth_anything_v2/dinov2_layers/swiglu_ffn.py +63 -0
  40. finetune/modules/depth_warping/depth_anything_v2/dpt.py +235 -0
  41. finetune/modules/depth_warping/depth_anything_v2/util/blocks.py +148 -0
  42. finetune/modules/depth_warping/depth_anything_v2/util/transform.py +158 -0
  43. finetune/modules/depth_warping/depth_pro/__init__.py +5 -0
  44. finetune/modules/depth_warping/depth_pro/cli/__init__.py +4 -0
  45. finetune/modules/depth_warping/depth_pro/cli/run.py +154 -0
  46. finetune/modules/depth_warping/depth_pro/depth_pro.py +298 -0
  47. finetune/modules/depth_warping/depth_pro/eval/boundary_metrics.py +332 -0
  48. finetune/modules/depth_warping/depth_pro/eval/dis5k_sample_list.txt +200 -0
  49. finetune/modules/depth_warping/depth_pro/network/__init__.py +2 -0
  50. finetune/modules/depth_warping/depth_pro/network/decoder.py +206 -0
.gitattributes CHANGED
@@ -73,3 +73,7 @@ assets/pages/res1.mp4 filter=lfs diff=lfs merge=lfs -text
  assets/pages/res2.mp4 filter=lfs diff=lfs merge=lfs -text
  assets/pages/res3.mp4 filter=lfs diff=lfs merge=lfs -text
  assets/pages/teaser.png filter=lfs diff=lfs merge=lfs -text
+ results/generated_videos/A_chef_in_a_white_coat_and_gla_1593596b99e2dde9.txt.mp4 filter=lfs diff=lfs merge=lfs -text
+ results/generated_videos/A_stunning_and_untouched_coast_6b6d20c6a46b9fe9.txt.mp4 filter=lfs diff=lfs merge=lfs -text
+ tools/caption/assests/CogVLM2-Caption-example.png filter=lfs diff=lfs merge=lfs -text
+ tools/caption/assests/cogvlm2-video-example.png filter=lfs diff=lfs merge=lfs -text
finetune/accelerate_config.yaml ADDED
@@ -0,0 +1,21 @@
+ compute_environment: LOCAL_MACHINE
+
+ gpu_ids: "0,1,2,3,4,5,6,7"
+ num_processes: 8  # should be the same as the number of GPUs
+
+ debug: false
+ deepspeed_config:
+   deepspeed_config_file: configs/zero2_controlnet.yaml  # e.g. configs/zero2.yaml; use an absolute path
+   zero3_init_flag: false
+ distributed_type: DEEPSPEED
+ downcast_bf16: 'no'
+ enable_cpu_affinity: false
+ machine_rank: 0
+ main_training_function: main
+ num_machines: 1
+ rdzv_backend: static
+ same_network: true
+ tpu_env: []
+ tpu_use_cluster: false
+ tpu_use_sudo: false
+ use_cpu: false
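
Note: with this launcher config and the DeepSpeed configs below, the "auto" value of train_batch_size resolves, by DeepSpeed's convention, to micro-batch size × gradient-accumulation steps × number of processes. A minimal sketch of that arithmetic using the values above (the helper name is illustrative, not part of this commit):

    # Hypothetical helper reproducing DeepSpeed's batch-size identity:
    # train_batch_size = train_micro_batch_size_per_gpu
    #                    * gradient_accumulation_steps * world_size
    def effective_batch_size(micro_batch_per_gpu: int,
                             gradient_accumulation_steps: int,
                             num_processes: int) -> int:
        return micro_batch_per_gpu * gradient_accumulation_steps * num_processes

    print(effective_batch_size(1, 1, 8))  # -> 8 with the values in these configs
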
finetune/configs/zero2.yaml ADDED
@@ -0,0 +1,38 @@
+ {
+     "bf16": {
+         "enabled": true
+     },
+     "optimizer": {
+         "type": "AdamW",
+         "params": {
+             "lr": "auto",
+             "weight_decay": "auto",
+             "torch_adam": true,
+             "adam_w_mode": true
+         }
+     },
+     "scheduler": {
+         "type": "WarmupDecayLR",
+         "params": {
+             "warmup_min_lr": "auto",
+             "warmup_max_lr": "auto",
+             "warmup_num_steps": "auto",
+             "total_num_steps": "auto"
+         }
+     },
+     "zero_optimization": {
+         "stage": 2,
+         "allgather_partitions": true,
+         "allgather_bucket_size": 2e8,
+         "overlap_comm": true,
+         "reduce_scatter": true,
+         "reduce_bucket_size": 5e8,
+         "contiguous_gradients": true
+     },
+     "gradient_accumulation_steps": 1,
+     "train_micro_batch_size_per_gpu": 1,
+     "train_batch_size": "auto",
+     "gradient_clipping": "auto",
+     "steps_per_print": 2000,
+     "wall_clock_breakdown": false
+ }
finetune/configs/zero2_controlnet.yaml ADDED
@@ -0,0 +1,38 @@
+ {
+     "bf16": {
+         "enabled": true
+     },
+     "optimizer": {
+         "type": "AdamW",
+         "params": {
+             "lr": "auto",
+             "weight_decay": "auto",
+             "torch_adam": true,
+             "adam_w_mode": true
+         }
+     },
+     "scheduler": {
+         "type": "WarmupCosineLR",
+         "params": {
+             "warmup_min_ratio": 0.0,
+             "cos_min_ratio": 0.0001,
+             "warmup_num_steps": 250,
+             "total_num_steps": "auto"
+         }
+     },
+     "zero_optimization": {
+         "stage": 2,
+         "allgather_partitions": true,
+         "allgather_bucket_size": 2e8,
+         "overlap_comm": true,
+         "reduce_scatter": true,
+         "reduce_bucket_size": 5e8,
+         "contiguous_gradients": true
+     },
+     "gradient_accumulation_steps": 1,
+     "train_micro_batch_size_per_gpu": 1,
+     "train_batch_size": "auto",
+     "gradient_clipping": "auto",
+     "steps_per_print": 2000,
+     "wall_clock_breakdown": false
+ }
finetune/configs/zero2_offload.yaml ADDED
@@ -0,0 +1,42 @@
+ {
+     "bf16": {
+         "enabled": true
+     },
+     "optimizer": {
+         "type": "AdamW",
+         "params": {
+             "lr": "auto",
+             "weight_decay": "auto",
+             "torch_adam": true,
+             "adam_w_mode": true
+         }
+     },
+     "scheduler": {
+         "type": "WarmupDecayLR",
+         "params": {
+             "warmup_min_lr": "auto",
+             "warmup_max_lr": "auto",
+             "warmup_num_steps": "auto",
+             "total_num_steps": "auto"
+         }
+     },
+     "zero_optimization": {
+         "stage": 2,
+         "allgather_partitions": true,
+         "allgather_bucket_size": 2e8,
+         "overlap_comm": true,
+         "reduce_scatter": true,
+         "reduce_bucket_size": 5e8,
+         "contiguous_gradients": true,
+         "offload_optimizer": {
+             "device": "cpu",
+             "pin_memory": true
+         }
+     },
+     "gradient_accumulation_steps": 1,
+     "train_micro_batch_size_per_gpu": 1,
+     "train_batch_size": "auto",
+     "gradient_clipping": "auto",
+     "steps_per_print": 2000,
+     "wall_clock_breakdown": false
+ }
finetune/configs/zero3.yaml ADDED
@@ -0,0 +1,41 @@
+ {
+     "bf16": {
+         "enabled": true
+     },
+     "optimizer": {
+         "type": "AdamW",
+         "params": {
+             "lr": "auto",
+             "weight_decay": "auto",
+             "torch_adam": true,
+             "adam_w_mode": true
+         }
+     },
+     "scheduler": {
+         "type": "WarmupDecayLR",
+         "params": {
+             "warmup_min_lr": "auto",
+             "warmup_max_lr": "auto",
+             "warmup_num_steps": "auto",
+             "total_num_steps": "auto"
+         }
+     },
+     "zero_optimization": {
+         "stage": 3,
+         "overlap_comm": true,
+         "contiguous_gradients": true,
+         "reduce_bucket_size": 5e8,
+         "sub_group_size": 1e9,
+         "stage3_max_live_parameters": 1e9,
+         "stage3_max_reuse_distance": 1e9,
+         "stage3_gather_16bit_weights_on_model_save": "auto",
+         "stage3_prefetch_bucket_size": 5e8,
+         "stage3_param_persistence_threshold": 1e5
+     },
+     "gradient_accumulation_steps": 1,
+     "train_micro_batch_size_per_gpu": 1,
+     "train_batch_size": "auto",
+     "gradient_clipping": "auto",
+     "steps_per_print": 2000,
+     "wall_clock_breakdown": false
+ }
finetune/configs/zero3_offload.yaml ADDED
@@ -0,0 +1,49 @@
+ {
+     "bf16": {
+         "enabled": true
+     },
+     "optimizer": {
+         "type": "AdamW",
+         "params": {
+             "lr": "auto",
+             "weight_decay": "auto",
+             "torch_adam": true,
+             "adam_w_mode": true
+         }
+     },
+     "scheduler": {
+         "type": "WarmupDecayLR",
+         "params": {
+             "warmup_min_lr": "auto",
+             "warmup_max_lr": "auto",
+             "warmup_num_steps": "auto",
+             "total_num_steps": "auto"
+         }
+     },
+     "zero_optimization": {
+         "stage": 3,
+         "offload_optimizer": {
+             "device": "cpu",
+             "pin_memory": true
+         },
+         "offload_param": {
+             "device": "cpu",
+             "pin_memory": true
+         },
+         "overlap_comm": true,
+         "contiguous_gradients": true,
+         "reduce_bucket_size": 5e8,
+         "sub_group_size": 1e9,
+         "stage3_max_live_parameters": 1e9,
+         "stage3_max_reuse_distance": 1e9,
+         "stage3_gather_16bit_weights_on_model_save": "auto",
+         "stage3_prefetch_bucket_size": 5e8,
+         "stage3_param_persistence_threshold": 1e6
+     },
+     "gradient_accumulation_steps": 1,
+     "train_micro_batch_size_per_gpu": 1,
+     "train_batch_size": "auto",
+     "gradient_clipping": "auto",
+     "steps_per_print": 2000,
+     "wall_clock_breakdown": false
+ }
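
Note: the configs above are JSON documents stored with a .yaml extension (valid, since JSON is a subset of YAML). JSON parsers keep only the last occurrence of a repeated key, so duplicated entries such as the stage3_* keys originally listed twice in the ZeRO-3 files silently override one another. A small hedged sketch (hypothetical helper, not part of this commit) that loads a config and fails loudly on duplicates:

    import json
    from pathlib import Path

    def load_ds_config(path: str) -> dict:
        """Load a DeepSpeed JSON config and raise if any key is duplicated."""
        def check_dupes(pairs):
            keys = [k for k, _ in pairs]
            dupes = sorted({k for k in keys if keys.count(k) > 1})
            if dupes:
                raise ValueError(f"duplicate keys in {path}: {dupes}")
            return dict(pairs)
        return json.loads(Path(path).read_text(), object_pairs_hook=check_dupes)

    # Example: load_ds_config("finetune/configs/zero3.yaml")
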
finetune/constants.py ADDED
@@ -0,0 +1,2 @@
+ LOG_NAME = "trainer"
+ LOG_LEVEL = "INFO"
finetune/datasets/__init__.py ADDED
@@ -0,0 +1,14 @@
+ from .bucket_sampler import BucketSampler
+ from .i2v_dataset import I2VDatasetWithBuckets, I2VDatasetWithResize
+ from .t2v_dataset import T2VDatasetWithBuckets, T2VDatasetWithResize
+ from .i2v_flow_dataset import I2VFlowDataset
+
+
+ __all__ = [
+     "I2VDatasetWithResize",
+     "I2VDatasetWithBuckets",
+     "T2VDatasetWithResize",
+     "T2VDatasetWithBuckets",
+     "BucketSampler",
+     "I2VFlowDataset",
+ ]
finetune/datasets/bucket_sampler.py ADDED
@@ -0,0 +1,71 @@
+ import logging
+ import random
+
+ from torch.utils.data import Dataset, Sampler
+
+
+ logger = logging.getLogger(__name__)
+
+
+ class BucketSampler(Sampler):
+     r"""
+     PyTorch Sampler that groups 3D data by height, width and frames.
+
+     Args:
+         data_source (`VideoDataset`):
+             A PyTorch dataset object that is an instance of `VideoDataset`.
+         batch_size (`int`, defaults to `8`):
+             The batch size to use for training.
+         shuffle (`bool`, defaults to `True`):
+             Whether or not to shuffle the data in each batch before dispatching to dataloader.
+         drop_last (`bool`, defaults to `False`):
+             Whether or not to drop incomplete buckets of data after completely iterating over all data
+             in the dataset. If set to True, only batches that have `batch_size` number of entries will
+             be yielded. If set to False, it is guaranteed that all data in the dataset will be processed
+             and batches that do not have `batch_size` number of entries will also be yielded.
+     """
+
+     def __init__(
+         self, data_source: Dataset, batch_size: int = 8, shuffle: bool = True, drop_last: bool = False
+     ) -> None:
+         self.data_source = data_source
+         self.batch_size = batch_size
+         self.shuffle = shuffle
+         self.drop_last = drop_last
+
+         self.buckets = {resolution: [] for resolution in data_source.video_resolution_buckets}
+
+         self._raised_warning_for_drop_last = False
+
+     def __len__(self):
+         if self.drop_last and not self._raised_warning_for_drop_last:
+             self._raised_warning_for_drop_last = True
+             logger.warning(
+                 "Calculating the length for bucket sampler is not possible when `drop_last` is set to True. This may cause problems when setting the number of epochs used for training."
+             )
+         return (len(self.data_source) + self.batch_size - 1) // self.batch_size
+
+     def __iter__(self):
+         for index, data in enumerate(self.data_source):
+             video_metadata = data["video_metadata"]
+             f, h, w = video_metadata["num_frames"], video_metadata["height"], video_metadata["width"]
+
+             self.buckets[(f, h, w)].append(data)
+             if len(self.buckets[(f, h, w)]) == self.batch_size:
+                 if self.shuffle:
+                     random.shuffle(self.buckets[(f, h, w)])
+                 yield self.buckets[(f, h, w)]
+                 del self.buckets[(f, h, w)]
+                 self.buckets[(f, h, w)] = []
+
+         if self.drop_last:
+             return
+
+         for fhw, bucket in list(self.buckets.items()):
+             if len(bucket) == 0:
+                 continue
+             if self.shuffle:
+                 random.shuffle(bucket)
+             yield bucket
+             del self.buckets[fhw]
+             self.buckets[fhw] = []
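
Note: BucketSampler yields whole buckets (lists of already-loaded samples) rather than indices, which is why the dataset __getitem__ methods below simply return the argument when it is a list. One plausible way to wire it into a DataLoader, under that assumption:

    from torch.utils.data import DataLoader

    from finetune.datasets import BucketSampler

    # `dataset` is assumed to be an I2VDatasetWithBuckets / T2VDatasetWithBuckets
    # instance; BucketSampler yields lists of samples, and the dataset's
    # __getitem__ passes such a list straight through.
    def build_bucket_dataloader(dataset, batch_size: int = 8) -> DataLoader:
        sampler = BucketSampler(dataset, batch_size=batch_size, shuffle=True)
        return DataLoader(
            dataset,
            sampler=sampler,
            batch_size=1,                       # each "item" is already a full bucket
            collate_fn=lambda batch: batch[0],  # unwrap the singleton list
        )
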
finetune/datasets/i2v_dataset.py ADDED
@@ -0,0 +1,311 @@
1
+ import hashlib
2
+ from pathlib import Path
3
+ from typing import TYPE_CHECKING, Any, Dict, List, Tuple
4
+
5
+ import torch
6
+ from accelerate.logging import get_logger
7
+ from safetensors.torch import load_file, save_file
8
+ from torch.utils.data import Dataset
9
+ from torchvision import transforms
10
+ from typing_extensions import override
11
+
12
+ from finetune.constants import LOG_LEVEL, LOG_NAME
13
+
14
+ from .utils import (
15
+ load_images,
16
+ load_images_from_videos,
17
+ load_prompts,
18
+ load_videos,
19
+ preprocess_image_with_resize,
20
+ preprocess_video_with_buckets,
21
+ preprocess_video_with_resize,
22
+ )
23
+
24
+
25
+ if TYPE_CHECKING:
26
+ from finetune.trainer import Trainer
27
+
28
+ # Must import after torch because this can sometimes lead to a nasty segmentation fault, or stack smashing error
29
+ # Very few bug reports but it happens. Look in decord Github issues for more relevant information.
30
+ import decord # isort:skip
31
+
32
+ decord.bridge.set_bridge("torch")
33
+
34
+ logger = get_logger(LOG_NAME, LOG_LEVEL)
35
+
36
+
37
+ class BaseI2VDataset(Dataset):
38
+ """
39
+ Base dataset class for Image-to-Video (I2V) training.
40
+
41
+ This dataset loads prompts, videos and corresponding conditioning images for I2V training.
42
+
43
+ Args:
44
+ data_root (str): Root directory containing the dataset files
45
+ caption_column (str): Path to file containing text prompts/captions
46
+ video_column (str): Path to file containing video paths
47
+ image_column (str): Path to file containing image paths
48
+ device (torch.device): Device to load the data on
49
+ encode_video_fn (Callable[[torch.Tensor], torch.Tensor], optional): Function to encode videos
50
+ """
51
+
52
+ def __init__(
53
+ self,
54
+ data_root: str,
55
+ caption_column: str,
56
+ video_column: str,
57
+ image_column: str | None,
58
+ device: torch.device,
59
+ trainer: "Trainer" = None,
60
+ *args,
61
+ **kwargs,
62
+ ) -> None:
63
+ super().__init__()
64
+
65
+ data_root = Path(data_root)
66
+ self.prompts = load_prompts(data_root / caption_column)
67
+ self.videos = load_videos(data_root / video_column)
68
+ if image_column is not None:
69
+ self.images = load_images(data_root / image_column)
70
+ else:
71
+ self.images = load_images_from_videos(self.videos)
72
+ self.trainer = trainer
73
+
74
+ self.device = device
75
+ self.encode_video = trainer.encode_video
76
+ self.encode_text = trainer.encode_text
77
+
78
+ # Check if number of prompts matches number of videos and images
79
+ if not (len(self.videos) == len(self.prompts) == len(self.images)):
80
+ raise ValueError(
81
+ f"Expected length of prompts, videos and images to be the same but found {len(self.prompts)=}, {len(self.videos)=} and {len(self.images)=}. Please ensure that the number of caption prompts, videos and images match in your dataset."
82
+ )
83
+
84
+ # Check if all video files exist
85
+ if any(not path.is_file() for path in self.videos):
86
+ raise ValueError(
87
+ f"Some video files were not found. Please ensure that all video files exist in the dataset directory. Missing file: {next(path for path in self.videos if not path.is_file())}"
88
+ )
89
+
90
+ # Check if all image files exist
91
+ if any(not path.is_file() for path in self.images):
92
+ raise ValueError(
93
+ f"Some image files were not found. Please ensure that all image files exist in the dataset directory. Missing file: {next(path for path in self.images if not path.is_file())}"
94
+ )
95
+
96
+ def __len__(self) -> int:
97
+ return len(self.videos)
98
+
99
+ def __getitem__(self, index: int) -> Dict[str, Any]:
100
+ if isinstance(index, list):
101
+ # Here, index is actually a list of data objects that we need to return.
102
+ # The BucketSampler should ideally return indices. But, in the sampler, we'd like
103
+ # to have information about num_frames, height and width. Since this is not stored
104
+ # as metadata, we need to read the video to get this information. You could read this
105
+ # information without loading the full video in memory, but we do it anyway. In order
106
+ # to not load the video twice (once to get the metadata, and once to return the loaded video
107
+ # based on sampled indices), we cache it in the BucketSampler. When the sampler is
108
+ # to yield, we yield the cache data instead of indices. So, this special check ensures
109
+ # that data is not loaded a second time. PRs are welcome for improvements.
110
+ return index
111
+
112
+ prompt = self.prompts[index]
113
+ video = self.videos[index]
114
+ image = self.images[index]
115
+ train_resolution_str = "x".join(str(x) for x in self.trainer.args.train_resolution)
116
+
117
+ cache_dir = self.trainer.args.data_root / "cache"
118
+ video_latent_dir = cache_dir / "video_latent" / self.trainer.args.model_name / train_resolution_str
119
+ prompt_embeddings_dir = cache_dir / "prompt_embeddings"
120
+ video_latent_dir.mkdir(parents=True, exist_ok=True)
121
+ prompt_embeddings_dir.mkdir(parents=True, exist_ok=True)
122
+
123
+ prompt_hash = str(hashlib.sha256(prompt.encode()).hexdigest())
124
+ prompt_embedding_path = prompt_embeddings_dir / (prompt_hash + ".safetensors")
125
+ encoded_video_path = video_latent_dir / (video.stem + ".safetensors")
126
+
127
+ if prompt_embedding_path.exists():
128
+ prompt_embedding = load_file(prompt_embedding_path)["prompt_embedding"]
129
+ logger.debug(
130
+ f"process {self.trainer.accelerator.process_index}: Loaded prompt embedding from {prompt_embedding_path}",
131
+ main_process_only=False,
132
+ )
133
+ else:
134
+ prompt_embedding = self.encode_text(prompt)
135
+ prompt_embedding = prompt_embedding.to("cpu")
136
+ # [1, seq_len, hidden_size] -> [seq_len, hidden_size]
137
+ prompt_embedding = prompt_embedding[0]
138
+ save_file({"prompt_embedding": prompt_embedding}, prompt_embedding_path)
139
+ logger.info(f"Saved prompt embedding to {prompt_embedding_path}", main_process_only=False)
140
+
141
+ if encoded_video_path.exists():
142
+ encoded_video = load_file(encoded_video_path)["encoded_video"]
143
+ logger.debug(f"Loaded encoded video from {encoded_video_path}", main_process_only=False)
144
+ # shape of image: [C, H, W]
145
+ _, image = self.preprocess(None, self.images[index])
146
+ image = self.image_transform(image)
147
+ else:
148
+ frames, image = self.preprocess(video, image)
149
+ frames = frames.to(self.device)
150
+ image = image.to(self.device)
151
+ image = self.image_transform(image)
152
+ # Current shape of frames: [F, C, H, W]
153
+ frames = self.video_transform(frames)
154
+
155
+ # Convert to [B, C, F, H, W]
156
+ frames = frames.unsqueeze(0)
157
+ frames = frames.permute(0, 2, 1, 3, 4).contiguous()
158
+ encoded_video = self.encode_video(frames)
159
+
160
+ # [1, C, F, H, W] -> [C, F, H, W]
161
+ encoded_video = encoded_video[0]
162
+ encoded_video = encoded_video.to("cpu")
163
+ image = image.to("cpu")
164
+ save_file({"encoded_video": encoded_video}, encoded_video_path)
165
+ logger.info(f"Saved encoded video to {encoded_video_path}", main_process_only=False)
166
+
167
+ # shape of encoded_video: [C, F, H, W]
168
+ # shape of image: [C, H, W]
169
+ return {
170
+ "image": image,
171
+ "prompt_embedding": prompt_embedding,
172
+ "encoded_video": encoded_video,
173
+ "video_metadata": {
174
+ "num_frames": encoded_video.shape[1],
175
+ "height": encoded_video.shape[2],
176
+ "width": encoded_video.shape[3],
177
+ },
178
+ }
179
+
180
+ def preprocess(self, video_path: Path | None, image_path: Path | None) -> Tuple[torch.Tensor, torch.Tensor]:
181
+ """
182
+ Loads and preprocesses a video and an image.
183
+ If either path is None, no preprocessing will be done for that input.
184
+
185
+ Args:
186
+ video_path: Path to the video file to load
187
+ image_path: Path to the image file to load
188
+
189
+ Returns:
190
+ A tuple containing:
191
+ - video(torch.Tensor) of shape [F, C, H, W] where F is number of frames,
192
+ C is number of channels, H is height and W is width
193
+ - image(torch.Tensor) of shape [C, H, W]
194
+ """
195
+ raise NotImplementedError("Subclass must implement this method")
196
+
197
+ def video_transform(self, frames: torch.Tensor) -> torch.Tensor:
198
+ """
199
+ Applies transformations to a video.
200
+
201
+ Args:
202
+ frames (torch.Tensor): A 4D tensor representing a video
203
+ with shape [F, C, H, W] where:
204
+ - F is number of frames
205
+ - C is number of channels (3 for RGB)
206
+ - H is height
207
+ - W is width
208
+
209
+ Returns:
210
+ torch.Tensor: The transformed video tensor
211
+ """
212
+ raise NotImplementedError("Subclass must implement this method")
213
+
214
+ def image_transform(self, image: torch.Tensor) -> torch.Tensor:
215
+ """
216
+ Applies transformations to an image.
217
+
218
+ Args:
219
+ image (torch.Tensor): A 3D tensor representing an image
220
+ with shape [C, H, W] where:
221
+ - C is number of channels (3 for RGB)
222
+ - H is height
223
+ - W is width
224
+
225
+ Returns:
226
+ torch.Tensor: The transformed image tensor
227
+ """
228
+ raise NotImplementedError("Subclass must implement this method")
229
+
230
+
231
+ class I2VDatasetWithResize(BaseI2VDataset):
232
+ """
233
+ A dataset class for image-to-video generation that resizes inputs to fixed dimensions.
234
+
235
+ This class preprocesses videos and images by resizing them to specified dimensions:
236
+ - Videos are resized to max_num_frames x height x width
237
+ - Images are resized to height x width
238
+
239
+ Args:
240
+ max_num_frames (int): Maximum number of frames to extract from videos
241
+ height (int): Target height for resizing videos and images
242
+ width (int): Target width for resizing videos and images
243
+ """
244
+
245
+ def __init__(self, max_num_frames: int, height: int, width: int, *args, **kwargs) -> None:
246
+ super().__init__(*args, **kwargs)
247
+
248
+ self.max_num_frames = max_num_frames
249
+ self.height = height
250
+ self.width = width
251
+
252
+ self.__frame_transforms = transforms.Compose([transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)])
253
+ self.__image_transforms = self.__frame_transforms
254
+
255
+ @override
256
+ def preprocess(self, video_path: Path | None, image_path: Path | None) -> Tuple[torch.Tensor, torch.Tensor]:
257
+ if video_path is not None:
258
+ video = preprocess_video_with_resize(video_path, self.max_num_frames, self.height, self.width)
259
+ else:
260
+ video = None
261
+ if image_path is not None:
262
+ image = preprocess_image_with_resize(image_path, self.height, self.width)
263
+ else:
264
+ image = None
265
+ return video, image
266
+
267
+ @override
268
+ def video_transform(self, frames: torch.Tensor) -> torch.Tensor:
269
+ return torch.stack([self.__frame_transforms(f) for f in frames], dim=0)
270
+
271
+ @override
272
+ def image_transform(self, image: torch.Tensor) -> torch.Tensor:
273
+ return self.__image_transforms(image)
274
+
275
+
276
+ class I2VDatasetWithBuckets(BaseI2VDataset):
277
+ def __init__(
278
+ self,
279
+ video_resolution_buckets: List[Tuple[int, int, int]],
280
+ vae_temporal_compression_ratio: int,
281
+ vae_height_compression_ratio: int,
282
+ vae_width_compression_ratio: int,
283
+ *args,
284
+ **kwargs,
285
+ ) -> None:
286
+ super().__init__(*args, **kwargs)
287
+
288
+ self.video_resolution_buckets = [
289
+ (
290
+ int(b[0] / vae_temporal_compression_ratio),
291
+ int(b[1] / vae_height_compression_ratio),
292
+ int(b[2] / vae_width_compression_ratio),
293
+ )
294
+ for b in video_resolution_buckets
295
+ ]
296
+ self.__frame_transforms = transforms.Compose([transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)])
297
+ self.__image_transforms = self.__frame_transforms
298
+
299
+ @override
300
+ def preprocess(self, video_path: Path, image_path: Path) -> Tuple[torch.Tensor, torch.Tensor]:
301
+ video = preprocess_video_with_buckets(video_path, self.video_resolution_buckets)
302
+ image = preprocess_image_with_resize(image_path, video.shape[2], video.shape[3])
303
+ return video, image
304
+
305
+ @override
306
+ def video_transform(self, frames: torch.Tensor) -> torch.Tensor:
307
+ return torch.stack([self.__frame_transforms(f) for f in frames], dim=0)
308
+
309
+ @override
310
+ def image_transform(self, image: torch.Tensor) -> torch.Tensor:
311
+ return self.__image_transforms(image)
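
Note: BaseI2VDataset caches text embeddings under cache/prompt_embeddings/, keyed by the SHA-256 of the prompt, and video latents under cache/video_latent/<model_name>/<train_resolution>/, keyed by the video file stem. A small sketch that mirrors that layout (helper name and example arguments are illustrative only) can be handy for inspecting or pre-populating the cache:

    import hashlib
    from pathlib import Path

    # Mirrors the cache layout used by BaseI2VDataset.__getitem__ above.
    def cache_paths(data_root, model_name, train_resolution, prompt, video_path):
        cache_dir = Path(data_root) / "cache"
        res = "x".join(str(x) for x in train_resolution)
        prompt_hash = hashlib.sha256(prompt.encode()).hexdigest()
        prompt_embedding_path = cache_dir / "prompt_embeddings" / f"{prompt_hash}.safetensors"
        video_latent_path = cache_dir / "video_latent" / model_name / res / f"{Path(video_path).stem}.safetensors"
        return prompt_embedding_path, video_latent_path

    # Illustrative arguments, not taken from the repo's defaults:
    print(cache_paths("data", "cogvideox-i2v", (49, 480, 720),
                      "A chef in a white coat plates a dish.", "videos/clip_0001.mp4"))
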
finetune/datasets/i2v_flow_dataset.py ADDED
@@ -0,0 +1,188 @@
1
+ import hashlib
2
+ from pathlib import Path
3
+ from typing import TYPE_CHECKING, Any, Dict, List, Tuple
4
+ import json
5
+ import random
6
+
7
+ import torch
8
+ from accelerate.logging import get_logger
9
+ from safetensors.torch import load_file, save_file
10
+ from torch.utils.data import Dataset
11
+ from torchvision import transforms
12
+ from typing_extensions import override
13
+
14
+ from finetune.constants import LOG_LEVEL, LOG_NAME
15
+
16
+ from .utils import (
17
+ load_images,
18
+ load_images_from_videos,
19
+ load_prompts,
20
+ load_videos,
21
+ preprocess_image_with_resize,
22
+ preprocess_video_with_buckets,
23
+ preprocess_video_with_resize,
24
+ load_binary_mask_compressed,
25
+ )
26
+
29
+ if TYPE_CHECKING:
30
+ from finetune.trainer import Trainer
31
+
32
+ # Must import after torch because this can sometimes lead to a nasty segmentation fault, or stack smashing error
33
+ # Very few bug reports but it happens. Look in decord Github issues for more relevant information.
34
+ import decord # isort:skip
35
+
36
+ decord.bridge.set_bridge("torch")
37
+
38
+ logger = get_logger(LOG_NAME, LOG_LEVEL)
39
+
40
+
41
+ class I2VFlowDataset(Dataset):
42
+ """
43
+ A dataset class for (image,flow)-to-video generation or image-to-flow_video that resizes inputs to fixed dimensions.
44
+
45
+ This class preprocesses videos and images by resizing them to specified dimensions:
46
+ - Videos are resized to max_num_frames x height x width
47
+ - Images are resized to height x width
48
+
49
+ Args:
50
+ max_num_frames (int): Maximum number of frames to extract from videos
51
+ height (int): Target height for resizing videos and images
52
+ width (int): Target width for resizing videos and images
53
+ """
54
+
55
+ def __init__(
56
+ self,
57
+ max_num_frames: int,
58
+ height: int,
59
+ width: int,
60
+ data_root: str,
61
+ caption_column: str,
62
+ video_column: str,
63
+ image_column: str | None,
64
+ device: torch.device,
65
+ trainer: "Trainer" = None,
66
+ *args,
67
+ **kwargs
68
+ ) -> None:
69
+ data_root = Path(data_root)
70
+ metadata_path = data_root / "metadata_revised.jsonl"
71
+ assert metadata_path.is_file(), "For this dataset type, you need metadata_revised.jsonl in the root path"
72
+
73
+ # Load metadata
74
+ # metadata = {
75
+ # "video_path": ...,
76
+ # "hash_code": ...,
77
+ # "prompt": ...,
78
+ # }
79
+ metadata = []
80
+ with open(metadata_path, "r") as f:
81
+ for line in f:
82
+ metadata.append( json.loads(line) )
83
+
84
+ self.prompts = [x["prompt"] for x in metadata]
85
+ if 'curated' in str(data_root).lower():
86
+ self.prompt_embeddings = [data_root / "prompt_embeddings" / (x["hash_code"] + '.safetensors') for x in metadata]
87
+ else:
88
+ self.prompt_embeddings = [data_root / "prompt_embeddings_revised" / (x["hash_code"] + '.safetensors') for x in metadata]
89
+ self.videos = [data_root / "video_latent" / "x".join(str(x) for x in trainer.args.train_resolution) / (x["hash_code"] + '.safetensors') for x in metadata]
90
+ self.images = [data_root / "first_frames" / (x["hash_code"] + '.png') for x in metadata]
91
+ self.flows = [data_root / "flow_direct_f_latent" / (x["hash_code"] + '.safetensors') for x in metadata]
92
+
93
+
94
+ # data_root = Path(data_root)
95
+ # self.prompts = load_prompts(data_root / caption_column)
96
+ # self.videos = load_videos(data_root / video_column)
97
+
98
+ self.trainer = trainer
99
+
100
+ self.device = device
101
+ self.encode_video = trainer.encode_video
102
+ self.encode_text = trainer.encode_text
103
+
104
+ # Check if number of prompts matches number of videos and images
105
+ if not (len(self.videos) == len(self.prompts) == len(self.images) == len(self.flows)):
106
+ raise ValueError(
107
+ f"Expected length of prompts, videos and images to be the same but found {len(self.prompts)=}, {len(self.videos)=}, {len(self.images)=} and {len(self.flows)=}. Please ensure that the number of caption prompts, videos and images match in your dataset."
108
+ )
109
+
110
+ self.max_num_frames = max_num_frames
111
+ self.height = height
112
+ self.width = width
113
+
114
+ self.__frame_transforms = transforms.Compose([transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)])
115
+ self.__image_transforms = self.__frame_transforms
116
+
117
+ self.length = len(self.videos)
118
+
119
+ print(f"Dataset size: {self.length}")
120
+
121
+ def __len__(self) -> int:
122
+ return self.length
123
+
124
+ def load_data_pair(self, index):
125
+ # prompt = self.prompts[index]
126
+ prompt_embedding_path = self.prompt_embeddings[index]
127
+ encoded_video_path = self.videos[index]
128
+ encoded_flow_path = self.flows[index]
129
+ # mask_path = self.masks[index]
130
+ # image_path = self.images[index]
131
+ # train_resolution_str = "x".join(str(x) for x in self.trainer.args.train_resolution)
132
+
133
+ prompt_embedding = load_file(prompt_embedding_path)["prompt_embedding"]
134
+ encoded_video = load_file(encoded_video_path)["encoded_video"] # CFHW
135
+ encoded_flow = load_file(encoded_flow_path)["encoded_flow_f"] # CFHW
136
+
137
+ return prompt_embedding, encoded_video, encoded_flow
138
+
139
+ def __getitem__(self, index: int) -> Dict[str, Any]:
140
+ while True:
141
+ try:
142
+ prompt_embedding, encoded_video, encoded_flow = self.load_data_pair(index)
143
+ break
144
+ except Exception as e:
145
+ print(f"Error loading {self.prompt_embeddings[index]}: {str(e)}")
146
+ index = random.randint(0, self.length - 1)
147
+
148
+ image_path = self.images[index]
149
+ prompt = self.prompts[index]
150
+ train_resolution_str = "x".join(str(x) for x in self.trainer.args.train_resolution)
151
+
152
+ _, image = self.preprocess(None, image_path)
153
+ image = self.image_transform(image)
154
+
155
+
156
+ # shape of encoded_video: [C, F, H, W]
157
+ # shape and scale of image: [C, H, W], [-1,1]
158
+ return {
159
+ "image": image,
160
+ "prompt_embedding": prompt_embedding,
161
+ "encoded_video": encoded_video,
162
+ "encoded_flow": encoded_flow,
163
+ "video_metadata": {
164
+ "num_frames": encoded_video.shape[1],
165
+ "height": encoded_video.shape[2],
166
+ "width": encoded_video.shape[3],
167
+ },
168
+ }
169
+
170
+ @override
171
+ def preprocess(self, video_path: Path | None, image_path: Path | None) -> Tuple[torch.Tensor, torch.Tensor]:
172
+ if video_path is not None:
173
+ video = preprocess_video_with_resize(video_path, self.max_num_frames, self.height, self.width)
174
+ else:
175
+ video = None
176
+ if image_path is not None:
177
+ image = preprocess_image_with_resize(image_path, self.height, self.width)
178
+ else:
179
+ image = None
180
+ return video, image
181
+
182
+ @override
183
+ def video_transform(self, frames: torch.Tensor) -> torch.Tensor:
184
+ return torch.stack([self.__frame_transforms(f) for f in frames], dim=0)
185
+
186
+ @override
187
+ def image_transform(self, image: torch.Tensor) -> torch.Tensor:
188
+ return self.__image_transforms(image)
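
Note: unlike the datasets above, I2VFlowDataset reads only precomputed artifacts, indexed by data_root/metadata_revised.jsonl with one JSON object per line carrying at least the "hash_code" and "prompt" fields used above (plus "video_path", per the inline comment); the matching .safetensors/.png files are expected to already exist under prompt_embeddings*/, video_latent/<train_resolution>/, first_frames/ and flow_direct_f_latent/. A sketch of writing such a metadata file (the paths, hash and prompt below are made up):

    import json
    from pathlib import Path

    records = [
        {
            "video_path": "videos/clip_0001.mp4",   # illustrative
            "hash_code": "0123456789abcdef",        # illustrative
            "prompt": "A chef in a white coat plates a dish.",
        },
    ]

    data_root = Path("data_root")
    data_root.mkdir(exist_ok=True)
    with open(data_root / "metadata_revised.jsonl", "w") as f:
        for rec in records:
            f.write(json.dumps(rec) + "\n")
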
finetune/datasets/t2v_dataset.py ADDED
@@ -0,0 +1,251 @@
1
+ import hashlib
2
+ from pathlib import Path
3
+ from typing import TYPE_CHECKING, Any, Dict, List, Tuple
4
+
5
+ import torch
6
+ from accelerate.logging import get_logger
7
+ from safetensors.torch import load_file, save_file
8
+ from torch.utils.data import Dataset
9
+ from torchvision import transforms
10
+ from typing_extensions import override
11
+
12
+ from finetune.constants import LOG_LEVEL, LOG_NAME
13
+
14
+ from .utils import load_prompts, load_videos, preprocess_video_with_buckets, preprocess_video_with_resize
15
+
16
+
17
+ if TYPE_CHECKING:
18
+ from finetune.trainer import Trainer
19
+
20
+ # Must import after torch because this can sometimes lead to a nasty segmentation fault, or stack smashing error
21
+ # Very few bug reports but it happens. Look in decord Github issues for more relevant information.
22
+ import decord # isort:skip
23
+
24
+ decord.bridge.set_bridge("torch")
25
+
26
+ logger = get_logger(LOG_NAME, LOG_LEVEL)
27
+
28
+
29
+ class BaseT2VDataset(Dataset):
30
+ """
31
+ Base dataset class for Text-to-Video (T2V) training.
32
+
33
+ This dataset loads prompts and videos for T2V training.
34
+
35
+ Args:
36
+ data_root (str): Root directory containing the dataset files
37
+ caption_column (str): Path to file containing text prompts/captions
38
+ video_column (str): Path to file containing video paths
39
+ device (torch.device): Device to load the data on
40
+ encode_video_fn (Callable[[torch.Tensor], torch.Tensor], optional): Function to encode videos
41
+ """
42
+
43
+ def __init__(
44
+ self,
45
+ data_root: str,
46
+ caption_column: str,
47
+ video_column: str,
48
+ device: torch.device = None,
49
+ trainer: "Trainer" = None,
50
+ *args,
51
+ **kwargs,
52
+ ) -> None:
53
+ super().__init__()
54
+
55
+ data_root = Path(data_root)
56
+ self.prompts = load_prompts(data_root / caption_column)
57
+ self.videos = load_videos(data_root / video_column)
58
+ self.device = device
59
+ self.encode_video = trainer.encode_video
60
+ self.encode_text = trainer.encode_text
61
+ self.trainer = trainer
62
+
63
+ # Check if all video files exist
64
+ if any(not path.is_file() for path in self.videos):
65
+ raise ValueError(
66
+ f"Some video files were not found. Please ensure that all video files exist in the dataset directory. Missing file: {next(path for path in self.videos if not path.is_file())}"
67
+ )
68
+
69
+ # Check if number of prompts matches number of videos
70
+ if len(self.videos) != len(self.prompts):
71
+ raise ValueError(
72
+ f"Expected length of prompts and videos to be the same but found {len(self.prompts)=} and {len(self.videos)=}. Please ensure that the number of caption prompts and videos match in your dataset."
73
+ )
74
+
75
+ def __len__(self) -> int:
76
+ return len(self.videos)
77
+
78
+ def __getitem__(self, index: int) -> Dict[str, Any]:
79
+ if isinstance(index, list):
80
+ # Here, index is actually a list of data objects that we need to return.
81
+ # The BucketSampler should ideally return indices. But, in the sampler, we'd like
82
+ # to have information about num_frames, height and width. Since this is not stored
83
+ # as metadata, we need to read the video to get this information. You could read this
84
+ # information without loading the full video in memory, but we do it anyway. In order
85
+ # to not load the video twice (once to get the metadata, and once to return the loaded video
86
+ # based on sampled indices), we cache it in the BucketSampler. When the sampler is
87
+ # to yield, we yield the cache data instead of indices. So, this special check ensures
88
+ # that data is not loaded a second time. PRs are welcome for improvements.
89
+ return index
90
+
91
+ prompt = self.prompts[index]
92
+ video = self.videos[index]
93
+ train_resolution_str = "x".join(str(x) for x in self.trainer.args.train_resolution)
94
+
95
+ cache_dir = self.trainer.args.data_root / "cache"
96
+ video_latent_dir = cache_dir / "video_latent" / self.trainer.args.model_name / train_resolution_str
97
+ prompt_embeddings_dir = cache_dir / "prompt_embeddings"
98
+ video_latent_dir.mkdir(parents=True, exist_ok=True)
99
+ prompt_embeddings_dir.mkdir(parents=True, exist_ok=True)
100
+
101
+ prompt_hash = str(hashlib.sha256(prompt.encode()).hexdigest())
102
+ prompt_embedding_path = prompt_embeddings_dir / (prompt_hash + ".safetensors")
103
+ encoded_video_path = video_latent_dir / (video.stem + ".safetensors")
104
+
105
+ if prompt_embedding_path.exists():
106
+ prompt_embedding = load_file(prompt_embedding_path)["prompt_embedding"]
107
+ logger.debug(
108
+ f"process {self.trainer.accelerator.process_index}: Loaded prompt embedding from {prompt_embedding_path}",
109
+ main_process_only=False,
110
+ )
111
+ else:
112
+ prompt_embedding = self.encode_text(prompt)
113
+ prompt_embedding = prompt_embedding.to("cpu")
114
+ # [1, seq_len, hidden_size] -> [seq_len, hidden_size]
115
+ prompt_embedding = prompt_embedding[0]
116
+ save_file({"prompt_embedding": prompt_embedding}, prompt_embedding_path)
117
+ logger.info(f"Saved prompt embedding to {prompt_embedding_path}", main_process_only=False)
118
+
119
+ if encoded_video_path.exists():
120
+ # encoded_video = torch.load(encoded_video_path, weights_only=True)
121
+ encoded_video = load_file(encoded_video_path)["encoded_video"]
122
+ logger.debug(f"Loaded encoded video from {encoded_video_path}", main_process_only=False)
124
+ else:
125
+ frames = self.preprocess(video)
126
+ frames = frames.to(self.device)
127
+ # Current shape of frames: [F, C, H, W]
128
+ frames = self.video_transform(frames)
129
+ # Convert to [B, C, F, H, W]
130
+ frames = frames.unsqueeze(0)
131
+ frames = frames.permute(0, 2, 1, 3, 4).contiguous()
132
+ encoded_video = self.encode_video(frames)
133
+
134
+ # [1, C, F, H, W] -> [C, F, H, W]
135
+ encoded_video = encoded_video[0]
136
+ encoded_video = encoded_video.to("cpu")
137
+ save_file({"encoded_video": encoded_video}, encoded_video_path)
138
+ logger.info(f"Saved encoded video to {encoded_video_path}", main_process_only=False)
139
+
140
+ # shape of encoded_video: [C, F, H, W]
141
+ return {
142
+ "prompt_embedding": prompt_embedding,
143
+ "encoded_video": encoded_video,
144
+ "video_metadata": {
145
+ "num_frames": encoded_video.shape[1],
146
+ "height": encoded_video.shape[2],
147
+ "width": encoded_video.shape[3],
148
+ },
149
+ }
150
+
151
+ def preprocess(self, video_path: Path) -> torch.Tensor:
152
+ """
153
+ Loads and preprocesses a video.
154
+
155
+ Args:
156
+ video_path: Path to the video file to load.
157
+
158
+ Returns:
159
+ torch.Tensor: Video tensor of shape [F, C, H, W] where:
160
+ - F is number of frames
161
+ - C is number of channels (3 for RGB)
162
+ - H is height
163
+ - W is width
164
+ """
165
+ raise NotImplementedError("Subclass must implement this method")
166
+
167
+ def video_transform(self, frames: torch.Tensor) -> torch.Tensor:
168
+ """
169
+ Applies transformations to a video.
170
+
171
+ Args:
172
+ frames (torch.Tensor): A 4D tensor representing a video
173
+ with shape [F, C, H, W] where:
174
+ - F is number of frames
175
+ - C is number of channels (3 for RGB)
176
+ - H is height
177
+ - W is width
178
+
179
+ Returns:
180
+ torch.Tensor: The transformed video tensor with the same shape as the input
181
+ """
182
+ raise NotImplementedError("Subclass must implement this method")
183
+
184
+
185
+ class T2VDatasetWithResize(BaseT2VDataset):
186
+ """
187
+ A dataset class for text-to-video generation that resizes inputs to fixed dimensions.
188
+
189
+ This class preprocesses videos by resizing them to specified dimensions:
190
+ - Videos are resized to max_num_frames x height x width
191
+
192
+ Args:
193
+ max_num_frames (int): Maximum number of frames to extract from videos
194
+ height (int): Target height for resizing videos
195
+ width (int): Target width for resizing videos
196
+ """
197
+
198
+ def __init__(self, max_num_frames: int, height: int, width: int, *args, **kwargs) -> None:
199
+ super().__init__(*args, **kwargs)
200
+
201
+ self.max_num_frames = max_num_frames
202
+ self.height = height
203
+ self.width = width
204
+
205
+ self.__frame_transform = transforms.Compose([transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)])
206
+
207
+ @override
208
+ def preprocess(self, video_path: Path) -> torch.Tensor:
209
+ return preprocess_video_with_resize(
210
+ video_path,
211
+ self.max_num_frames,
212
+ self.height,
213
+ self.width,
214
+ )
215
+
216
+ @override
217
+ def video_transform(self, frames: torch.Tensor) -> torch.Tensor:
218
+ return torch.stack([self.__frame_transform(f) for f in frames], dim=0)
219
+
220
+
221
+ class T2VDatasetWithBuckets(BaseT2VDataset):
222
+ def __init__(
223
+ self,
224
+ video_resolution_buckets: List[Tuple[int, int, int]],
225
+ vae_temporal_compression_ratio: int,
226
+ vae_height_compression_ratio: int,
227
+ vae_width_compression_ratio: int,
228
+ *args,
229
+ **kwargs,
230
+ ) -> None:
231
+ """ """
232
+ super().__init__(*args, **kwargs)
233
+
234
+ self.video_resolution_buckets = [
235
+ (
236
+ int(b[0] / vae_temporal_compression_ratio),
237
+ int(b[1] / vae_height_compression_ratio),
238
+ int(b[2] / vae_width_compression_ratio),
239
+ )
240
+ for b in video_resolution_buckets
241
+ ]
242
+
243
+ self.__frame_transform = transforms.Compose([transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)])
244
+
245
+ @override
246
+ def preprocess(self, video_path: Path) -> torch.Tensor:
247
+ return preprocess_video_with_buckets(video_path, self.video_resolution_buckets)
248
+
249
+ @override
250
+ def video_transform(self, frames: torch.Tensor) -> torch.Tensor:
251
+ return torch.stack([self.__frame_transform(f) for f in frames], dim=0)
finetune/datasets/utils.py ADDED
@@ -0,0 +1,211 @@
1
+ import logging
2
+ from pathlib import Path
3
+ from typing import List, Tuple
4
+
5
+ import cv2
6
+ import torch
7
+ from torchvision.transforms.functional import resize
8
+ from einops import repeat, rearrange
9
+
10
+ # Must import after torch because this can sometimes lead to a nasty segmentation fault, or stack smashing error
11
+ # Very few bug reports but it happens. Look in decord Github issues for more relevant information.
12
+ import decord # isort:skip
13
+
14
+ decord.bridge.set_bridge("torch")
15
+
16
+ from PIL import Image
17
+ import numpy as np
19
+
20
+ ########## loaders ##########
21
+
22
+
23
+ def load_prompts(prompt_path: Path) -> List[str]:
24
+ with open(prompt_path, "r", encoding="utf-8") as file:
25
+ return [line.strip() for line in file.readlines() if len(line.strip()) > 0]
26
+
27
+
28
+ def load_videos(video_path: Path) -> List[Path]:
29
+ with open(video_path, "r", encoding="utf-8") as file:
30
+ return [video_path.parent / line.strip() for line in file.readlines() if len(line.strip()) > 0]
31
+
32
+
33
+ def load_images(image_path: Path) -> List[Path]:
34
+ with open(image_path, "r", encoding="utf-8") as file:
35
+ return [image_path.parent / line.strip() for line in file.readlines() if len(line.strip()) > 0]
36
+
37
+
38
+ def load_images_from_videos(videos_path: List[Path]) -> List[Path]:
39
+ first_frames_dir = videos_path[0].parent.parent / "first_frames"
40
+ first_frames_dir.mkdir(exist_ok=True)
41
+
42
+ first_frame_paths = []
43
+ for video_path in videos_path:
44
+ frame_path = first_frames_dir / f"{video_path.stem}.png"
45
+ if frame_path.exists():
46
+ first_frame_paths.append(frame_path)
47
+ continue
48
+
49
+ # Open video
50
+ cap = cv2.VideoCapture(str(video_path))
51
+
52
+ # Read first frame
53
+ ret, frame = cap.read()
54
+ if not ret:
55
+ raise RuntimeError(f"Failed to read video: {video_path}")
56
+
57
+ # Save frame as PNG with same name as video
58
+ cv2.imwrite(str(frame_path), frame)
59
+ logging.info(f"Saved first frame to {frame_path}")
60
+
61
+ # Release video capture
62
+ cap.release()
63
+
64
+ first_frame_paths.append(frame_path)
65
+
66
+ return first_frame_paths
67
+
68
+
69
+ def load_binary_mask_compressed(path, shape, device, dtype):
70
+ # shape: (F,C,H,W), C=1
71
+ with open(path, 'rb') as f:
72
+ packed = np.frombuffer(f.read(), dtype=np.uint8)
73
+ unpacked = np.unpackbits(packed)[:np.prod(shape)]
74
+ mask_loaded = torch.from_numpy(unpacked).to(device, dtype).reshape(shape)
75
+
76
+ mask_interp = torch.nn.functional.interpolate(rearrange(mask_loaded, 'f c h w -> c f h w').unsqueeze(0), size=(shape[0]//4+1, shape[2]//8, shape[3]//8), mode='trilinear', align_corners=False).squeeze(0) # CFHW
77
+ mask_interp[mask_interp>=0.5] = 1.0
78
+ mask_interp[mask_interp<0.5] = 0.0
79
+
80
+ return rearrange(mask_loaded, 'f c h w -> c f h w'), mask_interp
81
+
82
+ ########## preprocessors ##########
83
+
84
+
85
+ def preprocess_image_with_resize(
86
+ image_path: Path | str,
87
+ height: int,
88
+ width: int,
89
+ ) -> torch.Tensor:
90
+ """
91
+ Loads and resizes a single image.
92
+
93
+ Args:
94
+ image_path: Path to the image file.
95
+ height: Target height for resizing.
96
+ width: Target width for resizing.
97
+
98
+ Returns:
99
+ torch.Tensor: Image tensor with shape [C, H, W] where:
100
+ C = number of channels (3 for RGB)
101
+ H = height
102
+ W = width
103
+ """
104
+ if isinstance(image_path, str):
105
+ image_path = Path(image_path)
106
+ # image = cv2.imread(image_path.as_posix())
107
+ # image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
108
+ # image = cv2.resize(image, (width, height))
109
+ # image = torch.from_numpy(image).float()
110
+ # image = image.permute(2, 0, 1).contiguous()
111
+
112
+ image = np.array(Image.open(image_path.as_posix()).resize((width, height)))
113
+ image = torch.from_numpy(image).float()
114
+ image = image.permute(2, 0, 1).contiguous()
115
+
116
+ return image
117
+
118
+
119
+ def preprocess_video_with_resize(
120
+ video_path: Path | str,
121
+ max_num_frames: int,
122
+ height: int,
123
+ width: int,
124
+ ) -> torch.Tensor:
125
+ """
126
+ Loads and resizes a single video.
127
+
128
+ The function processes the video through these steps:
129
+ 1. If video frame count > max_num_frames, downsample frames evenly
130
+ 2. If video dimensions don't match (height, width), resize frames
131
+
132
+ Args:
133
+ video_path: Path to the video file.
134
+ max_num_frames: Maximum number of frames to keep.
135
+ height: Target height for resizing.
136
+ width: Target width for resizing.
137
+
138
+ Returns:
139
+ A torch.Tensor with shape [F, C, H, W] where:
140
+ F = number of frames
141
+ C = number of channels (3 for RGB)
142
+ H = height
143
+ W = width
144
+ """
145
+ if isinstance(video_path, str):
146
+ video_path = Path(video_path)
147
+ video_reader = decord.VideoReader(uri=video_path.as_posix(), width=width, height=height)
148
+ video_num_frames = len(video_reader)
149
+ if video_num_frames < max_num_frames:
150
+ # Get all frames first
151
+ frames = video_reader.get_batch(list(range(video_num_frames)))
152
+ # Repeat the last frame until we reach max_num_frames
153
+ last_frame = frames[-1:]
154
+ num_repeats = max_num_frames - video_num_frames
155
+ repeated_frames = last_frame.repeat(num_repeats, 1, 1, 1)
156
+ frames = torch.cat([frames, repeated_frames], dim=0)
157
+ return frames.float().permute(0, 3, 1, 2).contiguous()
158
+ else:
159
+ indices = list(range(0, video_num_frames, video_num_frames // max_num_frames))
160
+ frames = video_reader.get_batch(indices)
163
+ frames = frames[:max_num_frames].float()
164
+ frames = frames.permute(0, 3, 1, 2).contiguous()
165
+ return frames
166
+
167
+
168
+ def preprocess_video_with_buckets(
169
+ video_path: Path,
170
+ resolution_buckets: List[Tuple[int, int, int]],
171
+ ) -> torch.Tensor:
172
+ """
173
+ Args:
174
+ video_path: Path to the video file.
175
+ resolution_buckets: List of tuples (num_frames, height, width) representing
176
+ available resolution buckets.
177
+
178
+ Returns:
179
+ torch.Tensor: Video tensor with shape [F, C, H, W] where:
180
+ F = number of frames
181
+ C = number of channels (3 for RGB)
182
+ H = height
183
+ W = width
184
+
185
+ The function processes the video through these steps:
186
+ 1. Finds nearest frame bucket <= video frame count
187
+ 2. Downsamples frames evenly to match bucket size
188
+ 3. Finds nearest resolution bucket based on dimensions
189
+ 4. Resizes frames to match bucket resolution
190
+ """
191
+ video_reader = decord.VideoReader(uri=video_path.as_posix())
192
+ video_num_frames = len(video_reader)
193
+ resolution_buckets = [bucket for bucket in resolution_buckets if bucket[0] <= video_num_frames]
194
+ if len(resolution_buckets) == 0:
195
+ raise ValueError(f"video frame count in {video_path} is less than all frame buckets {resolution_buckets}")
196
+
197
+ nearest_frame_bucket = min(
198
+ resolution_buckets,
199
+ key=lambda bucket: video_num_frames - bucket[0],
200
+ default=1,
201
+ )[0]
202
+ frame_indices = list(range(0, video_num_frames, video_num_frames // nearest_frame_bucket))
203
+ frames = video_reader.get_batch(frame_indices)
204
+ frames = frames[:nearest_frame_bucket].float()
205
+ frames = frames.permute(0, 3, 1, 2).contiguous()
206
+
207
+ nearest_res = min(resolution_buckets, key=lambda x: abs(x[1] - frames.shape[2]) + abs(x[2] - frames.shape[3]))
208
+ nearest_res = (nearest_res[1], nearest_res[2])
209
+ frames = torch.stack([resize(f, nearest_res) for f in frames], dim=0)
210
+
211
+ return frames
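
Note: load_binary_mask_compressed above assumes masks were bit-packed with numpy's packbits and written as raw bytes. A round-trip sketch of that on-disk format (the save helper is hypothetical; only the load side exists in this commit):

    import numpy as np

    # A binary mask of shape (F, C=1, H, W) is flattened, bit-packed with
    # np.packbits, and written as raw bytes; loading unpacks and reshapes it.
    def save_binary_mask_compressed(path, mask):
        packed = np.packbits(mask.astype(np.uint8).ravel())
        with open(path, "wb") as f:
            f.write(packed.tobytes())

    mask = np.random.rand(8, 1, 32, 32) > 0.5
    save_binary_mask_compressed("mask.bin", mask)

    with open("mask.bin", "rb") as f:
        packed = np.frombuffer(f.read(), dtype=np.uint8)
    restored = np.unpackbits(packed)[: mask.size].reshape(mask.shape)
    assert (restored == mask.astype(np.uint8)).all()
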
finetune/models/__init__.py ADDED
@@ -0,0 +1,12 @@
+ import importlib
+ from pathlib import Path
+
+
+ package_dir = Path(__file__).parent
+
+ for subdir in package_dir.iterdir():
+     if subdir.is_dir() and not subdir.name.startswith("_"):
+         for module_path in subdir.glob("*.py"):
+             module_name = module_path.stem
+             full_module_name = f".{subdir.name}.{module_name}"
+             importlib.import_module(full_module_name, package=__name__)
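
Note: this __init__.py eagerly imports every model submodule so that trainer classes can register themselves at import time (the trainer below imports `from ..utils import register`). A minimal sketch of that decorator-based registry pattern; the real register() in finetune/models/utils.py may differ in signature:

    # Self-contained illustration of import-time registration.
    _TRAINER_REGISTRY = {}

    def register(model_name, training_type):
        def decorator(cls):
            _TRAINER_REGISTRY[(model_name, training_type)] = cls
            return cls
        return decorator

    @register("cogvideox-i2v", "lora")
    class DummyTrainer:
        pass

    # Once all submodules have been imported, trainers can be looked up by key.
    print(_TRAINER_REGISTRY[("cogvideox-i2v", "lora")])
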
finetune/models/cogvideox_i2v/flovd_OMSM_lora_trainer.py ADDED
@@ -0,0 +1,748 @@
1
+ from typing import Any, Dict, List, Tuple
2
+ from pathlib import Path
3
+ import os
4
+ import hashlib
5
+ import json
6
+ import random
7
+ import wandb
8
+ import math
9
+ import numpy as np
10
+ from einops import rearrange, repeat
11
+ from safetensors.torch import load_file, save_file
12
+ from accelerate.logging import get_logger
13
+
14
+ import torch
15
+
16
+ from accelerate.utils import gather_object
17
+
18
+ from diffusers import (
19
+ AutoencoderKLCogVideoX,
20
+ CogVideoXDPMScheduler,
21
+ CogVideoXImageToVideoPipeline,
22
+ CogVideoXTransformer3DModel,
23
+ )
24
+ from diffusers.utils.export_utils import export_to_video
25
+
26
+ from finetune.pipeline.flovd_OMSM_cogvideox_pipeline import FloVDOMSMCogVideoXImageToVideoPipeline
27
+ from finetune.constants import LOG_LEVEL, LOG_NAME
28
+
29
+ from diffusers.models.embeddings import get_3d_rotary_pos_embed
30
+ from PIL import Image
31
+ from numpy import dtype
32
+ from transformers import AutoTokenizer, T5EncoderModel
33
+ from typing_extensions import override
34
+
35
+ from finetune.schemas import Args, Components, State
36
+ from finetune.trainer import Trainer
37
+ from finetune.utils import (
38
+ cast_training_params,
39
+ free_memory,
40
+ get_memory_statistics,
41
+ string_to_filename,
42
+ unwrap_model,
43
+ )
44
+ from finetune.datasets.utils import (
45
+ preprocess_image_with_resize,
46
+ load_binary_mask_compressed,
47
+ )
48
+ from finetune.modules.camera_sampler import SampleManualCam
49
+ from finetune.modules.camera_flow_generator import CameraFlowGenerator
50
+ from finetune.modules.utils import get_camera_flow_generator_input, forward_bilinear_splatting, flow_to_color
51
+
52
+ from ..utils import register
53
+
54
+ import sys
55
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
56
+
57
+ import pdb
58
+
59
+ logger = get_logger(LOG_NAME, LOG_LEVEL)
60
+
61
+ class FloVDOMSMCogVideoXI2VLoraTrainer(Trainer):
62
+ UNLOAD_LIST = ["text_encoder"]
63
+
64
+ @override
65
+ def __init__(self, args: Args) -> None:
66
+ super().__init__(args)
67
+
68
+
69
+ @override
70
+ def load_components(self) -> Dict[str, Any]:
71
+ # TODO. Change the pipeline and ...
72
+ components = Components()
73
+ model_path = str(self.args.model_path)
74
+
75
+ components.pipeline_cls = FloVDOMSMCogVideoXImageToVideoPipeline
76
+
77
+ components.tokenizer = AutoTokenizer.from_pretrained(model_path, subfolder="tokenizer")
78
+
79
+ components.text_encoder = T5EncoderModel.from_pretrained(model_path, subfolder="text_encoder")
80
+
81
+ components.transformer = CogVideoXTransformer3DModel.from_pretrained(model_path, subfolder="transformer")
82
+
83
+ components.vae = AutoencoderKLCogVideoX.from_pretrained(model_path, subfolder="vae")
84
+
85
+ components.scheduler = CogVideoXDPMScheduler.from_pretrained(model_path, subfolder="scheduler")
86
+
87
+ return components
88
+
89
+
90
+ @override
91
+ def initialize_pipeline(self) -> FloVDOMSMCogVideoXImageToVideoPipeline:
92
+ # TODO. Change the pipeline and ...
93
+ pipe = FloVDOMSMCogVideoXImageToVideoPipeline(
94
+ tokenizer=self.components.tokenizer,
95
+ text_encoder=self.components.text_encoder,
96
+ vae=self.components.vae,
97
+ transformer=unwrap_model(self.accelerator, self.components.transformer),
98
+ scheduler=self.components.scheduler,
99
+ )
100
+ return pipe
101
+
102
+ def initialize_flow_generator(self):
103
+ depth_estimator_kwargs = {
104
+ "target": 'modules.depth_warping.depth_warping.DepthWarping_wrapper',
105
+ "kwargs": {
106
+ "ckpt_path": '/workspace/workspace/checkpoints/depth_anything/depth_anything_v2_metric_hypersim_vitb.pth',
107
+ "model_config": {
108
+ "max_depth": 20,
109
+ "encoder": 'vitb',
110
+ "features": 128,
111
+ "out_channels": [96, 192, 384, 768],
112
+ }
113
+
114
+ }
115
+ }
116
+
117
+ return CameraFlowGenerator(depth_estimator_kwargs)
118
+
119
+ @override
120
+ def collate_fn(self, samples: List[Dict[str, Any]]) -> Dict[str, Any]:
121
+ ret = {"encoded_videos": [], "prompt_embedding": [], "images": [], "encoded_flow": []}
122
+
123
+ for sample in samples:
124
+ encoded_video = sample["encoded_video"]
125
+ prompt_embedding = sample["prompt_embedding"]
126
+ image = sample["image"]
127
+ encoded_flow = sample["encoded_flow"]
128
+
129
+ ret["encoded_videos"].append(encoded_video)
130
+ ret["prompt_embedding"].append(prompt_embedding)
131
+ ret["images"].append(image)
132
+ ret["encoded_flow"].append(encoded_flow)
133
+
134
+ ret["encoded_videos"] = torch.stack(ret["encoded_videos"])
135
+ ret["prompt_embedding"] = torch.stack(ret["prompt_embedding"])
136
+ ret["images"] = torch.stack(ret["images"])
137
+ ret["encoded_flow"] = torch.stack(ret["encoded_flow"])
138
+
139
+ return ret
140
+
141
+
142
+ @override
143
+ def compute_loss(self, batch) -> torch.Tensor:
144
+ prompt_embedding = batch["prompt_embedding"]
145
+ images = batch["images"]
146
+ latent_flow = batch["encoded_flow"]
147
+
148
+ # Shape of prompt_embedding: [B, seq_len, hidden_size]
149
+ # Shape of images: [B, C, H, W]
150
+ # Shape of latent_flow: [B, C, F, H, W]
151
+
152
+ patch_size_t = self.state.transformer_config.patch_size_t # WJ: None in i2v setting...
153
+ if patch_size_t is not None:
154
+ # ncopy = latent.shape[2] % patch_size_t
155
+ # # Copy the first frame ncopy times to match patch_size_t
156
+ # first_frame = latent[:, :, :1, :, :] # Get first frame [B, C, 1, H, W]
157
+ # latent = torch.cat([first_frame.repeat(1, 1, ncopy, 1, 1), latent], dim=2)
158
+ # assert latent.shape[2] % patch_size_t == 0
159
+ raise NotImplementedError("Do not use the case whose patch_size_t is not None")
160
+
161
+ batch_size, num_channels, num_frames, height, width = latent_flow.shape
162
+
163
+ # Get prompt embeddings
164
+ _, seq_len, _ = prompt_embedding.shape
165
+ prompt_embedding = prompt_embedding.view(batch_size, seq_len, -1).to(dtype=latent_flow.dtype)
166
+
167
+ # Add frame dimension to images [B,C,H,W] -> [B,C,F,H,W]
168
+ images = images.unsqueeze(2)
169
+ # Add noise to images
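+ # (The noise level sigma is drawn log-normally as exp(N(-3.0, 0.5)), lightly corrupting the
+ # conditioning image before VAE encoding, the same noise-augmentation scheme as the base
+ # CogVideoX I2V trainers.)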
170
+ image_noise_sigma = torch.normal(mean=-3.0, std=0.5, size=(1,), device=self.accelerator.device)
171
+ image_noise_sigma = torch.exp(image_noise_sigma).to(dtype=images.dtype)
172
+ noisy_images = images + torch.randn_like(images) * image_noise_sigma[:, None, None, None, None]
173
+ image_latent_dist = self.components.vae.encode(noisy_images.to(dtype=self.components.vae.dtype)).latent_dist
174
+ image_latents = image_latent_dist.sample() * self.components.vae.config.scaling_factor
175
+
176
+ # Sample a random timestep for each sample
177
+ timesteps = torch.randint(
178
+ 0, self.components.scheduler.config.num_train_timesteps, (batch_size,), device=self.accelerator.device
179
+ )
180
+ timesteps = timesteps.long()
181
+
182
+ # from [B, C, F, H, W] to [B, F, C, H, W]
183
+ latent_flow = latent_flow.permute(0, 2, 1, 3, 4)
184
+ image_latents = image_latents.permute(0, 2, 1, 3, 4)
185
+ assert (image_latents.shape[0], *image_latents.shape[2:]) == (latent_flow.shape[0], *latent_flow.shape[2:])
186
+
187
+ # Padding image_latents to the same frame number as latent
188
+ padding_shape = (latent_flow.shape[0], latent_flow.shape[1] - 1, *latent_flow.shape[2:])
189
+ latent_padding = image_latents.new_zeros(padding_shape)
190
+ image_latents = torch.cat([image_latents, latent_padding], dim=1)
191
+
192
+ # Add noise to latent
193
+ noise = torch.randn_like(latent_flow)
194
+ latent_flow_noisy = self.components.scheduler.add_noise(latent_flow, noise, timesteps)
195
+
196
+
197
+ # Concatenate latent and image_latents in the channel dimension
198
+ latent_flow_img_noisy = torch.cat([latent_flow_noisy, image_latents], dim=2)
199
+
200
+ # Prepare rotary embeds
201
+ vae_scale_factor_spatial = 2 ** (len(self.components.vae.config.block_out_channels) - 1)
202
+ transformer_config = self.state.transformer_config
203
+ rotary_emb = (
204
+ self.prepare_rotary_positional_embeddings(
205
+ height=height * vae_scale_factor_spatial,
206
+ width=width * vae_scale_factor_spatial,
207
+ num_frames=num_frames,
208
+ transformer_config=transformer_config,
209
+ vae_scale_factor_spatial=vae_scale_factor_spatial,
210
+ device=self.accelerator.device,
211
+ )
212
+ if transformer_config.use_rotary_positional_embeddings
213
+ else None
214
+ )
215
+
216
+ # Predict noise, For CogVideoX1.5 Only.
217
+ ofs_emb = (
218
+ None if self.state.transformer_config.ofs_embed_dim is None else latent_flow.new_full((1,), fill_value=2.0)
219
+ )
220
+
221
+ predicted_noise = self.components.transformer(
222
+ hidden_states=latent_flow_img_noisy,
223
+ encoder_hidden_states=prompt_embedding,
224
+ timestep=timesteps,
225
+ ofs=ofs_emb,
226
+ image_rotary_emb=rotary_emb,
227
+ return_dict=False,
228
+ )[0]
229
+
230
+ # Denoise
231
+ latent_pred = self.components.scheduler.get_velocity(predicted_noise, latent_flow_noisy, timesteps)
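+ # get_velocity(v_pred, x_t, t) evaluates sqrt(alpha_bar_t) * x_t - sqrt(1 - alpha_bar_t) * v_pred,
+ # converting the model's v-prediction back into a predicted clean flow latent for the weighted MSE below.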
232
+
233
+ alphas_cumprod = self.components.scheduler.alphas_cumprod[timesteps]
234
+ weights = 1 / (1 - alphas_cumprod)
235
+ while len(weights.shape) < len(latent_pred.shape):
236
+ weights = weights.unsqueeze(-1)
237
+
238
+ loss = torch.mean((weights * (latent_pred - latent_flow) ** 2).reshape(batch_size, -1), dim=1)
239
+ loss = loss.mean()
240
+
241
+ return loss
242
+
243
+ def prepare_rotary_positional_embeddings(
244
+ self,
245
+ height: int,
246
+ width: int,
247
+ num_frames: int,
248
+ transformer_config: Dict,
249
+ vae_scale_factor_spatial: int,
250
+ device: torch.device,
251
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
252
+ grid_height = height // (vae_scale_factor_spatial * transformer_config.patch_size)
253
+ grid_width = width // (vae_scale_factor_spatial * transformer_config.patch_size)
254
+
255
+ if transformer_config.patch_size_t is None:
256
+ base_num_frames = num_frames
257
+ else:
258
+ base_num_frames = (num_frames + transformer_config.patch_size_t - 1) // transformer_config.patch_size_t
259
+
260
+ freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
261
+ embed_dim=transformer_config.attention_head_dim,
262
+ crops_coords=None,
263
+ grid_size=(grid_height, grid_width),
264
+ temporal_size=base_num_frames,
265
+ grid_type="slice",
266
+ max_size=(grid_height, grid_width),
267
+ device=device,
268
+ )
269
+
270
+ return freqs_cos, freqs_sin
271
+
272
+ # Validation
273
+
274
+ @override
275
+ def prepare_for_validation(self):
276
+ # Load from dataset?
277
+ # Data_root
278
+ # - metadata.jsonl
279
+ # - video_latent / args.resolution /
280
+ # - prompt_embeddings /
281
+ # - first_frames /
282
+ # - flow_direct_f_latent /
283
+
284
+ data_root = self.args.data_root
285
+ metadata_path = data_root / "metadata_revised.jsonl"
286
+ assert metadata_path.is_file(), "For this dataset type, you need metadata_revised.jsonl in the data root"
287
+
288
+ # Load metadata
289
+ # metadata = {
290
+ # "video_path": ...,
291
+ # "hash_code": ...,
292
+ # "prompt": ...,
293
+ # }
294
+ metadata = []
295
+ with open(metadata_path, "r") as f:
296
+ for line in f:
297
+ metadata.append( json.loads(line) )
298
+
299
+ metadata = random.sample(metadata, self.args.max_scene)
300
+
301
+ prompts = [x["prompt"] for x in metadata]
302
+ if 'curated' in str(data_root).lower():
303
+ prompt_embeddings = [data_root / "prompt_embeddings" / (x["hash_code"] + '.safetensors') for x in metadata]  # local list; consumed by the zip() below
304
+ else:
305
+ prompt_embeddings = [data_root / "prompt_embeddings_revised" / (x["hash_code"] + '.safetensors') for x in metadata]
306
+ videos = [data_root / "video_latent" / "x".join(str(x) for x in self.args.train_resolution) / (x["hash_code"] + '.safetensors') for x in metadata]
307
+ images = [data_root / "first_frames" / (x["hash_code"] + '.png') for x in metadata]
308
+ flows = [data_root / "flow_direct_f_latent" / (x["hash_code"] + '.safetensors') for x in metadata]
309
+
310
+ # load prompt embedding
311
+ validation_prompts = []
312
+ validation_prompt_embeddings = []
313
+ validation_video_latents = []
314
+ validation_images = []
315
+ validation_flow_latents = []
316
+ for prompt, prompt_embedding, video_latent, image, flow_latent in zip(prompts, prompt_embeddings, videos, images, flows):
317
+ validation_prompts.append(prompt)
318
+ validation_prompt_embeddings.append(load_file(prompt_embedding)["prompt_embedding"].unsqueeze(0))
319
+ validation_video_latents.append(load_file(video_latent)["encoded_video"].unsqueeze(0))
320
+ validation_flow_latents.append(load_file(flow_latent)["encoded_flow_f"].unsqueeze(0))
321
+ # validation_images.append(preprocess_image_with_resize(image, self.args.train_resolution[1], self.args.train_resolution[2]))
322
+ validation_images.append(image)
323
+
324
+
325
+ validation_videos = [None] * len(validation_prompts)
326
+
327
+
328
+ self.state.validation_prompts = validation_prompts
329
+ self.state.validation_prompt_embeddings = validation_prompt_embeddings
330
+ self.state.validation_images = validation_images
331
+ self.state.validation_videos = validation_videos
332
+ self.state.validation_video_latents = validation_video_latents
333
+ self.state.validation_flow_latents = validation_flow_latents
334
+
335
+ # Debug..
336
+ self.validate(0)
337
+
338
+
339
+ @override
340
+ def validation_step(
341
+ self, eval_data: Dict[str, Any], pipe: FloVDOMSMCogVideoXImageToVideoPipeline
342
+ ) -> List[Tuple[str, Image.Image | List[Image.Image]]]:
343
+ """
344
+ Return the data that needs to be saved. For videos, the data format is List[PIL],
345
+ and for images, the data format is PIL
346
+ """
347
+
348
+ prompt_embedding, image = eval_data["prompt_embedding"], eval_data["image"]
349
+
350
+ flow_latent_generate = pipe(
351
+ num_frames=self.state.train_frames,
352
+ height=self.state.train_height,
353
+ width=self.state.train_width,
354
+ prompt=None,
355
+ prompt_embeds=prompt_embedding,
356
+ image=image,
357
+ generator=self.state.generator,
358
+ num_inference_steps=50,
359
+ output_type='latent'
360
+ ).frames[0]
361
+
362
+ flow_generate = decode_flow(flow_latent_generate.unsqueeze(0).to(self.accelerator.device), self.components.vae, flow_scale_factor=[60, 36]) # BF,C,H,W
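+ # flow_scale_factor=[60, 36] has to match the factors used when the flow latents were encoded,
+ # since decode_flow inverts the same adaptive normalization.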
363
+
364
+ return [("synthesized_flow", flow_generate)]
365
+
366
+
367
+ @override
368
+ def validate(self, step: int) -> None:
369
+ #TODO. Fix the codes!!!!
370
+ logger.info("Starting validation")
371
+
372
+ accelerator = self.accelerator
373
+ num_validation_samples = len(self.state.validation_prompts)
374
+
375
+ if num_validation_samples == 0:
376
+ logger.warning("No validation samples found. Skipping validation.")
377
+ return
378
+
379
+ self.components.transformer.eval()
380
+ torch.set_grad_enabled(False)
381
+
382
+ memory_statistics = get_memory_statistics()
383
+ logger.info(f"Memory before validation start: {json.dumps(memory_statistics, indent=4)}")
384
+
385
+ ##### Initialize pipeline #####
386
+ pipe = self.initialize_pipeline()
387
+ camera_flow_generator = self.initialize_flow_generator().to(device=self.accelerator.device, dtype=self.state.weight_dtype)
388
+
389
+ if self.state.using_deepspeed:
390
+ # Can't use model_cpu_offload in deepspeed,
391
+ # so we need to move all components in pipe to device
392
+ # pipe.to(self.accelerator.device, dtype=self.state.weight_dtype)
393
+ self.__move_components_to_device(dtype=self.state.weight_dtype, ignore_list=["transformer"])
394
+ else:
395
+ # if not using deepspeed, use model_cpu_offload to further reduce memory usage
396
+ # Or use pipe.enable_sequential_cpu_offload() to further reduce memory usage
397
+ pipe.enable_model_cpu_offload(device=self.accelerator.device)
398
+
399
+ # Convert all model weights to training dtype
400
+ # Note, this will change LoRA weights in self.components.transformer to training dtype, rather than keep them in fp32
401
+ pipe = pipe.to(dtype=self.state.weight_dtype)
402
+
403
+ #################################
404
+ all_processes_artifacts = []
405
+ for i in range(num_validation_samples):
406
+ if self.state.using_deepspeed and self.accelerator.deepspeed_plugin.zero_stage != 3:
407
+ # Skip current validation on all processes but one
408
+ if i % accelerator.num_processes != accelerator.process_index:
409
+ continue
410
+
411
+ prompt = self.state.validation_prompts[i]
412
+ image = self.state.validation_images[i]
413
+ video = self.state.validation_videos[i]
414
+ video_latent = self.state.validation_video_latents[i].permute(0,2,1,3,4) # [B,F,C,H,W] (e.g., [B, 13, 16, 60, 90])
415
+ prompt_embedding = self.state.validation_prompt_embeddings[i]
416
+ flow_latent = self.state.validation_flow_latents[i].permute(0,2,1,3,4) # [B,F,C,H,W] (e.g., [B, 13, 16, 60, 90])
417
+
418
+
419
+ if image is not None:
420
+ image = preprocess_image_with_resize(image, self.state.train_height, self.state.train_width)
421
+ image_torch = image.detach().clone()
422
+ # Convert image tensor (C, H, W) to PIL images
423
+ image = image.to(torch.uint8)
424
+ image = image.permute(1, 2, 0).cpu().numpy()
425
+ image = Image.fromarray(image)
426
+
427
+ if video is not None:
428
+ video = preprocess_video_with_resize(
429
+ video, self.state.train_frames, self.state.train_height, self.state.train_width
430
+ )
431
+ # Convert video tensor (F, C, H, W) to list of PIL images
432
+ video = video.round().clamp(0, 255).to(torch.uint8)
433
+ video = [Image.fromarray(frame.permute(1, 2, 0).cpu().numpy()) for frame in video]
434
+ else:
435
+ with torch.cuda.amp.autocast(enabled=True, dtype=self.state.weight_dtype):
436
+ try:
437
+ video_decoded = decode_latents(video_latent.to(self.accelerator.device), self.components.vae)
438
+ except:
439
+ pass
440
+ video_decoded = decode_latents(video_latent.to(self.accelerator.device), self.components.vae)
441
+ video = ((video_decoded + 1.) / 2. * 255.)[0].permute(1,0,2,3).float().clip(0., 255.).to(torch.uint8)
442
+ video = [Image.fromarray(frame.permute(1, 2, 0).cpu().numpy()) for frame in video]
443
+
444
+ with torch.cuda.amp.autocast(enabled=True, dtype=self.state.weight_dtype):
445
+ try:
446
+ flow_decoded = decode_flow(flow_latent.to(self.accelerator.device), self.components.vae, flow_scale_factor=[60, 36])
447
+ except:
448
+ pass
449
+ flow_decoded = decode_flow(flow_latent.to(self.accelerator.device), self.components.vae, flow_scale_factor=[60, 36]) # (BF)CHW (C=2)
450
+
451
+
452
+ logger.debug(
453
+ f"Validating sample {i + 1}/{num_validation_samples} on process {accelerator.process_index}. Prompt: {prompt}",
454
+ main_process_only=False,
455
+ )
456
+ # validation_artifacts = self.validation_step({"prompt": prompt, "image": image, "video": video}, pipe)
457
+ validation_artifacts = self.validation_step({"prompt_embedding": prompt_embedding, "image": image}, pipe)
458
+
459
+ if (
460
+ self.state.using_deepspeed
461
+ and self.accelerator.deepspeed_plugin.zero_stage == 3
462
+ and not accelerator.is_main_process
463
+ ):
464
+ continue
465
+
466
+ prompt_filename = string_to_filename(prompt)[:25]
467
+ # Calculate hash of reversed prompt as a unique identifier
468
+ reversed_prompt = prompt[::-1]
469
+ hash_suffix = hashlib.md5(reversed_prompt.encode()).hexdigest()[:5]
470
+
471
+ artifacts = {
472
+ "image": {"type": "image", "value": image},
473
+ "video": {"type": "video", "value": video},
474
+ }
475
+ for i, (artifact_type, artifact_value) in enumerate(validation_artifacts):
476
+ artifacts.update({f"artifact_{i}": {"type": artifact_type, "value": artifact_value}})
477
+
478
+ # Log flow
479
+ artifacts.update({f"artifact_flow_{i}": {"type": 'flow', "value": flow_decoded}})
480
+
481
+ # Log flow_warped_frames
482
+ image_tensor = repeat(rearrange(torch.tensor(np.array(image)).to(flow_decoded.device, torch.float), 'h w c -> 1 c h w'), 'b c h w -> (b f) c h w', f=flow_decoded.size(0)) # scale~(0,255) (BF) C H W
483
+ warped_video = forward_bilinear_splatting(image_tensor, flow_decoded.to(torch.float)) # if we have an occlusion mask from dataset, we can use it.
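+ # Splatting the first frame with the decoded ground-truth flow is a visual sanity check that
+ # the flow latents decode into a usable motion field.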
484
+ frame_list = []
485
+ for frame in warped_video:
486
+ frame = (frame.permute(1,2,0).float().detach().cpu().numpy()).clip(0, 255).astype(np.uint8)  # clip before the uint8 cast to avoid wrap-around
487
+ frame_list.append(Image.fromarray(frame))
488
+
489
+ artifacts.update({f"artifact_warped_video_{i}": {"type": 'warped_video', "value": frame_list}})
490
+
491
+ # Log synthesized_flow_wraped_frames
492
+ # artifact_value: synthesized optical flow
493
+ warped_video2 = forward_bilinear_splatting(image_tensor, artifact_value.to(torch.float)) # if we have an occlusion mask from dataset, we can use it. For OMSM, do not use.
494
+ frame_list2 = []
495
+ for frame in warped_video2:
496
+ frame = (frame.permute(1,2,0).float().detach().cpu().numpy()).clip(0, 255).astype(np.uint8)
497
+ frame_list2.append(Image.fromarray(frame))
498
+
499
+ artifacts.update({f"artifact_synthesized_flow_warped_video_{i}": {"type": 'synthesized_flow_warped_video', "value": frame_list2}})
500
+
501
+
502
+ logger.debug(
503
+ f"Validation artifacts on process {accelerator.process_index}: {list(artifacts.keys())}",
504
+ main_process_only=False,
505
+ )
506
+
507
+ for key, value in list(artifacts.items()):
508
+ artifact_type = value["type"]
509
+ artifact_value = value["value"]
510
+ if artifact_type not in ["image", "video", "flow", "warped_video", "synthesized_flow", "synthesized_flow_warped_video"] or artifact_value is None:
511
+ continue
512
+
513
+ extension = "png" if artifact_type == "image" else "mp4"
514
+ if artifact_type == "warped_video" or artifact_type == "synthesized_flow_warped_video":
515
+ filename = f"validation-{step}-{accelerator.process_index}-{prompt_filename}-{hash_suffix}_{artifact_type}.{extension}"
516
+ elif artifact_type == "synthesized_flow":
517
+ filename = f"validation-{step}-{accelerator.process_index}-{prompt_filename}-{hash_suffix}_synthesized_flow.{extension}"
518
+ elif artifact_type == "flow":
519
+ filename = f"validation-{step}-{accelerator.process_index}-{prompt_filename}-{hash_suffix}_original_flow.{extension}"
520
+ else:
521
+ filename = f"validation-{step}-{accelerator.process_index}-{prompt_filename}-{hash_suffix}.{extension}"
522
+ validation_path = self.args.output_dir / "validation_res"
523
+ validation_path.mkdir(parents=True, exist_ok=True)
524
+ filename = str(validation_path / filename)
525
+
526
+ if artifact_type == "image":
527
+ logger.debug(f"Saving image to {filename}")
528
+ artifact_value.save(filename)
529
+ artifact_value = wandb.Image(filename)
530
+ elif artifact_type == "video" or artifact_type == "warped_video" or artifact_type == "synthesized_flow_warped_video":
531
+ logger.debug(f"Saving video to {filename}")
532
+ export_to_video(artifact_value, filename, fps=self.args.gen_fps)
533
+ artifact_value = wandb.Video(filename, caption=f"[{artifact_type}]--{prompt}")
534
+ elif artifact_type == "synthesized_flow" or artifact_type == "flow":
535
+ # TODO. RGB Visualization of optical flow. (F,2,H,W)
536
+ artifact_value_RGB = flow_to_color(artifact_value) # BF,C,H,W (B=1)
537
+
538
+ frame_list = []
539
+ for frame in artifact_value_RGB:
540
+ frame = (frame.permute(1,2,0).float().detach().cpu().numpy()).clip(0, 255).astype(np.uint8)
541
+ frame_list.append(Image.fromarray(frame))
542
+
543
+ logger.debug(f"Saving video to {filename}")
544
+ export_to_video(frame_list, filename, fps=self.args.gen_fps)
545
+ artifact_value = wandb.Video(filename, caption=f"[{artifact_type}]--{prompt}")
546
+
547
+ all_processes_artifacts.append(artifact_value)
548
+
549
+ all_artifacts = gather_object(all_processes_artifacts)
550
+
551
+ if accelerator.is_main_process:
552
+ tracker_key = "validation"
553
+ for tracker in accelerator.trackers:
554
+ if tracker.name == "wandb":
555
+ image_artifacts = [artifact for artifact in all_artifacts if isinstance(artifact, wandb.Image)]
556
+ video_artifacts = [artifact for artifact in all_artifacts if isinstance(artifact, wandb.Video)]
557
+ tracker.log(
558
+ {
559
+ tracker_key: {f"images": image_artifacts, f"videos": video_artifacts},
560
+ },
561
+ step=step,
562
+ )
563
+
564
+ ########## Clean up ##########
565
+ if self.state.using_deepspeed:
566
+ del pipe
567
+ # Unload models except those needed for training
568
+ self.__move_components_to_cpu(unload_list=self.UNLOAD_LIST)
569
+ else:
570
+ pipe.remove_all_hooks()
571
+ del pipe
572
+ # Load models except those not needed for training
573
+ self.__move_components_to_device(dtype=self.state.weight_dtype, ignore_list=self.UNLOAD_LIST)
574
+ self.components.transformer.to(self.accelerator.device, dtype=self.state.weight_dtype)
575
+
576
+ # Change trainable weights back to fp32 to keep with dtype after prepare the model
577
+ cast_training_params([self.components.transformer], dtype=torch.float32)
578
+
579
+ del camera_flow_generator
580
+
581
+ free_memory()
582
+ accelerator.wait_for_everyone()
583
+ ################################
584
+
585
+ memory_statistics = get_memory_statistics()
586
+ logger.info(f"Memory after validation end: {json.dumps(memory_statistics, indent=4)}")
587
+ torch.cuda.reset_peak_memory_stats(accelerator.device)
588
+
589
+ torch.set_grad_enabled(True)
590
+ self.components.transformer.train()
591
+
592
+
593
+ # mangling
594
+ def __move_components_to_device(self, dtype, ignore_list: List[str] = []):
595
+ ignore_list = set(ignore_list)
596
+ components = self.components.model_dump()
597
+ for name, component in components.items():
598
+ if not isinstance(component, type) and hasattr(component, "to"):
599
+ if name not in ignore_list:
600
+ setattr(self.components, name, component.to(self.accelerator.device, dtype=dtype))
601
+
602
+ # mangling
603
+ def __move_components_to_cpu(self, unload_list: List[str] = []):
604
+ unload_list = set(unload_list)
605
+ components = self.components.model_dump()
606
+ for name, component in components.items():
607
+ if not isinstance(component, type) and hasattr(component, "to"):
608
+ if name in unload_list:
609
+ setattr(self.components, name, component.to("cpu"))
610
+
611
+
612
+ register("cogvideox-flovd-omsm", "lora", FloVDOMSMCogVideoXI2VLoraTrainer)
613
+
614
+
615
+ #--------------------------------------------------------------------------------------------------
616
+ # Extract function
617
+ def encode_text(prompt: str, components, device) -> torch.Tensor:
618
+ prompt_token_ids = components.tokenizer(
619
+ prompt,
620
+ padding="max_length",
621
+ max_length=components.transformer.config.max_text_seq_length,
622
+ truncation=True,
623
+ add_special_tokens=True,
624
+ return_tensors="pt",
625
+ )
626
+ prompt_token_ids = prompt_token_ids.input_ids
627
+ prompt_embedding = components.text_encoder(prompt_token_ids.to(device))[0]
628
+ return prompt_embedding
629
+
630
+ def encode_video(video: torch.Tensor, vae) -> torch.Tensor:
631
+ # shape of input video: [B, C, F, H, W]
632
+ video = video.to(vae.device, dtype=vae.dtype)
633
+ latent_dist = vae.encode(video).latent_dist
634
+ latent = latent_dist.sample() * vae.config.scaling_factor
635
+ return latent
636
+
637
+ def decode_latents(latents: torch.Tensor, vae) -> torch.Tensor:
638
+ latents = latents.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width]
639
+ latents = 1 / vae.config.scaling_factor * latents
640
+
641
+ frames = vae.decode(latents).sample
642
+ return frames
643
+
644
+ def compute_optical_flow(raft, ctxt, trgt, raft_iter=20, chunk=2, only_forward=True):
645
+ num_frames = ctxt.shape[0]
646
+ chunk_size = (num_frames // chunk) + 1
647
+
648
+ flow_f_list = []
649
+ if not only_forward:
650
+ flow_b_list = []
651
+ for i in range(chunk):
652
+ start = chunk_size * i
653
+ end = chunk_size * (i+1)
654
+
655
+ with torch.no_grad():
656
+ flow_f = raft(ctxt[start:end], trgt[start:end], num_flow_updates=raft_iter)[-1]
657
+ if not only_forward:
658
+ flow_b = raft(trgt[start:end], ctxt[start:end], num_flow_updates=raft_iter)[-1]
659
+
660
+ flow_f_list.append(flow_f)
661
+ if not only_forward:
662
+ flow_b_list.append(flow_b)
663
+
664
+ flow_f = torch.cat(flow_f_list)
665
+ if not only_forward:
666
+ flow_b = torch.cat(flow_b_list)
667
+
668
+ if not only_forward:
669
+ return flow_f, flow_b
670
+ else:
671
+ return flow_f, None
672
+
673
+ def encode_flow(flow, vae, flow_scale_factor):
674
+ # flow: BF,C,H,W
675
+ # flow_scale_factor [sf_x, sf_y]
676
+ assert flow.ndim == 4
677
+ num_frames, _, height, width = flow.shape
678
+
679
+ # Normalize optical flow
680
+ # ndim: 4 -> 5
681
+ flow = rearrange(flow, '(b f) c h w -> b f c h w', b=1)
682
+ flow_norm = adaptive_normalize(flow, flow_scale_factor[0], flow_scale_factor[1])
683
+
684
+ # ndim: 5 -> 4
685
+ flow_norm = rearrange(flow_norm, 'b f c h w -> (b f) c h w', b=1)
686
+
687
+ # Duplicate mean value for third channel
688
+ num_frames, _, H, W = flow_norm.shape
689
+ flow_norm_extended = torch.empty((num_frames, 3, height, width)).to(flow_norm)
690
+ flow_norm_extended[:,:2] = flow_norm
691
+ flow_norm_extended[:,-1:] = flow_norm.mean(dim=1, keepdim=True)
692
+ flow_norm_extended = rearrange(flow_norm_extended, '(b f) c h w -> b c f h w', f=num_frames)
693
+
694
+ return encode_video(flow_norm_extended, vae)
695
+
696
+ def decode_flow(flow_latent, vae, flow_scale_factor):
697
+ flow_latent = flow_latent.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width]
698
+ flow_latent = 1 / vae.config.scaling_factor * flow_latent
699
+
700
+ flow = vae.decode(flow_latent).sample # BCFHW
701
+
702
+ # discard third channel (which is a mean value of f_x and f_y)
703
+ flow = flow[:,:2].detach().clone()
704
+
705
+ # Unnormalize optical flow
706
+ flow = rearrange(flow, 'b c f h w -> b f c h w')
707
+ flow = adaptive_unnormalize(flow, flow_scale_factor[0], flow_scale_factor[1])
708
+
709
+ flow = rearrange(flow, 'b f c h w -> (b f) c h w')
710
+ return flow # BF,C,H,W
711
+
712
+ def adaptive_normalize(flow, sf_x, sf_y):
713
+ # x: BFCHW, optical flow
714
+ assert flow.ndim == 5, 'Set the shape of the flow input as (B, F, C, H, W)'
715
+ assert sf_x is not None and sf_y is not None
716
+ b, f, c, h, w = flow.shape
717
+
718
+ max_clip_x = math.sqrt(w/sf_x) * 1.0
719
+ max_clip_y = math.sqrt(h/sf_y) * 1.0
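+ # After the signed-sqrt compression below, |flow_x| <= w maps to at most sqrt(w/sf_x)
+ # (likewise for y), so the clamp only trims extreme outliers.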
720
+
721
+ flow_norm = flow.detach().clone()
722
+ flow_x = flow[:, :, 0].detach().clone()
723
+ flow_y = flow[:, :, 1].detach().clone()
724
+
725
+ flow_x_norm = torch.sign(flow_x) * torch.sqrt(torch.abs(flow_x)/sf_x + 1e-7)
726
+ flow_y_norm = torch.sign(flow_y) * torch.sqrt(torch.abs(flow_y)/sf_y + 1e-7)
727
+
728
+ flow_norm[:, :, 0] = torch.clamp(flow_x_norm, min=-max_clip_x, max=max_clip_x)
729
+ flow_norm[:, :, 1] = torch.clamp(flow_y_norm, min=-max_clip_y, max=max_clip_y)
730
+
731
+ return flow_norm
732
+
733
+
734
+ def adaptive_unnormalize(flow, sf_x, sf_y):
735
+ # x: BFCHW, optical flow
736
+ assert flow.ndim == 5, 'Set the shape of the flow input as (B, F, C, H, W)'
737
+ assert sf_x is not None and sf_y is not None
738
+
739
+ flow_orig = flow.detach().clone()
740
+ flow_x = flow[:, :, 0].detach().clone()
741
+ flow_y = flow[:, :, 1].detach().clone()
742
+
743
+ flow_orig[:, :, 0] = torch.sign(flow_x) * sf_x * (flow_x**2 - 1e-7)
744
+ flow_orig[:, :, 1] = torch.sign(flow_y) * sf_y * (flow_y**2 - 1e-7)
745
+
746
+ return flow_orig
747
+
748
+ #--------------------------------------------------------------------------------------------------
finetune/models/cogvideox_i2v/flovd_controlnet_trainer.py ADDED
@@ -0,0 +1,814 @@
1
+ from typing import Any, Dict, List, Tuple
2
+ from pathlib import Path
3
+ import os
4
+ import hashlib
5
+ import json
6
+ import random
7
+ import wandb
8
+ import math
9
+ import numpy as np
10
+ from einops import rearrange, repeat
11
+ from safetensors.torch import load_file, save_file
12
+ from accelerate.logging import get_logger
13
+
14
+ import torch
15
+
16
+ from accelerate.utils import gather_object
17
+
18
+ from diffusers import (
19
+ AutoencoderKLCogVideoX,
20
+ CogVideoXDPMScheduler,
21
+ CogVideoXImageToVideoPipeline,
22
+ CogVideoXTransformer3DModel,
23
+ )
24
+ from diffusers.utils.export_utils import export_to_video
25
+
26
+ from finetune.pipeline.flovd_FVSM_cogvideox_controlnet_pipeline import FloVDCogVideoXControlnetImageToVideoPipeline
27
+ from finetune.constants import LOG_LEVEL, LOG_NAME
28
+
29
+ from diffusers.models.embeddings import get_3d_rotary_pos_embed
30
+ from PIL import Image
31
+ from numpy import dtype
32
+ from transformers import AutoTokenizer, T5EncoderModel
33
+ from typing_extensions import override
34
+
35
+ from finetune.schemas import Args, Components, State
36
+ from finetune.trainer import Trainer
37
+ from finetune.utils import (
38
+ cast_training_params,
39
+ free_memory,
40
+ get_memory_statistics,
41
+ string_to_filename,
42
+ unwrap_model,
43
+ )
44
+ from finetune.datasets.utils import (
45
+ preprocess_image_with_resize,
46
+ load_binary_mask_compressed,
47
+ )
48
+
49
+ from finetune.modules.cogvideox_controlnet import CogVideoXControlnet
50
+ from finetune.modules.cogvideox_custom_model import CustomCogVideoXTransformer3DModel
51
+ from finetune.modules.camera_sampler import SampleManualCam
52
+ from finetune.modules.camera_flow_generator import CameraFlowGenerator
53
+ from finetune.modules.utils import get_camera_flow_generator_input, forward_bilinear_splatting
54
+
55
+ from ..utils import register
56
+
57
+ import sys
58
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
59
+
60
+ import pdb
61
+
62
+ logger = get_logger(LOG_NAME, LOG_LEVEL)
63
+
64
+ class FloVDCogVideoXI2VControlnetTrainer(Trainer):
65
+ UNLOAD_LIST = ["text_encoder"]
66
+
67
+ @override
68
+ def __init__(self, args: Args) -> None:
69
+ super().__init__(args)
70
+
71
+ # For validation
72
+ self.CameraSampler = SampleManualCam()
73
+
74
+
75
+
76
+ @override
77
+ def load_components(self) -> Dict[str, Any]:
78
+ # TODO. Change the pipeline and ...
79
+ components = Components()
80
+ model_path = str(self.args.model_path)
81
+
82
+ components.pipeline_cls = FloVDCogVideoXControlnetImageToVideoPipeline
83
+
84
+ components.tokenizer = AutoTokenizer.from_pretrained(model_path, subfolder="tokenizer")
85
+
86
+ components.text_encoder = T5EncoderModel.from_pretrained(model_path, subfolder="text_encoder")
87
+
88
+ # components.transformer = CogVideoXTransformer3DModel.from_pretrained(model_path, subfolder="transformer")
89
+
90
+ components.transformer = CustomCogVideoXTransformer3DModel.from_pretrained(model_path, subfolder="transformer")
91
+
92
+ additional_kwargs = {
93
+ 'num_layers': self.args.controlnet_transformer_num_layers,
94
+ 'out_proj_dim_factor': self.args.controlnet_out_proj_dim_factor,
95
+ 'out_proj_dim_zero_init': self.args.controlnet_out_proj_zero_init,
96
+ 'notextinflow': self.args.notextinflow,
97
+ }
98
+ components.controlnet = CogVideoXControlnet.from_pretrained(model_path, subfolder="transformer", **additional_kwargs)
99
+
100
+ components.vae = AutoencoderKLCogVideoX.from_pretrained(model_path, subfolder="vae")
101
+
102
+ components.scheduler = CogVideoXDPMScheduler.from_pretrained(model_path, subfolder="scheduler")
103
+
104
+ return components
105
+
106
+
107
+ @override
108
+ def initialize_pipeline(self) -> FloVDCogVideoXControlnetImageToVideoPipeline:
109
+ # TODO. Change the pipeline and ...
110
+ pipe = FloVDCogVideoXControlnetImageToVideoPipeline(
111
+ tokenizer=self.components.tokenizer,
112
+ text_encoder=unwrap_model(self.accelerator, self.components.text_encoder),
113
+ vae=unwrap_model(self.accelerator, self.components.vae),
114
+ transformer=unwrap_model(self.accelerator, self.components.transformer),
115
+ controlnet=unwrap_model(self.accelerator, self.components.controlnet),
116
+ scheduler=self.components.scheduler,
117
+ )
118
+ return pipe
119
+
120
+ def initialize_flow_generator(self, ckpt_path):
121
+ depth_estimator_kwargs = {
122
+ "target": 'modules.depth_warping.depth_warping.DepthWarping_wrapper',
123
+ "kwargs": {
124
+ "ckpt_path": ckpt_path,
125
+ "model_config": {
126
+ "max_depth": 20,
127
+ "encoder": 'vitb',
128
+ "features": 128,
129
+ "out_channels": [96, 192, 384, 768],
130
+ }
131
+
132
+ }
133
+ }
134
+
135
+ return CameraFlowGenerator(depth_estimator_kwargs)
136
+
137
+ @override
138
+ def collate_fn(self, samples: List[Dict[str, Any]]) -> Dict[str, Any]:
139
+ ret = {"encoded_videos": [], "prompt_embedding": [], "images": [], "encoded_flow": []}
140
+
141
+ for sample in samples:
142
+ encoded_video = sample["encoded_video"]
143
+ prompt_embedding = sample["prompt_embedding"]
144
+ image = sample["image"]
145
+ encoded_flow = sample["encoded_flow"]
146
+
147
+ ret["encoded_videos"].append(encoded_video)
148
+ ret["prompt_embedding"].append(prompt_embedding)
149
+ ret["images"].append(image)
150
+ ret["encoded_flow"].append(encoded_flow)
151
+
152
+
153
+ ret["encoded_videos"] = torch.stack(ret["encoded_videos"])
154
+ ret["prompt_embedding"] = torch.stack(ret["prompt_embedding"])
155
+ ret["images"] = torch.stack(ret["images"])
156
+ ret["encoded_flow"] = torch.stack(ret["encoded_flow"])
157
+
158
+ return ret
159
+
160
+
161
+ @override
162
+ def compute_loss(self, batch) -> torch.Tensor:
163
+ prompt_embedding = batch["prompt_embedding"]
164
+ latent = batch["encoded_videos"]
165
+ images = batch["images"]
166
+ latent_flow = batch["encoded_flow"]
167
+
168
+ # Shape of prompt_embedding: [B, seq_len, hidden_size]
169
+ # Shape of latent: [B, C, F, H, W]
170
+ # Shape of images: [B, C, H, W]
171
+ # Shape of latent_flow: [B, C, F, H, W]
172
+
173
+ patch_size_t = self.state.transformer_config.patch_size_t # WJ: None in i2v setting...
174
+ if patch_size_t is not None:
175
+ ncopy = latent.shape[2] % patch_size_t
176
+ # Copy the first frame ncopy times to match patch_size_t
177
+ first_frame = latent[:, :, :1, :, :] # Get first frame [B, C, 1, H, W]
178
+ latent = torch.cat([first_frame.repeat(1, 1, ncopy, 1, 1), latent], dim=2)
179
+ assert latent.shape[2] % patch_size_t == 0
180
+
181
+ batch_size, num_channels, num_frames, height, width = latent.shape
182
+
183
+ # Get prompt embeddings
184
+ _, seq_len, _ = prompt_embedding.shape
185
+ prompt_embedding = prompt_embedding.view(batch_size, seq_len, -1).to(dtype=latent.dtype)
186
+
187
+ # Add frame dimension to images [B,C,H,W] -> [B,C,F,H,W]
188
+ images = images.unsqueeze(2)
189
+ # Add noise to images
190
+ image_noise_sigma = torch.normal(mean=-3.0, std=0.5, size=(1,), device=self.accelerator.device)
191
+ image_noise_sigma = torch.exp(image_noise_sigma).to(dtype=images.dtype)
192
+ noisy_images = images + torch.randn_like(images) * image_noise_sigma[:, None, None, None, None]
193
+ image_latent_dist = self.components.vae.encode(noisy_images.to(dtype=self.components.vae.dtype)).latent_dist
194
+ image_latents = image_latent_dist.sample() * self.components.vae.config.scaling_factor
195
+
196
+ """
197
+ Modify below
198
+ """
199
+ # Sample a random timestep for each sample
200
+ # timesteps = torch.randint(
201
+ # 0, self.components.scheduler.config.num_train_timesteps, (batch_size,), device=self.accelerator.device
202
+ # )
203
+ if self.args.enable_time_sampling:
204
+ if self.args.time_sampling_type == "truncated_normal":
205
+ time_sampling_dict = {
206
+ 'mean': self.args.time_sampling_mean,
207
+ 'std': self.args.time_sampling_std,
208
+ 'a': 1 - self.args.controlnet_guidance_end,
209
+ 'b': 1 - self.args.controlnet_guidance_start,
210
+ }
211
+ timesteps = torch.nn.init.trunc_normal_(
212
+ torch.empty(batch_size, device=latent.device), **time_sampling_dict
213
+ ) * self.components.scheduler.config.num_train_timesteps
214
+ elif self.args.time_sampling_type == "truncated_uniform":
215
+ timesteps = torch.randint(
216
+ int((1- self.args.controlnet_guidance_end) * self.components.scheduler.config.num_train_timesteps),
217
+ int((1 - self.args.controlnet_guidance_start) * self.components.scheduler.config.num_train_timesteps),
218
+ (batch_size,), device=latent.device
219
+ )
220
+ else:
221
+ timesteps = torch.randint(
222
+ 0, self.components.scheduler.config.num_train_timesteps, (batch_size,), device=self.accelerator.device
223
+ )
224
+ timesteps = timesteps.long()
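+ # With enable_time_sampling, training timesteps are restricted to the part of the schedule where
+ # the controlnet is active at inference ([1 - guidance_end, 1 - guidance_start] of the range),
+ # drawn from a truncated normal or uniformly; otherwise the full schedule is sampled.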
225
+
226
+ # from [B, C, F, H, W] to [B, F, C, H, W]
227
+ latent = latent.permute(0, 2, 1, 3, 4)
228
+ latent_flow = latent_flow.permute(0, 2, 1, 3, 4)
229
+ image_latents = image_latents.permute(0, 2, 1, 3, 4)
230
+ assert (latent.shape[0], *latent.shape[2:]) == (image_latents.shape[0], *image_latents.shape[2:]) == (latent_flow.shape[0], *latent_flow.shape[2:])
231
+
232
+ # Padding image_latents to the same frame number as latent
233
+ padding_shape = (latent.shape[0], latent.shape[1] - 1, *latent.shape[2:])
234
+ latent_padding = image_latents.new_zeros(padding_shape)
235
+ image_latents = torch.cat([image_latents, latent_padding], dim=1)
236
+
237
+ # Add noise to latent
238
+ noise = torch.randn_like(latent)
239
+ latent_noisy = self.components.scheduler.add_noise(latent, noise, timesteps)
240
+
241
+
242
+ # Concatenate latent and image_latents in the channel dimension
243
+ # latent_img_flow_noisy = torch.cat([latent_noisy, image_latents, latent_flow], dim=2)
244
+ latent_img_noisy = torch.cat([latent_noisy, image_latents], dim=2)
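+ # Note: the flow latent is not concatenated into the transformer input here (see the commented
+ # alternative above); it conditions generation only through the controlnet branch.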
245
+
246
+ # Prepare rotary embeds
247
+ vae_scale_factor_spatial = 2 ** (len(self.components.vae.config.block_out_channels) - 1)
248
+ transformer_config = self.state.transformer_config
249
+ rotary_emb = (
250
+ self.prepare_rotary_positional_embeddings(
251
+ height=height * vae_scale_factor_spatial,
252
+ width=width * vae_scale_factor_spatial,
253
+ num_frames=num_frames,
254
+ transformer_config=transformer_config,
255
+ vae_scale_factor_spatial=vae_scale_factor_spatial,
256
+ device=self.accelerator.device,
257
+ )
258
+ if transformer_config.use_rotary_positional_embeddings
259
+ else None
260
+ )
261
+
262
+ # Predict noise, For CogVideoX1.5 Only.
263
+ ofs_emb = (
264
+ None if self.state.transformer_config.ofs_embed_dim is None else latent.new_full((1,), fill_value=2.0)
265
+ )
266
+
267
+ # Controlnet feedforward
268
+ controlnet_states = self.components.controlnet(
269
+ hidden_states=latent_noisy,
270
+ encoder_hidden_states=prompt_embedding,
271
+ image_rotary_emb=rotary_emb,
272
+ controlnet_hidden_states=latent_flow,
273
+ timestep=timesteps,
274
+ return_dict=False,
275
+ )[0]
276
+ if isinstance(controlnet_states, (tuple, list)):
277
+ controlnet_states = [x.to(dtype=self.state.weight_dtype) for x in controlnet_states]
278
+ else:
279
+ controlnet_states = controlnet_states.to(dtype=self.state.weight_dtype)
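+ # The controlnet sees the noisy video latent together with the flow latent and returns
+ # intermediate features that the custom transformer injects below, scaled by controlnet_weights.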
280
+
281
+
282
+ # Transformer feedforward
283
+ predicted_noise = self.components.transformer(
284
+ hidden_states=latent_img_noisy,
285
+ encoder_hidden_states=prompt_embedding,
286
+ controlnet_states=controlnet_states,
287
+ controlnet_weights=self.args.controlnet_weights,
288
+ timestep=timesteps,
289
+ # ofs=ofs_emb,
290
+ image_rotary_emb=rotary_emb,
291
+ return_dict=False,
292
+ )[0]
293
+
294
+
295
+ # Denoise
296
+ latent_pred = self.components.scheduler.get_velocity(predicted_noise, latent_noisy, timesteps)
297
+
298
+ alphas_cumprod = self.components.scheduler.alphas_cumprod[timesteps]
299
+ weights = 1 / (1 - alphas_cumprod)
300
+ while len(weights.shape) < len(latent_pred.shape):
301
+ weights = weights.unsqueeze(-1)
302
+
303
+ loss = torch.mean((weights * (latent_pred - latent) ** 2).reshape(batch_size, -1), dim=1)
304
+ loss = loss.mean()
305
+
306
+ return loss
307
+
308
+ def prepare_rotary_positional_embeddings(
309
+ self,
310
+ height: int,
311
+ width: int,
312
+ num_frames: int,
313
+ transformer_config: Dict,
314
+ vae_scale_factor_spatial: int,
315
+ device: torch.device,
316
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
317
+ grid_height = height // (vae_scale_factor_spatial * transformer_config.patch_size)
318
+ grid_width = width // (vae_scale_factor_spatial * transformer_config.patch_size)
319
+
320
+ if transformer_config.patch_size_t is None:
321
+ base_num_frames = num_frames
322
+ else:
323
+ base_num_frames = (num_frames + transformer_config.patch_size_t - 1) // transformer_config.patch_size_t
324
+
325
+ freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
326
+ embed_dim=transformer_config.attention_head_dim,
327
+ crops_coords=None,
328
+ grid_size=(grid_height, grid_width),
329
+ temporal_size=base_num_frames,
330
+ grid_type="slice",
331
+ max_size=(grid_height, grid_width),
332
+ device=device,
333
+ )
334
+
335
+ return freqs_cos, freqs_sin
336
+
337
+ # Validation
338
+
339
+ @override
340
+ def prepare_for_validation(self):
341
+ # Load from dataset?
342
+ # Data_root
343
+ # - metadata.jsonl
344
+ # - video_latent / args.resolution /
345
+ # - prompt_embeddings /
346
+ # - first_frames /
347
+ # - flow_direct_f_latent /
348
+
349
+ data_root = self.args.data_root
350
+ metadata_path = data_root / "metadata_revised.jsonl"
351
+ assert metadata_path.is_file(), "For this dataset type, you need metadata.jsonl or metadata_revised.jsonl in the root path"
352
+
353
+ # Load metadata
354
+ # metadata = {
355
+ # "video_path": ...,
356
+ # "hash_code": ...,
357
+ # "prompt": ...,
358
+ # }
359
+ metadata = []
360
+ with open(metadata_path, "r") as f:
361
+ for line in f:
362
+ metadata.append( json.loads(line) )
363
+
364
+ metadata = random.sample(metadata, self.args.max_scene)
365
+
366
+ prompts = [x["prompt"] for x in metadata]
367
+ prompt_embeddings = [data_root / "prompt_embeddings_revised" / (x["hash_code"] + '.safetensors') for x in metadata]
368
+ videos = [data_root / "video_latent" / "x".join(str(x) for x in self.args.train_resolution) / (x["hash_code"] + '.safetensors') for x in metadata]
369
+ images = [data_root / "first_frames" / (x["hash_code"] + '.png') for x in metadata]
370
+ flows = [data_root / "flow_direct_f_latent" / (x["hash_code"] + '.safetensors') for x in metadata]
371
+
372
+ # load prompt embedding
373
+ validation_prompts = []
374
+ validation_prompt_embeddings = []
375
+ validation_video_latents = []
376
+ validation_images = []
377
+ validation_flow_latents = []
378
+ for prompt, prompt_embedding, video_latent, image, flow_latent in zip(prompts, prompt_embeddings, videos, images, flows):
379
+ validation_prompts.append(prompt)
380
+ validation_prompt_embeddings.append(load_file(prompt_embedding)["prompt_embedding"].unsqueeze(0))
381
+ validation_video_latents.append(load_file(video_latent)["encoded_video"].unsqueeze(0))
382
+ validation_flow_latents.append(load_file(flow_latent)["encoded_flow_f"].unsqueeze(0))
383
+ # validation_images.append(preprocess_image_with_resize(image, self.args.train_resolution[1], self.args.train_resolution[2]))
384
+ validation_images.append(image)
385
+
386
+
387
+ validation_videos = [None] * len(validation_prompts)
388
+
389
+
390
+ self.state.validation_prompts = validation_prompts
391
+ self.state.validation_prompt_embeddings = validation_prompt_embeddings
392
+ self.state.validation_images = validation_images
393
+ self.state.validation_videos = validation_videos
394
+ self.state.validation_video_latents = validation_video_latents
395
+ self.state.validation_flow_latents = validation_flow_latents
396
+
397
+ # Debug..
398
+ # self.validate(0)
399
+
400
+
401
+ @override
402
+ def validation_step(
403
+ self, eval_data: Dict[str, Any], pipe: FloVDCogVideoXControlnetImageToVideoPipeline
404
+ ) -> List[Tuple[str, Image.Image | List[Image.Image]]]:
405
+ """
406
+ Return the data that needs to be saved. For videos, the data format is List[PIL],
407
+ and for images, the data format is PIL
408
+ """
409
+
410
+ prompt_embedding, image, flow_latent = eval_data["prompt_embedding"], eval_data["image"], eval_data["flow_latent"]
411
+
412
+ video_generate = pipe(
413
+ num_frames=self.state.train_frames,
414
+ height=self.state.train_height,
415
+ width=self.state.train_width,
416
+ prompt=None,
417
+ prompt_embeds=prompt_embedding,
418
+ image=image,
419
+ flow_latent=flow_latent,
420
+ generator=self.state.generator,
421
+ num_inference_steps=50,
422
+ controlnet_guidance_start = self.args.controlnet_guidance_start,
423
+ controlnet_guidance_end = self.args.controlnet_guidance_end,
424
+ ).frames[0]
425
+ return [("synthesized_video", video_generate)]
426
+
427
+
428
+ @override
429
+ def validate(self, step: int) -> None:
430
+ #TODO. Fix the codes!!!!
431
+ logger.info("Starting validation")
432
+
433
+ accelerator = self.accelerator
434
+ num_validation_samples = len(self.state.validation_prompts)
435
+
436
+ if num_validation_samples == 0:
437
+ logger.warning("No validation samples found. Skipping validation.")
438
+ return
439
+
440
+ self.components.controlnet.eval()
441
+ torch.set_grad_enabled(False)
442
+
443
+ memory_statistics = get_memory_statistics()
444
+ logger.info(f"Memory before validation start: {json.dumps(memory_statistics, indent=4)}")
445
+
446
+ ##### Initialize pipeline #####
447
+ pipe = self.initialize_pipeline()
448
+ camera_flow_generator = self.initialize_flow_generator(ckpt_path=self.args.depth_ckpt_path).to(device=self.accelerator.device, dtype=self.state.weight_dtype)
449
+
450
+ if self.state.using_deepspeed:
451
+ # Can't using model_cpu_offload in deepspeed,
452
+ # so we need to move all components in pipe to device
453
+ # pipe.to(self.accelerator.device, dtype=self.state.weight_dtype)
454
+ self.__move_components_to_device(dtype=self.state.weight_dtype, ignore_list=["controlnet"])
455
+ # self.__move_components_to_device(dtype=self.state.weight_dtype, ignore_list=["transformer", "controlnet"])
456
+ else:
457
+ # if not using deepspeed, use model_cpu_offload to further reduce memory usage
458
+ # Or use pipe.enable_sequential_cpu_offload() to further reduce memory usage
459
+ pipe.enable_model_cpu_offload(device=self.accelerator.device)
460
+
461
+ # Convert all model weights to training dtype
462
+ # Note, this will change LoRA weights in self.components.transformer to training dtype, rather than keep them in fp32
463
+ pipe = pipe.to(dtype=self.state.weight_dtype)
464
+
465
+
466
+ #################################
467
+ inference_type = ['training', 'inference']
468
+ # inference_type = ['inference']
469
+ for infer_type in inference_type:
470
+
471
+
472
+ all_processes_artifacts = []
473
+ for i in range(num_validation_samples):
474
+ if self.state.using_deepspeed and self.accelerator.deepspeed_plugin.zero_stage != 3:
475
+ # Skip current validation on all processes but one
476
+ if i % accelerator.num_processes != accelerator.process_index:
477
+ continue
478
+
479
+ prompt = self.state.validation_prompts[i]
480
+ image = self.state.validation_images[i]
481
+ video = self.state.validation_videos[i]
482
+ video_latent = self.state.validation_video_latents[i].permute(0,2,1,3,4) # [B,F,C,H,W] (e.g., [B, 13, 16, 60, 90])
483
+ prompt_embedding = self.state.validation_prompt_embeddings[i]
484
+ flow_latent = self.state.validation_flow_latents[i].permute(0,2,1,3,4) # [B,F,C,H,W] (e.g., [B, 13, 16, 60, 90])
485
+
486
+
487
+ if image is not None:
488
+ image = preprocess_image_with_resize(image, self.state.train_height, self.state.train_width)
489
+ image_torch = image.detach().clone()
490
+ # Convert image tensor (C, H, W) to PIL images
491
+ image = image.to(torch.uint8)
492
+ image = image.permute(1, 2, 0).cpu().numpy()
493
+ image = Image.fromarray(image)
494
+
495
+ if video is not None:
496
+ video = preprocess_video_with_resize(
497
+ video, self.state.train_frames, self.state.train_height, self.state.train_width
498
+ )
499
+ # Convert video tensor (F, C, H, W) to list of PIL images
500
+ video = video.round().clamp(0, 255).to(torch.uint8)
501
+ video = [Image.fromarray(frame.permute(1, 2, 0).cpu().numpy()) for frame in video]
502
+ else:
503
+ if infer_type == 'training':
504
+ with torch.cuda.amp.autocast(enabled=True, dtype=self.state.weight_dtype):
505
+ try:
506
+ video_decoded = decode_latents(video_latent.to(self.accelerator.device), self.components.vae)
507
+ except:
508
+ pass
509
+ video_decoded = decode_latents(video_latent.to(self.accelerator.device), self.components.vae)
510
+ video = ((video_decoded + 1.) / 2. * 255.)[0].permute(1,0,2,3).float().clip(0., 255.).to(torch.uint8)
511
+ video = [Image.fromarray(frame.permute(1, 2, 0).cpu().numpy()) for frame in video]
512
+
513
+ with torch.cuda.amp.autocast(enabled=True, dtype=self.state.weight_dtype):
514
+ try:
515
+ flow_decoded = decode_flow(flow_latent.to(self.accelerator.device), self.components.vae, flow_scale_factor=[60, 36])
516
+ except:
517
+ pass
518
+ flow_decoded = decode_flow(flow_latent.to(self.accelerator.device), self.components.vae, flow_scale_factor=[60, 36]) # (BF)CHW (C=2)
519
+
520
+
521
+ # Prepare camera flow
522
+ if infer_type == 'inference':
523
+ with torch.cuda.amp.autocast(enabled=True, dtype=self.state.weight_dtype):
524
+ camparam, cam_name = self.CameraSampler.sample()
525
+ camera_flow_generator_input = get_camera_flow_generator_input(image_torch, camparam, device=self.accelerator.device, speed=0.5)
526
+ image_torch = ((image_torch.unsqueeze(0) / 255.) * 2. - 1.).to(self.accelerator.device)
527
+ camera_flow, log_dict = camera_flow_generator(image_torch, camera_flow_generator_input)
528
+ camera_flow = camera_flow.to(self.accelerator.device)
529
+ # Unknown bug: the first VAE encode of the camera flow can fail, so do a warm-up attempt first.
530
+ try:
531
+ flow_latent = rearrange(encode_flow(camera_flow, self.components.vae, flow_scale_factor=[60, 36]), 'b c f h w -> b f c h w').to(self.accelerator.device, self.state.weight_dtype)
532
+ except:
533
+ pass
534
+ flow_latent = rearrange(encode_flow(camera_flow, self.components.vae, flow_scale_factor=[60, 36]), 'b c f h w -> b f c h w').to(self.accelerator.device, self.state.weight_dtype)
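+ # For the 'inference' pass, a camera trajectory is sampled, a depth-based camera flow is
+ # synthesized from the first frame, and its encoded latent replaces the dataset flow latent
+ # as the controlnet condition.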
535
+
536
+
537
+ logger.debug(
538
+ f"Validating sample {i + 1}/{num_validation_samples} on process {accelerator.process_index}. Prompt: {prompt}",
539
+ main_process_only=False,
540
+ )
541
+ # validation_artifacts = self.validation_step({"prompt": prompt, "image": image, "video": video}, pipe)
542
+ validation_artifacts = self.validation_step({"prompt_embedding": prompt_embedding, "image": image, "flow_latent": flow_latent}, pipe)
543
+
544
+ if (
545
+ self.state.using_deepspeed
546
+ and self.accelerator.deepspeed_plugin.zero_stage == 3
547
+ and not accelerator.is_main_process
548
+ ):
549
+ continue
550
+
551
+ prompt_filename = string_to_filename(prompt)[:25]
552
+ # Calculate hash of reversed prompt as a unique identifier
553
+ reversed_prompt = prompt[::-1]
554
+ hash_suffix = hashlib.md5(reversed_prompt.encode()).hexdigest()[:5]
555
+
556
+ artifacts = {
557
+ "image": {"type": "image", "value": image},
558
+ "video": {"type": "video", "value": video},
559
+ }
560
+ for i, (artifact_type, artifact_value) in enumerate(validation_artifacts):
561
+ artifacts.update({f"artifact_{i}": {"type": artifact_type, "value": artifact_value}})
562
+ if infer_type == 'training':
563
+ # Log flow_warped_frames
564
+ image_tensor = repeat(rearrange(torch.tensor(np.array(image)).to(flow_decoded.device, torch.float), 'h w c -> 1 c h w'), 'b c h w -> (b f) c h w', f=flow_decoded.size(0)) # scale~(0,255) (BF) C H W
565
+ warped_video = forward_bilinear_splatting(image_tensor, flow_decoded.to(torch.float)) # if we have an occlusion mask from dataset, we can use it.
566
+ frame_list = []
567
+ for frame in warped_video:
568
+ frame = (frame.permute(1,2,0).float().detach().cpu().numpy()).clip(0, 255).astype(np.uint8)  # clip before the uint8 cast to avoid wrap-around
569
+ frame_list.append(Image.fromarray(frame))
570
+
571
+ artifacts.update({f"artifact_warped_video_{i}": {"type": 'warped_video', "value": frame_list}})
572
+
573
+ if infer_type == 'inference':
574
+ warped_video = log_dict['depth_warped_frames']
575
+ frame_list = []
576
+ for frame in warped_video:
577
+ frame = (frame + 1.)/2. * 255.
578
+ frame = (frame.permute(1,2,0).float().detach().cpu().numpy()).clip(0, 255).astype(np.uint8)
579
+ frame_list.append(Image.fromarray(frame))
580
+
581
+ artifacts.update({f"artifact_warped_video_{i}": {"type": 'warped_video', "value": frame_list}})
582
+ logger.debug(
583
+ f"Validation artifacts on process {accelerator.process_index}: {list(artifacts.keys())}",
584
+ main_process_only=False,
585
+ )
586
+
587
+ for key, value in list(artifacts.items()):
588
+ artifact_type = value["type"]
589
+ artifact_value = value["value"]
590
+ if artifact_type not in ["image", "video", "warped_video", "synthesized_video"] or artifact_value is None:
591
+ continue
592
+
593
+ extension = "png" if artifact_type == "image" else "mp4"
594
+ if artifact_type == "warped_video":
595
+ filename = f"validation-{step}-{accelerator.process_index}-{prompt_filename}-{hash_suffix}-{infer_type}_warped_video.{extension}"
596
+ elif artifact_type == "synthesized_video":
597
+ filename = f"validation-{step}-{accelerator.process_index}-{prompt_filename}-{hash_suffix}-{infer_type}_synthesized_video.{extension}"
598
+ else:
599
+ filename = f"validation-{step}-{accelerator.process_index}-{prompt_filename}-{hash_suffix}-{infer_type}.{extension}"
600
+ validation_path = self.args.output_dir / "validation_res"
601
+ validation_path.mkdir(parents=True, exist_ok=True)
602
+ filename = str(validation_path / filename)
603
+
604
+ if artifact_type == "image":
605
+ logger.debug(f"Saving image to {filename}")
606
+ artifact_value.save(filename)
607
+ artifact_value = wandb.Image(filename)
608
+ elif artifact_type == "video" or artifact_type == "warped_video" or artifact_type == "synthesized_video":
609
+ logger.debug(f"Saving video to {filename}")
610
+ export_to_video(artifact_value, filename, fps=self.args.gen_fps)
611
+ artifact_value = wandb.Video(filename, caption=prompt)
612
+
613
+ all_processes_artifacts.append(artifact_value)
614
+
615
+ all_artifacts = gather_object(all_processes_artifacts)
616
+
617
+ if accelerator.is_main_process:
618
+ tracker_key = "validation"
619
+ for tracker in accelerator.trackers:
620
+ if tracker.name == "wandb":
621
+ image_artifacts = [artifact for artifact in all_artifacts if isinstance(artifact, wandb.Image)]
622
+ video_artifacts = [artifact for artifact in all_artifacts if isinstance(artifact, wandb.Video)]
623
+ tracker.log(
624
+ {
625
+ tracker_key: {f"images_{infer_type}": image_artifacts, f"videos_{infer_type}": video_artifacts},
626
+ },
627
+ step=step,
628
+ )
629
+
630
+ ########## Clean up ##########
631
+ if self.state.using_deepspeed:
632
+ del pipe
633
+ # Unload models except those needed for training
634
+ self.__move_components_to_cpu(unload_list=self.UNLOAD_LIST)
635
+ else:
636
+ pipe.remove_all_hooks()
637
+ del pipe
638
+ # Move the components needed for training back onto the device (the unloaded ones stay on CPU)
639
+ self.__move_components_to_device(dtype=self.state.weight_dtype, ignore_list=self.UNLOAD_LIST)
640
+ self.components.controlnet.to(self.accelerator.device, dtype=self.state.weight_dtype)
641
+
642
+ # Cast the trainable weights back to fp32 to match the dtype they had after the model was prepared
643
+ cast_training_params([self.components.controlnet], dtype=torch.float32)
644
+
645
+ del camera_flow_generator
646
+
647
+ free_memory()
648
+ accelerator.wait_for_everyone()
649
+ ################################
650
+
651
+ memory_statistics = get_memory_statistics()
652
+ logger.info(f"Memory after validation end: {json.dumps(memory_statistics, indent=4)}")
653
+ torch.cuda.reset_peak_memory_stats(accelerator.device)
654
+
655
+ torch.set_grad_enabled(True)
656
+ self.components.controlnet.train()
657
+
658
+
659
+ # mangling
660
+ def __move_components_to_device(self, dtype, ignore_list: List[str] = []):
661
+ ignore_list = set(ignore_list)
662
+ components = self.components.model_dump()
663
+ for name, component in components.items():
664
+ if not isinstance(component, type) and hasattr(component, "to"):
665
+ if name not in ignore_list:
666
+ setattr(self.components, name, component.to(self.accelerator.device, dtype=dtype))
667
+
668
+ # mangling
669
+ def __move_components_to_cpu(self, unload_list: List[str] = []):
670
+ unload_list = set(unload_list)
671
+ components = self.components.model_dump()
672
+ for name, component in components.items():
673
+ if not isinstance(component, type) and hasattr(component, "to"):
674
+ if name in unload_list:
675
+ setattr(self.components, name, component.to("cpu"))
676
+
677
+
678
+ register("cogvideox-flovd", "controlnet", FloVDCogVideoXI2VControlnetTrainer)
679
+
680
+
681
+ #--------------------------------------------------------------------------------------------------
682
+ # Extract function
683
+ def encode_text(prompt: str, components, device) -> torch.Tensor:
684
+ prompt_token_ids = components.tokenizer(
685
+ prompt,
686
+ padding="max_length",
687
+ max_length=components.transformer.config.max_text_seq_length,
688
+ truncation=True,
689
+ add_special_tokens=True,
690
+ return_tensors="pt",
691
+ )
692
+ prompt_token_ids = prompt_token_ids.input_ids
693
+ prompt_embedding = components.text_encoder(prompt_token_ids.to(device))[0]
694
+ return prompt_embedding
695
+
696
+ def encode_video(video: torch.Tensor, vae) -> torch.Tensor:
697
+ # shape of input video: [B, C, F, H, W]
698
+ video = video.to(vae.device, dtype=vae.dtype)
699
+ latent_dist = vae.encode(video).latent_dist
700
+ latent = latent_dist.sample() * vae.config.scaling_factor
701
+ return latent
702
+
703
+ def decode_latents(latents: torch.Tensor, vae) -> torch.Tensor:
704
+ latents = latents.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width]
705
+ latents = 1 / vae.config.scaling_factor * latents
706
+
707
+ frames = vae.decode(latents).sample
708
+ return frames
709
+
710
+ def compute_optical_flow(raft, ctxt, trgt, raft_iter=20, chunk=2, only_forward=True):
711
+ num_frames = ctxt.shape[0]
712
+ chunk_size = (num_frames // chunk) + 1
713
+
714
+ flow_f_list = []
715
+ if not only_forward:
716
+ flow_b_list = []
717
+ for i in range(chunk):
718
+ start = chunk_size * i
719
+ end = chunk_size * (i+1)
720
+
721
+ with torch.no_grad():
722
+ flow_f = raft(ctxt[start:end], trgt[start:end], num_flow_updates=raft_iter)[-1]
723
+ if not only_forward:
724
+ flow_b = raft(trgt[start:end], ctxt[start:end], num_flow_updates=raft_iter)[-1]
725
+
726
+ flow_f_list.append(flow_f)
727
+ if not only_forward:
728
+ flow_b_list.append(flow_b)
729
+
730
+ flow_f = torch.cat(flow_f_list)
731
+ if not only_forward:
732
+ flow_b = torch.cat(flow_b_list)
733
+
734
+ if not only_forward:
735
+ return flow_f, flow_b
736
+ else:
737
+ return flow_f, None
738
+
739
+ def encode_flow(flow, vae, flow_scale_factor):
740
+ # flow: BF,C,H,W
741
+ # flow_scale_factor [sf_x, sf_y]
742
+ assert flow.ndim == 4
743
+ num_frames, _, height, width = flow.shape
744
+
745
+ # Normalize optical flow
746
+ # ndim: 4 -> 5
747
+ flow = rearrange(flow, '(b f) c h w -> b f c h w', b=1)
748
+ flow_norm = adaptive_normalize(flow, flow_scale_factor[0], flow_scale_factor[1])
749
+
750
+ # ndim: 5 -> 4
751
+ flow_norm = rearrange(flow_norm, 'b f c h w -> (b f) c h w', b=1)
752
+
753
+ # Duplicate mean value for third channel
754
+ num_frames, _, H, W = flow_norm.shape
755
+ flow_norm_extended = torch.empty((num_frames, 3, height, width)).to(flow_norm)
756
+ flow_norm_extended[:,:2] = flow_norm
757
+ flow_norm_extended[:,-1:] = flow_norm.mean(dim=1, keepdim=True)
758
+ flow_norm_extended = rearrange(flow_norm_extended, '(b f) c h w -> b c f h w', f=num_frames)
759
+
760
+ return encode_video(flow_norm_extended, vae)
761
+
762
+ def decode_flow(flow_latent, vae, flow_scale_factor):
763
+ flow_latent = flow_latent.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width]
764
+ flow_latent = 1 / vae.config.scaling_factor * flow_latent
765
+
766
+ flow = vae.decode(flow_latent).sample # BCFHW
767
+
768
+ # discard third channel (which is a mean value of f_x and f_y)
769
+ flow = flow[:,:2].detach().clone()
770
+
771
+ # Unnormalize optical flow
772
+ flow = rearrange(flow, 'b c f h w -> b f c h w')
773
+ flow = adaptive_unnormalize(flow, flow_scale_factor[0], flow_scale_factor[1])
774
+
775
+ flow = rearrange(flow, 'b f c h w -> (b f) c h w')
776
+ return flow # BF,C,H,W
777
+
778
+ def adaptive_normalize(flow, sf_x, sf_y):
779
+ # flow: (B, F, C, H, W) optical flow
780
+ assert flow.ndim == 5, 'Set the shape of the flow input as (B, F, C, H, W)'
781
+ assert sf_x is not None and sf_y is not None
782
+ b, f, c, h, w = flow.shape
783
+
784
+ max_clip_x = math.sqrt(w/sf_x) * 1.0
785
+ max_clip_y = math.sqrt(h/sf_y) * 1.0
786
+
787
+ flow_norm = flow.detach().clone()
788
+ flow_x = flow[:, :, 0].detach().clone()
789
+ flow_y = flow[:, :, 1].detach().clone()
790
+
791
+ flow_x_norm = torch.sign(flow_x) * torch.sqrt(torch.abs(flow_x)/sf_x + 1e-7)
792
+ flow_y_norm = torch.sign(flow_y) * torch.sqrt(torch.abs(flow_y)/sf_y + 1e-7)
793
+
794
+ flow_norm[:, :, 0] = torch.clamp(flow_x_norm, min=-max_clip_x, max=max_clip_x)
795
+ flow_norm[:, :, 1] = torch.clamp(flow_y_norm, min=-max_clip_y, max=max_clip_y)
796
+
797
+ return flow_norm
798
+
799
+
800
+ def adaptive_unnormalize(flow, sf_x, sf_y):
801
+ # flow: (B, F, C, H, W) optical flow
802
+ assert flow.ndim == 5, 'Set the shape of the flow input as (B, F, C, H, W)'
803
+ assert sf_x is not None and sf_y is not None
804
+
805
+ flow_orig = flow.detach().clone()
806
+ flow_x = flow[:, :, 0].detach().clone()
807
+ flow_y = flow[:, :, 1].detach().clone()
808
+
809
+ flow_orig[:, :, 0] = torch.sign(flow_x) * sf_x * (flow_x**2 - 1e-7)
810
+ flow_orig[:, :, 1] = torch.sign(flow_y) * sf_y * (flow_y**2 - 1e-7)
811
+
812
+ return flow_orig
813
+
814
+ #--------------------------------------------------------------------------------------------------
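The adaptive normalization above compresses raw flow with a signed square root scaled by sf_x / sf_y and clamps it, and adaptive_unnormalize inverts it exactly for values inside the clamp range. A minimal standalone sketch of that round trip (the two helper functions re-state the formulas above; the sample values and scale factor are illustrative):

import torch

def normalize(v, sf):
    # signed square-root compression, as in adaptive_normalize
    return torch.sign(v) * torch.sqrt(torch.abs(v) / sf + 1e-7)

def unnormalize(v, sf):
    # exact inverse inside the clamp range, as in adaptive_unnormalize
    return torch.sign(v) * sf * (v ** 2 - 1e-7)

flow_x = torch.tensor([-40.0, 0.0, 25.0])   # raw horizontal flow in pixels (illustrative)
sf_x = 60.0                                  # same x scale factor the trainer passes
n = normalize(flow_x, sf_x)
print(n)                                                           # compressed values, roughly in [-0.82, 0.65]
print(torch.allclose(unnormalize(n, sf_x), flow_x, atol=1e-4))     # True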
finetune/models/cogvideox_i2v/lora_trainer.py ADDED
@@ -0,0 +1,246 @@
1
+ from typing import Any, Dict, List, Tuple
2
+
3
+ import torch
4
+ from diffusers import (
5
+ AutoencoderKLCogVideoX,
6
+ CogVideoXDPMScheduler,
7
+ CogVideoXImageToVideoPipeline,
8
+ CogVideoXTransformer3DModel,
9
+ )
10
+ from diffusers.models.embeddings import get_3d_rotary_pos_embed
11
+ from PIL import Image
12
+ from numpy import dtype
13
+ from transformers import AutoTokenizer, T5EncoderModel
14
+ from typing_extensions import override
15
+
16
+ from finetune.schemas import Components
17
+ from finetune.trainer import Trainer
18
+ from finetune.utils import unwrap_model
19
+
20
+ from ..utils import register
21
+
22
+
23
+ class CogVideoXI2VLoraTrainer(Trainer):
24
+ UNLOAD_LIST = ["text_encoder"]
25
+
26
+ @override
27
+ def load_components(self) -> Dict[str, Any]:
28
+ components = Components()
29
+ model_path = str(self.args.model_path)
30
+
31
+ components.pipeline_cls = CogVideoXImageToVideoPipeline
32
+
33
+ components.tokenizer = AutoTokenizer.from_pretrained(model_path, subfolder="tokenizer")
34
+
35
+ components.text_encoder = T5EncoderModel.from_pretrained(model_path, subfolder="text_encoder")
36
+
37
+ components.transformer = CogVideoXTransformer3DModel.from_pretrained(model_path, subfolder="transformer")
38
+
39
+ components.vae = AutoencoderKLCogVideoX.from_pretrained(model_path, subfolder="vae")
40
+
41
+ components.scheduler = CogVideoXDPMScheduler.from_pretrained(model_path, subfolder="scheduler")
42
+
43
+ return components
44
+
45
+ @override
46
+ def initialize_pipeline(self) -> CogVideoXImageToVideoPipeline:
47
+ pipe = CogVideoXImageToVideoPipeline(
48
+ tokenizer=self.components.tokenizer,
49
+ text_encoder=self.components.text_encoder,
50
+ vae=self.components.vae,
51
+ transformer=unwrap_model(self.accelerator, self.components.transformer),
52
+ scheduler=self.components.scheduler,
53
+ )
54
+ return pipe
55
+
56
+ @override
57
+ def encode_video(self, video: torch.Tensor) -> torch.Tensor:
58
+ # shape of input video: [B, C, F, H, W]
59
+ vae = self.components.vae
60
+ video = video.to(vae.device, dtype=vae.dtype)
61
+ latent_dist = vae.encode(video).latent_dist
62
+ latent = latent_dist.sample() * vae.config.scaling_factor
63
+ return latent
64
+
65
+ @override
66
+ def encode_text(self, prompt: str) -> torch.Tensor:
67
+ prompt_token_ids = self.components.tokenizer(
68
+ prompt,
69
+ padding="max_length",
70
+ max_length=self.state.transformer_config.max_text_seq_length,
71
+ truncation=True,
72
+ add_special_tokens=True,
73
+ return_tensors="pt",
74
+ )
75
+ prompt_token_ids = prompt_token_ids.input_ids
76
+ prompt_embedding = self.components.text_encoder(prompt_token_ids.to(self.accelerator.device))[0]
77
+ return prompt_embedding
78
+
79
+ @override
80
+ def collate_fn(self, samples: List[Dict[str, Any]]) -> Dict[str, Any]:
81
+ ret = {"encoded_videos": [], "prompt_embedding": [], "images": []}
82
+
83
+ for sample in samples:
84
+ encoded_video = sample["encoded_video"]
85
+ prompt_embedding = sample["prompt_embedding"]
86
+ image = sample["image"]
87
+
88
+ ret["encoded_videos"].append(encoded_video)
89
+ ret["prompt_embedding"].append(prompt_embedding)
90
+ ret["images"].append(image)
91
+
92
+ ret["encoded_videos"] = torch.stack(ret["encoded_videos"])
93
+ ret["prompt_embedding"] = torch.stack(ret["prompt_embedding"])
94
+ ret["images"] = torch.stack(ret["images"])
95
+
96
+ return ret
97
+
98
+ @override
99
+ def compute_loss(self, batch) -> torch.Tensor:
100
+ prompt_embedding = batch["prompt_embedding"]
101
+ latent = batch["encoded_videos"]
102
+ images = batch["images"]
103
+
104
+ # Shape of prompt_embedding: [B, seq_len, hidden_size]
105
+ # Shape of latent: [B, C, F, H, W]
106
+ # Shape of images: [B, C, H, W]
107
+
108
+ patch_size_t = self.state.transformer_config.patch_size_t
109
+ if patch_size_t is not None:
110
+ ncopy = latent.shape[2] % patch_size_t
111
+ # Copy the first frame ncopy times to match patch_size_t
112
+ first_frame = latent[:, :, :1, :, :] # Get first frame [B, C, 1, H, W]
113
+ latent = torch.cat([first_frame.repeat(1, 1, ncopy, 1, 1), latent], dim=2)
114
+ assert latent.shape[2] % patch_size_t == 0
115
+
116
+ batch_size, num_channels, num_frames, height, width = latent.shape
117
+
118
+ # Get prompt embeddings
119
+ _, seq_len, _ = prompt_embedding.shape
120
+ prompt_embedding = prompt_embedding.view(batch_size, seq_len, -1).to(dtype=latent.dtype)
121
+
122
+ # Add frame dimension to images [B,C,H,W] -> [B,C,F,H,W]
123
+ images = images.unsqueeze(2)
124
+ # Add noise to images
125
+ image_noise_sigma = torch.normal(mean=-3.0, std=0.5, size=(1,), device=self.accelerator.device)
126
+ image_noise_sigma = torch.exp(image_noise_sigma).to(dtype=images.dtype)
127
+ noisy_images = images + torch.randn_like(images) * image_noise_sigma[:, None, None, None, None]
128
+ image_latent_dist = self.components.vae.encode(noisy_images.to(dtype=self.components.vae.dtype)).latent_dist
129
+ image_latents = image_latent_dist.sample() * self.components.vae.config.scaling_factor
130
+
131
+ # Sample a random timestep for each sample
132
+ timesteps = torch.randint(
133
+ 0, self.components.scheduler.config.num_train_timesteps, (batch_size,), device=self.accelerator.device
134
+ )
135
+ timesteps = timesteps.long()
136
+
137
+ # from [B, C, F, H, W] to [B, F, C, H, W]
138
+ latent = latent.permute(0, 2, 1, 3, 4)
139
+ image_latents = image_latents.permute(0, 2, 1, 3, 4)
140
+ assert (latent.shape[0], *latent.shape[2:]) == (image_latents.shape[0], *image_latents.shape[2:])
141
+
142
+ # Padding image_latents to the same frame number as latent
143
+ padding_shape = (latent.shape[0], latent.shape[1] - 1, *latent.shape[2:])
144
+ latent_padding = image_latents.new_zeros(padding_shape)
145
+ image_latents = torch.cat([image_latents, latent_padding], dim=1)
146
+
147
+ # Add noise to latent
148
+ noise = torch.randn_like(latent)
149
+ latent_noisy = self.components.scheduler.add_noise(latent, noise, timesteps)
150
+
151
+ # Concatenate latent and image_latents in the channel dimension
152
+ latent_img_noisy = torch.cat([latent_noisy, image_latents], dim=2)
153
+
154
+ # Prepare rotary embeds
155
+ vae_scale_factor_spatial = 2 ** (len(self.components.vae.config.block_out_channels) - 1)
156
+ transformer_config = self.state.transformer_config
157
+ rotary_emb = (
158
+ self.prepare_rotary_positional_embeddings(
159
+ height=height * vae_scale_factor_spatial,
160
+ width=width * vae_scale_factor_spatial,
161
+ num_frames=num_frames,
162
+ transformer_config=transformer_config,
163
+ vae_scale_factor_spatial=vae_scale_factor_spatial,
164
+ device=self.accelerator.device,
165
+ )
166
+ if transformer_config.use_rotary_positional_embeddings
167
+ else None
168
+ )
169
+
170
+ # Predict noise, For CogVideoX1.5 Only.
171
+ ofs_emb = (
172
+ None if self.state.transformer_config.ofs_embed_dim is None else latent.new_full((1,), fill_value=2.0)
173
+ )
174
+ predicted_noise = self.components.transformer(
175
+ hidden_states=latent_img_noisy,
176
+ encoder_hidden_states=prompt_embedding,
177
+ timestep=timesteps,
178
+ ofs=ofs_emb,
179
+ image_rotary_emb=rotary_emb,
180
+ return_dict=False,
181
+ )[0]
182
+
183
+ # Denoise
184
+ latent_pred = self.components.scheduler.get_velocity(predicted_noise, latent_noisy, timesteps)
185
+
186
+ alphas_cumprod = self.components.scheduler.alphas_cumprod[timesteps]
187
+ weights = 1 / (1 - alphas_cumprod)
188
+ while len(weights.shape) < len(latent_pred.shape):
189
+ weights = weights.unsqueeze(-1)
190
+
191
+ loss = torch.mean((weights * (latent_pred - latent) ** 2).reshape(batch_size, -1), dim=1)
192
+ loss = loss.mean()
193
+
194
+ return loss
195
+
196
+ @override
197
+ def validation_step(
198
+ self, eval_data: Dict[str, Any], pipe: CogVideoXImageToVideoPipeline
199
+ ) -> List[Tuple[str, Image.Image | List[Image.Image]]]:
200
+ """
201
+ Return the data that needs to be saved. For videos, the data format is List[PIL],
202
+ and for images, the data format is PIL
203
+ """
204
+ prompt, image, video = eval_data["prompt"], eval_data["image"], eval_data["video"]
205
+
206
+ video_generate = pipe(
207
+ num_frames=self.state.train_frames,
208
+ height=self.state.train_height,
209
+ width=self.state.train_width,
210
+ prompt=prompt,
211
+ image=image,
212
+ generator=self.state.generator,
213
+ ).frames[0]
214
+ return [("video", video_generate)]
215
+
216
+ def prepare_rotary_positional_embeddings(
217
+ self,
218
+ height: int,
219
+ width: int,
220
+ num_frames: int,
221
+ transformer_config: Dict,
222
+ vae_scale_factor_spatial: int,
223
+ device: torch.device,
224
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
225
+ grid_height = height // (vae_scale_factor_spatial * transformer_config.patch_size)
226
+ grid_width = width // (vae_scale_factor_spatial * transformer_config.patch_size)
227
+
228
+ if transformer_config.patch_size_t is None:
229
+ base_num_frames = num_frames
230
+ else:
231
+ base_num_frames = (num_frames + transformer_config.patch_size_t - 1) // transformer_config.patch_size_t
232
+
233
+ freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
234
+ embed_dim=transformer_config.attention_head_dim,
235
+ crops_coords=None,
236
+ grid_size=(grid_height, grid_width),
237
+ temporal_size=base_num_frames,
238
+ grid_type="slice",
239
+ max_size=(grid_height, grid_width),
240
+ device=device,
241
+ )
242
+
243
+ return freqs_cos, freqs_sin
244
+
245
+
246
+ register("cogvideox-i2v", "lora", CogVideoXI2VLoraTrainer)
finetune/models/cogvideox_i2v/sft_trainer.py ADDED
@@ -0,0 +1,9 @@
1
+ from ..cogvideox_i2v.lora_trainer import CogVideoXI2VLoraTrainer
2
+ from ..utils import register
3
+
4
+
5
+ class CogVideoXI2VSftTrainer(CogVideoXI2VLoraTrainer):
6
+ pass
7
+
8
+
9
+ register("cogvideox-i2v", "sft", CogVideoXI2VSftTrainer)
finetune/models/utils.py ADDED
@@ -0,0 +1,57 @@
1
+ from typing import Dict, Literal
2
+
3
+ from finetune.trainer import Trainer
4
+
5
+
6
+ SUPPORTED_MODELS: Dict[str, Dict[str, Trainer]] = {}
7
+
8
+
9
+ def register(model_name: str, training_type: Literal["lora", "sft", "controlnet"], trainer_cls: Trainer):
10
+ """Register a model and its associated functions for a specific training type.
11
+
12
+ Args:
13
+ model_name (str): Name of the model to register (e.g. "cogvideox-5b")
14
+ training_type (Literal["lora", "sft", "controlnet"]): Type of training: "lora", "sft", or "controlnet"
15
+ trainer_cls (Trainer): Trainer class to register.
16
+ """
17
+
18
+ # Check if model_name and training_type exists in SUPPORTED_MODELS
19
+ if model_name not in SUPPORTED_MODELS:
20
+ SUPPORTED_MODELS[model_name] = {}
21
+ else:
22
+ if training_type in SUPPORTED_MODELS[model_name]:
23
+ raise ValueError(f"Training type {training_type} already exists for model {model_name}")
24
+
25
+ SUPPORTED_MODELS[model_name][training_type] = trainer_cls
26
+
27
+
28
+ def show_supported_models():
29
+ """Print all currently supported models and their training types."""
30
+
31
+ print("\nSupported Models:")
32
+ print("================")
33
+
34
+ for model_name, training_types in SUPPORTED_MODELS.items():
35
+ print(f"\n{model_name}")
36
+ print("-" * len(model_name))
37
+ for training_type in training_types:
38
+ print(f" • {training_type}")
39
+
40
+
41
+ def get_model_cls(model_type: str, training_type: Literal["lora", "sft", "controlnet"]) -> Trainer:
42
+ """Get the trainer class for a specific model and training type."""
43
+ if model_type not in SUPPORTED_MODELS:
44
+ print(f"\nModel '{model_type}' is not supported.")
45
+ print("\nSupported models are:")
46
+ for supported_model in SUPPORTED_MODELS:
47
+ print(f" • {supported_model}")
48
+ raise ValueError(f"Model '{model_type}' is not supported")
49
+
50
+ if training_type not in SUPPORTED_MODELS[model_type]:
51
+ print(f"\nTraining type '{training_type}' is not supported for model '{model_type}'.")
52
+ print(f"\nSupported training types for '{model_type}' are:")
53
+ for supported_type in SUPPORTED_MODELS[model_type]:
54
+ print(f" • {supported_type}")
55
+ raise ValueError(f"Training type '{training_type}' is not supported for model '{model_type}'")
56
+
57
+ return SUPPORTED_MODELS[model_type][training_type]
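The registry is a two-level dict keyed by model name and training type, and registering the same pair twice raises. A self-contained illustration of the same pattern (DummyTrainer and register_trainer are stand-ins for this sketch, not names from the repo):

from typing import Dict

REGISTRY: Dict[str, Dict[str, type]] = {}

def register_trainer(model_name: str, training_type: str, trainer_cls: type) -> None:
    slot = REGISTRY.setdefault(model_name, {})
    if training_type in slot:
        raise ValueError(f"Training type {training_type} already exists for model {model_name}")
    slot[training_type] = trainer_cls

class DummyTrainer:          # illustrative stand-in for a Trainer subclass
    pass

register_trainer("cogvideox-flovd", "controlnet", DummyTrainer)
print(REGISTRY["cogvideox-flovd"]["controlnet"].__name__)    # DummyTrainer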
finetune/modules/__init__.py ADDED
File without changes
finetune/modules/camera_flow_generator.py ADDED
@@ -0,0 +1,46 @@
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from einops import rearrange, repeat
6
+
7
+ from .utils import instantiate_from_config, get_camera_flow_generator_input, warp_image
8
+
9
+ import pdb
10
+
11
+ class CameraFlowGenerator(nn.Module):
12
+ def __init__(
13
+ self,
14
+ depth_estimator_kwargs,
15
+ use_observed_mask=False,
16
+ cycle_th=3.,
17
+ ):
18
+ super().__init__()
19
+
20
+ self.depth_warping_module = instantiate_from_config(depth_estimator_kwargs)
21
+ self.use_observed_mask = use_observed_mask
22
+ self.cycle_th = cycle_th
23
+
24
+ def forward(self, condition_image, camera_flow_generator_input):
25
+ # NOTE. camera_flow_generator_input is a dict of network inputs!
26
+ # camera_flow_generator_input: Dict
27
+ # - image
28
+ # - intrinsics
29
+ # - extrinsics
30
+ with torch.no_grad():
31
+ flow_f, flow_b, depth_warped_frames, depth_ctxt, depth_trgt = self.depth_warping_module(camera_flow_generator_input)
32
+ image_ctxt = repeat(condition_image, "b c h w -> (b v) c h w", v=(depth_warped_frames.shape[0]//condition_image.shape[0]))
33
+
34
+ log_dict = {
35
+ 'depth_warped_frames': depth_warped_frames,
36
+ 'depth_ctxt': depth_ctxt,
37
+ 'depth_trgt': depth_trgt,
38
+ }
39
+
40
+ # if self.use_observed_mask:
41
+ # observed_mask = run_filtering(flow_f, flow_b, cycle_th=self.cycle_th)
42
+ # log_dict[
43
+ # 'observed_mask': observed_mask
44
+ # ]
45
+
46
+ return flow_f, log_dict
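The repeat call above tiles the single conditioning image so it lines up with the batch of depth-warped target views returned by the warping module. A tiny standalone demo of that reshape (shapes are illustrative):

import torch
from einops import repeat

condition_image = torch.randn(1, 3, 60, 90)        # [B, C, H, W], one conditioning frame
depth_warped_frames = torch.randn(49, 3, 60, 90)    # [(B*V), C, H, W], e.g. 49 warped views

v = depth_warped_frames.shape[0] // condition_image.shape[0]
image_ctxt = repeat(condition_image, "b c h w -> (b v) c h w", v=v)
print(image_ctxt.shape)                              # torch.Size([49, 3, 60, 90])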
finetune/modules/camera_sampler.py ADDED
@@ -0,0 +1,52 @@
1
+ import numpy as np
2
+ from glob import glob
3
+ import random
4
+ import os
5
+ import pdb
6
+ random.seed(7777)
7
+
8
+ class SampleManualCam:
9
+ def __init__(
10
+ self,
11
+ pose_type = 'manual',
12
+ root_path = '../assets/manual_poses',
13
+ ):
14
+ self.root_path = root_path
15
+ if pose_type == 'manual':
16
+ self.MANUAL_CAM = ['I', 'D', 'L', 'O', 'R', 'U']
17
+ elif pose_type == 're10k':
18
+ self.RE10K_CAM = os.listdir(root_path)
19
+ # self.pose_path = glob(root_path, "*.txt")
20
+
21
+ self.pose_type = pose_type
22
+
23
+ def sample(self, order=None, name=None):
24
+ # Sample camera parameters (W2C)
25
+
26
+ if self.pose_type == 'manual':
27
+ if name is not None:
28
+ assert name in self.MANUAL_CAM
29
+ cam_name = name
30
+ elif order is not None:
31
+ order = order % len(self.MANUAL_CAM)
32
+ cam_name = self.MANUAL_CAM[order]
33
+ else:
34
+ cam_name = random.choice(self.MANUAL_CAM)
35
+ path = os.path.join(self.root_path, f"camera_{cam_name}.txt")
36
+ elif self.pose_type == 're10k':
37
+ if name is not None:
38
+ assert name in self.RE10K_CAM
39
+ cam_name = name
40
+ elif order is not None:
41
+ order = order % len(self.RE10K_CAM)
42
+ cam_name = self.RE10K_CAM[order]
43
+ else:
44
+ cam_name = random.choice(self.RE10K_CAM)
45
+ path = os.path.join(self.root_path, cam_name)
46
+ with open(path, 'r') as f:
47
+ poses = f.readlines()
48
+
49
+ poses = [pose.strip().split(' ') for pose in poses]
50
+ poses = [[float(x) for x in pose] for pose in poses]
51
+
52
+ return poses, cam_name
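Each pose file row is parsed into a list of floats; downstream, the Camera class added later in this commit reads entry[1:5] as fx, fy, cx, cy and entry[7:] as a flattened 3x4 world-to-camera matrix. A small sketch of that layout (the identity pose below is made up for illustration, and the two entries at indices 5 and 6 are simply skipped):

import numpy as np

# One parsed pose row: [timestamp, fx, fy, cx, cy, <2 unused>, r00 r01 r02 t0, r10 r11 r12 t1, r20 r21 r22 t2]
entry = [0.0, 0.5, 0.89, 0.5, 0.5, 0.0, 0.0,
         1.0, 0.0, 0.0, 0.0,
         0.0, 1.0, 0.0, 0.0,
         0.0, 0.0, 1.0, 0.0]

fx, fy, cx, cy = entry[1:5]
w2c = np.eye(4)
w2c[:3, :] = np.array(entry[7:], dtype=np.float64).reshape(3, 4)
print((fx, fy, cx, cy))
print(w2c)                                           # 4x4 W2C with identity rotation and zero translation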
finetune/modules/cogvideox_controlnet.py ADDED
@@ -0,0 +1,353 @@
1
+ from typing import Any, Dict, Optional, Tuple, Union
2
+
3
+ import torch
4
+ from torch import nn
5
+ from einops import rearrange
6
+ import torch.nn.functional as F
7
+
8
+ from diffusers.configuration_utils import FrozenDict
9
+
10
+ from diffusers import CogVideoXTransformer3DModel
11
+ from diffusers.models.transformers.cogvideox_transformer_3d import Transformer2DModelOutput, CogVideoXBlock
12
+ from diffusers.utils import is_torch_version
13
+ from diffusers.loaders import PeftAdapterMixin
14
+ from diffusers.utils.torch_utils import maybe_allow_in_graph
15
+ from diffusers.models.embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps, get_3d_sincos_pos_embed
16
+ from diffusers.models.modeling_utils import ModelMixin
17
+ from diffusers.models.attention import Attention, FeedForward
18
+ from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor2_0
19
+ from diffusers.models.normalization import AdaLayerNorm, CogVideoXLayerNormZero, AdaLayerNormZeroSingle
20
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
21
+
22
+ from .cogvideox_custom_modules import CustomCogVideoXPatchEmbed, CustomCogVideoXBlock
23
+
24
+ import pdb
25
+
26
+ class CogVideoXControlnet(ModelMixin, ConfigMixin, PeftAdapterMixin):
27
+ _supports_gradient_checkpointing = True
28
+
29
+ @register_to_config
30
+ def __init__(
31
+ self,
32
+ num_attention_heads: int = 30, # 48 for 5B, 30 for 2B.
33
+ attention_head_dim: int = 64,
34
+ # in_channels: int = 3,
35
+ in_channels: int = 16,
36
+ out_channels: Optional[int] = 16, # Not used
37
+ flip_sin_to_cos: bool = True,
38
+ freq_shift: int = 0,
39
+ time_embed_dim: int = 512,
40
+ ofs_embed_dim: Optional[int] = None,
41
+ text_embed_dim: int = 4096,
42
+ num_layers: int = 30,
43
+ dropout: float = 0.0,
44
+ attention_bias: bool = True,
45
+ sample_width: int = 90,
46
+ sample_height: int = 60,
47
+ sample_frames: int = 49,
48
+ patch_size: int = 2,
49
+ patch_size_t: Optional[int] = None,
50
+ temporal_compression_ratio: int = 4,
51
+ max_text_seq_length: int = 226,
52
+ activation_fn: str = "gelu-approximate",
53
+ timestep_activation_fn: str = "silu",
54
+ norm_elementwise_affine: bool = True,
55
+ norm_eps: float = 1e-5,
56
+ spatial_interpolation_scale: float = 1.875,
57
+ temporal_interpolation_scale: float = 1.0,
58
+ use_rotary_positional_embeddings: bool = False,
59
+ use_learned_positional_embeddings: bool = False,
60
+ patch_bias: bool = True,
61
+ out_proj_dim_factor: int = 8,
62
+ out_proj_dim_zero_init: bool = True,
63
+ notextinflow: bool = False,
64
+ ):
65
+ super().__init__()
66
+ inner_dim = num_attention_heads * attention_head_dim
67
+
68
+ self.notextinflow = notextinflow
69
+
70
+ if not use_rotary_positional_embeddings and use_learned_positional_embeddings:
71
+ raise ValueError(
72
+ "There are no CogVideoX checkpoints available with disable rotary embeddings and learned positional "
73
+ "embeddings. If you're using a custom model and/or believe this should be supported, please open an "
74
+ "issue at https://github.com/huggingface/diffusers/issues."
75
+ )
76
+
77
+ """
78
+ Removed: the pixel-space ControlNet encoder below is not needed here.
79
+ In FloVD, controlnet_hidden_states are flow latents that have already been encoded by the 3D causal VAE.
80
+ """
81
+ # start_channels = in_channels * (downscale_coef ** 2)
82
+ # input_channels = [start_channels, start_channels // 2, start_channels // 4]
83
+ # self.unshuffle = nn.PixelUnshuffle(downscale_coef)
84
+
85
+ # self.controlnet_encode_first = nn.Sequential(
86
+ # nn.Conv2d(input_channels[0], input_channels[1], kernel_size=1, stride=1, padding=0),
87
+ # nn.GroupNorm(2, input_channels[1]),
88
+ # nn.ReLU(),
89
+ # )
90
+
91
+ # self.controlnet_encode_second = nn.Sequential(
92
+ # nn.Conv2d(input_channels[1], input_channels[2], kernel_size=1, stride=1, padding=0),
93
+ # nn.GroupNorm(2, input_channels[2]),
94
+ # nn.ReLU(),
95
+ # )
96
+
97
+ # """
98
+ # Modify below.
99
+ # In our case, patch_embed takes encoder_hidden_states, hidden_states, controlnet_hidden_states (flow)
100
+ # """
101
+ # 1. Patch embedding
102
+ self.patch_embed = CogVideoXPatchEmbed(
103
+ patch_size=patch_size,
104
+ in_channels=in_channels,
105
+ embed_dim=inner_dim,
106
+ bias=True,
107
+ sample_width=sample_width,
108
+ sample_height=sample_height,
109
+ sample_frames=sample_frames,
110
+ temporal_compression_ratio=temporal_compression_ratio,
111
+ spatial_interpolation_scale=spatial_interpolation_scale,
112
+ temporal_interpolation_scale=temporal_interpolation_scale,
113
+ use_positional_embeddings=not use_rotary_positional_embeddings,
114
+ use_learned_positional_embeddings=use_learned_positional_embeddings,
115
+ )
116
+ # self.patch_embed = CustomCogVideoXPatchEmbed(
117
+ # patch_size=patch_size,
118
+ # patch_size_t=patch_size_t,
119
+ # in_channels=in_channels,
120
+ # embed_dim=inner_dim,
121
+ # text_embed_dim=text_embed_dim,
122
+ # bias=patch_bias,
123
+ # sample_width=sample_width,
124
+ # sample_height=sample_height,
125
+ # sample_frames=sample_frames,
126
+ # temporal_compression_ratio=temporal_compression_ratio,
127
+ # max_text_seq_length=max_text_seq_length,
128
+ # spatial_interpolation_scale=spatial_interpolation_scale,
129
+ # temporal_interpolation_scale=temporal_interpolation_scale,
130
+ # use_positional_embeddings=not use_rotary_positional_embeddings,
131
+ # use_learned_positional_embeddings=use_learned_positional_embeddings,
132
+ # )
133
+
134
+ self.embedding_dropout = nn.Dropout(dropout)
135
+
136
+ # 2. Time embeddings
137
+ self.time_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift)
138
+ self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, timestep_activation_fn)
139
+
140
+ # 3. Define spatio-temporal transformers blocks
141
+ # self.transformer_blocks = nn.ModuleList(
142
+ # [
143
+ # CogVideoXBlock(
144
+ # dim=inner_dim,
145
+ # num_attention_heads=num_attention_heads,
146
+ # attention_head_dim=attention_head_dim,
147
+ # time_embed_dim=time_embed_dim,
148
+ # dropout=dropout,
149
+ # activation_fn=activation_fn,
150
+ # attention_bias=attention_bias,
151
+ # norm_elementwise_affine=norm_elementwise_affine,
152
+ # norm_eps=norm_eps,
153
+ # )
154
+ # for _ in range(num_layers)
155
+ # ]
156
+ # )
157
+ self.transformer_blocks = nn.ModuleList(
158
+ [
159
+ CustomCogVideoXBlock(
160
+ dim=inner_dim,
161
+ num_attention_heads=num_attention_heads,
162
+ attention_head_dim=attention_head_dim,
163
+ time_embed_dim=time_embed_dim,
164
+ dropout=dropout,
165
+ activation_fn=activation_fn,
166
+ attention_bias=attention_bias,
167
+ norm_elementwise_affine=norm_elementwise_affine,
168
+ norm_eps=norm_eps,
169
+ )
170
+ for _ in range(num_layers)
171
+ ]
172
+ )
173
+
174
+ self.out_projectors = None
175
+ if out_proj_dim_factor is not None:
176
+ out_proj_dim = num_attention_heads * out_proj_dim_factor
177
+ self.out_projectors = nn.ModuleList(
178
+ [nn.Linear(inner_dim, out_proj_dim) for _ in range(num_layers)]
179
+ )
180
+ if out_proj_dim_zero_init:
181
+ for out_projector in self.out_projectors:
182
+ self.zeros_init_linear(out_projector)
183
+
184
+ self.gradient_checkpointing = False
185
+
186
+ def zeros_init_linear(self, linear: nn.Module):
187
+ if isinstance(linear, (nn.Linear, nn.Conv1d)):
188
+ if hasattr(linear, "weight"):
189
+ nn.init.zeros_(linear.weight)
190
+ if hasattr(linear, "bias"):
191
+ nn.init.zeros_(linear.bias)
192
+
193
+ def _set_gradient_checkpointing(self, module, value=False):
194
+ self.gradient_checkpointing = value
195
+
196
+ def compress_time(self, x, num_frames):
197
+ x = rearrange(x, '(b f) c h w -> b f c h w', f=num_frames)
198
+ batch_size, frames, channels, height, width = x.shape
199
+ x = rearrange(x, 'b f c h w -> (b h w) c f')
200
+
201
+ if x.shape[-1] % 2 == 1:
202
+ x_first, x_rest = x[..., 0], x[..., 1:]
203
+ if x_rest.shape[-1] > 0:
204
+ x_rest = F.avg_pool1d(x_rest, kernel_size=2, stride=2)
205
+
206
+ x = torch.cat([x_first[..., None], x_rest], dim=-1)
207
+ else:
208
+ x = F.avg_pool1d(x, kernel_size=2, stride=2)
209
+ x = rearrange(x, '(b h w) c f -> (b f) c h w', b=batch_size, h=height, w=width)
210
+ return x
211
+
212
+ # """
213
+ # Add below.
214
+ # Load pre-trained weight from Diffusers
215
+ # For patch_embed, copy a projection layer for controlnet_states
216
+ # """
217
+ @classmethod
218
+ def from_pretrained(cls, model_path, subfolder, **additional_kwargs):
219
+ base = CogVideoXTransformer3DModel.from_pretrained(model_path, subfolder=subfolder)
220
+ controlnet_config = FrozenDict({**base.config, **additional_kwargs})
221
+ model = cls(**controlnet_config)
222
+
223
+ missing, unexpected = model.load_state_dict(base.state_dict(), strict=False)
224
+ print(f"Load CogVideoXTransformer3DModel.")
225
+ # if len(missing) != 0 or len(unexpected) != 0:
226
+ # print(f"Missing keys: {missing}")
227
+ # print(f"Unexpected keys: {unexpected}")
228
+
229
+ del base
230
+ torch.cuda.empty_cache()
231
+
232
+
233
+ return model
234
+
235
+ def forward(
236
+ self,
237
+ hidden_states: torch.Tensor,
238
+ encoder_hidden_states: torch.Tensor,
239
+ controlnet_hidden_states: torch.Tensor,
240
+ timestep: Union[int, float, torch.LongTensor],
241
+ controlnet_valid_mask: torch.Tensor = None,
242
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
243
+ timestep_cond: Optional[torch.Tensor] = None,
244
+ return_dict: bool = True,
245
+ ):
246
+ """
247
+ Removed: the pixel-space ControlNet encoder below is not needed here.
248
+ In FloVD, controlnet_hidden_states are flow latents that have already been encoded by the 3D causal VAE.
249
+ """
250
+ # batch_size, num_frames, channels, height, width = controlnet_states.shape
251
+ # # 0. Controlnet encoder
252
+ # controlnet_states = rearrange(controlnet_states, 'b f c h w -> (b f) c h w')
253
+ # controlnet_states = self.unshuffle(controlnet_states)
254
+ # controlnet_states = self.controlnet_encode_first(controlnet_states)
255
+ # controlnet_states = self.compress_time(controlnet_states, num_frames=num_frames)
256
+ # num_frames = controlnet_states.shape[0] // batch_size
257
+
258
+ # controlnet_states = self.controlnet_encode_second(controlnet_states)
259
+ # controlnet_states = self.compress_time(controlnet_states, num_frames=num_frames)
260
+ # controlnet_states = rearrange(controlnet_states, '(b f) c h w -> b f c h w', b=batch_size)
261
+
262
+ batch_size, num_frames, channels, height, width = hidden_states.shape
263
+
264
+
265
+ # """
266
+ # Modify below.
267
+ # Distinguish hidden_states and controlnet_states (i.e., flow_hidden_states)
268
+ # """
269
+ hidden_states = torch.cat([hidden_states, controlnet_hidden_states], dim=2)  # flow latents (instead of image latents) are concatenated as the condition
270
+
271
+ # controlnet_states = self.controlnext_encoder(controlnet_states, timestep=timestep)
272
+ # 1. Time embedding
273
+ timesteps = timestep
274
+ t_emb = self.time_proj(timesteps)
275
+
276
+ # timesteps does not contain any weights and will always return f32 tensors
277
+ # but time_embedding might actually be running in fp16. so we need to cast here.
278
+ # there might be better ways to encapsulate this.
279
+ t_emb = t_emb.to(dtype=hidden_states.dtype)
280
+ emb = self.time_embedding(t_emb, timestep_cond)
281
+
282
+ # """
283
+ # Modify below.
284
+ # patch_embed takes encoder_hidden_states and hidden_states (a custom variant that also takes controlnet_hidden_states is left commented out below)
285
+ # """
286
+ hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
287
+ # hidden_states = self.patch_embed(encoder_hidden_states, hidden_states, controlnet_hidden_states) # output: [text_embeds, image_embeds, flow_embeds] [B, 35326, 3072]
288
+ hidden_states = self.embedding_dropout(hidden_states)
289
+
290
+ """
291
+ Not modified below.
292
+ hidden_states now carries both the video tokens and the flow (controlnet) tokens
293
+ """
294
+ text_seq_length = encoder_hidden_states.shape[1]
295
+ encoder_hidden_states = hidden_states[:, :text_seq_length] # [text_embeds] [B, 226, 3072]
296
+ hidden_states = hidden_states[:, text_seq_length:] # [image_embeds, flow_embeds] [B, 35100, 3072]
297
+
298
+ # attention mask
299
+ if controlnet_valid_mask is not None:
300
+ mask_shape = controlnet_valid_mask.shape
301
+ attention_mask = torch.nn.functional.interpolate(controlnet_valid_mask, size=(mask_shape[2], mask_shape[3]//2, mask_shape[4]//2), mode='trilinear', align_corners=False) # CFHW
302
+ attention_mask[attention_mask>=0.5] = 1
303
+ attention_mask[attention_mask<0.5] = 0
304
+ attention_mask = attention_mask.to(torch.bool)
305
+ attention_mask = rearrange(attention_mask.squeeze(1), 'b f h w -> b (f h w)') # (B, N=(fxhxw))
306
+
307
+ # TODO: decide whether the text tokens should be included in the validity mask; for now they are padded in with zeros
308
+ if not self.notextinflow:
309
+ attention_mask = F.pad(attention_mask, (text_seq_length, 0), value=0.0)
310
+
311
+ attention_kwargs = {
312
+ 'attention_mask': attention_mask if controlnet_valid_mask is not None else None,
313
+ 'notextinflow': self.notextinflow,
314
+ }
315
+
316
+ controlnet_hidden_states = ()
317
+ # 3. Transformer blocks
318
+ for i, block in enumerate(self.transformer_blocks):
319
+ if self.training and self.gradient_checkpointing:
320
+
321
+ def create_custom_forward(module):
322
+ def custom_forward(*inputs):
323
+ return module(*inputs)
324
+
325
+ return custom_forward
326
+
327
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
328
+ hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
329
+ create_custom_forward(block),
330
+ hidden_states,
331
+ encoder_hidden_states,
332
+ emb,
333
+ image_rotary_emb,
334
+ attention_kwargs,
335
+ **ckpt_kwargs,
336
+ )
337
+ else:
338
+ hidden_states, encoder_hidden_states = block(
339
+ hidden_states=hidden_states,
340
+ encoder_hidden_states=encoder_hidden_states,
341
+ temb=emb,
342
+ image_rotary_emb=image_rotary_emb,
343
+ attention_kwargs=attention_kwargs,
344
+ )
345
+
346
+ if self.out_projectors is not None:
347
+ controlnet_hidden_states += (self.out_projectors[i](hidden_states),)
348
+ else:
349
+ controlnet_hidden_states += (hidden_states,)
350
+
351
+ if not return_dict:
352
+ return (controlnet_hidden_states,)
353
+ return Transformer2DModelOutput(sample=controlnet_hidden_states)
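With out_proj_dim_zero_init=True, every per-block output projector starts as an all-zero linear layer, so the residuals handed to the base transformer are exactly zero at the start of training. A minimal sketch of that initialization (the dimensions below are illustrative):

import torch
from torch import nn

inner_dim = 1920                       # e.g. 30 heads x 64 dims
out_proj_dim = 30 * 8                  # num_attention_heads * out_proj_dim_factor

proj = nn.Linear(inner_dim, out_proj_dim)
nn.init.zeros_(proj.weight)
nn.init.zeros_(proj.bias)

hidden_states = torch.randn(1, 128, inner_dim)
print(proj(hidden_states).abs().max().item())        # 0.0 -> the injected residual starts as a no-op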
finetune/modules/cogvideox_custom_model.py ADDED
@@ -0,0 +1,109 @@
1
+ from typing import Any, Dict, Optional, Tuple, Union
2
+
3
+ import torch
4
+ import numpy as np
5
+ from diffusers.utils import is_torch_version
6
+ from diffusers.models.transformers.cogvideox_transformer_3d import CogVideoXTransformer3DModel, Transformer2DModelOutput
7
+
8
+ import pdb
9
+
10
+ class CustomCogVideoXTransformer3DModel(CogVideoXTransformer3DModel):
11
+ def forward(
12
+ self,
13
+ hidden_states: torch.Tensor,
14
+ encoder_hidden_states: torch.Tensor,
15
+ timestep: Union[int, float, torch.LongTensor],
16
+ start_frame = None,
17
+ timestep_cond: Optional[torch.Tensor] = None,
18
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
19
+ controlnet_states: torch.Tensor = None,
20
+ controlnet_weights: Optional[Union[float, int, list, np.ndarray, torch.FloatTensor]] = 1.0,
21
+ return_dict: bool = True,
22
+ ):
23
+ batch_size, num_frames, channels, height, width = hidden_states.shape
24
+
25
+ if start_frame is not None:
26
+ hidden_states = torch.cat([start_frame, hidden_states], dim=2)
27
+ # 1. Time embedding
28
+ timesteps = timestep
29
+ t_emb = self.time_proj(timesteps)
30
+
31
+ # timesteps does not contain any weights and will always return f32 tensors
32
+ # but time_embedding might actually be running in fp16. so we need to cast here.
33
+ # there might be better ways to encapsulate this.
34
+ t_emb = t_emb.to(dtype=hidden_states.dtype)
35
+ emb = self.time_embedding(t_emb, timestep_cond)
36
+
37
+ # 2. Patch embedding
38
+ hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
39
+ hidden_states = self.embedding_dropout(hidden_states)
40
+
41
+ text_seq_length = encoder_hidden_states.shape[1]
42
+ encoder_hidden_states = hidden_states[:, :text_seq_length]
43
+ hidden_states = hidden_states[:, text_seq_length:]
44
+
45
+ # 3. Transformer blocks
46
+ for i, block in enumerate(self.transformer_blocks):
47
+ if self.training and self.gradient_checkpointing:
48
+
49
+ def create_custom_forward(module):
50
+ def custom_forward(*inputs):
51
+ return module(*inputs)
52
+
53
+ return custom_forward
54
+
55
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
56
+ hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
57
+ create_custom_forward(block),
58
+ hidden_states,
59
+ encoder_hidden_states,
60
+ emb,
61
+ image_rotary_emb,
62
+ **ckpt_kwargs,
63
+ )
64
+ else:
65
+ hidden_states, encoder_hidden_states = block(
66
+ hidden_states=hidden_states,
67
+ encoder_hidden_states=encoder_hidden_states,
68
+ temb=emb,
69
+ image_rotary_emb=image_rotary_emb,
70
+ )
71
+
72
+ if (controlnet_states is not None) and (i < len(controlnet_states)):
73
+ controlnet_states_block = controlnet_states[i]
74
+ controlnet_block_weight = 1.0
75
+ if isinstance(controlnet_weights, (list, np.ndarray)) or torch.is_tensor(controlnet_weights):
76
+ controlnet_block_weight = controlnet_weights[i]
77
+ elif isinstance(controlnet_weights, (float, int)):
78
+ controlnet_block_weight = controlnet_weights
79
+ hidden_states = hidden_states + controlnet_states_block * controlnet_block_weight
80
+
81
+ if not self.config.use_rotary_positional_embeddings:
82
+ # CogVideoX-2B
83
+ hidden_states = self.norm_final(hidden_states)
84
+ else:
85
+ # CogVideoX-5B
86
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
87
+ hidden_states = self.norm_final(hidden_states)
88
+ hidden_states = hidden_states[:, text_seq_length:]
89
+
90
+ # 4. Final block
91
+ hidden_states = self.norm_out(hidden_states, temb=emb)
92
+ hidden_states = self.proj_out(hidden_states)
93
+
94
+ # 5. Unpatchify
95
+ p = self.config.patch_size
96
+ p_t = self.config.patch_size_t
97
+
98
+ if p_t is None:
99
+ output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p)
100
+ output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
101
+ else:
102
+ output = hidden_states.reshape(
103
+ batch_size, (num_frames + p_t - 1) // p_t, height // p, width // p, -1, p_t, p, p
104
+ )
105
+ output = output.permute(0, 1, 5, 4, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(1, 2)
106
+
107
+ if not return_dict:
108
+ return (output,)
109
+ return Transformer2DModelOutput(sample=output)
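controlnet_weights can be a single scalar shared by every block or a per-block sequence; the loop above picks one weight per injected residual. A standalone sketch of that selection with a depth-decaying schedule (all tensors are stand-ins):

import numpy as np
import torch

num_layers = 30
controlnet_states = [torch.randn(1, 4, 8) for _ in range(num_layers)]   # stand-in per-block residuals
hidden_states = torch.zeros(1, 4, 8)

controlnet_weights = np.linspace(1.0, 0.2, num_layers)   # or simply a float such as 1.0

for i in range(num_layers):
    if isinstance(controlnet_weights, (list, np.ndarray)) or torch.is_tensor(controlnet_weights):
        w = controlnet_weights[i]
    else:
        w = controlnet_weights
    hidden_states = hidden_states + controlnet_states[i] * float(w)
print(hidden_states.shape)                                # torch.Size([1, 4, 8])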
finetune/modules/cogvideox_custom_modules.py ADDED
@@ -0,0 +1,357 @@
1
+ import math
2
+ from typing import List, Optional, Tuple, Union, Dict, Any
3
+ import copy
4
+
5
+ import numpy as np
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from torch import nn
9
+ from diffusers import CogVideoXTransformer3DModel
10
+ from diffusers.models.transformers.cogvideox_transformer_3d import CogVideoXBlock
11
+ from diffusers.models.normalization import CogVideoXLayerNormZero
12
+ from diffusers.models.attention import FeedForward
13
+ from diffusers.models.attention_processor import CogVideoXAttnProcessor2_0, Attention
14
+ from diffusers.models.embeddings import CogVideoXPatchEmbed
15
+ from diffusers.models.modeling_outputs import Transformer2DModelOutput
16
+ from diffusers.utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
17
+ from diffusers.utils.torch_utils import maybe_allow_in_graph
18
+
19
+ from contextlib import contextmanager
20
+ from peft.tuners.lora.layer import LoraLayer  # base class of PEFT LoRA layers
21
+
22
+ import pdb
23
+
24
+ # Code heavily borrowed from https://github.com/huggingface/diffusers
25
+
26
+
27
+ class enable_lora:
28
+ def __init__(self, modules, enable=True):
29
+ self.modules = modules
30
+ self.enable = enable
31
+ self.prev_states = {}
32
+
33
+ def __enter__(self):
34
+ for module in self.modules:
35
+ self.prev_states[module] = getattr(module, "lora_enabled", True)
36
+ setattr(module, "lora_enabled", self.enable)
37
+ return self
38
+
39
+ def __exit__(self, exc_type, exc_val, exc_tb):
40
+ for module in self.modules:
41
+ setattr(module, "lora_enabled", self.prev_states[module])
42
+ return False
43
+
44
+
45
+
46
+ class CustomCogVideoXPatchEmbed(CogVideoXPatchEmbed):
47
+ def __init__(self, **kwargs):
48
+ super().__init__(**kwargs)
49
+
50
+ patch_size = kwargs['patch_size']
51
+ patch_size_t = kwargs['patch_size_t']
52
+ bias = kwargs['bias']
53
+ in_channels = kwargs['in_channels']
54
+ embed_dim = kwargs['embed_dim']
55
+
56
+ # projection layer for flow latents
57
+ if patch_size_t is None:
58
+ # CogVideoX 1.0 checkpoints
59
+ self.flow_proj = nn.Conv2d(in_channels//2, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias)
60
+ else:
61
+ # CogVideoX 1.5 checkpoints
62
+ self.flow_proj = nn.Linear(in_channels//2 * patch_size * patch_size * patch_size_t, embed_dim)
63
+
64
+ # Add positional embedding for flow_embeds
65
+ if self.use_positional_embeddings or self.use_learned_positional_embeddings:
66
+ flow_pos_embedding = self._get_positional_embeddings(self.sample_height, self.sample_width, self.sample_frames)[:,self.max_text_seq_length:] # shape: [1, 17550, 3072]
67
+ self.flow_pos_embedding = nn.Parameter(flow_pos_embedding)
68
+
69
+ def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor, flow_embeds: torch.Tensor):
70
+ r"""
71
+ Args:
72
+ text_embeds (`torch.Tensor`):
73
+ Input text embeddings. Expected shape: (batch_size, seq_length, embedding_dim).
74
+ image_embeds (`torch.Tensor`):
75
+ Input image embeddings. Expected shape: (batch_size, num_frames, channels, height, width).
76
+ flow_embeds (`torch.Tensor`):
77
+ Input flow embeddings. Expected shape: (batch_size, num_frames, channels, height, width).
78
+ """
79
+ text_embeds = self.text_proj(text_embeds)
80
+
81
+ batch_size, num_frames, channels, height, width = image_embeds.shape
82
+
83
+ if self.patch_size_t is None:
84
+ # embed video latents
85
+ image_embeds = image_embeds.reshape(-1, channels, height, width)
86
+ image_embeds = self.proj(image_embeds)
87
+ image_embeds = image_embeds.view(batch_size, num_frames, *image_embeds.shape[1:])
88
+ image_embeds = image_embeds.flatten(3).transpose(2, 3) # [batch, num_frames, height x width, channels]
89
+ image_embeds = image_embeds.flatten(1, 2) # [batch, num_frames x height x width, channels]
90
+
91
+ # embed flow latents
92
+ flow_embeds = flow_embeds.reshape(-1, channels//2, height, width)
93
+ flow_embeds = self.flow_proj(flow_embeds)
94
+ flow_embeds = flow_embeds.view(batch_size, num_frames, *flow_embeds.shape[1:])
95
+ flow_embeds = flow_embeds.flatten(3).transpose(2, 3) # [batch, num_frames, height x width, channels]
96
+ flow_embeds = flow_embeds.flatten(1, 2) # [batch, num_frames x height x width, channels]
97
+ else:
98
+ p = self.patch_size
99
+ p_t = self.patch_size_t
100
+
101
+ # embed video latents
102
+ image_embeds = image_embeds.permute(0, 1, 3, 4, 2)
103
+ image_embeds = image_embeds.reshape(
104
+ batch_size, num_frames // p_t, p_t, height // p, p, width // p, p, channels
105
+ )
106
+ image_embeds = image_embeds.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(4, 7).flatten(1, 3)
107
+ image_embeds = self.proj(image_embeds)
108
+
109
+ # embed flow latents
110
+ flow_embeds = flow_embeds.permute(0, 1, 3, 4, 2)
111
+ flow_embeds = flow_embeds.reshape(
112
+ batch_size, num_frames // p_t, p_t, height // p, p, width // p, p, channels//2
113
+ )
114
+ flow_embeds = flow_embeds.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(4, 7).flatten(1, 3)
115
+ flow_embeds = self.flow_proj(flow_embeds)
116
+
117
+ # Curriculum learning of flow token
118
+ # flow_embeds = self.flow_scale * flow_embeds
119
+
120
+
121
+ embeds = torch.cat(
122
+ [text_embeds, image_embeds, flow_embeds], dim=1
123
+ ).contiguous() # [batch, num_frames x height x width + seq_length + num_frames x height x width, channels]
124
+
125
+ if self.use_positional_embeddings or self.use_learned_positional_embeddings:
126
+ if self.use_learned_positional_embeddings and (self.sample_width != width or self.sample_height != height):
127
+ raise ValueError(
128
+ "It is currently not possible to generate videos at a different resolution that the defaults. This should only be the case with 'THUDM/CogVideoX-5b-I2V'."
129
+ "If you think this is incorrect, please open an issue at https://github.com/huggingface/diffusers/issues."
130
+ )
131
+
132
+ pre_time_compression_frames = (num_frames - 1) * self.temporal_compression_ratio + 1
133
+
134
+ if (
135
+ self.sample_height != height
136
+ or self.sample_width != width
137
+ or self.sample_frames != pre_time_compression_frames
138
+ ):
139
+ pos_embedding = self._get_positional_embeddings(
140
+ height, width, pre_time_compression_frames, device=embeds.device
141
+ )
142
+ else:
143
+ pos_embedding = self.pos_embedding
144
+
145
+ # Original behavior (positional embedding without flow tokens):
146
+ # pos_embedding = pos_embedding.to(dtype=embeds.dtype)
147
+ # embeds = embeds + pos_embedding
148
+
149
+ # Append the positional embedding for the flow tokens
150
+ # flow_pos_embedding = self.flow_pos_scale * self.flow_pos_embedding
151
+ flow_pos_embedding = self.flow_pos_embedding
152
+ pos_embedding_total = torch.cat([pos_embedding, flow_pos_embedding], dim=1).to(dtype=embeds.dtype)
153
+ embeds = embeds + pos_embedding_total
154
+
155
+ return embeds
156
+
157
+
158
+
159
+ @maybe_allow_in_graph
160
+ class CustomCogVideoXBlock(nn.Module):
161
+ r"""
162
+ Transformer block used in [CogVideoX](https://github.com/THUDM/CogVideo) model.
163
+
164
+ Parameters:
165
+ dim (`int`):
166
+ The number of channels in the input and output.
167
+ num_attention_heads (`int`):
168
+ The number of heads to use for multi-head attention.
169
+ attention_head_dim (`int`):
170
+ The number of channels in each head.
171
+ time_embed_dim (`int`):
172
+ The number of channels in timestep embedding.
173
+ dropout (`float`, defaults to `0.0`):
174
+ The dropout probability to use.
175
+ activation_fn (`str`, defaults to `"gelu-approximate"`):
176
+ Activation function to be used in feed-forward.
177
+ attention_bias (`bool`, defaults to `False`):
178
+ Whether or not to use bias in attention projection layers.
179
+ qk_norm (`bool`, defaults to `True`):
180
+ Whether or not to use normalization after query and key projections in Attention.
181
+ norm_elementwise_affine (`bool`, defaults to `True`):
182
+ Whether to use learnable elementwise affine parameters for normalization.
183
+ norm_eps (`float`, defaults to `1e-5`):
184
+ Epsilon value for normalization layers.
185
+ final_dropout (`bool` defaults to `False`):
186
+ Whether to apply a final dropout after the last feed-forward layer.
187
+ ff_inner_dim (`int`, *optional*, defaults to `None`):
188
+ Custom hidden dimension of Feed-forward layer. If not provided, `4 * dim` is used.
189
+ ff_bias (`bool`, defaults to `True`):
190
+ Whether or not to use bias in Feed-forward layer.
191
+ attention_out_bias (`bool`, defaults to `True`):
192
+ Whether or not to use bias in Attention output projection layer.
193
+ """
194
+
195
+ def __init__(
196
+ self,
197
+ dim: int,
198
+ num_attention_heads: int,
199
+ attention_head_dim: int,
200
+ time_embed_dim: int,
201
+ dropout: float = 0.0,
202
+ activation_fn: str = "gelu-approximate",
203
+ attention_bias: bool = False,
204
+ qk_norm: bool = True,
205
+ norm_elementwise_affine: bool = True,
206
+ norm_eps: float = 1e-5,
207
+ final_dropout: bool = True,
208
+ ff_inner_dim: Optional[int] = None,
209
+ ff_bias: bool = True,
210
+ attention_out_bias: bool = True,
211
+ ):
212
+ super().__init__()
213
+
214
+ # 1. Self Attention
215
+ self.norm1 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)
216
+
217
+ self.attn1 = Attention(
218
+ query_dim=dim,
219
+ dim_head=attention_head_dim,
220
+ heads=num_attention_heads,
221
+ qk_norm="layer_norm" if qk_norm else None,
222
+ eps=1e-6,
223
+ bias=attention_bias,
224
+ out_bias=attention_out_bias,
225
+ processor=CustomCogVideoXAttnProcessor2_0(),
226
+ )
227
+
228
+ # 2. Feed Forward
229
+ self.norm2 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)
230
+
231
+ self.ff = FeedForward(
232
+ dim,
233
+ dropout=dropout,
234
+ activation_fn=activation_fn,
235
+ final_dropout=final_dropout,
236
+ inner_dim=ff_inner_dim,
237
+ bias=ff_bias,
238
+ )
239
+
240
+ def forward(
241
+ self,
242
+ hidden_states: torch.Tensor,
243
+ encoder_hidden_states: torch.Tensor,
244
+ temb: torch.Tensor,
245
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
246
+ attention_kwargs: Optional[Dict[str, Any]] = None,
247
+ ) -> torch.Tensor:
248
+ text_seq_length = encoder_hidden_states.size(1)
249
+ attention_kwargs = attention_kwargs or {}
250
+
251
+ # norm & modulate
252
+ norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1(
253
+ hidden_states, encoder_hidden_states, temb
254
+ )
255
+
256
+ # attention
257
+ attn_hidden_states, attn_encoder_hidden_states = self.attn1(
258
+ hidden_states=norm_hidden_states,
259
+ encoder_hidden_states=norm_encoder_hidden_states,
260
+ image_rotary_emb=image_rotary_emb,
261
+ **attention_kwargs,
262
+ )
263
+
264
+ hidden_states = hidden_states + gate_msa * attn_hidden_states
265
+ encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_encoder_hidden_states
266
+
267
+ # norm & modulate
268
+ norm_hidden_states, norm_encoder_hidden_states, gate_ff, enc_gate_ff = self.norm2(
269
+ hidden_states, encoder_hidden_states, temb
270
+ )
271
+
272
+ # feed-forward
273
+ norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1)
274
+ ff_output = self.ff(norm_hidden_states)
275
+
276
+ hidden_states = hidden_states + gate_ff * ff_output[:, text_seq_length:]
277
+ encoder_hidden_states = encoder_hidden_states + enc_gate_ff * ff_output[:, :text_seq_length]
278
+
279
+ return hidden_states, encoder_hidden_states
280
+
281
+
282
+ class CustomCogVideoXAttnProcessor2_0:
283
+ r"""
284
+ Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
285
+ query and key vectors, but does not include spatial normalization.
286
+ """
287
+
288
+ def __init__(self):
289
+ if not hasattr(F, "scaled_dot_product_attention"):
290
+ raise ImportError("CustomCogVideoXAttnProcessor2_0 requires PyTorch 2.0; please upgrade PyTorch.")
291
+
292
+ def __call__(
293
+ self,
294
+ attn: Attention,
295
+ hidden_states: torch.Tensor,
296
+ encoder_hidden_states: torch.Tensor,
297
+ attention_mask: Optional[torch.Tensor] = None,
298
+ image_rotary_emb: Optional[torch.Tensor] = None,
299
+ notextinflow: Optional[bool] = False,
300
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
301
+ text_seq_length = encoder_hidden_states.size(1)
302
+
303
+ if not notextinflow:
304
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
305
+
306
+ batch_size, sequence_length, _ = hidden_states.shape
307
+
308
+ if attention_mask is not None:
309
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
310
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
311
+
312
+ query = attn.to_q(hidden_states)
313
+ key = attn.to_k(hidden_states)
314
+ value = attn.to_v(hidden_states)
315
+
316
+ inner_dim = key.shape[-1]
317
+ head_dim = inner_dim // attn.heads
318
+
319
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
320
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
321
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
322
+
323
+ if attn.norm_q is not None:
324
+ query = attn.norm_q(query)
325
+ if attn.norm_k is not None:
326
+ key = attn.norm_k(key)
327
+
328
+ # Apply RoPE if needed
329
+ if image_rotary_emb is not None:
330
+ from diffusers.models.embeddings import apply_rotary_emb
331
+
332
+ if not notextinflow:
333
+ query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
334
+ if not attn.is_cross_attention:
335
+ key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
336
+ else:
337
+ query[:, :, :] = apply_rotary_emb(query[:, :, :], image_rotary_emb)
338
+ if not attn.is_cross_attention:
339
+ key[:, :, :] = apply_rotary_emb(key[:, :, :], image_rotary_emb)
340
+
341
+ hidden_states = F.scaled_dot_product_attention(
342
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
343
+ )
344
+
345
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
346
+
347
+ # linear proj
348
+ hidden_states = attn.to_out[0](hidden_states)
349
+ # dropout
350
+ hidden_states = attn.to_out[1](hidden_states)
351
+
352
+ if not notextinflow:
353
+ encoder_hidden_states, hidden_states = hidden_states.split(
354
+ [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
355
+ )
356
+
357
+ return hidden_states, encoder_hidden_states
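The processor above is the piece that decides whether text tokens take part in attention. A minimal sketch of that token layout in plain torch (the sequence lengths and feature dim below are illustrative assumptions, not values from this commit):

import torch

# Assumed sizes, for illustration only.
text_seq_length, vision_seq_length, dim = 226, 1350, 64
encoder_hidden_states = torch.randn(1, text_seq_length, dim)   # text tokens
hidden_states = torch.randn(1, vision_seq_length, dim)         # vision/flow tokens

notextinflow = False
if not notextinflow:
    # Joint sequence [text | vision]; RoPE is applied to the vision slice only.
    tokens = torch.cat([encoder_hidden_states, hidden_states], dim=1)
else:
    # Flow-only mode: attention runs over vision tokens alone and RoPE covers the whole sequence.
    tokens = hidden_states

# ... q/k/v projection and scaled dot-product attention over `tokens` ...

if not notextinflow:
    # Split the joint sequence back into text and vision streams, as in __call__ above.
    encoder_hidden_states, hidden_states = tokens.split(
        [text_seq_length, tokens.size(1) - text_seq_length], dim=1
    )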
finetune/modules/depth_warping/__init__.py ADDED
File without changes
finetune/modules/depth_warping/camera/Camera.py ADDED
@@ -0,0 +1,70 @@
1
+ import os
2
+ import random
3
+ import json
4
+ import torch
5
+
6
+ import torch.nn as nn
7
+ import torchvision.transforms as transforms
8
+ import torchvision.transforms.functional as F
9
+ import numpy as np
10
+ from einops import rearrange, repeat
11
+
12
+ class Camera(object):
13
+ def __init__(self, entry):
14
+ fx, fy, cx, cy = entry[1:5]
15
+ self.fx = fx
16
+ self.fy = fy
17
+ self.cx = cx
18
+ self.cy = cy
19
+ w2c_mat = np.array(entry[7:]).reshape(3, 4)
20
+ w2c_mat_4x4 = np.eye(4)
21
+ w2c_mat_4x4[:3, :] = w2c_mat
22
+ self.w2c_mat = w2c_mat_4x4
23
+ self.c2w_mat = np.linalg.inv(w2c_mat_4x4)
24
+
25
+ def load_cameras(path):
26
+ with open(path, 'r') as f:
27
+ poses = f.readlines()
28
+ poses = [pose.strip().split(' ') for pose in poses[1:]]
29
+ cam_params = [[float(x) for x in pose] for pose in poses]
30
+ cam_params = [Camera(cam_param) for cam_param in cam_params]
31
+ return cam_params
32
+
33
+ def get_relative_pose(cam_params):
34
+ abs_w2cs = [cam_param.w2c_mat for cam_param in cam_params]
35
+ abs_c2ws = [cam_param.c2w_mat for cam_param in cam_params]
36
+ source_cam_c2w = abs_c2ws[0]
37
+ cam_to_origin = 0
38
+ target_cam_c2w = np.array([
39
+ [1, 0, 0, 0],
40
+ [0, 1, 0, -cam_to_origin],
41
+ [0, 0, 1, 0],
42
+ [0, 0, 0, 1]
43
+ ])
44
+ abs2rel = target_cam_c2w @ abs_w2cs[0]
45
+ ret_poses = [target_cam_c2w, ] + [abs2rel @ abs_c2w for abs_c2w in abs_c2ws[1:]]
46
+ ret_poses = np.array(ret_poses, dtype=np.float32)
47
+ return ret_poses
48
+
49
+ def get_K(intrinsics, size, do_normalize=False):
50
+ def normalize_intrinsic(x, size):
51
+ h, w = size
52
+ x[:,:,0:1] = x[:,:,0:1] / w
53
+ x[:,:,1:2] = x[:,:,1:2] / h
54
+ return x
55
+
56
+ b, _, t, _ = intrinsics.shape
57
+ K = torch.zeros((b, t, 9), dtype=intrinsics.dtype, device=intrinsics.device)
58
+ fx, fy, cx, cy = intrinsics.squeeze(1).chunk(4, dim=-1)
59
+
60
+ K[:,:,0:1] = fx
61
+ K[:,:,2:3] = cx
62
+ K[:,:,4:5] = fy
63
+ K[:,:,5:6] = cy
64
+ K[:,:,8:9] = 1.0
65
+
66
+ K = rearrange(K, "b t (h w) -> b t h w", h=3, w=3)
67
+ if do_normalize:
68
+ K = normalize_intrinsic(K, size)
69
+
70
+ return K
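A small usage sketch for these helpers, assuming a RealEstate10K-style pose entry (timestamp, fx, fy, cx, cy, two unused fields, then a flattened 3x4 world-to-camera matrix); all values and the import path are illustrative assumptions:

import torch
# assumed import path for this file:
# from finetune.modules.depth_warping.camera.Camera import Camera, get_relative_pose, get_K

entry = [0.0, 0.5, 0.9, 0.5, 0.5, 0.0, 0.0,       # timestamp, fx, fy, cx, cy, unused, unused
         1, 0, 0, 0,
         0, 1, 0, 0,
         0, 0, 1, 1]                              # flattened 3x4 world-to-camera matrix
cams = [Camera(entry), Camera(entry)]
rel_poses = get_relative_pose(cams)               # (2, 4, 4); the first pose is the canonical target
print(rel_poses.shape)

# Intrinsics packed as (b, 1, t, 4) = (fx, fy, cx, cy) per frame.
intrinsics = torch.tensor([[[[0.5, 0.9, 0.5, 0.5]] * 2]])    # b=1, t=2
K = get_K(intrinsics, size=(480, 720), do_normalize=False)   # (1, 2, 3, 3)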
finetune/modules/depth_warping/camera/WarperPytorch.py ADDED
@@ -0,0 +1,416 @@
1
+ # Shree KRISHNAya Namaha
2
+ # Differentiable warper implemented in PyTorch. Warping is done on batches.
3
+ # Tested on PyTorch 1.8.1
4
+ # Author: Nagabhushan S N
5
+ # Last Modified: 27/09/2021
6
+ # Code from https://github.com/NagabhushanSN95/Pose-Warping
7
+
8
+ import datetime
9
+ import time
10
+ import traceback
11
+ from pathlib import Path
12
+ from typing import Tuple, Optional
13
+
14
+ import numpy
15
+ # import skimage.io
16
+ import torch
17
+ import torch.nn.functional as F
18
+ from einops import rearrange, repeat
19
+ # import Imath
20
+ # import OpenEXR
21
+
22
+ import pdb
23
+
24
+ class Warper:
25
+ def __init__(self, resolution: tuple = None):
26
+ self.resolution = resolution
27
+
28
+ def forward_warp(self, frame1: torch.Tensor, mask1: Optional[torch.Tensor], depth1: torch.Tensor,
29
+ transformation1: torch.Tensor, transformation2: torch.Tensor, intrinsic1: torch.Tensor,
30
+ intrinsic2: Optional[torch.Tensor], is_image=True) -> \
31
+ Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
32
+ """
33
+ Given a frame1 and global transformations transformation1 and transformation2, warps frame1 to next view using
34
+ bilinear splatting.
35
+ All arrays should be torch tensors with batch dimension and channel first
36
+ :param frame1: (b, 3, h, w). If frame1 is not in the range [-1, 1], either set is_image=False when calling
37
+ bilinear_splatting on frame within this function, or modify clipping in bilinear_splatting()
38
+ method accordingly.
39
+ :param mask1: (b, 1, h, w) - 1 for known, 0 for unknown. Optional
40
+ :param depth1: (b, 1, h, w)
41
+ :param transformation1: (b, 4, 4) extrinsic transformation matrix of first view: [R, t; 0, 1]
42
+ :param transformation2: (b, 4, 4) extrinsic transformation matrix of second view: [R, t; 0, 1]
43
+ :param intrinsic1: (b, 3, 3) camera intrinsic matrix
44
+ :param intrinsic2: (b, 3, 3) camera intrinsic matrix. Optional
45
+ """
46
+ self.device = frame1.device
47
+
48
+ if self.resolution is not None:
49
+ assert frame1.shape[2:4] == self.resolution
50
+ b, c, h, w = frame1.shape
51
+ if mask1 is None:
52
+ mask1 = torch.ones(size=(b, 1, h, w)).to(frame1)
53
+ if intrinsic2 is None:
54
+ intrinsic2 = intrinsic1.clone()
55
+
56
+ assert frame1.shape == (b, 3, h, w) or frame1.shape == (b, 2, h, w) # flow b2hw
57
+ assert mask1.shape == (b, 1, h, w)
58
+ assert depth1.shape == (b, 1, h, w)
59
+ assert transformation1.shape == (b, 4, 4)
60
+ assert transformation2.shape == (b, 4, 4)
61
+ assert intrinsic1.shape == (b, 3, 3)
62
+ assert intrinsic2.shape == (b, 3, 3)
63
+
64
+ frame1 = frame1.to(self.device)
65
+ mask1 = mask1.to(self.device)
66
+ depth1 = depth1.to(self.device)
67
+ transformation1 = transformation1.to(self.device)
68
+ transformation2 = transformation2.to(self.device)
69
+ intrinsic1 = intrinsic1.to(self.device)
70
+ intrinsic2 = intrinsic2.to(self.device)
71
+
72
+ trans_points1 = self.compute_transformed_points(depth1, transformation1, transformation2, intrinsic1,
73
+ intrinsic2)
74
+ # trans_coordinates = trans_points1[:, :, :2, 0] / trans_points1[:, :, 2:3, 0]
75
+ trans_coordinates = trans_points1[:, :, :, :2, 0] / (trans_points1[:, :, :, 2:3, 0]+1e-7)
76
+ trans_depth1 = rearrange(trans_points1[:, :, :, 2:3, 0], "b h w c -> b c h w")
77
+
78
+ grid = self.create_grid(b, h, w).to(trans_coordinates)
79
+ flow12 = rearrange(trans_coordinates, "b h w c -> b c h w") - grid
80
+
81
+ warped_frame2, mask2 = self.bilinear_splatting(frame1, mask1, trans_depth1, flow12, None, is_image=is_image)
82
+ warped_depth2 = self.bilinear_splatting(trans_depth1, mask1, trans_depth1, flow12, None, is_image=False)[0] # [0][:, :, 0]
83
+
84
+ return warped_frame2, mask2, warped_depth2, flow12
85
+
86
+ def forward_warp_displacement(self, depth1: torch.Tensor, flow1: torch.Tensor,
87
+ transformation1: torch.Tensor, transformation2: torch.Tensor, intrinsic1: torch.Tensor, intrinsic2: Optional[torch.Tensor],):
88
+ """
89
+ Reprojects the pixel grid of view 1 into view 2 twice, once from the original grid and once from the grid
+ offset by flow1, and returns the difference of the two reprojected coordinates (the flow displacement).
91
+ All arrays should be torch tensors with batch dimension and channel first
92
+ :param depth1: (b, 1, h, w)
93
+ :param flow1: (b, 2, h, w)
94
+ :param transformation1: (b, 4, 4) extrinsic transformation matrix of first view: [R, t; 0, 1]
95
+ :param transformation2: (b, 4, 4) extrinsic transformation matrix of second view: [R, t; 0, 1]
96
+ :param intrinsic1: (b, 3, 3) camera intrinsic matrix
97
+ :param intrinsic2: (b, 3, 3) camera intrinsic matrix. Optional
98
+ """
99
+ self.device = flow1.device
100
+
101
+ if self.resolution is not None:
102
+ assert flow1.shape[2:4] == self.resolution
103
+ b, c, h, w = flow1.shape
104
+ if intrinsic2 is None:
105
+ intrinsic2 = intrinsic1.clone()
106
+
107
+ assert flow1.shape == (b, 2, h, w)
108
+ assert depth1.shape == (b, 1, h, w)
109
+ assert transformation1.shape == (b, 4, 4)
110
+ assert transformation2.shape == (b, 4, 4)
111
+ assert intrinsic1.shape == (b, 3, 3)
112
+ assert intrinsic2.shape == (b, 3, 3)
113
+
114
+ depth1 = depth1.to(self.device)
115
+ flow1 = flow1.to(self.device)
116
+ transformation1 = transformation1.to(self.device)
117
+ transformation2 = transformation2.to(self.device)
118
+ intrinsic1 = intrinsic1.to(self.device)
119
+ intrinsic2 = intrinsic2.to(self.device)
120
+
121
+ trans_points1 = self.compute_transformed_points(depth1, transformation1, transformation2, intrinsic1, intrinsic2)
122
+ trans_coordinates1 = trans_points1[:, :, :, :2, 0] / (trans_points1[:, :, :, 2:3, 0]+1e-7)
123
+
124
+ trans_points2 = self.compute_transformed_points(depth1, transformation1, transformation2, intrinsic1, intrinsic2, flow1)
125
+ trans_coordinates2 = trans_points2[:, :, :, :2, 0] / (trans_points2[:, :, :, 2:3, 0]+1e-7)
126
+
127
+ flow12_displacement = rearrange(trans_coordinates2 - trans_coordinates1, "b h w c -> b c h w")
128
+
129
+ return flow12_displacement
130
+
131
+ def compute_transformed_points(self, depth1: torch.Tensor, transformation1: torch.Tensor, transformation2: torch.Tensor,
132
+ intrinsic1: torch.Tensor, intrinsic2: Optional[torch.Tensor], flow1: Optional[torch.Tensor]=None):
133
+ """
134
+ Computes transformed position for each pixel location
135
+ """
136
+ if self.resolution is not None:
137
+ assert depth1.shape[2:4] == self.resolution
138
+ b, _, h, w = depth1.shape
139
+ if intrinsic2 is None:
140
+ intrinsic2 = intrinsic1.clone()
141
+ transformation = torch.bmm(transformation2, torch.linalg.inv(transformation1)).to(transformation1.dtype) # (b, 4, 4)
142
+
143
+ x1d = torch.arange(0, w)[None]
144
+ y1d = torch.arange(0, h)[:, None]
145
+ x2d = x1d.repeat([h, 1]).to(depth1) # (h, w)
146
+ y2d = y1d.repeat([1, w]).to(depth1) # (h, w)
147
+
148
+ ones_2d = torch.ones(size=(h, w)).to(depth1) # (h, w)
149
+ ones_4d = ones_2d[None, :, :, None, None].repeat([b, 1, 1, 1, 1]) # (b, h, w, 1, 1)
150
+
151
+
152
+ if flow1 is not None:
153
+ x4d = repeat(x2d[None, :, :, None], '1 h w c -> b h w c', b=b)
154
+ y4d = repeat(y2d[None, :, :, None], '1 h w c -> b h w c', b=b)
155
+ flow1_x4d = rearrange(flow1[:,:1].detach().clone(), "b c h w -> b h w c")
156
+ flow1_y4d = rearrange(flow1[:,1:].detach().clone(), "b c h w -> b h w c")
157
+
158
+ x4d = x4d + flow1_x4d
159
+ y4d = y4d + flow1_y4d
160
+
161
+ pos_vectors_homo = torch.stack([x4d, y4d, ones_4d.squeeze(-1)], dim=3) # (b, h, w, 3, 1)
162
+ else:
163
+ pos_vectors_homo = torch.stack([x2d, y2d, ones_2d], dim=2)[None, :, :, :, None] # (1, h, w, 3, 1)
164
+
165
+ intrinsic1_inv = torch.linalg.inv(intrinsic1) # (b, 3, 3)
166
+ intrinsic1_inv_4d = intrinsic1_inv[:, None, None] # (b, 1, 1, 3, 3)
167
+ intrinsic2_4d = intrinsic2[:, None, None] # (b, 1, 1, 3, 3)
168
+ depth_4d = depth1[:, 0][:, :, :, None, None] # (b, h, w, 1, 1)
169
+ trans_4d = transformation[:, None, None] # (b, 1, 1, 4, 4)
170
+
171
+ unnormalized_pos = torch.matmul(intrinsic1_inv_4d, pos_vectors_homo).to(transformation1.dtype) # (b, h, w, 3, 1)
172
+ world_points = depth_4d * unnormalized_pos # (b, h, w, 3, 1)
173
+ world_points_homo = torch.cat([world_points, ones_4d], dim=3) # (b, h, w, 4, 1)
174
+ trans_world_homo = torch.matmul(trans_4d, world_points_homo).to(transformation1.dtype) # (b, h, w, 4, 1)
175
+ trans_world = trans_world_homo[:, :, :, :3] # (b, h, w, 3, 1)
176
+ trans_norm_points = torch.matmul(intrinsic2_4d, trans_world).to(transformation1.dtype) # (b, h, w, 3, 1)
177
+ return trans_norm_points
178
+
179
+ def bilinear_splatting(self, frame1: torch.Tensor, mask1: Optional[torch.Tensor], depth1: torch.Tensor,
180
+ flow12: torch.Tensor, flow12_mask: Optional[torch.Tensor], is_image: bool = False) -> \
181
+ Tuple[torch.Tensor, torch.Tensor]:
182
+ """
183
+ Bilinear splatting
184
+ :param frame1: (b,c,h,w)
185
+ :param mask1: (b,1,h,w): 1 for known, 0 for unknown. Optional
186
+ :param depth1: (b,1,h,w)
187
+ :param flow12: (b,2,h,w)
188
+ :param flow12_mask: (b,1,h,w): 1 for valid flow, 0 for invalid flow. Optional
189
+ :param is_image: if true, output will be clipped to (-1,1) range
190
+ :return: warped_frame2: (b,c,h,w)
191
+ mask2: (b,1,h,w): 1 for known and 0 for unknown
192
+ """
193
+ if self.resolution is not None:
194
+ assert frame1.shape[2:4] == self.resolution
195
+ b, c, h, w = frame1.shape
196
+ if mask1 is None:
197
+ mask1 = torch.ones(size=(b, 1, h, w)).to(frame1)
198
+ if flow12_mask is None:
199
+ flow12_mask = torch.ones(size=(b, 1, h, w)).to(flow12)
200
+ grid = self.create_grid(b, h, w).to(frame1)
201
+ trans_pos = flow12 + grid
202
+
203
+ trans_pos_offset = trans_pos + 1
204
+ trans_pos_floor = torch.floor(trans_pos_offset).long()
205
+ trans_pos_ceil = torch.ceil(trans_pos_offset).long()
206
+ trans_pos_offset = torch.stack([
207
+ torch.clamp(trans_pos_offset[:, 0], min=0, max=w + 1),
208
+ torch.clamp(trans_pos_offset[:, 1], min=0, max=h + 1)], dim=1)
209
+ trans_pos_floor = torch.stack([
210
+ torch.clamp(trans_pos_floor[:, 0], min=0, max=w + 1),
211
+ torch.clamp(trans_pos_floor[:, 1], min=0, max=h + 1)], dim=1)
212
+ trans_pos_ceil = torch.stack([
213
+ torch.clamp(trans_pos_ceil[:, 0], min=0, max=w + 1),
214
+ torch.clamp(trans_pos_ceil[:, 1], min=0, max=h + 1)], dim=1)
215
+
216
+ prox_weight_nw = (1 - (trans_pos_offset[:, 1:2] - trans_pos_floor[:, 1:2])) * \
217
+ (1 - (trans_pos_offset[:, 0:1] - trans_pos_floor[:, 0:1]))
218
+ prox_weight_sw = (1 - (trans_pos_ceil[:, 1:2] - trans_pos_offset[:, 1:2])) * \
219
+ (1 - (trans_pos_offset[:, 0:1] - trans_pos_floor[:, 0:1]))
220
+ prox_weight_ne = (1 - (trans_pos_offset[:, 1:2] - trans_pos_floor[:, 1:2])) * \
221
+ (1 - (trans_pos_ceil[:, 0:1] - trans_pos_offset[:, 0:1]))
222
+ prox_weight_se = (1 - (trans_pos_ceil[:, 1:2] - trans_pos_offset[:, 1:2])) * \
223
+ (1 - (trans_pos_ceil[:, 0:1] - trans_pos_offset[:, 0:1]))
224
+
225
+ sat_depth1 = torch.clamp(depth1, min=0, max=1000)
226
+ log_depth1 = torch.log(1 + sat_depth1)
227
+ depth_weights = torch.exp(log_depth1 / log_depth1.max() * 50)
228
+
229
+ weight_nw = torch.moveaxis(prox_weight_nw * mask1 * flow12_mask / depth_weights, [0, 1, 2, 3], [0, 3, 1, 2])
230
+ weight_sw = torch.moveaxis(prox_weight_sw * mask1 * flow12_mask / depth_weights, [0, 1, 2, 3], [0, 3, 1, 2])
231
+ weight_ne = torch.moveaxis(prox_weight_ne * mask1 * flow12_mask / depth_weights, [0, 1, 2, 3], [0, 3, 1, 2])
232
+ weight_se = torch.moveaxis(prox_weight_se * mask1 * flow12_mask / depth_weights, [0, 1, 2, 3], [0, 3, 1, 2])
233
+
234
+ warped_frame = torch.zeros(size=(b, h + 2, w + 2, c), dtype=torch.float32).to(frame1)
235
+ warped_weights = torch.zeros(size=(b, h + 2, w + 2, 1), dtype=torch.float32).to(frame1)
236
+
237
+ frame1_cl = torch.moveaxis(frame1, [0, 1, 2, 3], [0, 3, 1, 2])
238
+ batch_indices = torch.arange(b)[:, None, None].to(frame1.device)
239
+ warped_frame.index_put_((batch_indices, trans_pos_floor[:, 1], trans_pos_floor[:, 0]),
240
+ frame1_cl * weight_nw, accumulate=True)
241
+ warped_frame.index_put_((batch_indices, trans_pos_ceil[:, 1], trans_pos_floor[:, 0]),
242
+ frame1_cl * weight_sw, accumulate=True)
243
+ warped_frame.index_put_((batch_indices, trans_pos_floor[:, 1], trans_pos_ceil[:, 0]),
244
+ frame1_cl * weight_ne, accumulate=True)
245
+ warped_frame.index_put_((batch_indices, trans_pos_ceil[:, 1], trans_pos_ceil[:, 0]),
246
+ frame1_cl * weight_se, accumulate=True)
247
+
248
+ warped_weights.index_put_((batch_indices, trans_pos_floor[:, 1], trans_pos_floor[:, 0]),
249
+ weight_nw, accumulate=True)
250
+ warped_weights.index_put_((batch_indices, trans_pos_ceil[:, 1], trans_pos_floor[:, 0]),
251
+ weight_sw, accumulate=True)
252
+ warped_weights.index_put_((batch_indices, trans_pos_floor[:, 1], trans_pos_ceil[:, 0]),
253
+ weight_ne, accumulate=True)
254
+ warped_weights.index_put_((batch_indices, trans_pos_ceil[:, 1], trans_pos_ceil[:, 0]),
255
+ weight_se, accumulate=True)
256
+
257
+ warped_frame_cf = torch.moveaxis(warped_frame, [0, 1, 2, 3], [0, 2, 3, 1])
258
+ warped_weights_cf = torch.moveaxis(warped_weights, [0, 1, 2, 3], [0, 2, 3, 1])
259
+ cropped_warped_frame = warped_frame_cf[:, :, 1:-1, 1:-1]
260
+ cropped_weights = warped_weights_cf[:, :, 1:-1, 1:-1]
261
+
262
+ mask = cropped_weights > 0
263
+ zero_value = -1 if is_image else 0
264
+ zero_tensor = torch.tensor(zero_value, dtype=frame1.dtype, device=frame1.device)
265
+ warped_frame2 = torch.where(mask, cropped_warped_frame / cropped_weights, zero_tensor)
266
+ mask2 = mask.to(frame1)
267
+
268
+ if is_image:
269
+ assert warped_frame2.min() >= -1.1 # Allow for rounding errors
270
+ assert warped_frame2.max() <= 1.1
271
+ warped_frame2 = torch.clamp(warped_frame2, min=-1, max=1)
272
+ return warped_frame2, mask2
273
+
274
+ def bilinear_interpolation(self, frame2: torch.Tensor, mask2: Optional[torch.Tensor], flow12: torch.Tensor,
275
+ flow12_mask: Optional[torch.Tensor], is_image: bool = False) -> \
276
+ Tuple[torch.Tensor, torch.Tensor]:
277
+ """
278
+ Bilinear interpolation
279
+ :param frame2: (b, c, h, w)
280
+ :param mask2: (b, 1, h, w): 1 for known, 0 for unknown. Optional
281
+ :param flow12: (b, 2, h, w)
282
+ :param flow12_mask: (b, 1, h, w): 1 for valid flow, 0 for invalid flow. Optional
283
+ :param is_image: if true, output will be clipped to (-1,1) range
284
+ :return: warped_frame1: (b, c, h, w)
285
+ mask1: (b, 1, h, w): 1 for known and 0 for unknown
286
+ """
287
+ if self.resolution is not None:
288
+ assert frame2.shape[2:4] == self.resolution
289
+ b, c, h, w = frame2.shape
290
+ if mask2 is None:
291
+ mask2 = torch.ones(size=(b, 1, h, w)).to(frame2)
292
+ if flow12_mask is None:
293
+ flow12_mask = torch.ones(size=(b, 1, h, w)).to(flow12)
294
+ grid = self.create_grid(b, h, w).to(frame2)
295
+ trans_pos = flow12 + grid
296
+
297
+ trans_pos_offset = trans_pos + 1
298
+ trans_pos_floor = torch.floor(trans_pos_offset).long()
299
+ trans_pos_ceil = torch.ceil(trans_pos_offset).long()
300
+ trans_pos_offset = torch.stack([
301
+ torch.clamp(trans_pos_offset[:, 0], min=0, max=w + 1),
302
+ torch.clamp(trans_pos_offset[:, 1], min=0, max=h + 1)], dim=1)
303
+ trans_pos_floor = torch.stack([
304
+ torch.clamp(trans_pos_floor[:, 0], min=0, max=w + 1),
305
+ torch.clamp(trans_pos_floor[:, 1], min=0, max=h + 1)], dim=1)
306
+ trans_pos_ceil = torch.stack([
307
+ torch.clamp(trans_pos_ceil[:, 0], min=0, max=w + 1),
308
+ torch.clamp(trans_pos_ceil[:, 1], min=0, max=h + 1)], dim=1)
309
+
310
+ prox_weight_nw = (1 - (trans_pos_offset[:, 1:2] - trans_pos_floor[:, 1:2])) * \
311
+ (1 - (trans_pos_offset[:, 0:1] - trans_pos_floor[:, 0:1]))
312
+ prox_weight_sw = (1 - (trans_pos_ceil[:, 1:2] - trans_pos_offset[:, 1:2])) * \
313
+ (1 - (trans_pos_offset[:, 0:1] - trans_pos_floor[:, 0:1]))
314
+ prox_weight_ne = (1 - (trans_pos_offset[:, 1:2] - trans_pos_floor[:, 1:2])) * \
315
+ (1 - (trans_pos_ceil[:, 0:1] - trans_pos_offset[:, 0:1]))
316
+ prox_weight_se = (1 - (trans_pos_ceil[:, 1:2] - trans_pos_offset[:, 1:2])) * \
317
+ (1 - (trans_pos_ceil[:, 0:1] - trans_pos_offset[:, 0:1]))
318
+
319
+ weight_nw = torch.moveaxis(prox_weight_nw * flow12_mask, [0, 1, 2, 3], [0, 3, 1, 2])
320
+ weight_sw = torch.moveaxis(prox_weight_sw * flow12_mask, [0, 1, 2, 3], [0, 3, 1, 2])
321
+ weight_ne = torch.moveaxis(prox_weight_ne * flow12_mask, [0, 1, 2, 3], [0, 3, 1, 2])
322
+ weight_se = torch.moveaxis(prox_weight_se * flow12_mask, [0, 1, 2, 3], [0, 3, 1, 2])
323
+
324
+ frame2_offset = F.pad(frame2, [1, 1, 1, 1])
325
+ mask2_offset = F.pad(mask2, [1, 1, 1, 1])
326
+ bi = torch.arange(b)[:, None, None]
327
+
328
+ f2_nw = frame2_offset[bi, :, trans_pos_floor[:, 1], trans_pos_floor[:, 0]]
329
+ f2_sw = frame2_offset[bi, :, trans_pos_ceil[:, 1], trans_pos_floor[:, 0]]
330
+ f2_ne = frame2_offset[bi, :, trans_pos_floor[:, 1], trans_pos_ceil[:, 0]]
331
+ f2_se = frame2_offset[bi, :, trans_pos_ceil[:, 1], trans_pos_ceil[:, 0]]
332
+
333
+ m2_nw = mask2_offset[bi, :, trans_pos_floor[:, 1], trans_pos_floor[:, 0]]
334
+ m2_sw = mask2_offset[bi, :, trans_pos_ceil[:, 1], trans_pos_floor[:, 0]]
335
+ m2_ne = mask2_offset[bi, :, trans_pos_floor[:, 1], trans_pos_ceil[:, 0]]
336
+ m2_se = mask2_offset[bi, :, trans_pos_ceil[:, 1], trans_pos_ceil[:, 0]]
337
+
338
+ nr = weight_nw * f2_nw * m2_nw + weight_sw * f2_sw * m2_sw + \
339
+ weight_ne * f2_ne * m2_ne + weight_se * f2_se * m2_se
340
+ dr = weight_nw * m2_nw + weight_sw * m2_sw + weight_ne * m2_ne + weight_se * m2_se
341
+
342
+ zero_value = -1 if is_image else 0
343
+ zero_tensor = torch.tensor(zero_value, dtype=nr.dtype, device=nr.device)
344
+ warped_frame1 = torch.where(dr > 0, nr / dr, zero_tensor)
345
+ mask1 = (dr > 0).to(frame2)
346
+
347
+ # Convert to channel first
348
+ warped_frame1 = torch.moveaxis(warped_frame1, [0, 1, 2, 3], [0, 2, 3, 1])
349
+ mask1 = torch.moveaxis(mask1, [0, 1, 2, 3], [0, 2, 3, 1])
350
+
351
+ if is_image:
352
+ assert warped_frame1.min() >= -1.1 # Allow for rounding errors
353
+ assert warped_frame1.max() <= 1.1
354
+ warped_frame1 = torch.clamp(warped_frame1, min=-1, max=1)
355
+ return warped_frame1, mask1
356
+
357
+ @staticmethod
358
+ def create_grid(b, h, w):
359
+ x_1d = torch.arange(0, w)[None]
360
+ y_1d = torch.arange(0, h)[:, None]
361
+ x_2d = x_1d.repeat([h, 1])
362
+ y_2d = y_1d.repeat([1, w])
363
+ grid = torch.stack([x_2d, y_2d], dim=0)
364
+ batch_grid = grid[None].repeat([b, 1, 1, 1])
365
+ return batch_grid
366
+
367
+ # @staticmethod
368
+ # def read_image(path: Path) -> torch.Tensor:
369
+ # image = skimage.io.imread(path.as_posix())
370
+ # return image
371
+
372
+ # @staticmethod
373
+ # def read_depth(path: Path) -> torch.Tensor:
374
+ # if path.suffix == '.png':
375
+ # depth = skimage.io.imread(path.as_posix())
376
+ # elif path.suffix == '.npy':
377
+ # depth = numpy.load(path.as_posix())
378
+ # elif path.suffix == '.npz':
379
+ # with numpy.load(path.as_posix()) as depth_data:
380
+ # depth = depth_data['depth']
381
+ # elif path.suffix == '.exr':
382
+ # exr_file = OpenEXR.InputFile(path.as_posix())
383
+ # raw_bytes = exr_file.channel('B', Imath.PixelType(Imath.PixelType.FLOAT))
384
+ # depth_vector = numpy.frombuffer(raw_bytes, dtype=numpy.float32)
385
+ # height = exr_file.header()['displayWindow'].max.y + 1 - exr_file.header()['displayWindow'].min.y
386
+ # width = exr_file.header()['displayWindow'].max.x + 1 - exr_file.header()['displayWindow'].min.x
387
+ # depth = numpy.reshape(depth_vector, (height, width))
388
+ # else:
389
+ # raise RuntimeError(f'Unknown depth format: {path.suffix}')
390
+ # return depth
391
+
392
+ # @staticmethod
393
+ # def camera_intrinsic_transform(capture_width=1920, capture_height=1080, patch_start_point: tuple = (0, 0)):
394
+ # start_y, start_x = patch_start_point
395
+ # camera_intrinsics = numpy.eye(4)
396
+ # camera_intrinsics[0, 0] = 2100
397
+ # camera_intrinsics[0, 2] = capture_width / 2.0 - start_x
398
+ # camera_intrinsics[1, 1] = 2100
399
+ # camera_intrinsics[1, 2] = capture_height / 2.0 - start_y
400
+ # return camera_intrinsics
401
+
402
+ # @staticmethod
403
+ # def get_device(device: str):
404
+ # """
405
+ # Returns torch device object
406
+ # :param device: cpu/gpu0/gpu1
407
+ # :return:
408
+ # """
409
+ # if device == 'cpu':
410
+ # device = torch.device('cpu')
411
+ # elif device.startswith('gpu') and torch.cuda.is_available():
412
+ # gpu_num = int(device[3:])
413
+ # device = torch.device(f'cuda:{gpu_num}')
414
+ # else:
415
+ # device = torch.device('cpu')
416
+ # return device
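A minimal forward-warp sketch with dummy tensors (shapes, depth, and camera values are illustrative assumptions): the frame lies in [-1, 1], depth is positive, intrinsics are in pixel units, and the extrinsics are 4x4 world-to-camera matrices.

import torch

warper = Warper()
b, h, w = 1, 32, 48
frame1 = torch.rand(b, 3, h, w) * 2 - 1                # image in [-1, 1]
depth1 = torch.full((b, 1, h, w), 2.0)                 # constant positive depth
intrinsic = torch.tensor([[[w / 2, 0.0, w / 2],
                           [0.0, w / 2, h / 2],
                           [0.0, 0.0, 1.0]]])          # (b, 3, 3), pixel units
pose1 = torch.eye(4)[None]                             # reference view
pose2 = torch.eye(4)[None].clone()
pose2[:, 0, 3] = 0.1                                   # small sideways translation

warped_frame2, mask2, warped_depth2, flow12 = warper.forward_warp(
    frame1, None, depth1, pose1, pose2, intrinsic, None, is_image=True
)
print(warped_frame2.shape, mask2.shape, flow12.shape)  # (1,3,32,48) (1,1,32,48) (1,2,32,48)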
finetune/modules/depth_warping/depth_anything_v2/depth_anything_wrapper.py ADDED
@@ -0,0 +1,12 @@
+ import torch
+ import torch.nn as nn
+ 
+ from .dpt import DepthAnythingV2
+ 
+ class MVSplat_wrapper(nn.Module):
+     def __init__(
+         self,
+         model_configs,
+         ckpt_path,
+     ):
+         super().__init__()
+ 
+         # Build Depth-Anything-V2 from its config dict and load the checkpoint weights.
+         self.depth_anything = DepthAnythingV2(**model_configs)
+         self.depth_anything.load_state_dict(torch.load(ckpt_path, map_location="cpu"))
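For reference, a hypothetical configuration for the wrapper above; the keys follow dpt.py in this folder and the values mirror the commonly published Depth-Anything-V2 ViT-L preset, so treat them as an assumption to verify rather than values taken from this commit:

# Assumed preset and checkpoint path, for illustration only.
vitl_config = {"encoder": "vitl", "features": 256, "out_channels": [256, 512, 1024, 1024]}
wrapper = MVSplat_wrapper(model_configs=vitl_config,
                          ckpt_path="checkpoints/depth_anything_v2_vitl.pth")  # hypothetical path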
finetune/modules/depth_warping/depth_anything_v2/dinov2.py ADDED
@@ -0,0 +1,415 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # References:
7
+ # https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
8
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
9
+
10
+ from functools import partial
11
+ import math
12
+ import logging
13
+ from typing import Sequence, Tuple, Union, Callable
14
+
15
+ import torch
16
+ import torch.nn as nn
17
+ import torch.utils.checkpoint
18
+ from torch.nn.init import trunc_normal_
19
+
20
+ from .dinov2_layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block
21
+
22
+
23
+ logger = logging.getLogger("dinov2")
24
+
25
+
26
+ def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module:
27
+ if not depth_first and include_root:
28
+ fn(module=module, name=name)
29
+ for child_name, child_module in module.named_children():
30
+ child_name = ".".join((name, child_name)) if name else child_name
31
+ named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)
32
+ if depth_first and include_root:
33
+ fn(module=module, name=name)
34
+ return module
35
+
36
+
37
+ class BlockChunk(nn.ModuleList):
38
+ def forward(self, x):
39
+ for b in self:
40
+ x = b(x)
41
+ return x
42
+
43
+
44
+ class DinoVisionTransformer(nn.Module):
45
+ def __init__(
46
+ self,
47
+ img_size=224,
48
+ patch_size=16,
49
+ in_chans=3,
50
+ embed_dim=768,
51
+ depth=12,
52
+ num_heads=12,
53
+ mlp_ratio=4.0,
54
+ qkv_bias=True,
55
+ ffn_bias=True,
56
+ proj_bias=True,
57
+ drop_path_rate=0.0,
58
+ drop_path_uniform=False,
59
+ init_values=None, # for layerscale: None or 0 => no layerscale
60
+ embed_layer=PatchEmbed,
61
+ act_layer=nn.GELU,
62
+ block_fn=Block,
63
+ ffn_layer="mlp",
64
+ block_chunks=1,
65
+ num_register_tokens=0,
66
+ interpolate_antialias=False,
67
+ interpolate_offset=0.1,
68
+ ):
69
+ """
70
+ Args:
71
+ img_size (int, tuple): input image size
72
+ patch_size (int, tuple): patch size
73
+ in_chans (int): number of input channels
74
+ embed_dim (int): embedding dimension
75
+ depth (int): depth of transformer
76
+ num_heads (int): number of attention heads
77
+ mlp_ratio (int): ratio of mlp hidden dim to embedding dim
78
+ qkv_bias (bool): enable bias for qkv if True
79
+ proj_bias (bool): enable bias for proj in attn if True
80
+ ffn_bias (bool): enable bias for ffn if True
81
+ drop_path_rate (float): stochastic depth rate
82
+ drop_path_uniform (bool): apply uniform drop rate across blocks
83
+ weight_init (str): weight init scheme
84
+ init_values (float): layer-scale init values
85
+ embed_layer (nn.Module): patch embedding layer
86
+ act_layer (nn.Module): MLP activation layer
87
+ block_fn (nn.Module): transformer block class
88
+ ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
89
+ block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
90
+ num_register_tokens: (int) number of extra cls tokens (so-called "registers")
91
+ interpolate_antialias: (bool) flag to apply anti-aliasing when interpolating positional embeddings
92
+ interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings
93
+ """
94
+ super().__init__()
95
+ norm_layer = partial(nn.LayerNorm, eps=1e-6)
96
+
97
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
98
+ self.num_tokens = 1
99
+ self.n_blocks = depth
100
+ self.num_heads = num_heads
101
+ self.patch_size = patch_size
102
+ self.num_register_tokens = num_register_tokens
103
+ self.interpolate_antialias = interpolate_antialias
104
+ self.interpolate_offset = interpolate_offset
105
+
106
+ self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
107
+ num_patches = self.patch_embed.num_patches
108
+
109
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
110
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
111
+ assert num_register_tokens >= 0
112
+ self.register_tokens = (
113
+ nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None
114
+ )
115
+
116
+ if drop_path_uniform is True:
117
+ dpr = [drop_path_rate] * depth
118
+ else:
119
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
120
+
121
+ if ffn_layer == "mlp":
122
+ logger.info("using MLP layer as FFN")
123
+ ffn_layer = Mlp
124
+ elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
125
+ logger.info("using SwiGLU layer as FFN")
126
+ ffn_layer = SwiGLUFFNFused
127
+ elif ffn_layer == "identity":
128
+ logger.info("using Identity layer as FFN")
129
+
130
+ def f(*args, **kwargs):
131
+ return nn.Identity()
132
+
133
+ ffn_layer = f
134
+ else:
135
+ raise NotImplementedError
136
+
137
+ blocks_list = [
138
+ block_fn(
139
+ dim=embed_dim,
140
+ num_heads=num_heads,
141
+ mlp_ratio=mlp_ratio,
142
+ qkv_bias=qkv_bias,
143
+ proj_bias=proj_bias,
144
+ ffn_bias=ffn_bias,
145
+ drop_path=dpr[i],
146
+ norm_layer=norm_layer,
147
+ act_layer=act_layer,
148
+ ffn_layer=ffn_layer,
149
+ init_values=init_values,
150
+ )
151
+ for i in range(depth)
152
+ ]
153
+ if block_chunks > 0:
154
+ self.chunked_blocks = True
155
+ chunked_blocks = []
156
+ chunksize = depth // block_chunks
157
+ for i in range(0, depth, chunksize):
158
+ # this is to keep the block index consistent if we chunk the block list
159
+ chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize])
160
+ self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
161
+ else:
162
+ self.chunked_blocks = False
163
+ self.blocks = nn.ModuleList(blocks_list)
164
+
165
+ self.norm = norm_layer(embed_dim)
166
+ self.head = nn.Identity()
167
+
168
+ self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))
169
+
170
+ self.init_weights()
171
+
172
+ def init_weights(self):
173
+ trunc_normal_(self.pos_embed, std=0.02)
174
+ nn.init.normal_(self.cls_token, std=1e-6)
175
+ if self.register_tokens is not None:
176
+ nn.init.normal_(self.register_tokens, std=1e-6)
177
+ named_apply(init_weights_vit_timm, self)
178
+
179
+ def interpolate_pos_encoding(self, x, w, h):
180
+ previous_dtype = x.dtype
181
+ npatch = x.shape[1] - 1
182
+ N = self.pos_embed.shape[1] - 1
183
+ if npatch == N and w == h:
184
+ return self.pos_embed
185
+ pos_embed = self.pos_embed.float()
186
+ class_pos_embed = pos_embed[:, 0]
187
+ patch_pos_embed = pos_embed[:, 1:]
188
+ dim = x.shape[-1]
189
+ w0 = w // self.patch_size
190
+ h0 = h // self.patch_size
191
+ # we add a small number to avoid floating point error in the interpolation
192
+ # see discussion at https://github.com/facebookresearch/dino/issues/8
193
+ # DINOv2 with register modify the interpolate_offset from 0.1 to 0.0
194
+ w0, h0 = w0 + self.interpolate_offset, h0 + self.interpolate_offset
195
+ # w0, h0 = w0 + 0.1, h0 + 0.1
196
+
197
+ sqrt_N = math.sqrt(N)
198
+ sx, sy = float(w0) / sqrt_N, float(h0) / sqrt_N
199
+ patch_pos_embed = nn.functional.interpolate(
200
+ patch_pos_embed.reshape(1, int(sqrt_N), int(sqrt_N), dim).permute(0, 3, 1, 2),
201
+ scale_factor=(sx, sy),
202
+ # (int(w0), int(h0)), # to solve the upsampling shape issue
203
+ mode="bicubic",
204
+ antialias=self.interpolate_antialias
205
+ )
206
+
207
+ assert int(w0) == patch_pos_embed.shape[-2]
208
+ assert int(h0) == patch_pos_embed.shape[-1]
209
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
210
+ return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)
211
+
212
+ def prepare_tokens_with_masks(self, x, masks=None):
213
+ B, nc, w, h = x.shape
214
+ x = self.patch_embed(x)
215
+ if masks is not None:
216
+ x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)
217
+
218
+ x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
219
+ x = x + self.interpolate_pos_encoding(x, w, h)
220
+
221
+ if self.register_tokens is not None:
222
+ x = torch.cat(
223
+ (
224
+ x[:, :1],
225
+ self.register_tokens.expand(x.shape[0], -1, -1),
226
+ x[:, 1:],
227
+ ),
228
+ dim=1,
229
+ )
230
+
231
+ return x
232
+
233
+ def forward_features_list(self, x_list, masks_list):
234
+ x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]
235
+ for blk in self.blocks:
236
+ x = blk(x)
237
+
238
+ all_x = x
239
+ output = []
240
+ for x, masks in zip(all_x, masks_list):
241
+ x_norm = self.norm(x)
242
+ output.append(
243
+ {
244
+ "x_norm_clstoken": x_norm[:, 0],
245
+ "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
246
+ "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
247
+ "x_prenorm": x,
248
+ "masks": masks,
249
+ }
250
+ )
251
+ return output
252
+
253
+ def forward_features(self, x, masks=None):
254
+ if isinstance(x, list):
255
+ return self.forward_features_list(x, masks)
256
+
257
+ x = self.prepare_tokens_with_masks(x, masks)
258
+
259
+ for blk in self.blocks:
260
+ x = blk(x)
261
+
262
+ x_norm = self.norm(x)
263
+ return {
264
+ "x_norm_clstoken": x_norm[:, 0],
265
+ "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
266
+ "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
267
+ "x_prenorm": x,
268
+ "masks": masks,
269
+ }
270
+
271
+ def _get_intermediate_layers_not_chunked(self, x, n=1):
272
+ x = self.prepare_tokens_with_masks(x)
273
+ # If n is an int, take the n last blocks. If it's a list, take them
274
+ output, total_block_len = [], len(self.blocks)
275
+ blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
276
+ for i, blk in enumerate(self.blocks):
277
+ x = blk(x)
278
+ if i in blocks_to_take:
279
+ output.append(x)
280
+ assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
281
+ return output
282
+
283
+ def _get_intermediate_layers_chunked(self, x, n=1):
284
+ x = self.prepare_tokens_with_masks(x)
285
+ output, i, total_block_len = [], 0, len(self.blocks[-1])
286
+ # If n is an int, take the n last blocks. If it's a list, take them
287
+ blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
288
+ for block_chunk in self.blocks:
289
+ for blk in block_chunk[i:]: # Passing the nn.Identity()
290
+ x = blk(x)
291
+ if i in blocks_to_take:
292
+ output.append(x)
293
+ i += 1
294
+ assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
295
+ return output
296
+
297
+ def get_intermediate_layers(
298
+ self,
299
+ x: torch.Tensor,
300
+ n: Union[int, Sequence] = 1, # Layers or n last layers to take
301
+ reshape: bool = False,
302
+ return_class_token: bool = False,
303
+ norm=True
304
+ ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
305
+ if self.chunked_blocks:
306
+ outputs = self._get_intermediate_layers_chunked(x, n)
307
+ else:
308
+ outputs = self._get_intermediate_layers_not_chunked(x, n)
309
+ if norm:
310
+ outputs = [self.norm(out) for out in outputs]
311
+ class_tokens = [out[:, 0] for out in outputs]
312
+ outputs = [out[:, 1 + self.num_register_tokens:] for out in outputs]
313
+ if reshape:
314
+ B, _, w, h = x.shape
315
+ outputs = [
316
+ out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()
317
+ for out in outputs
318
+ ]
319
+ if return_class_token:
320
+ return tuple(zip(outputs, class_tokens))
321
+ return tuple(outputs)
322
+
323
+ def forward(self, *args, is_training=False, **kwargs):
324
+ ret = self.forward_features(*args, **kwargs)
325
+ if is_training:
326
+ return ret
327
+ else:
328
+ return self.head(ret["x_norm_clstoken"])
329
+
330
+
331
+ def init_weights_vit_timm(module: nn.Module, name: str = ""):
332
+ """ViT weight initialization, original timm impl (for reproducibility)"""
333
+ if isinstance(module, nn.Linear):
334
+ trunc_normal_(module.weight, std=0.02)
335
+ if module.bias is not None:
336
+ nn.init.zeros_(module.bias)
337
+
338
+
339
+ def vit_small(patch_size=16, num_register_tokens=0, **kwargs):
340
+ model = DinoVisionTransformer(
341
+ patch_size=patch_size,
342
+ embed_dim=384,
343
+ depth=12,
344
+ num_heads=6,
345
+ mlp_ratio=4,
346
+ block_fn=partial(Block, attn_class=MemEffAttention),
347
+ num_register_tokens=num_register_tokens,
348
+ **kwargs,
349
+ )
350
+ return model
351
+
352
+
353
+ def vit_base(patch_size=16, num_register_tokens=0, **kwargs):
354
+ model = DinoVisionTransformer(
355
+ patch_size=patch_size,
356
+ embed_dim=768,
357
+ depth=12,
358
+ num_heads=12,
359
+ mlp_ratio=4,
360
+ block_fn=partial(Block, attn_class=MemEffAttention),
361
+ num_register_tokens=num_register_tokens,
362
+ **kwargs,
363
+ )
364
+ return model
365
+
366
+
367
+ def vit_large(patch_size=16, num_register_tokens=0, **kwargs):
368
+ model = DinoVisionTransformer(
369
+ patch_size=patch_size,
370
+ embed_dim=1024,
371
+ depth=24,
372
+ num_heads=16,
373
+ mlp_ratio=4,
374
+ block_fn=partial(Block, attn_class=MemEffAttention),
375
+ num_register_tokens=num_register_tokens,
376
+ **kwargs,
377
+ )
378
+ return model
379
+
380
+
381
+ def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs):
382
+ """
383
+ Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64
384
+ """
385
+ model = DinoVisionTransformer(
386
+ patch_size=patch_size,
387
+ embed_dim=1536,
388
+ depth=40,
389
+ num_heads=24,
390
+ mlp_ratio=4,
391
+ block_fn=partial(Block, attn_class=MemEffAttention),
392
+ num_register_tokens=num_register_tokens,
393
+ **kwargs,
394
+ )
395
+ return model
396
+
397
+
398
+ def DINOv2(model_name):
399
+ model_zoo = {
400
+ "vits": vit_small,
401
+ "vitb": vit_base,
402
+ "vitl": vit_large,
403
+ "vitg": vit_giant2
404
+ }
405
+
406
+ return model_zoo[model_name](
407
+ img_size=518,
408
+ patch_size=14,
409
+ init_values=1.0,
410
+ ffn_layer="mlp" if model_name != "vitg" else "swiglufused",
411
+ block_chunks=0,
412
+ num_register_tokens=0,
413
+ interpolate_antialias=False,
414
+ interpolate_offset=0.1
415
+ )
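A quick sketch of pulling patch features from this backbone (randomly initialised here, no pretrained weights; 518 is chosen to match the 14-pixel patch grid). With xFormers installed, MemEffAttention needs a device/dtype that memory-efficient attention supports, so run this on GPU or in an environment without xFormers:

import torch

backbone = DINOv2("vits")                          # ViT-S/14, img_size=518, randomly initialised
x = torch.randn(1, 3, 518, 518)
with torch.no_grad():
    feats = backbone.get_intermediate_layers(x, n=4, reshape=True)
print([f.shape for f in feats])                    # 4 x (1, 384, 37, 37)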
finetune/modules/depth_warping/depth_anything_v2/dinov2_layers/__init__.py ADDED
@@ -0,0 +1,11 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from .mlp import Mlp
8
+ from .patch_embed import PatchEmbed
9
+ from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused
10
+ from .block import NestedTensorBlock
11
+ from .attention import MemEffAttention
finetune/modules/depth_warping/depth_anything_v2/dinov2_layers/attention.py ADDED
@@ -0,0 +1,83 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # References:
8
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
9
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
10
+
11
+ import logging
12
+
13
+ from torch import Tensor
14
+ from torch import nn
15
+
16
+
17
+ logger = logging.getLogger("dinov2")
18
+
19
+
20
+ try:
21
+ from xformers.ops import memory_efficient_attention, unbind, fmha
22
+
23
+ XFORMERS_AVAILABLE = True
24
+ except ImportError:
25
+ logger.warning("xFormers not available")
26
+ XFORMERS_AVAILABLE = False
27
+
28
+
29
+ class Attention(nn.Module):
30
+ def __init__(
31
+ self,
32
+ dim: int,
33
+ num_heads: int = 8,
34
+ qkv_bias: bool = False,
35
+ proj_bias: bool = True,
36
+ attn_drop: float = 0.0,
37
+ proj_drop: float = 0.0,
38
+ ) -> None:
39
+ super().__init__()
40
+ self.num_heads = num_heads
41
+ head_dim = dim // num_heads
42
+ self.scale = head_dim**-0.5
43
+
44
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
45
+ self.attn_drop = nn.Dropout(attn_drop)
46
+ self.proj = nn.Linear(dim, dim, bias=proj_bias)
47
+ self.proj_drop = nn.Dropout(proj_drop)
48
+
49
+ def forward(self, x: Tensor) -> Tensor:
50
+ B, N, C = x.shape
51
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
52
+
53
+ q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
54
+ attn = q @ k.transpose(-2, -1)
55
+
56
+ attn = attn.softmax(dim=-1)
57
+ attn = self.attn_drop(attn)
58
+
59
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
60
+ x = self.proj(x)
61
+ x = self.proj_drop(x)
62
+ return x
63
+
64
+
65
+ class MemEffAttention(Attention):
66
+ def forward(self, x: Tensor, attn_bias=None) -> Tensor:
67
+ if not XFORMERS_AVAILABLE:
68
+ assert attn_bias is None, "xFormers is required for nested tensors usage"
69
+ return super().forward(x)
70
+
71
+ B, N, C = x.shape
72
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
73
+
74
+ q, k, v = unbind(qkv, 2)
75
+
76
+ x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
77
+ x = x.reshape([B, N, C])
78
+
79
+ x = self.proj(x)
80
+ x = self.proj_drop(x)
81
+ return x
82
+
83
+
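A tiny interface check (illustrative sizes): without xFormers, MemEffAttention simply defers to the Attention.forward path above, so both classes map (B, N, C) to (B, N, C).

import torch

attn = Attention(dim=384, num_heads=6, qkv_bias=True)
tokens = torch.randn(2, 197, 384)
with torch.no_grad():
    out = attn(tokens)
print(out.shape)               # torch.Size([2, 197, 384])
# MemEffAttention(dim=384, num_heads=6)(tokens) gives the same shape; it uses
# xformers.memory_efficient_attention when available and falls back to this path otherwise.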
finetune/modules/depth_warping/depth_anything_v2/dinov2_layers/block.py ADDED
@@ -0,0 +1,252 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # References:
8
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
9
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
10
+
11
+ import logging
12
+ from typing import Callable, List, Any, Tuple, Dict
13
+
14
+ import torch
15
+ from torch import nn, Tensor
16
+
17
+ from .attention import Attention, MemEffAttention
18
+ from .drop_path import DropPath
19
+ from .layer_scale import LayerScale
20
+ from .mlp import Mlp
21
+
22
+
23
+ logger = logging.getLogger("dinov2")
24
+
25
+
26
+ try:
27
+ from xformers.ops import fmha
28
+ from xformers.ops import scaled_index_add, index_select_cat
29
+
30
+ XFORMERS_AVAILABLE = True
31
+ except ImportError:
32
+ logger.warning("xFormers not available")
33
+ XFORMERS_AVAILABLE = False
34
+
35
+
36
+ class Block(nn.Module):
37
+ def __init__(
38
+ self,
39
+ dim: int,
40
+ num_heads: int,
41
+ mlp_ratio: float = 4.0,
42
+ qkv_bias: bool = False,
43
+ proj_bias: bool = True,
44
+ ffn_bias: bool = True,
45
+ drop: float = 0.0,
46
+ attn_drop: float = 0.0,
47
+ init_values=None,
48
+ drop_path: float = 0.0,
49
+ act_layer: Callable[..., nn.Module] = nn.GELU,
50
+ norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
51
+ attn_class: Callable[..., nn.Module] = Attention,
52
+ ffn_layer: Callable[..., nn.Module] = Mlp,
53
+ ) -> None:
54
+ super().__init__()
55
+ # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}")
56
+ self.norm1 = norm_layer(dim)
57
+ self.attn = attn_class(
58
+ dim,
59
+ num_heads=num_heads,
60
+ qkv_bias=qkv_bias,
61
+ proj_bias=proj_bias,
62
+ attn_drop=attn_drop,
63
+ proj_drop=drop,
64
+ )
65
+ self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
66
+ self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
67
+
68
+ self.norm2 = norm_layer(dim)
69
+ mlp_hidden_dim = int(dim * mlp_ratio)
70
+ self.mlp = ffn_layer(
71
+ in_features=dim,
72
+ hidden_features=mlp_hidden_dim,
73
+ act_layer=act_layer,
74
+ drop=drop,
75
+ bias=ffn_bias,
76
+ )
77
+ self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
78
+ self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
79
+
80
+ self.sample_drop_ratio = drop_path
81
+
82
+ def forward(self, x: Tensor) -> Tensor:
83
+ def attn_residual_func(x: Tensor) -> Tensor:
84
+ return self.ls1(self.attn(self.norm1(x)))
85
+
86
+ def ffn_residual_func(x: Tensor) -> Tensor:
87
+ return self.ls2(self.mlp(self.norm2(x)))
88
+
89
+ if self.training and self.sample_drop_ratio > 0.1:
90
+ # the overhead is compensated only for a drop path rate larger than 0.1
91
+ x = drop_add_residual_stochastic_depth(
92
+ x,
93
+ residual_func=attn_residual_func,
94
+ sample_drop_ratio=self.sample_drop_ratio,
95
+ )
96
+ x = drop_add_residual_stochastic_depth(
97
+ x,
98
+ residual_func=ffn_residual_func,
99
+ sample_drop_ratio=self.sample_drop_ratio,
100
+ )
101
+ elif self.training and self.sample_drop_ratio > 0.0:
102
+ x = x + self.drop_path1(attn_residual_func(x))
103
+ x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2
104
+ else:
105
+ x = x + attn_residual_func(x)
106
+ x = x + ffn_residual_func(x)
107
+ return x
108
+
109
+
110
+ def drop_add_residual_stochastic_depth(
111
+ x: Tensor,
112
+ residual_func: Callable[[Tensor], Tensor],
113
+ sample_drop_ratio: float = 0.0,
114
+ ) -> Tensor:
115
+ # 1) extract subset using permutation
116
+ b, n, d = x.shape
117
+ sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
118
+ brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
119
+ x_subset = x[brange]
120
+
121
+ # 2) apply residual_func to get residual
122
+ residual = residual_func(x_subset)
123
+
124
+ x_flat = x.flatten(1)
125
+ residual = residual.flatten(1)
126
+
127
+ residual_scale_factor = b / sample_subset_size
128
+
129
+ # 3) add the residual
130
+ x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
131
+ return x_plus_residual.view_as(x)
132
+
133
+
134
+ def get_branges_scales(x, sample_drop_ratio=0.0):
135
+ b, n, d = x.shape
136
+ sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
137
+ brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
138
+ residual_scale_factor = b / sample_subset_size
139
+ return brange, residual_scale_factor
140
+
141
+
142
+ def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
143
+ if scaling_vector is None:
144
+ x_flat = x.flatten(1)
145
+ residual = residual.flatten(1)
146
+ x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
147
+ else:
148
+ x_plus_residual = scaled_index_add(
149
+ x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
150
+ )
151
+ return x_plus_residual
152
+
153
+
154
+ attn_bias_cache: Dict[Tuple, Any] = {}
155
+
156
+
157
+ def get_attn_bias_and_cat(x_list, branges=None):
158
+ """
159
+ this will perform the index select, cat the tensors, and provide the attn_bias from cache
160
+ """
161
+ batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list]
162
+ all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
163
+ if all_shapes not in attn_bias_cache.keys():
164
+ seqlens = []
165
+ for b, x in zip(batch_sizes, x_list):
166
+ for _ in range(b):
167
+ seqlens.append(x.shape[1])
168
+ attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
169
+ attn_bias._batch_sizes = batch_sizes
170
+ attn_bias_cache[all_shapes] = attn_bias
171
+
172
+ if branges is not None:
173
+ cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1])
174
+ else:
175
+ tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
176
+ cat_tensors = torch.cat(tensors_bs1, dim=1)
177
+
178
+ return attn_bias_cache[all_shapes], cat_tensors
179
+
180
+
181
+ def drop_add_residual_stochastic_depth_list(
182
+ x_list: List[Tensor],
183
+ residual_func: Callable[[Tensor, Any], Tensor],
184
+ sample_drop_ratio: float = 0.0,
185
+ scaling_vector=None,
186
+ ) -> Tensor:
187
+ # 1) generate random set of indices for dropping samples in the batch
188
+ branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
189
+ branges = [s[0] for s in branges_scales]
190
+ residual_scale_factors = [s[1] for s in branges_scales]
191
+
192
+ # 2) get attention bias and index+concat the tensors
193
+ attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)
194
+
195
+ # 3) apply residual_func to get residual, and split the result
196
+ residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore
197
+
198
+ outputs = []
199
+ for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors):
200
+ outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x))
201
+ return outputs
202
+
203
+
204
+ class NestedTensorBlock(Block):
205
+ def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
206
+ """
207
+ x_list contains a list of tensors to nest together and run
208
+ """
209
+ assert isinstance(self.attn, MemEffAttention)
210
+
211
+ if self.training and self.sample_drop_ratio > 0.0:
212
+
213
+ def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
214
+ return self.attn(self.norm1(x), attn_bias=attn_bias)
215
+
216
+ def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
217
+ return self.mlp(self.norm2(x))
218
+
219
+ x_list = drop_add_residual_stochastic_depth_list(
220
+ x_list,
221
+ residual_func=attn_residual_func,
222
+ sample_drop_ratio=self.sample_drop_ratio,
223
+ scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,
224
+ )
225
+ x_list = drop_add_residual_stochastic_depth_list(
226
+ x_list,
227
+ residual_func=ffn_residual_func,
228
+ sample_drop_ratio=self.sample_drop_ratio,
229
+ scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None,
230
+ )
231
+ return x_list
232
+ else:
233
+
234
+ def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
235
+ return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))
236
+
237
+ def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
238
+ return self.ls2(self.mlp(self.norm2(x)))
239
+
240
+ attn_bias, x = get_attn_bias_and_cat(x_list)
241
+ x = x + attn_residual_func(x, attn_bias=attn_bias)
242
+ x = x + ffn_residual_func(x)
243
+ return attn_bias.split(x)
244
+
245
+ def forward(self, x_or_x_list):
246
+ if isinstance(x_or_x_list, Tensor):
247
+ return super().forward(x_or_x_list)
248
+ elif isinstance(x_or_x_list, list):
249
+ assert XFORMERS_AVAILABLE, "Please install xFormers for nested tensors usage"
250
+ return self.forward_nested(x_or_x_list)
251
+ else:
252
+ raise AssertionError
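A short sketch of the two call modes of NestedTensorBlock (sizes are illustrative): a plain tensor runs through the standard Block path with any attention class, while the list form nests variable-length sequences and requires MemEffAttention plus xFormers.

import torch

blk = NestedTensorBlock(dim=384, num_heads=6, attn_class=Attention).eval()
x = torch.randn(2, 197, 384)
with torch.no_grad():
    y = blk(x)                 # falls through to Block.forward
print(y.shape)                 # torch.Size([2, 197, 384])
# blk([x1, x2]) would dispatch to forward_nested, which asserts MemEffAttention and xFormers.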
finetune/modules/depth_warping/depth_anything_v2/dinov2_layers/drop_path.py ADDED
@@ -0,0 +1,35 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # References:
8
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
9
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py
10
+
11
+
12
+ from torch import nn
13
+
14
+
15
+ def drop_path(x, drop_prob: float = 0.0, training: bool = False):
16
+ if drop_prob == 0.0 or not training:
17
+ return x
18
+ keep_prob = 1 - drop_prob
19
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
20
+ random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
21
+ if keep_prob > 0.0:
22
+ random_tensor.div_(keep_prob)
23
+ output = x * random_tensor
24
+ return output
25
+
26
+
27
+ class DropPath(nn.Module):
28
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
29
+
30
+ def __init__(self, drop_prob=None):
31
+ super(DropPath, self).__init__()
32
+ self.drop_prob = drop_prob
33
+
34
+ def forward(self, x):
35
+ return drop_path(x, self.drop_prob, self.training)
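A short numerical check of the stochastic-depth scaling above (illustrative): rows that survive are rescaled by 1/keep_prob, so the expected activation matches the input, and the module is the identity at inference time.

import torch

dp = DropPath(drop_prob=0.2).train()
x = torch.ones(10000, 8)
y = dp(x)
print(y.mean().item())          # ~1.0: zeroed rows are offset by the 1/0.8 rescale of survivors
dp.eval()
assert torch.equal(dp(x), x)    # identity when not training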
finetune/modules/depth_warping/depth_anything_v2/dinov2_layers/layer_scale.py ADDED
@@ -0,0 +1,28 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110
8
+
9
+ from typing import Union
10
+
11
+ import torch
12
+ from torch import Tensor
13
+ from torch import nn
14
+
15
+
16
+ class LayerScale(nn.Module):
17
+ def __init__(
18
+ self,
19
+ dim: int,
20
+ init_values: Union[float, Tensor] = 1e-5,
21
+ inplace: bool = False,
22
+ ) -> None:
23
+ super().__init__()
24
+ self.inplace = inplace
25
+ self.gamma = nn.Parameter(init_values * torch.ones(dim))
26
+
27
+ def forward(self, x: Tensor) -> Tensor:
28
+ return x.mul_(self.gamma) if self.inplace else x * self.gamma
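LayerScale is simply a learnable per-channel gain on a residual branch; the small init (1e-5) keeps that branch nearly silent at the start of training. An inline stand-in equivalent to LayerScale(dim=8, init_values=1e-5):

import torch
from torch import nn

gamma = nn.Parameter(1e-5 * torch.ones(8))   # one gain per channel, as in the class above
x = torch.randn(2, 16, 8)                    # (batch, tokens, dim)
y = x * gamma                                # broadcast over batch and tokens
print(y.shape, float(gamma.mean()))          # torch.Size([2, 16, 8]) ~1e-05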
finetune/modules/depth_warping/depth_anything_v2/dinov2_layers/mlp.py ADDED
@@ -0,0 +1,41 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # References:
8
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
9
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py
10
+
11
+
12
+ from typing import Callable, Optional
13
+
14
+ from torch import Tensor, nn
15
+
16
+
17
+ class Mlp(nn.Module):
18
+ def __init__(
19
+ self,
20
+ in_features: int,
21
+ hidden_features: Optional[int] = None,
22
+ out_features: Optional[int] = None,
23
+ act_layer: Callable[..., nn.Module] = nn.GELU,
24
+ drop: float = 0.0,
25
+ bias: bool = True,
26
+ ) -> None:
27
+ super().__init__()
28
+ out_features = out_features or in_features
29
+ hidden_features = hidden_features or in_features
30
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
31
+ self.act = act_layer()
32
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
33
+ self.drop = nn.Dropout(drop)
34
+
35
+ def forward(self, x: Tensor) -> Tensor:
36
+ x = self.fc1(x)
37
+ x = self.act(x)
38
+ x = self.drop(x)
39
+ x = self.fc2(x)
40
+ x = self.drop(x)
41
+ return x
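A brief shape-only usage sketch of Mlp; the import path below assumes the repository root is on sys.path (adjust as needed):

import torch
from finetune.modules.depth_warping.depth_anything_v2.dinov2_layers.mlp import Mlp  # assumed path

mlp = Mlp(in_features=384, hidden_features=4 * 384, drop=0.1)
tokens = torch.randn(2, 196, 384)   # (batch, tokens, dim)
out = mlp(tokens)
print(out.shape)                    # torch.Size([2, 196, 384]) -- the token dimension is preserved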
finetune/modules/depth_warping/depth_anything_v2/dinov2_layers/patch_embed.py ADDED
@@ -0,0 +1,89 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # References:
8
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
9
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
10
+
11
+ from typing import Callable, Optional, Tuple, Union
12
+
13
+ from torch import Tensor
14
+ import torch.nn as nn
15
+
16
+
17
+ def make_2tuple(x):
18
+ if isinstance(x, tuple):
19
+ assert len(x) == 2
20
+ return x
21
+
22
+ assert isinstance(x, int)
23
+ return (x, x)
24
+
25
+
26
+ class PatchEmbed(nn.Module):
27
+ """
28
+ 2D image to patch embedding: (B,C,H,W) -> (B,N,D)
29
+
30
+ Args:
31
+ img_size: Image size.
32
+ patch_size: Patch token size.
33
+ in_chans: Number of input image channels.
34
+ embed_dim: Number of linear projection output channels.
35
+ norm_layer: Normalization layer.
36
+ """
37
+
38
+ def __init__(
39
+ self,
40
+ img_size: Union[int, Tuple[int, int]] = 224,
41
+ patch_size: Union[int, Tuple[int, int]] = 16,
42
+ in_chans: int = 3,
43
+ embed_dim: int = 768,
44
+ norm_layer: Optional[Callable] = None,
45
+ flatten_embedding: bool = True,
46
+ ) -> None:
47
+ super().__init__()
48
+
49
+ image_HW = make_2tuple(img_size)
50
+ patch_HW = make_2tuple(patch_size)
51
+ patch_grid_size = (
52
+ image_HW[0] // patch_HW[0],
53
+ image_HW[1] // patch_HW[1],
54
+ )
55
+
56
+ self.img_size = image_HW
57
+ self.patch_size = patch_HW
58
+ self.patches_resolution = patch_grid_size
59
+ self.num_patches = patch_grid_size[0] * patch_grid_size[1]
60
+
61
+ self.in_chans = in_chans
62
+ self.embed_dim = embed_dim
63
+
64
+ self.flatten_embedding = flatten_embedding
65
+
66
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
67
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
68
+
69
+ def forward(self, x: Tensor) -> Tensor:
70
+ _, _, H, W = x.shape
71
+ patch_H, patch_W = self.patch_size
72
+
73
+ assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
74
+ assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width {patch_W}"
75
+
76
+ x = self.proj(x) # B C H W
77
+ H, W = x.size(2), x.size(3)
78
+ x = x.flatten(2).transpose(1, 2) # B HW C
79
+ x = self.norm(x)
80
+ if not self.flatten_embedding:
81
+ x = x.reshape(-1, H, W, self.embed_dim) # B H W C
82
+ return x
83
+
84
+ def flops(self) -> float:
85
+ Ho, Wo = self.patches_resolution
86
+ flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
87
+ if self.norm is not None:
88
+ flops += Ho * Wo * self.embed_dim
89
+ return flops
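A small sketch of the (B, C, H, W) -> (B, N, D) mapping; the import path is an assumption about how the package is laid out on sys.path:

import torch
from finetune.modules.depth_warping.depth_anything_v2.dinov2_layers.patch_embed import PatchEmbed  # assumed path

embed = PatchEmbed(img_size=224, patch_size=16, in_chans=3, embed_dim=768)
img = torch.randn(1, 3, 224, 224)
tokens = embed(img)
print(tokens.shape, embed.num_patches)   # torch.Size([1, 196, 768]) 196  (a 14 x 14 patch grid)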
finetune/modules/depth_warping/depth_anything_v2/dinov2_layers/swiglu_ffn.py ADDED
@@ -0,0 +1,63 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from typing import Callable, Optional
8
+
9
+ from torch import Tensor, nn
10
+ import torch.nn.functional as F
11
+
12
+
13
+ class SwiGLUFFN(nn.Module):
14
+ def __init__(
15
+ self,
16
+ in_features: int,
17
+ hidden_features: Optional[int] = None,
18
+ out_features: Optional[int] = None,
19
+ act_layer: Callable[..., nn.Module] = None,
20
+ drop: float = 0.0,
21
+ bias: bool = True,
22
+ ) -> None:
23
+ super().__init__()
24
+ out_features = out_features or in_features
25
+ hidden_features = hidden_features or in_features
26
+ self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
27
+ self.w3 = nn.Linear(hidden_features, out_features, bias=bias)
28
+
29
+ def forward(self, x: Tensor) -> Tensor:
30
+ x12 = self.w12(x)
31
+ x1, x2 = x12.chunk(2, dim=-1)
32
+ hidden = F.silu(x1) * x2
33
+ return self.w3(hidden)
34
+
35
+
36
+ try:
37
+ from xformers.ops import SwiGLU
38
+
39
+ XFORMERS_AVAILABLE = True
40
+ except ImportError:
41
+ SwiGLU = SwiGLUFFN
42
+ XFORMERS_AVAILABLE = False
43
+
44
+
45
+ class SwiGLUFFNFused(SwiGLU):
46
+ def __init__(
47
+ self,
48
+ in_features: int,
49
+ hidden_features: Optional[int] = None,
50
+ out_features: Optional[int] = None,
51
+ act_layer: Callable[..., nn.Module] = None,
52
+ drop: float = 0.0,
53
+ bias: bool = True,
54
+ ) -> None:
55
+ out_features = out_features or in_features
56
+ hidden_features = hidden_features or in_features
57
+ hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
58
+ super().__init__(
59
+ in_features=in_features,
60
+ hidden_features=hidden_features,
61
+ out_features=out_features,
62
+ bias=bias,
63
+ )
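Two details worth noting: the gating splits a single projection into two halves and computes silu(x1) * x2, and SwiGLUFFNFused shrinks the requested hidden width to 2/3 and rounds it up to a multiple of 8. A self-contained check of both:

import torch
import torch.nn.functional as F

# Gating on the two halves of one projection, as in SwiGLUFFN.forward.
x12 = torch.randn(2, 10, 2 * 256)
x1, x2 = x12.chunk(2, dim=-1)
hidden = F.silu(x1) * x2
print(hidden.shape)                                   # torch.Size([2, 10, 256])

# Hidden-width rule from SwiGLUFFNFused: 2/3 of the request, rounded up to a multiple of 8.
hidden_features = 4 * 384
print((int(hidden_features * 2 / 3) + 7) // 8 * 8)    # 1024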
finetune/modules/depth_warping/depth_anything_v2/dpt.py ADDED
@@ -0,0 +1,235 @@
1
+ import cv2
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ import torchvision.transforms as tf
6
+ from torchvision.transforms import Compose
7
+
8
+ from .dinov2 import DINOv2
9
+ from .util.blocks import FeatureFusionBlock, _make_scratch
10
+ from .util.transform import Resize, NormalizeImage, PrepareForNet
11
+
12
+
13
+ def _make_fusion_block(features, use_bn, size=None):
14
+ return FeatureFusionBlock(
15
+ features,
16
+ nn.ReLU(False),
17
+ deconv=False,
18
+ bn=use_bn,
19
+ expand=False,
20
+ align_corners=True,
21
+ size=size,
22
+ )
23
+
24
+
25
+ class ConvBlock(nn.Module):
26
+ def __init__(self, in_feature, out_feature):
27
+ super().__init__()
28
+
29
+ self.conv_block = nn.Sequential(
30
+ nn.Conv2d(in_feature, out_feature, kernel_size=3, stride=1, padding=1),
31
+ nn.BatchNorm2d(out_feature),
32
+ nn.ReLU(True)
33
+ )
34
+
35
+ def forward(self, x):
36
+ return self.conv_block(x)
37
+
38
+
39
+ class DPTHead(nn.Module):
40
+ def __init__(
41
+ self,
42
+ in_channels,
43
+ features=256,
44
+ use_bn=False,
45
+ out_channels=[256, 512, 1024, 1024],
46
+ use_clstoken=False
47
+ ):
48
+ super(DPTHead, self).__init__()
49
+
50
+ self.use_clstoken = use_clstoken
51
+
52
+ self.projects = nn.ModuleList([
53
+ nn.Conv2d(
54
+ in_channels=in_channels,
55
+ out_channels=out_channel,
56
+ kernel_size=1,
57
+ stride=1,
58
+ padding=0,
59
+ ) for out_channel in out_channels
60
+ ])
61
+
62
+ self.resize_layers = nn.ModuleList([
63
+ nn.ConvTranspose2d(
64
+ in_channels=out_channels[0],
65
+ out_channels=out_channels[0],
66
+ kernel_size=4,
67
+ stride=4,
68
+ padding=0),
69
+ nn.ConvTranspose2d(
70
+ in_channels=out_channels[1],
71
+ out_channels=out_channels[1],
72
+ kernel_size=2,
73
+ stride=2,
74
+ padding=0),
75
+ nn.Identity(),
76
+ nn.Conv2d(
77
+ in_channels=out_channels[3],
78
+ out_channels=out_channels[3],
79
+ kernel_size=3,
80
+ stride=2,
81
+ padding=1)
82
+ ])
83
+
84
+ if use_clstoken:
85
+ self.readout_projects = nn.ModuleList()
86
+ for _ in range(len(self.projects)):
87
+ self.readout_projects.append(
88
+ nn.Sequential(
89
+ nn.Linear(2 * in_channels, in_channels),
90
+ nn.GELU()))
91
+
92
+ self.scratch = _make_scratch(
93
+ out_channels,
94
+ features,
95
+ groups=1,
96
+ expand=False,
97
+ )
98
+
99
+ self.scratch.stem_transpose = None
100
+
101
+ self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
102
+ self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
103
+ self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
104
+ self.scratch.refinenet4 = _make_fusion_block(features, use_bn)
105
+
106
+ head_features_1 = features
107
+ head_features_2 = 32
108
+
109
+ self.scratch.output_conv1 = nn.Conv2d(head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1)
110
+ self.scratch.output_conv2 = nn.Sequential(
111
+ nn.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1),
112
+ nn.ReLU(True),
113
+ nn.Conv2d(head_features_2, 1, kernel_size=1, stride=1, padding=0),
114
+ nn.Sigmoid()
115
+ )
116
+
117
+ def forward(self, out_features, patch_h, patch_w):
118
+ out = []
119
+ for i, x in enumerate(out_features):
120
+ if self.use_clstoken:
121
+ x, cls_token = x[0], x[1]
122
+ readout = cls_token.unsqueeze(1).expand_as(x)
123
+ x = self.readout_projects[i](torch.cat((x, readout), -1))
124
+ else:
125
+ x = x[0]
126
+
127
+ x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w))
128
+
129
+ x = self.projects[i](x)
130
+ x = self.resize_layers[i](x)
131
+
132
+ out.append(x)
133
+
134
+ layer_1, layer_2, layer_3, layer_4 = out
135
+
136
+ layer_1_rn = self.scratch.layer1_rn(layer_1)
137
+ layer_2_rn = self.scratch.layer2_rn(layer_2)
138
+ layer_3_rn = self.scratch.layer3_rn(layer_3)
139
+ layer_4_rn = self.scratch.layer4_rn(layer_4)
140
+
141
+ path_4 = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])
142
+ path_3 = self.scratch.refinenet3(path_4, layer_3_rn, size=layer_2_rn.shape[2:])
143
+ path_2 = self.scratch.refinenet2(path_3, layer_2_rn, size=layer_1_rn.shape[2:])
144
+ path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
145
+
146
+ out = self.scratch.output_conv1(path_1)
147
+ out = F.interpolate(out, (int(patch_h * 14), int(patch_w * 14)), mode="bilinear", align_corners=True)
148
+ out = self.scratch.output_conv2(out)
149
+
150
+ return out
151
+
152
+
153
+ class DepthAnythingV2(nn.Module):
154
+ def __init__(
155
+ self,
156
+ encoder='vitl',
157
+ features=256,
158
+ out_channels=[256, 512, 1024, 1024],
159
+ use_bn=False,
160
+ use_clstoken=False,
161
+ max_depth=20.0
162
+ ):
163
+ super(DepthAnythingV2, self).__init__()
164
+
165
+ self.intermediate_layer_idx = {
166
+ 'vits': [2, 5, 8, 11],
167
+ 'vitb': [2, 5, 8, 11],
168
+ 'vitl': [4, 11, 17, 23],
169
+ 'vitg': [9, 19, 29, 39]
170
+ }
171
+
172
+ self.max_depth = max_depth
173
+
174
+ self.encoder = encoder
175
+ self.pretrained = DINOv2(model_name=encoder)
176
+
177
+ self.depth_head = DPTHead(self.pretrained.embed_dim, features, use_bn, out_channels=out_channels, use_clstoken=use_clstoken)
178
+
179
+ def forward(self, x):
180
+ patch_h, patch_w = x.shape[-2] // 14, x.shape[-1] // 14
181
+
182
+ features = self.pretrained.get_intermediate_layers(x, self.intermediate_layer_idx[self.encoder], return_class_token=True)
183
+
184
+ depth = self.depth_head(features, patch_h, patch_w) * self.max_depth
185
+
186
+ return depth.squeeze(1)
187
+
188
+ @torch.no_grad()
189
+ def infer_image(self, raw_image, input_size=518):
190
+ image, (h, w) = self.image2tensor(raw_image, input_size)
191
+
192
+ depth = self.forward(image)
193
+
194
+ depth = F.interpolate(depth[:, None], (h, w), mode="bilinear", align_corners=True)[0, 0]
195
+
196
+ return depth
197
+ # return depth.cpu().numpy()
198
+
199
+
200
+ # TODO. transform for torch.Tensor
201
+ # TODO. inference for torch.Tensor
202
+ # def image2tensor_pt(self, raw_image, input_size=518):
203
+ # transform = Compose([
204
+ # tf
205
+ # ])
206
+
207
+
208
+ def image2tensor(self, raw_image, input_size=518):
209
+ transform = Compose([
210
+ Resize(
211
+ width=input_size,
212
+ height=input_size,
213
+ resize_target=False,
214
+ keep_aspect_ratio=True,
215
+ ensure_multiple_of=14,
216
+ resize_method='lower_bound',
217
+ image_interpolation_method=cv2.INTER_CUBIC,
218
+ ),
219
+ NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
220
+ PrepareForNet(),
221
+ ])
222
+
223
+ h, w = raw_image.shape[:2]
224
+
225
+ # raw_image already has RGB order, [0,255]
226
+ # image = cv2.cvtColor(raw_image, cv2.COLOR_BGR2RGB) / 255.0
227
+ image = raw_image / 255.0
228
+
229
+ image = transform({'image': image})['image']
230
+ image = torch.from_numpy(image).unsqueeze(0)
231
+
232
+ DEVICE = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
233
+ image = image.to(DEVICE)
234
+
235
+ return image, (h, w)
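A hedged end-to-end sketch of infer_image; the import path and the ViT-S configuration (features=64, out_channels=[48, 96, 192, 384]) are assumptions, and no checkpoint is loaded, so this only exercises shapes and the internal resize-to-a-multiple-of-14 logic:

import numpy as np
import torch
from finetune.modules.depth_warping.depth_anything_v2.dpt import DepthAnythingV2  # assumed path

device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
model = DepthAnythingV2(encoder='vits', features=64, out_channels=[48, 96, 192, 384]).to(device).eval()

rgb = (np.random.rand(480, 640, 3) * 255).astype(np.float32)   # RGB, [0, 255], as image2tensor expects
depth = model.infer_image(rgb)       # resized internally, then interpolated back to 480 x 640
print(depth.shape)                   # torch.Size([480, 640]); values lie in [0, max_depth]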
finetune/modules/depth_warping/depth_anything_v2/util/blocks.py ADDED
@@ -0,0 +1,148 @@
1
+ import torch.nn as nn
2
+
3
+
4
+ def _make_scratch(in_shape, out_shape, groups=1, expand=False):
5
+ scratch = nn.Module()
6
+
7
+ out_shape1 = out_shape
8
+ out_shape2 = out_shape
9
+ out_shape3 = out_shape
10
+ if len(in_shape) >= 4:
11
+ out_shape4 = out_shape
12
+
13
+ if expand:
14
+ out_shape1 = out_shape
15
+ out_shape2 = out_shape * 2
16
+ out_shape3 = out_shape * 4
17
+ if len(in_shape) >= 4:
18
+ out_shape4 = out_shape * 8
19
+
20
+ scratch.layer1_rn = nn.Conv2d(in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups)
21
+ scratch.layer2_rn = nn.Conv2d(in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups)
22
+ scratch.layer3_rn = nn.Conv2d(in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups)
23
+ if len(in_shape) >= 4:
24
+ scratch.layer4_rn = nn.Conv2d(in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups)
25
+
26
+ return scratch
27
+
28
+
29
+ class ResidualConvUnit(nn.Module):
30
+ """Residual convolution module.
31
+ """
32
+
33
+ def __init__(self, features, activation, bn):
34
+ """Init.
35
+
36
+ Args:
37
+ features (int): number of features
38
+ """
39
+ super().__init__()
40
+
41
+ self.bn = bn
42
+
43
+ self.groups=1
44
+
45
+ self.conv1 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
46
+
47
+ self.conv2 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
48
+
49
+ if self.bn == True:
50
+ self.bn1 = nn.BatchNorm2d(features)
51
+ self.bn2 = nn.BatchNorm2d(features)
52
+
53
+ self.activation = activation
54
+
55
+ self.skip_add = nn.quantized.FloatFunctional()
56
+
57
+ def forward(self, x):
58
+ """Forward pass.
59
+
60
+ Args:
61
+ x (tensor): input
62
+
63
+ Returns:
64
+ tensor: output
65
+ """
66
+
67
+ out = self.activation(x)
68
+ out = self.conv1(out)
69
+ if self.bn == True:
70
+ out = self.bn1(out)
71
+
72
+ out = self.activation(out)
73
+ out = self.conv2(out)
74
+ if self.bn == True:
75
+ out = self.bn2(out)
76
+
77
+ if self.groups > 1:
78
+ out = self.conv_merge(out)
79
+
80
+ return self.skip_add.add(out, x)
81
+
82
+
83
+ class FeatureFusionBlock(nn.Module):
84
+ """Feature fusion block.
85
+ """
86
+
87
+ def __init__(
88
+ self,
89
+ features,
90
+ activation,
91
+ deconv=False,
92
+ bn=False,
93
+ expand=False,
94
+ align_corners=True,
95
+ size=None
96
+ ):
97
+ """Init.
98
+
99
+ Args:
100
+ features (int): number of features
101
+ """
102
+ super(FeatureFusionBlock, self).__init__()
103
+
104
+ self.deconv = deconv
105
+ self.align_corners = align_corners
106
+
107
+ self.groups=1
108
+
109
+ self.expand = expand
110
+ out_features = features
111
+ if self.expand == True:
112
+ out_features = features // 2
113
+
114
+ self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1)
115
+
116
+ self.resConfUnit1 = ResidualConvUnit(features, activation, bn)
117
+ self.resConfUnit2 = ResidualConvUnit(features, activation, bn)
118
+
119
+ self.skip_add = nn.quantized.FloatFunctional()
120
+
121
+ self.size=size
122
+
123
+ def forward(self, *xs, size=None):
124
+ """Forward pass.
125
+
126
+ Returns:
127
+ tensor: output
128
+ """
129
+ output = xs[0]
130
+
131
+ if len(xs) == 2:
132
+ res = self.resConfUnit1(xs[1])
133
+ output = self.skip_add.add(output, res)
134
+
135
+ output = self.resConfUnit2(output)
136
+
137
+ if (size is None) and (self.size is None):
138
+ modifier = {"scale_factor": 2}
139
+ elif size is None:
140
+ modifier = {"size": self.size}
141
+ else:
142
+ modifier = {"size": size}
143
+
144
+ output = nn.functional.interpolate(output, **modifier, mode="bilinear", align_corners=self.align_corners)
145
+
146
+ output = self.out_conv(output)
147
+
148
+ return output
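FeatureFusionBlock upsamples by a factor of 2 unless an explicit target size is supplied (via size= at call time or in the constructor). A small shape check, assuming the import path below:

import torch
from torch import nn
from finetune.modules.depth_warping.depth_anything_v2.util.blocks import FeatureFusionBlock  # assumed path

block = FeatureFusionBlock(features=32, activation=nn.ReLU(False), bn=False)
a = torch.randn(1, 32, 24, 24)            # coarser decoder feature map
b = torch.randn(1, 32, 24, 24)            # skip connection at the same resolution
print(block(a, b).shape)                  # torch.Size([1, 32, 48, 48]) -- default 2x upsampling
print(block(a, b, size=(30, 40)).shape)   # torch.Size([1, 32, 30, 40]) -- explicit target size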
finetune/modules/depth_warping/depth_anything_v2/util/transform.py ADDED
@@ -0,0 +1,158 @@
1
+ import numpy as np
2
+ import cv2
3
+
4
+
5
+ class Resize(object):
6
+ """Resize sample to given size (width, height).
7
+ """
8
+
9
+ def __init__(
10
+ self,
11
+ width,
12
+ height,
13
+ resize_target=True,
14
+ keep_aspect_ratio=False,
15
+ ensure_multiple_of=1,
16
+ resize_method="lower_bound",
17
+ image_interpolation_method=cv2.INTER_AREA,
18
+ ):
19
+ """Init.
20
+
21
+ Args:
22
+ width (int): desired output width
23
+ height (int): desired output height
24
+ resize_target (bool, optional):
25
+ True: Resize the full sample (image, mask, target).
26
+ False: Resize image only.
27
+ Defaults to True.
28
+ keep_aspect_ratio (bool, optional):
29
+ True: Keep the aspect ratio of the input sample.
30
+ Output sample might not have the given width and height, and
31
+ resize behaviour depends on the parameter 'resize_method'.
32
+ Defaults to False.
33
+ ensure_multiple_of (int, optional):
34
+ Output width and height is constrained to be multiple of this parameter.
35
+ Defaults to 1.
36
+ resize_method (str, optional):
37
+ "lower_bound": Output will be at least as large as the given size.
38
+ "upper_bound": Output will be at most as large as the given size. (Output size might be smaller than given size.)
39
+ "minimal": Scale as little as possible. (Output size might be smaller than given size.)
40
+ Defaults to "lower_bound".
41
+ """
42
+ self.__width = width
43
+ self.__height = height
44
+
45
+ self.__resize_target = resize_target
46
+ self.__keep_aspect_ratio = keep_aspect_ratio
47
+ self.__multiple_of = ensure_multiple_of
48
+ self.__resize_method = resize_method
49
+ self.__image_interpolation_method = image_interpolation_method
50
+
51
+ def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
52
+ y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)
53
+
54
+ if max_val is not None and y > max_val:
55
+ y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int)
56
+
57
+ if y < min_val:
58
+ y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int)
59
+
60
+ return y
61
+
62
+ def get_size(self, width, height):
63
+ # determine new height and width
64
+ scale_height = self.__height / height
65
+ scale_width = self.__width / width
66
+
67
+ if self.__keep_aspect_ratio:
68
+ if self.__resize_method == "lower_bound":
69
+ # scale such that output size is lower bound
70
+ if scale_width > scale_height:
71
+ # fit width
72
+ scale_height = scale_width
73
+ else:
74
+ # fit height
75
+ scale_width = scale_height
76
+ elif self.__resize_method == "upper_bound":
77
+ # scale such that output size is upper bound
78
+ if scale_width < scale_height:
79
+ # fit width
80
+ scale_height = scale_width
81
+ else:
82
+ # fit height
83
+ scale_width = scale_height
84
+ elif self.__resize_method == "minimal":
85
+ # scale as little as possible
86
+ if abs(1 - scale_width) < abs(1 - scale_height):
87
+ # fit width
88
+ scale_height = scale_width
89
+ else:
90
+ # fit height
91
+ scale_width = scale_height
92
+ else:
93
+ raise ValueError(f"resize_method {self.__resize_method} not implemented")
94
+
95
+ if self.__resize_method == "lower_bound":
96
+ new_height = self.constrain_to_multiple_of(scale_height * height, min_val=self.__height)
97
+ new_width = self.constrain_to_multiple_of(scale_width * width, min_val=self.__width)
98
+ elif self.__resize_method == "upper_bound":
99
+ new_height = self.constrain_to_multiple_of(scale_height * height, max_val=self.__height)
100
+ new_width = self.constrain_to_multiple_of(scale_width * width, max_val=self.__width)
101
+ elif self.__resize_method == "minimal":
102
+ new_height = self.constrain_to_multiple_of(scale_height * height)
103
+ new_width = self.constrain_to_multiple_of(scale_width * width)
104
+ else:
105
+ raise ValueError(f"resize_method {self.__resize_method} not implemented")
106
+
107
+ return (new_width, new_height)
108
+
109
+ def __call__(self, sample):
110
+ width, height = self.get_size(sample["image"].shape[1], sample["image"].shape[0])
111
+
112
+ # resize sample
113
+ sample["image"] = cv2.resize(sample["image"], (width, height), interpolation=self.__image_interpolation_method)
114
+
115
+ if self.__resize_target:
116
+ if "depth" in sample:
117
+ sample["depth"] = cv2.resize(sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST)
118
+
119
+ if "mask" in sample:
120
+ sample["mask"] = cv2.resize(sample["mask"].astype(np.float32), (width, height), interpolation=cv2.INTER_NEAREST)
121
+
122
+ return sample
123
+
124
+
125
+ class NormalizeImage(object):
126
+ """Normalize image by the given mean and std.
127
+ """
128
+
129
+ def __init__(self, mean, std):
130
+ self.__mean = mean
131
+ self.__std = std
132
+
133
+ def __call__(self, sample):
134
+ sample["image"] = (sample["image"] - self.__mean) / self.__std
135
+
136
+ return sample
137
+
138
+
139
+ class PrepareForNet(object):
140
+ """Prepare sample for usage as network input.
141
+ """
142
+
143
+ def __init__(self):
144
+ pass
145
+
146
+ def __call__(self, sample):
147
+ image = np.transpose(sample["image"], (2, 0, 1))
148
+ sample["image"] = np.ascontiguousarray(image).astype(np.float32)
149
+
150
+ if "depth" in sample:
151
+ depth = sample["depth"].astype(np.float32)
152
+ sample["depth"] = np.ascontiguousarray(depth)
153
+
154
+ if "mask" in sample:
155
+ sample["mask"] = sample["mask"].astype(np.float32)
156
+ sample["mask"] = np.ascontiguousarray(sample["mask"])
157
+
158
+ return sample
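With the settings used in dpt.image2tensor (lower_bound, keep_aspect_ratio, multiple of 14), the resize makes the short side at least 518 and snaps both sides to multiples of 14. A quick check of get_size, assuming the import path below:

import cv2
from finetune.modules.depth_warping.depth_anything_v2.util.transform import Resize  # assumed path

resize = Resize(
    width=518, height=518,
    resize_target=False, keep_aspect_ratio=True,
    ensure_multiple_of=14, resize_method='lower_bound',
    image_interpolation_method=cv2.INTER_CUBIC,
)
w, h = resize.get_size(640, 480)
print(int(w), int(h))   # 686 518 -- short side lands on 518, both sides are multiples of 14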
finetune/modules/depth_warping/depth_pro/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ # Copyright (C) 2024 Apple Inc. All Rights Reserved.
2
+ """Depth Pro package."""
3
+
4
+ from .depth_pro import create_model_and_transforms # noqa
5
+ from .utils import load_rgb # noqa
finetune/modules/depth_warping/depth_pro/cli/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ # Copyright (C) 2024 Apple Inc. All Rights Reserved.
2
+ """Depth Pro CLI and tools."""
3
+
4
+ from .run import main as run_main # noqa
finetune/modules/depth_warping/depth_pro/cli/run.py ADDED
@@ -0,0 +1,154 @@
1
+ #!/usr/bin/env python3
2
+ """Sample script to run DepthPro.
3
+
4
+ Copyright (C) 2024 Apple Inc. All Rights Reserved.
5
+ """
6
+
7
+
8
+ import argparse
9
+ import logging
10
+ from pathlib import Path
11
+
12
+ import numpy as np
13
+ import PIL.Image
14
+ import torch
15
+ from matplotlib import pyplot as plt
16
+ from tqdm import tqdm
17
+
18
+ from depth_pro import create_model_and_transforms, load_rgb
19
+
20
+ LOGGER = logging.getLogger(__name__)
21
+
22
+
23
+ def get_torch_device() -> torch.device:
24
+ """Get the Torch device."""
25
+ device = torch.device("cpu")
26
+ if torch.cuda.is_available():
27
+ device = torch.device("cuda:0")
28
+ elif torch.backends.mps.is_available():
29
+ device = torch.device("mps")
30
+ return device
31
+
32
+
33
+ def run(args):
34
+ """Run Depth Pro on a sample image."""
35
+ if args.verbose:
36
+ logging.basicConfig(level=logging.INFO)
37
+
38
+ # Load model.
39
+ model, transform = create_model_and_transforms(
40
+ device=get_torch_device(),
41
+ precision=torch.half,
42
+ )
43
+ model.eval()
44
+
45
+ image_paths = [args.image_path]
46
+ if args.image_path.is_dir():
47
+ image_paths = args.image_path.glob("**/*")
48
+ relative_path = args.image_path
49
+ else:
50
+ relative_path = args.image_path.parent
51
+
52
+ if not args.skip_display:
53
+ plt.ion()
54
+ fig = plt.figure()
55
+ ax_rgb = fig.add_subplot(121)
56
+ ax_disp = fig.add_subplot(122)
57
+
58
+ for image_path in tqdm(image_paths):
59
+ # Load image and focal length from exif info (if found).
60
+ try:
61
+ LOGGER.info(f"Loading image {image_path} ...")
62
+ image, _, f_px = load_rgb(image_path)
63
+ except Exception as e:
64
+ LOGGER.error(str(e))
65
+ continue
66
+ # Run prediction. If `f_px` is provided, it is used to estimate the final metric depth,
67
+ # otherwise the model estimates `f_px` to compute the depth metricness.
68
+ prediction = model.infer(transform(image), f_px=f_px)
69
+
70
+ # Extract the depth and focal length.
71
+ depth = prediction["depth"].detach().cpu().numpy().squeeze()
72
+ if f_px is not None:
73
+ LOGGER.debug(f"Focal length (from exif): {f_px:0.2f}")
74
+ elif prediction["focallength_px"] is not None:
75
+ focallength_px = prediction["focallength_px"].detach().cpu().item()
76
+ LOGGER.info(f"Estimated focal length: {focallength_px}")
77
+
78
+ inverse_depth = 1 / depth
79
+ # Visualize inverse depth instead of depth, clipped to [0.1m;250m] range for better visualization.
80
+ max_invdepth_vizu = min(inverse_depth.max(), 1 / 0.1)
81
+ min_invdepth_vizu = max(1 / 250, inverse_depth.min())
82
+ inverse_depth_normalized = (inverse_depth - min_invdepth_vizu) / (
83
+ max_invdepth_vizu - min_invdepth_vizu
84
+ )
85
+
86
+ # Save Depth as npz file.
87
+ if args.output_path is not None:
88
+ output_file = (
89
+ args.output_path
90
+ / image_path.relative_to(relative_path).parent
91
+ / image_path.stem
92
+ )
93
+ LOGGER.info(f"Saving depth map to: {str(output_file)}")
94
+ output_file.parent.mkdir(parents=True, exist_ok=True)
95
+ np.savez_compressed(output_file, depth=depth)
96
+
97
+ # Save as color-mapped "turbo" jpg image.
98
+ cmap = plt.get_cmap("turbo")
99
+ color_depth = (cmap(inverse_depth_normalized)[..., :3] * 255).astype(
100
+ np.uint8
101
+ )
102
+ color_map_output_file = str(output_file) + ".jpg"
103
+ LOGGER.info(f"Saving color-mapped depth to: {color_map_output_file}")
104
+ PIL.Image.fromarray(color_depth).save(
105
+ color_map_output_file, format="JPEG", quality=90
106
+ )
107
+
108
+ # Display the image and estimated depth map.
109
+ if not args.skip_display:
110
+ ax_rgb.imshow(image)
111
+ ax_disp.imshow(inverse_depth_normalized, cmap="turbo")
112
+ fig.canvas.draw()
113
+ fig.canvas.flush_events()
114
+
115
+ LOGGER.info("Done predicting depth!")
116
+ if not args.skip_display:
117
+ plt.show(block=True)
118
+
119
+
120
+ def main():
121
+ """Run DepthPro inference example."""
122
+ parser = argparse.ArgumentParser(
123
+ description="Inference scripts of DepthPro with PyTorch models."
124
+ )
125
+ parser.add_argument(
126
+ "-i",
127
+ "--image-path",
128
+ type=Path,
129
+ default="./data/example.jpg",
130
+ help="Path to input image.",
131
+ )
132
+ parser.add_argument(
133
+ "-o",
134
+ "--output-path",
135
+ type=Path,
136
+ help="Path to store output files.",
137
+ )
138
+ parser.add_argument(
139
+ "--skip-display",
140
+ action="store_true",
141
+ help="Skip matplotlib display.",
142
+ )
143
+ parser.add_argument(
144
+ "-v",
145
+ "--verbose",
146
+ action="store_true",
147
+ help="Show verbose output."
148
+ )
149
+
150
+ run(parser.parse_args())
151
+
152
+
153
+ if __name__ == "__main__":
154
+ main()
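The visualization step maps inverse depth, clipped to roughly the [0.1 m, 250 m] range, into [0, 1] before color-mapping. The normalization in isolation, on synthetic values:

import numpy as np

depth = np.random.uniform(0.5, 80.0, size=(4, 4))      # synthetic metric depth in meters
inverse_depth = 1 / depth
max_invdepth_vizu = min(inverse_depth.max(), 1 / 0.1)  # cap at 0.1 m (inverse depth 10)
min_invdepth_vizu = max(1 / 250, inverse_depth.min())  # floor at 250 m (inverse depth 0.004)
normalized = (inverse_depth - min_invdepth_vizu) / (max_invdepth_vizu - min_invdepth_vizu)
print(normalized.min(), normalized.max())              # stays within [0, 1] for depths in range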
finetune/modules/depth_warping/depth_pro/depth_pro.py ADDED
@@ -0,0 +1,298 @@
1
+ # Copyright (C) 2024 Apple Inc. All Rights Reserved.
2
+ # Depth Pro: Sharp Monocular Metric Depth in Less Than a Second
3
+
4
+
5
+ from __future__ import annotations
6
+
7
+ from dataclasses import dataclass
8
+ from typing import Mapping, Optional, Tuple, Union
9
+
10
+ import torch
11
+ from torch import nn
12
+ from torchvision.transforms import (
13
+ Compose,
14
+ ConvertImageDtype,
15
+ Lambda,
16
+ Normalize,
17
+ ToTensor,
18
+ )
19
+
20
+ from .network.decoder import MultiresConvDecoder
21
+ from .network.encoder import DepthProEncoder
22
+ from .network.fov import FOVNetwork
23
+ from .network.vit_factory import VIT_CONFIG_DICT, ViTPreset, create_vit
24
+
25
+
26
+ @dataclass
27
+ class DepthProConfig:
28
+ """Configuration for DepthPro."""
29
+
30
+ patch_encoder_preset: ViTPreset
31
+ image_encoder_preset: ViTPreset
32
+ decoder_features: int
33
+
34
+ checkpoint_uri: Optional[str] = None
35
+ fov_encoder_preset: Optional[ViTPreset] = None
36
+ use_fov_head: bool = True
37
+
38
+
39
+ DEFAULT_MONODEPTH_CONFIG_DICT = DepthProConfig(
40
+ patch_encoder_preset="dinov2l16_384",
41
+ image_encoder_preset="dinov2l16_384",
42
+ checkpoint_uri="./checkpoints/depth_pro.pt",
43
+ decoder_features=256,
44
+ use_fov_head=True,
45
+ fov_encoder_preset="dinov2l16_384",
46
+ )
47
+
48
+
49
+ def create_backbone_model(
50
+ preset: ViTPreset
51
+ ) -> Tuple[nn.Module, ViTPreset]:
52
+ """Create and load a backbone model given a config.
53
+
54
+ Args:
55
+ ----
56
+ preset: A backbone preset to load pre-defined configs.
57
+
58
+ Returns:
59
+ -------
60
+ A Torch module and the associated config.
61
+
62
+ """
63
+ if preset in VIT_CONFIG_DICT:
64
+ config = VIT_CONFIG_DICT[preset]
65
+ model = create_vit(preset=preset, use_pretrained=False)
66
+ else:
67
+ raise KeyError(f"Preset {preset} not found.")
68
+
69
+ return model, config
70
+
71
+
72
+ def create_model_and_transforms(
73
+ config: DepthProConfig = DEFAULT_MONODEPTH_CONFIG_DICT,
74
+ device: torch.device = torch.device("cpu"),
75
+ precision: torch.dtype = torch.float32,
76
+ ) -> Tuple[DepthPro, Compose]:
77
+ """Create a DepthPro model and load weights from `config.checkpoint_uri`.
78
+
79
+ Args:
80
+ ----
81
+ config: The configuration for the DPT model architecture.
82
+ device: The optional Torch device to load the model onto, default runs on "cpu".
83
+ precision: The optional precision used for the model, default is FP32.
84
+
85
+ Returns:
86
+ -------
87
+ The Torch DepthPro model and associated Transform.
88
+
89
+ """
90
+ patch_encoder, patch_encoder_config = create_backbone_model(
91
+ preset=config.patch_encoder_preset
92
+ )
93
+ image_encoder, _ = create_backbone_model(
94
+ preset=config.image_encoder_preset
95
+ )
96
+
97
+ fov_encoder = None
98
+ if config.use_fov_head and config.fov_encoder_preset is not None:
99
+ fov_encoder, _ = create_backbone_model(preset=config.fov_encoder_preset)
100
+
101
+ dims_encoder = patch_encoder_config.encoder_feature_dims
102
+ hook_block_ids = patch_encoder_config.encoder_feature_layer_ids
103
+ encoder = DepthProEncoder(
104
+ dims_encoder=dims_encoder,
105
+ patch_encoder=patch_encoder,
106
+ image_encoder=image_encoder,
107
+ hook_block_ids=hook_block_ids,
108
+ decoder_features=config.decoder_features,
109
+ )
110
+ decoder = MultiresConvDecoder(
111
+ dims_encoder=[config.decoder_features] + list(encoder.dims_encoder),
112
+ dim_decoder=config.decoder_features,
113
+ )
114
+ model = DepthPro(
115
+ encoder=encoder,
116
+ decoder=decoder,
117
+ last_dims=(32, 1),
118
+ use_fov_head=config.use_fov_head,
119
+ fov_encoder=fov_encoder,
120
+ ).to(device)
121
+
122
+ if precision == torch.half:
123
+ model.half()
124
+
125
+ transform = Compose(
126
+ [
127
+ ToTensor(),
128
+ Lambda(lambda x: x.to(device)),
129
+ Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
130
+ ConvertImageDtype(precision),
131
+ ]
132
+ )
133
+
134
+ if config.checkpoint_uri is not None:
135
+ state_dict = torch.load(config.checkpoint_uri, map_location="cpu")
136
+ missing_keys, unexpected_keys = model.load_state_dict(
137
+ state_dict=state_dict, strict=True
138
+ )
139
+
140
+ if len(unexpected_keys) != 0:
141
+ raise KeyError(
142
+ f"Found unexpected keys when loading monodepth: {unexpected_keys}"
143
+ )
144
+
145
+ # fc_norm is only for the classification head,
146
+ # which we would not use. We only use the encoding.
147
+ missing_keys = [key for key in missing_keys if "fc_norm" not in key]
148
+ if len(missing_keys) != 0:
149
+ raise KeyError(f"Keys are missing when loading monodepth: {missing_keys}")
150
+
151
+ return model, transform
152
+
153
+
154
+ class DepthPro(nn.Module):
155
+ """DepthPro network."""
156
+
157
+ def __init__(
158
+ self,
159
+ encoder: DepthProEncoder,
160
+ decoder: MultiresConvDecoder,
161
+ last_dims: tuple[int, int],
162
+ use_fov_head: bool = True,
163
+ fov_encoder: Optional[nn.Module] = None,
164
+ ):
165
+ """Initialize DepthPro.
166
+
167
+ Args:
168
+ ----
169
+ encoder: The DepthProEncoder backbone.
170
+ decoder: The MultiresConvDecoder decoder.
171
+ last_dims: The dimension for the last convolution layers.
172
+ use_fov_head: Whether to use the field-of-view head.
173
+ fov_encoder: A separate encoder for the field of view.
174
+
175
+ """
176
+ super().__init__()
177
+
178
+ self.encoder = encoder
179
+ self.decoder = decoder
180
+
181
+ dim_decoder = decoder.dim_decoder
182
+ self.head = nn.Sequential(
183
+ nn.Conv2d(
184
+ dim_decoder, dim_decoder // 2, kernel_size=3, stride=1, padding=1
185
+ ),
186
+ nn.ConvTranspose2d(
187
+ in_channels=dim_decoder // 2,
188
+ out_channels=dim_decoder // 2,
189
+ kernel_size=2,
190
+ stride=2,
191
+ padding=0,
192
+ bias=True,
193
+ ),
194
+ nn.Conv2d(
195
+ dim_decoder // 2,
196
+ last_dims[0],
197
+ kernel_size=3,
198
+ stride=1,
199
+ padding=1,
200
+ ),
201
+ nn.ReLU(True),
202
+ nn.Conv2d(last_dims[0], last_dims[1], kernel_size=1, stride=1, padding=0),
203
+ nn.ReLU(),
204
+ )
205
+
206
+ # Set the final convolution layer's bias to be 0.
207
+ self.head[4].bias.data.fill_(0)
208
+
209
+ # Set the FOV estimation head.
210
+ if use_fov_head:
211
+ self.fov = FOVNetwork(num_features=dim_decoder, fov_encoder=fov_encoder)
212
+
213
+ @property
214
+ def img_size(self) -> int:
215
+ """Return the internal image size of the network."""
216
+ return self.encoder.img_size
217
+
218
+ def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
219
+ """Decode by projection and fusion of multi-resolution encodings.
220
+
221
+ Args:
222
+ ----
223
+ x (torch.Tensor): Input image.
224
+
225
+ Returns:
226
+ -------
227
+ The canonical inverse depth map [m] and the optional estimated field of view [deg].
228
+
229
+ """
230
+ _, _, H, W = x.shape
231
+ assert H == self.img_size and W == self.img_size
232
+
233
+ encodings = self.encoder(x)
234
+ features, features_0 = self.decoder(encodings)
235
+ canonical_inverse_depth = self.head(features)
236
+
237
+ fov_deg = None
238
+ if hasattr(self, "fov"):
239
+ fov_deg = self.fov.forward(x, features_0.detach())
240
+
241
+ return canonical_inverse_depth, fov_deg
242
+
243
+ @torch.no_grad()
244
+ def infer(
245
+ self,
246
+ x: torch.Tensor,
247
+ f_px: Optional[Union[float, torch.Tensor]] = None,
248
+ interpolation_mode="bilinear",
249
+ ) -> Mapping[str, torch.Tensor]:
250
+ """Infer depth and fov for a given image.
251
+
252
+ If the image is not at network resolution, it is resized to 1536x1536 and
253
+ the estimated depth is resized to the original image resolution.
254
+ Note: if the focal length is given, the estimated value is ignored and the provided
255
+ focal length is use to generate the metric depth values.
256
+
257
+ Args:
258
+ ----
259
+ x (torch.Tensor): Input image
260
+ f_px (torch.Tensor): Optional focal length in pixels corresponding to `x`.
261
+ interpolation_mode (str): Interpolation function for downsampling/upsampling.
262
+
263
+ Returns:
264
+ -------
265
+ Tensor dictionary (torch.Tensor): depth [m], focallength [pixels].
266
+
267
+ """
268
+ if len(x.shape) == 3:
269
+ x = x.unsqueeze(0)
270
+ _, _, H, W = x.shape
271
+ resize = H != self.img_size or W != self.img_size
272
+
273
+ if resize:
274
+ x = nn.functional.interpolate(
275
+ x,
276
+ size=(self.img_size, self.img_size),
277
+ mode=interpolation_mode,
278
+ align_corners=False,
279
+ )
280
+
281
+ canonical_inverse_depth, fov_deg = self.forward(x)
282
+ if f_px is None:
283
+ f_px = 0.5 * W / torch.tan(0.5 * torch.deg2rad(fov_deg.to(torch.float)))
284
+
285
+ inverse_depth = canonical_inverse_depth * (W / f_px)
286
+ f_px = f_px.squeeze()
287
+
288
+ if resize:
289
+ inverse_depth = nn.functional.interpolate(
290
+ inverse_depth, size=(H, W), mode=interpolation_mode, align_corners=False
291
+ )
292
+
293
+ depth = 1.0 / torch.clamp(inverse_depth, min=1e-4, max=1e4)
294
+
295
+ return {
296
+ "depth": depth.squeeze(),
297
+ "focallength_px": f_px,
298
+ }
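When no focal length is supplied, infer derives it from the predicted field of view as f_px = 0.5 * W / tan(0.5 * fov). A small numeric round trip of that formula (the 60-degree FOV is just an example value):

import torch

W = 1536.0                                    # network input width in pixels
fov_deg = torch.tensor(60.0)                  # example predicted horizontal field of view
f_px = 0.5 * W / torch.tan(0.5 * torch.deg2rad(fov_deg))
print(float(f_px))                            # ~1330.2 px

fov_back = torch.rad2deg(2 * torch.atan(0.5 * W / f_px))
print(float(fov_back))                        # ~60.0 deg -- the inverse mapping recovers the FOV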
finetune/modules/depth_warping/depth_pro/eval/boundary_metrics.py ADDED
@@ -0,0 +1,332 @@
1
+ from typing import List, Tuple
2
+
3
+ import numpy as np
4
+
5
+
6
+ def connected_component(r: np.ndarray, c: np.ndarray) -> List[List[int]]:
7
+ """Find connected components in the given row and column indices.
8
+
9
+ Args:
10
+ ----
11
+ r (np.ndarray): Row indices.
12
+ c (np.ndarray): Column indices.
13
+
14
+ Yields:
15
+ ------
16
+ List[int]: Indices of connected components.
17
+
18
+ """
19
+ indices = [0]
20
+ for i in range(1, r.size):
21
+ if r[i] == r[indices[-1]] and c[i] == c[indices[-1]] + 1:
22
+ indices.append(i)
23
+ else:
24
+ yield indices
25
+ indices = [i]
26
+ yield indices
27
+
28
+
29
+ def nms_horizontal(ratio: np.ndarray, threshold: float) -> np.ndarray:
30
+ """Apply Non-Maximum Suppression (NMS) horizontally on the given ratio matrix.
31
+
32
+ Args:
33
+ ----
34
+ ratio (np.ndarray): Input ratio matrix.
35
+ threshold (float): Threshold for NMS.
36
+
37
+ Returns:
38
+ -------
39
+ np.ndarray: Binary mask after applying NMS.
40
+
41
+ """
42
+ mask = np.zeros_like(ratio, dtype=bool)
43
+ r, c = np.nonzero(ratio > threshold)
44
+ if len(r) == 0:
45
+ return mask
46
+ for ids in connected_component(r, c):
47
+ values = [ratio[r[i], c[i]] for i in ids]
48
+ mi = np.argmax(values)
49
+ mask[r[ids[mi]], c[ids[mi]]] = True
50
+ return mask
51
+
52
+
53
+ def nms_vertical(ratio: np.ndarray, threshold: float) -> np.ndarray:
54
+ """Apply Non-Maximum Suppression (NMS) vertically on the given ratio matrix.
55
+
56
+ Args:
57
+ ----
58
+ ratio (np.ndarray): Input ratio matrix.
59
+ threshold (float): Threshold for NMS.
60
+
61
+ Returns:
62
+ -------
63
+ np.ndarray: Binary mask after applying NMS.
64
+
65
+ """
66
+ return np.transpose(nms_horizontal(np.transpose(ratio), threshold))
67
+
68
+
69
+ def fgbg_depth(
70
+ d: np.ndarray, t: float
71
+ ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
72
+ """Find foreground-background relations between neighboring pixels.
73
+
74
+ Args:
75
+ ----
76
+ d (np.ndarray): Depth matrix.
77
+ t (float): Threshold for comparison.
78
+
79
+ Returns:
80
+ -------
81
+ Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: Four matrices indicating
82
+ left, top, right, and bottom foreground-background relations.
83
+
84
+ """
85
+ right_is_big_enough = (d[..., :, 1:] / d[..., :, :-1]) > t
86
+ left_is_big_enough = (d[..., :, :-1] / d[..., :, 1:]) > t
87
+ bottom_is_big_enough = (d[..., 1:, :] / d[..., :-1, :]) > t
88
+ top_is_big_enough = (d[..., :-1, :] / d[..., 1:, :]) > t
89
+ return (
90
+ left_is_big_enough,
91
+ top_is_big_enough,
92
+ right_is_big_enough,
93
+ bottom_is_big_enough,
94
+ )
95
+
96
+
97
+ def fgbg_depth_thinned(
98
+ d: np.ndarray, t: float
99
+ ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
100
+ """Find foreground-background relations between neighboring pixels with Non-Maximum Suppression.
101
+
102
+ Args:
103
+ ----
104
+ d (np.ndarray): Depth matrix.
105
+ t (float): Threshold for NMS.
106
+
107
+ Returns:
108
+ -------
109
+ Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: Four matrices indicating
110
+ left, top, right, and bottom foreground-background relations with NMS applied.
111
+
112
+ """
113
+ right_is_big_enough = nms_horizontal(d[..., :, 1:] / d[..., :, :-1], t)
114
+ left_is_big_enough = nms_horizontal(d[..., :, :-1] / d[..., :, 1:], t)
115
+ bottom_is_big_enough = nms_vertical(d[..., 1:, :] / d[..., :-1, :], t)
116
+ top_is_big_enough = nms_vertical(d[..., :-1, :] / d[..., 1:, :], t)
117
+ return (
118
+ left_is_big_enough,
119
+ top_is_big_enough,
120
+ right_is_big_enough,
121
+ bottom_is_big_enough,
122
+ )
123
+
124
+
125
+ def fgbg_binary_mask(
126
+ d: np.ndarray,
127
+ ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
128
+ """Find foreground-background relations between neighboring pixels in binary masks.
129
+
130
+ Args:
131
+ ----
132
+ d (np.ndarray): Binary depth matrix.
133
+
134
+ Returns:
135
+ -------
136
+ Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: Four matrices indicating
137
+ left, top, right, and bottom foreground-background relations in binary masks.
138
+
139
+ """
140
+ assert d.dtype == bool
141
+ right_is_big_enough = d[..., :, 1:] & ~d[..., :, :-1]
142
+ left_is_big_enough = d[..., :, :-1] & ~d[..., :, 1:]
143
+ bottom_is_big_enough = d[..., 1:, :] & ~d[..., :-1, :]
144
+ top_is_big_enough = d[..., :-1, :] & ~d[..., 1:, :]
145
+ return (
146
+ left_is_big_enough,
147
+ top_is_big_enough,
148
+ right_is_big_enough,
149
+ bottom_is_big_enough,
150
+ )
151
+
152
+
153
+ def edge_recall_matting(pr: np.ndarray, gt: np.ndarray, t: float) -> float:
154
+ """Calculate edge recall for image matting.
155
+
156
+ Args:
157
+ ----
158
+ pr (np.ndarray): Predicted depth matrix.
159
+ gt (np.ndarray): Ground truth binary mask.
160
+ t (float): Threshold for NMS.
161
+
162
+ Returns:
163
+ -------
164
+ float: Edge recall value.
165
+
166
+ """
167
+ assert gt.dtype == bool
168
+ ap, bp, cp, dp = fgbg_depth_thinned(pr, t)
169
+ ag, bg, cg, dg = fgbg_binary_mask(gt)
170
+ return 0.25 * (
171
+ np.count_nonzero(ap & ag) / max(np.count_nonzero(ag), 1)
172
+ + np.count_nonzero(bp & bg) / max(np.count_nonzero(bg), 1)
173
+ + np.count_nonzero(cp & cg) / max(np.count_nonzero(cg), 1)
174
+ + np.count_nonzero(dp & dg) / max(np.count_nonzero(dg), 1)
175
+ )
176
+
177
+
178
+ def boundary_f1(
179
+ pr: np.ndarray,
180
+ gt: np.ndarray,
181
+ t: float,
182
+ return_p: bool = False,
183
+ return_r: bool = False,
184
+ ) -> float:
185
+ """Calculate Boundary F1 score.
186
+
187
+ Args:
188
+ ----
189
+ pr (np.ndarray): Predicted depth matrix.
190
+ gt (np.ndarray): Ground truth depth matrix.
191
+ t (float): Threshold for comparison.
192
+ return_p (bool, optional): If True, return precision. Defaults to False.
193
+ return_r (bool, optional): If True, return recall. Defaults to False.
194
+
195
+ Returns:
196
+ -------
197
+ float: Boundary F1 score, or precision, or recall depending on the flags.
198
+
199
+ """
200
+ ap, bp, cp, dp = fgbg_depth(pr, t)
201
+ ag, bg, cg, dg = fgbg_depth(gt, t)
202
+
203
+ r = 0.25 * (
204
+ np.count_nonzero(ap & ag) / max(np.count_nonzero(ag), 1)
205
+ + np.count_nonzero(bp & bg) / max(np.count_nonzero(bg), 1)
206
+ + np.count_nonzero(cp & cg) / max(np.count_nonzero(cg), 1)
207
+ + np.count_nonzero(dp & dg) / max(np.count_nonzero(dg), 1)
208
+ )
209
+ p = 0.25 * (
210
+ np.count_nonzero(ap & ag) / max(np.count_nonzero(ap), 1)
211
+ + np.count_nonzero(bp & bg) / max(np.count_nonzero(bp), 1)
212
+ + np.count_nonzero(cp & cg) / max(np.count_nonzero(cp), 1)
213
+ + np.count_nonzero(dp & dg) / max(np.count_nonzero(dp), 1)
214
+ )
215
+ if r + p == 0:
216
+ return 0.0
217
+ if return_p:
218
+ return p
219
+ if return_r:
220
+ return r
221
+ return 2 * (r * p) / (r + p)
222
+
223
+
224
+ def get_thresholds_and_weights(
225
+ t_min: float, t_max: float, N: int
226
+ ) -> Tuple[np.ndarray, np.ndarray]:
227
+ """Generate thresholds and weights for the given range.
228
+
229
+ Args:
230
+ ----
231
+ t_min (float): Minimum threshold.
232
+ t_max (float): Maximum threshold.
233
+ N (int): Number of thresholds.
234
+
235
+ Returns:
236
+ -------
237
+ Tuple[np.ndarray, np.ndarray]: Array of thresholds and corresponding weights.
238
+
239
+ """
240
+ thresholds = np.linspace(t_min, t_max, N)
241
+ weights = thresholds / thresholds.sum()
242
+ return thresholds, weights
243
+
244
+
245
+ def invert_depth(depth: np.ndarray, eps: float = 1e-6) -> np.ndarray:
246
+ """Inverts a depth map with numerical stability.
247
+
248
+ Args:
249
+ ----
250
+ depth (np.ndarray): Depth map to be inverted.
251
+ eps (float): Minimum value to avoid division by zero (default is 1e-6).
252
+
253
+ Returns:
254
+ -------
255
+ np.ndarray: Inverted depth map.
256
+
257
+ """
258
+ inverse_depth = 1.0 / depth.clip(min=eps)
259
+ return inverse_depth
260
+
261
+
262
+ def SI_boundary_F1(
263
+ predicted_depth: np.ndarray,
264
+ target_depth: np.ndarray,
265
+ t_min: float = 1.05,
266
+ t_max: float = 1.25,
267
+ N: int = 10,
268
+ ) -> float:
269
+ """Calculate Scale-Invariant Boundary F1 Score for depth-based ground-truth.
270
+
271
+ Args:
272
+ ----
273
+ predicted_depth (np.ndarray): Predicted depth matrix.
274
+ target_depth (np.ndarray): Ground truth depth matrix.
275
+ t_min (float, optional): Minimum threshold. Defaults to 1.05.
276
+ t_max (float, optional): Maximum threshold. Defaults to 1.25.
277
+ N (int, optional): Number of thresholds. Defaults to 10.
278
+
279
+ Returns:
280
+ -------
281
+ float: Scale-Invariant Boundary F1 Score.
282
+
283
+ """
284
+ assert predicted_depth.ndim == target_depth.ndim == 2
285
+ thresholds, weights = get_thresholds_and_weights(t_min, t_max, N)
286
+ f1_scores = np.array(
287
+ [
288
+ boundary_f1(invert_depth(predicted_depth), invert_depth(target_depth), t)
289
+ for t in thresholds
290
+ ]
291
+ )
292
+ return np.sum(f1_scores * weights)
293
+
294
+
295
+ def SI_boundary_Recall(
296
+ predicted_depth: np.ndarray,
297
+ target_mask: np.ndarray,
298
+ t_min: float = 1.05,
299
+ t_max: float = 1.25,
300
+ N: int = 10,
301
+ alpha_threshold: float = 0.1,
302
+ ) -> float:
303
+ """Calculate Scale-Invariant Boundary Recall Score for mask-based ground-truth.
304
+
305
+ Args:
306
+ ----
307
+ predicted_depth (np.ndarray): Predicted depth matrix.
308
+ target_mask (np.ndarray): Ground truth binary mask.
309
+ t_min (float, optional): Minimum threshold. Defaults to 1.05.
310
+ t_max (float, optional): Maximum threshold. Defaults to 1.25.
311
+ N (int, optional): Number of thresholds. Defaults to 10.
312
+ alpha_threshold (float, optional): Threshold for alpha masking. Defaults to 0.1.
313
+
314
+ Returns:
315
+ -------
316
+ float: Scale-Invariant Boundary Recall Score.
317
+
318
+ """
319
+ assert predicted_depth.ndim == target_mask.ndim == 2
320
+ thresholds, weights = get_thresholds_and_weights(t_min, t_max, N)
321
+ thresholded_target = target_mask > alpha_threshold
322
+
323
+ recall_scores = np.array(
324
+ [
325
+ edge_recall_matting(
326
+ invert_depth(predicted_depth), thresholded_target, t=float(t)
327
+ )
328
+ for t in thresholds
329
+ ]
330
+ )
331
+ weighted_recall = np.sum(recall_scores * weights)
332
+ return weighted_recall
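A hedged usage sketch of SI_boundary_F1 on synthetic depth maps with one sharp vertical step; the import path is an assumption. Note that the metric averages over four edge directions, so a perfectly matched one-directional step scores 0.25 rather than 1.0:

import numpy as np
from finetune.modules.depth_warping.depth_pro.eval.boundary_metrics import SI_boundary_F1  # assumed path

gt = np.full((64, 64), 10.0)
gt[:, 32:] = 2.0                        # sharp far-to-near step at column 32
pred_good = gt.copy()                   # reproduces the step exactly
pred_flat = np.full_like(gt, 5.0)       # predicts no depth discontinuity at all

print(SI_boundary_F1(pred_good, gt))    # ~0.25: only the 'right' direction carries ground-truth edges
print(SI_boundary_F1(pred_flat, gt))    # 0.0: no boundaries predicted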
finetune/modules/depth_warping/depth_pro/eval/dis5k_sample_list.txt ADDED
@@ -0,0 +1,200 @@
1
+ DIS5K/DIS-TE1/im/12#Graphics#4#TrafficSign#8245751856_821be14f86_o.jpg
2
+ DIS5K/DIS-TE1/im/13#Insect#4#Butterfly#16023994688_7ff8cdccb1_o.jpg
3
+ DIS5K/DIS-TE1/im/14#Kitchenware#4#Kitchenware#IMG_20210520_205538.jpg
4
+ DIS5K/DIS-TE1/im/14#Kitchenware#8#SweetStand#4848284981_fc90f54b50_o.jpg
5
+ DIS5K/DIS-TE1/im/17#Non-motor Vehicle#4#Cart#15012855035_d10b57014f_o.jpg
6
+ DIS5K/DIS-TE1/im/2#Aircraft#5#Kite#13104545564_5afceec9bd_o.jpg
7
+ DIS5K/DIS-TE1/im/20#Sports#10#Skateboarding#8472763540_bb2390e928_o.jpg
8
+ DIS5K/DIS-TE1/im/21#Tool#14#Sword#32473146960_dcc6b77848_o.jpg
9
+ DIS5K/DIS-TE1/im/21#Tool#15#Tapeline#9680492386_2d2020f282_o.jpg
10
+ DIS5K/DIS-TE1/im/21#Tool#4#Flag#507752845_ef852100f0_o.jpg
11
+ DIS5K/DIS-TE1/im/21#Tool#6#Key#11966089533_3becd78b44_o.jpg
12
+ DIS5K/DIS-TE1/im/21#Tool#8#Scale#31946428472_d28def471b_o.jpg
13
+ DIS5K/DIS-TE1/im/22#Weapon#4#Rifle#8472656430_3eb908b211_o.jpg
14
+ DIS5K/DIS-TE1/im/8#Electronics#3#Earphone#1177468301_641df8c267_o.jpg
15
+ DIS5K/DIS-TE1/im/8#Electronics#9#MusicPlayer#2235782872_7d47847bb4_o.jpg
16
+ DIS5K/DIS-TE2/im/11#Furniture#13#Ladder#3878434417_2ed740586e_o.jpg
17
+ DIS5K/DIS-TE2/im/13#Insect#1#Ant#27047700955_3b3a1271f8_o.jpg
18
+ DIS5K/DIS-TE2/im/13#Insect#11#Spider#5567179191_38d1f65589_o.jpg
19
+ DIS5K/DIS-TE2/im/13#Insect#8#Locust#5237933769_e6687c05e4_o.jpg
20
+ DIS5K/DIS-TE2/im/14#Kitchenware#2#DishRack#70838854_40cf689da7_o.jpg
21
+ DIS5K/DIS-TE2/im/14#Kitchenware#8#SweetStand#8467929412_fef7f4275d_o.jpg
22
+ DIS5K/DIS-TE2/im/16#Music Instrument#2#Harp#28058219806_28e05ff24a_o.jpg
23
+ DIS5K/DIS-TE2/im/17#Non-motor Vehicle#1#BabyCarriage#29794777180_2e1695a0cf_o.jpg
24
+ DIS5K/DIS-TE2/im/19#Ship#3#Sailboat#22442908623_5977e3becf_o.jpg
25
+ DIS5K/DIS-TE2/im/2#Aircraft#5#Kite#44654358051_1400e71cc4_o.jpg
26
+ DIS5K/DIS-TE2/im/21#Tool#11#Stand#IMG_20210520_205442.jpg
27
+ DIS5K/DIS-TE2/im/21#Tool#17#Tripod#9318977876_34615ec9a0_o.jpg
28
+ DIS5K/DIS-TE2/im/5#Artifact#3#Handcraft#50860882577_8482143b1b_o.jpg
29
+ DIS5K/DIS-TE2/im/8#Electronics#10#Robot#3093360210_fee54dc5c5_o.jpg
30
+ DIS5K/DIS-TE2/im/8#Electronics#6#Microphone#47411477652_6da66cbc10_o.jpg
31
+ DIS5K/DIS-TE3/im/14#Kitchenware#4#Kitchenware#2451122898_ef883175dd_o.jpg
32
+ DIS5K/DIS-TE3/im/15#Machine#4#SewingMachine#9311164128_97ba1d3947_o.jpg
33
+ DIS5K/DIS-TE3/im/16#Music Instrument#2#Harp#7670920550_59e992fd7b_o.jpg
34
+ DIS5K/DIS-TE3/im/17#Non-motor Vehicle#1#BabyCarriage#8389984877_1fddf8715c_o.jpg
35
+ DIS5K/DIS-TE3/im/17#Non-motor Vehicle#3#Carriage#5947122724_98e0fc3d1f_o.jpg
36
+ DIS5K/DIS-TE3/im/2#Aircraft#2#Balloon#2487168092_641505883f_o.jpg
37
+ DIS5K/DIS-TE3/im/2#Aircraft#4#Helicopter#8401177591_06c71c8df2_o.jpg
38
+ DIS5K/DIS-TE3/im/20#Sports#1#Archery#12520003103_faa43ea3e0_o.jpg
39
+ DIS5K/DIS-TE3/im/21#Tool#11#Stand#IMG_20210709_221507.jpg
40
+ DIS5K/DIS-TE3/im/21#Tool#2#Clip#5656649687_63d0c6696d_o.jpg
41
+ DIS5K/DIS-TE3/im/21#Tool#6#Key#12878459244_6387a140ea_o.jpg
42
+ DIS5K/DIS-TE3/im/3#Aquatic#1#Lobster#109214461_f52b4b6093_o.jpg
43
+ DIS5K/DIS-TE3/im/4#Architecture#19#Windmill#20195851863_2627117e0e_o.jpg
44
+ DIS5K/DIS-TE3/im/5#Artifact#2#Cage#5821476369_ea23927487_o.jpg
45
+ DIS5K/DIS-TE3/im/8#Electronics#7#MobileHolder#49732997896_7f53c290b5_o.jpg
46
+ DIS5K/DIS-TE4/im/13#Insect#6#Centipede#15302179708_a267850881_o.jpg
47
+ DIS5K/DIS-TE4/im/17#Non-motor Vehicle#11#Tricycle#5771069105_a3aef6f665_o.jpg
48
+ DIS5K/DIS-TE4/im/17#Non-motor Vehicle#2#Bicycle#4245936196_fdf812dcb7_o.jpg
49
+ DIS5K/DIS-TE4/im/17#Non-motor Vehicle#9#ShoppingCart#4674052920_a5b7a2b236_o.jpg
50
+ DIS5K/DIS-TE4/im/18#Plant#1#Bonsai#3539420884_ca8973e2c0_o.jpg
51
+ DIS5K/DIS-TE4/im/2#Aircraft#6#Parachute#33590416634_9d6f2325e7_o.jpg
52
+ DIS5K/DIS-TE4/im/20#Sports#1#Archery#46924476515_0be1caa684_o.jpg
53
+ DIS5K/DIS-TE4/im/20#Sports#8#Racket#19337607166_dd1985fb59_o.jpg
54
+ DIS5K/DIS-TE4/im/21#Tool#6#Key#3193329588_839b0c74ce_o.jpg
55
+ DIS5K/DIS-TE4/im/5#Artifact#2#Cage#5821886526_0573ba2d0d_o.jpg
56
+ DIS5K/DIS-TE4/im/5#Artifact#3#Handcraft#50105138282_3c1d02c968_o.jpg
57
+ DIS5K/DIS-TE4/im/8#Electronics#1#Antenna#4305034305_874f21a701_o.jpg
58
+ DIS5K/DIS-TR/im/1#Accessories#1#Bag#15554964549_3105e51b6f_o.jpg
59
+ DIS5K/DIS-TR/im/1#Accessories#1#Bag#41104261980_098a6c4a56_o.jpg
60
+ DIS5K/DIS-TR/im/1#Accessories#2#Clothes#2284764037_871b2e8ca4_o.jpg
61
+ DIS5K/DIS-TR/im/1#Accessories#3#Eyeglasses#1824643784_70d0134156_o.jpg
62
+ DIS5K/DIS-TR/im/1#Accessories#3#Eyeglasses#3590020230_37b09a29b3_o.jpg
63
+ DIS5K/DIS-TR/im/1#Accessories#3#Eyeglasses#4809652879_4da8a69f3b_o.jpg
64
+ DIS5K/DIS-TR/im/1#Accessories#3#Eyeglasses#792204934_f9b28f99b4_o.jpg
65
+ DIS5K/DIS-TR/im/1#Accessories#5#Jewelry#13909132974_c4750c5fb7_o.jpg
66
+ DIS5K/DIS-TR/im/1#Accessories#7#Shoe#2483391615_9199ece8d6_o.jpg
67
+ DIS5K/DIS-TR/im/1#Accessories#8#Watch#4343266960_f6633b029b_o.jpg
68
+ DIS5K/DIS-TR/im/10#Frame#2#BicycleFrame#17897573_42964dd104_o.jpg
69
+ DIS5K/DIS-TR/im/10#Frame#5#Rack#15898634812_64807069ff_o.jpg
70
+ DIS5K/DIS-TR/im/10#Frame#5#Rack#23928546819_c184cb0b60_o.jpg
71
+ DIS5K/DIS-TR/im/11#Furniture#19#Shower#6189119596_77bcfe80ee_o.jpg
72
+ DIS5K/DIS-TR/im/11#Furniture#2#Bench#3263647075_9306e280b5_o.jpg
73
+ DIS5K/DIS-TR/im/11#Furniture#5#CoatHanger#12774091054_cd5ff520ef_o.jpg
74
+ DIS5K/DIS-TR/im/11#Furniture#6#DentalChair#13878156865_d0439dcb32_o.jpg
75
+ DIS5K/DIS-TR/im/11#Furniture#9#Easel#5861024714_2070cd480c_o.jpg
76
+ DIS5K/DIS-TR/im/12#Graphics#4#TrafficSign#40621867334_f3c32ec189_o.jpg
77
+ DIS5K/DIS-TR/im/13#Insect#1#Ant#3295038190_db5dd0d4f4_o.jpg
78
+ DIS5K/DIS-TR/im/13#Insect#10#Mosquito#24341339_a88a1dad4c_o.jpg
79
+ DIS5K/DIS-TR/im/13#Insect#11#Spider#27171518270_63b78069ff_o.jpg
80
+ DIS5K/DIS-TR/im/13#Insect#11#Spider#49925050281_fa727c154e_o.jpg
81
+ DIS5K/DIS-TR/im/13#Insect#2#Beatle#279616486_2f1e64f591_o.jpg
82
+ DIS5K/DIS-TR/im/13#Insect#3#Bee#43892067695_82cf3e536b_o.jpg
83
+ DIS5K/DIS-TR/im/13#Insect#6#Centipede#20874281788_3e15c90a1c_o.jpg
84
+ DIS5K/DIS-TR/im/13#Insect#7#Dragonfly#14106671120_1b824d77e4_o.jpg
85
+ DIS5K/DIS-TR/im/13#Insect#8#Locust#21637491048_676ef7c9f7_o.jpg
86
+ DIS5K/DIS-TR/im/13#Insect#9#Mantis#1381120202_9dff6987b2_o.jpg
87
+ DIS5K/DIS-TR/im/14#Kitchenware#1#Cup#12812517473_327d6474b8_o.jpg
88
+ DIS5K/DIS-TR/im/14#Kitchenware#10#WineGlass#6402491641_389275d4d1_o.jpg
89
+ DIS5K/DIS-TR/im/14#Kitchenware#3#Hydrovalve#3129932040_8c05825004_o.jpg
90
+ DIS5K/DIS-TR/im/14#Kitchenware#4#Kitchenware#2881934780_87d5218ebb_o.jpg
91
+ DIS5K/DIS-TR/im/14#Kitchenware#4#Kitchenware#IMG_20210520_205527.jpg
92
+ DIS5K/DIS-TR/im/14#Kitchenware#6#Spoon#32989113501_b69eccf0df_o.jpg
93
+ DIS5K/DIS-TR/im/14#Kitchenware#8#SweetStand#2867322189_c56d1e0b87_o.jpg
94
+ DIS5K/DIS-TR/im/15#Machine#1#Gear#19217846720_f5f2807475_o.jpg
95
+ DIS5K/DIS-TR/im/15#Machine#2#Machine#1620160659_9571b7a7ab_o.jpg
96
+ DIS5K/DIS-TR/im/16#Music Instrument#2#Harp#6012801603_1a6e2c16a6_o.jpg
97
+ DIS5K/DIS-TR/im/16#Music Instrument#5#Trombone#8683292118_d223c17ccb_o.jpg
98
+ DIS5K/DIS-TR/im/16#Music Instrument#6#Trumpet#8393262740_b8c216142c_o.jpg
99
+ DIS5K/DIS-TR/im/16#Music Instrument#8#Violin#1511267391_40e4949d68_o.jpg
100
+ DIS5K/DIS-TR/im/17#Non-motor Vehicle#1#BabyCarriage#6989512997_38b3dbc88b_o.jpg
101
+ DIS5K/DIS-TR/im/17#Non-motor Vehicle#12#Wheel#14627183228_b2d68cf501_o.jpg
102
+ DIS5K/DIS-TR/im/17#Non-motor Vehicle#12#Wheel#2932226475_1b2403e549_o.jpg
103
+ DIS5K/DIS-TR/im/17#Non-motor Vehicle#12#Wheel#5420155648_86459905b8_o.jpg
104
+ DIS5K/DIS-TR/im/17#Non-motor Vehicle#2#Bicycle#IMG_20210513_134904.jpg
105
+ DIS5K/DIS-TR/im/17#Non-motor Vehicle#3#Carriage#3311962551_6f211b7bd6_o.jpg
106
+ DIS5K/DIS-TR/im/17#Non-motor Vehicle#4#Cart#2609732026_baf7fff3a1_o.jpg
107
+ DIS5K/DIS-TR/im/17#Non-motor Vehicle#5#Handcart#5821282211_201cefeaf2_o.jpg
108
+ DIS5K/DIS-TR/im/17#Non-motor Vehicle#7#Mower#5779003232_3bb3ae531a_o.jpg
109
+ DIS5K/DIS-TR/im/17#Non-motor Vehicle#9#ShoppingCart#10051622843_ace07e32b8_o.jpg
110
+ DIS5K/DIS-TR/im/17#Non-motor Vehicle#9#ShoppingCart#8075259294_f23e243849_o.jpg
111
+ DIS5K/DIS-TR/im/18#Plant#2#Tree#44800999741_e377e16dbb_o.jpg
112
+ DIS5K/DIS-TR/im/2#Aircraft#1#Airplane#2631761913_3ac67d0223_o.jpg
113
+ DIS5K/DIS-TR/im/2#Aircraft#1#Airplane#37707911566_e908a261b6_o.jpg
114
+ DIS5K/DIS-TR/im/2#Aircraft#3#HangGlider#2557220131_b8506920c5_o.jpg
115
+ DIS5K/DIS-TR/im/2#Aircraft#4#Helicopter#6215659280_5dbd9b4546_o.jpg
116
+ DIS5K/DIS-TR/im/2#Aircraft#6#Parachute#20185790493_e56fcaf8c6_o.jpg
117
+ DIS5K/DIS-TR/im/20#Sports#1#Archery#3871269982_ae4c59a7eb_o.jpg
118
+ DIS5K/DIS-TR/im/20#Sports#9#RockClimbing#9662433268_51299bc50e_o.jpg
119
+ DIS5K/DIS-TR/im/21#Tool#14#Sword#26258479365_2950d7fa37_o.jpg
120
+ DIS5K/DIS-TR/im/21#Tool#15#Tapeline#15505703447_e0fdeaa5a6_o.jpg
121
+ DIS5K/DIS-TR/im/21#Tool#4#Flag#26678602024_9b665742de_o.jpg
122
+ DIS5K/DIS-TR/im/21#Tool#4#Flag#5774823110_d603ce3cc8_o.jpg
123
+ DIS5K/DIS-TR/im/21#Tool#5#Hook#6867989814_dba18d673c_o.jpg
124
+ DIS5K/DIS-TR/im/22#Weapon#4#Rifle#4451713125_cd91719189_o.jpg
125
+ DIS5K/DIS-TR/im/3#Aquatic#2#Seadragon#4910944581_913139b238_o.jpg
126
+ DIS5K/DIS-TR/im/4#Architecture#12#Scaffold#3661448960_8aff24cc4d_o.jpg
127
+ DIS5K/DIS-TR/im/4#Architecture#13#Sculpture#6385318715_9a88d4eba7_o.jpg
128
+ DIS5K/DIS-TR/im/4#Architecture#17#Well#5011603479_75cf42808a_o.jpg
129
+ DIS5K/DIS-TR/im/5#Artifact#2#Cage#4892828841_7f1bc05682_o.jpg
130
+ DIS5K/DIS-TR/im/5#Artifact#3#Handcraft#15404211628_9e9ff2ce2e_o.jpg
131
+ DIS5K/DIS-TR/im/5#Artifact#3#Handcraft#3200169865_7c84cfcccf_o.jpg
132
+ DIS5K/DIS-TR/im/5#Artifact#3#Handcraft#5859295071_c217e7c22f_o.jpg
133
+ DIS5K/DIS-TR/im/6#Automobile#10#SteeringWheel#17200338026_f1e2122d8e_o.jpg
134
+ DIS5K/DIS-TR/im/6#Automobile#3#Car#3780893425_1a7d275e09_o.jpg
135
+ DIS5K/DIS-TR/im/6#Automobile#5#Crane#15282506502_1b1132a7c3_o.jpg
136
+ DIS5K/DIS-TR/im/7#Electrical#1#Cable#16767791875_8e6df41752_o.jpg
137
+ DIS5K/DIS-TR/im/7#Electrical#1#Cable#3291433361_38747324c4_o.jpg
138
+ DIS5K/DIS-TR/im/7#Electrical#1#Cable#4195104238_12a754c61a_o.jpg
139
+ DIS5K/DIS-TR/im/7#Electrical#1#Cable#49645415132_61e5664ecf_o.jpg
140
+ DIS5K/DIS-TR/im/7#Electrical#1#Cable#IMG_20210521_232406.jpg
141
+ DIS5K/DIS-TR/im/7#Electrical#10#UtilityPole#3298312021_92f431e3e9_o.jpg
142
+ DIS5K/DIS-TR/im/7#Electrical#10#UtilityPole#47950134773_fbfff63f4e_o.jpg
143
+ DIS5K/DIS-TR/im/7#Electrical#11#VacuumCleaner#5448403677_6a29e21881_o.jpg
144
+ DIS5K/DIS-TR/im/7#Electrical#2#CeilingLamp#611568868_680ed5d39f_o.jpg
145
+ DIS5K/DIS-TR/im/7#Electrical#3#Fan#3391683115_990525a693_o.jpg
146
+ DIS5K/DIS-TR/im/7#Electrical#6#StreetLamp#150049122_0692266618_o.jpg
147
+ DIS5K/DIS-TR/im/7#Electrical#9#TransmissionTower#31433908671_7e7e277dfe_o.jpg
148
+ DIS5K/DIS-TR/im/8#Electronics#1#Antenna#8727884873_e0622ee5c4_o.jpg
149
+ DIS5K/DIS-TR/im/8#Electronics#2#Camcorder#4172690390_7e5f280ace_o.jpg
150
+ DIS5K/DIS-TR/im/8#Electronics#3#Earphone#413984555_f290febdf5_o.jpg
151
+ DIS5K/DIS-TR/im/8#Electronics#5#Headset#30574225373_3717ed9fa4_o.jpg
152
+ DIS5K/DIS-TR/im/8#Electronics#6#Microphone#538006482_4aae4f5bd6_o.jpg
153
+ DIS5K/DIS-TR/im/8#Electronics#9#MusicPlayer#1306012480_2ea80d2afd_o.jpg
154
+ DIS5K/DIS-TR/im/9#Entertainment#1#GymEquipment#33071754135_8f3195cbd1_o.jpg
155
+ DIS5K/DIS-TR/im/9#Entertainment#2#KidsPlayground#2305807849_be53d724ea_o.jpg
156
+ DIS5K/DIS-TR/im/9#Entertainment#2#KidsPlayground#3862040422_5bbf903204_o.jpg
157
+ DIS5K/DIS-TR/im/9#Entertainment#3#OutdoorFitnessEquipment#10814507005_3dacaa28b3_o.jpg
158
+ DIS5K/DIS-TR/im/9#Entertainment#4#FerrisWheel#81640293_4b0ee62040_o.jpg
159
+ DIS5K/DIS-TR/im/9#Entertainment#5#Swing#49867339188_08073f4b76_o.jpg
160
+ DIS5K/DIS-VD/im/1#Accessories#1#Bag#6815402415_e01c1a41e6_o.jpg
161
+ DIS5K/DIS-VD/im/1#Accessories#5#Jewelry#2744070193_1486582e8d_o.jpg
162
+ DIS5K/DIS-VD/im/10#Frame#1#BasketballHoop#IMG_20210521_232650.jpg
163
+ DIS5K/DIS-VD/im/10#Frame#5#Rack#6156611713_49ebf12b1e_o.jpg
164
+ DIS5K/DIS-VD/im/11#Furniture#11#Handrail#3276641240_1b84b5af85_o.jpg
165
+ DIS5K/DIS-VD/im/11#Furniture#13#Ladder#33423266_5391cf47e9_o.jpg
166
+ DIS5K/DIS-VD/im/11#Furniture#17#Table#3725111755_4fc101e7ab_o.jpg
167
+ DIS5K/DIS-VD/im/11#Furniture#2#Bench#35556410400_7235b58070_o.jpg
168
+ DIS5K/DIS-VD/im/11#Furniture#4#Chair#3301769985_e49de6739f_o.jpg
169
+ DIS5K/DIS-VD/im/11#Furniture#6#DentalChair#23811071619_2a95c3a688_o.jpg
170
+ DIS5K/DIS-VD/im/11#Furniture#9#Easel#8322807354_df6d56542e_o.jpg
171
+ DIS5K/DIS-VD/im/13#Insect#10#Mosquito#12391674863_0cdf430d3f_o.jpg
172
+ DIS5K/DIS-VD/im/13#Insect#7#Dragonfly#14693028899_344ea118f2_o.jpg
173
+ DIS5K/DIS-VD/im/14#Kitchenware#10#WineGlass#4450148455_8f460f541a_o.jpg
174
+ DIS5K/DIS-VD/im/14#Kitchenware#3#Hydrovalve#IMG_20210520_203410.jpg
175
+ DIS5K/DIS-VD/im/15#Machine#3#PlowHarrow#34521712846_df4babb024_o.jpg
176
+ DIS5K/DIS-VD/im/16#Music Instrument#5#Trombone#6222242743_e7189405cd_o.jpg
177
+ DIS5K/DIS-VD/im/17#Non-motor Vehicle#12#Wheel#25677578797_ea47e1d9e8_o.jpg
178
+ DIS5K/DIS-VD/im/17#Non-motor Vehicle#2#Bicycle#5153474856_21560b081b_o.jpg
179
+ DIS5K/DIS-VD/im/17#Non-motor Vehicle#7#Mower#16992510572_8a6ff27398_o.jpg
180
+ DIS5K/DIS-VD/im/19#Ship#2#Canoe#40571458163_7faf8b73d9_o.jpg
181
+ DIS5K/DIS-VD/im/2#Aircraft#1#Airplane#4270588164_66a619e834_o.jpg
182
+ DIS5K/DIS-VD/im/2#Aircraft#4#Helicopter#86789665_650b94b2ee_o.jpg
183
+ DIS5K/DIS-VD/im/20#Sports#14#Wakesurfing#5589577652_5061c168d2_o.jpg
184
+ DIS5K/DIS-VD/im/21#Tool#10#Spade#37018312543_63b21b0784_o.jpg
185
+ DIS5K/DIS-VD/im/21#Tool#14#Sword#24789047250_42df9bf422_o.jpg
186
+ DIS5K/DIS-VD/im/21#Tool#18#Umbrella#IMG_20210513_140445.jpg
187
+ DIS5K/DIS-VD/im/21#Tool#6#Key#43939732715_5a6e28b518_o.jpg
188
+ DIS5K/DIS-VD/im/22#Weapon#1#Cannon#12758066705_90b54295e7_o.jpg
189
+ DIS5K/DIS-VD/im/22#Weapon#4#Rifle#8019368790_fb6dc469a7_o.jpg
190
+ DIS5K/DIS-VD/im/3#Aquatic#5#Shrimp#2582833427_7a99e7356e_o.jpg
191
+ DIS5K/DIS-VD/im/4#Architecture#12#Scaffold#1013402687_590750354e_o.jpg
192
+ DIS5K/DIS-VD/im/4#Architecture#13#Sculpture#17176841759_272a3ed6e3_o.jpg
193
+ DIS5K/DIS-VD/im/4#Architecture#14#Stair#15079108505_0d11281624_o.jpg
194
+ DIS5K/DIS-VD/im/4#Architecture#19#Windmill#2928111082_ceb3051c04_o.jpg
195
+ DIS5K/DIS-VD/im/4#Architecture#3#Crack#3551574032_17dd106d31_o.jpg
196
+ DIS5K/DIS-VD/im/4#Architecture#5#GasStation#4564307581_c3069bdc62_o.jpg
197
+ DIS5K/DIS-VD/im/4#Architecture#8#ObservationTower#2704526950_d4f0ddc807_o.jpg
198
+ DIS5K/DIS-VD/im/5#Artifact#3#Handcraft#10873642323_1bafce3aa5_o.jpg
199
+ DIS5K/DIS-VD/im/6#Automobile#11#Tractor#8594504006_0c2c557d85_o.jpg
200
+ DIS5K/DIS-VD/im/8#Electronics#3#Earphone#8106454803_1178d867cc_o.jpg
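Each entry in this sample list follows a fixed naming scheme: DIS5K/<split>/im/<category-id>#<Category>#<subcategory-id>#<SubCategory>#<image-name>.jpg. A minimal parsing sketch in Python (the field names below are labels inferred from the paths above, not something the dataset file itself defines):

from pathlib import PurePosixPath

def parse_dis5k_entry(entry: str) -> dict:
    """Split one sample-list line into its split name and '#'-delimited fields."""
    path = PurePosixPath(entry)
    split = path.parts[1]  # e.g. "DIS-TR", "DIS-VD", "DIS-TE2"
    cat_id, category, sub_id, subcategory, image = path.name.split("#", 4)
    return {
        "split": split,
        "category_id": int(cat_id),
        "category": category,
        "subcategory_id": int(sub_id),
        "subcategory": subcategory,
        "image": image,
    }

# Example with an entry taken from the list above.
print(parse_dis5k_entry("DIS5K/DIS-TE3/im/2#Aircraft#2#Balloon#2487168092_641505883f_o.jpg"))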
finetune/modules/depth_warping/depth_pro/network/__init__.py ADDED
@@ -0,0 +1,2 @@
+ # Copyright (C) 2024 Apple Inc. All Rights Reserved.
+ """Depth Pro network blocks."""
finetune/modules/depth_warping/depth_pro/network/decoder.py ADDED
@@ -0,0 +1,206 @@
+ """Copyright (C) 2024 Apple Inc. All Rights Reserved.
+
+ Dense Prediction Transformer Decoder architecture.
+
+ Implements a variant of Vision Transformers for Dense Prediction, https://arxiv.org/abs/2103.13413
+ """
+
+ from __future__ import annotations
+
+ from typing import Iterable
+
+ import torch
+ from torch import nn
+
+
+ class MultiresConvDecoder(nn.Module):
+     """Decoder for multi-resolution encodings."""
+
+     def __init__(
+         self,
+         dims_encoder: Iterable[int],
+         dim_decoder: int,
+     ):
+         """Initialize multiresolution convolutional decoder.
+
+         Args:
+         ----
+             dims_encoder: Expected dims at each level from the encoder.
+             dim_decoder: Dim of decoder features.
+
+         """
+         super().__init__()
+         self.dims_encoder = list(dims_encoder)
+         self.dim_decoder = dim_decoder
+         self.dim_out = dim_decoder
+
+         num_encoders = len(self.dims_encoder)
+
+         # At the highest resolution, i.e. level 0, we apply projection w/ 1x1 convolution
+         # when the dimensions mismatch. Otherwise we do not do anything, which is
+         # the default behavior of monodepth.
+         conv0 = (
+             nn.Conv2d(self.dims_encoder[0], dim_decoder, kernel_size=1, bias=False)
+             if self.dims_encoder[0] != dim_decoder
+             else nn.Identity()
+         )
+
+         convs = [conv0]
+         for i in range(1, num_encoders):
+             convs.append(
+                 nn.Conv2d(
+                     self.dims_encoder[i],
+                     dim_decoder,
+                     kernel_size=3,
+                     stride=1,
+                     padding=1,
+                     bias=False,
+                 )
+             )
+
+         self.convs = nn.ModuleList(convs)
+
+         fusions = []
+         for i in range(num_encoders):
+             fusions.append(
+                 FeatureFusionBlock2d(
+                     num_features=dim_decoder,
+                     deconv=(i != 0),
+                     batch_norm=False,
+                 )
+             )
+         self.fusions = nn.ModuleList(fusions)
+
+     def forward(self, encodings: list[torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
+         """Decode the multi-resolution encodings into (features, lowres_features)."""
+         num_levels = len(encodings)
+         num_encoders = len(self.dims_encoder)
+
+         if num_levels != num_encoders:
+             raise ValueError(
+                 f"Got encoder output levels={num_levels}, expected levels={num_encoders}."
+             )
+
+         # Project features of different encoder dims to the same decoder dim.
+         # Fuse features from the lowest resolution (num_levels-1)
+         # to the highest (0).
+         features = self.convs[-1](encodings[-1])
+         lowres_features = features
+         features = self.fusions[-1](features)
+         for i in range(num_levels - 2, -1, -1):
+             features_i = self.convs[i](encodings[i])
+             features = self.fusions[i](features, features_i)
+         return features, lowres_features
+
+
+ class ResidualBlock(nn.Module):
+     """Generic implementation of residual blocks.
+
+     This implements a generic residual block from
+     He et al. - Identity Mappings in Deep Residual Networks (2016),
+     https://arxiv.org/abs/1603.05027
+     which can be further customized via factory functions.
+     """
+
+     def __init__(self, residual: nn.Module, shortcut: nn.Module | None = None) -> None:
+         """Initialize ResidualBlock."""
+         super().__init__()
+         self.residual = residual
+         self.shortcut = shortcut
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         """Apply residual block."""
+         delta_x = self.residual(x)
+
+         if self.shortcut is not None:
+             x = self.shortcut(x)
+
+         return x + delta_x
+
+
+ class FeatureFusionBlock2d(nn.Module):
+     """Feature fusion for DPT."""
+
+     def __init__(
+         self,
+         num_features: int,
+         deconv: bool = False,
+         batch_norm: bool = False,
+     ):
+         """Initialize feature fusion block.
+
+         Args:
+         ----
+             num_features: Input and output dimensions.
+             deconv: Whether to use deconv before the final output conv.
+             batch_norm: Whether to use batch normalization in resnet blocks.
+
+         """
+         super().__init__()
+
+         self.resnet1 = self._residual_block(num_features, batch_norm)
+         self.resnet2 = self._residual_block(num_features, batch_norm)
+
+         self.use_deconv = deconv
+         if deconv:
+             self.deconv = nn.ConvTranspose2d(
+                 in_channels=num_features,
+                 out_channels=num_features,
+                 kernel_size=2,
+                 stride=2,
+                 padding=0,
+                 bias=False,
+             )
+
+         self.out_conv = nn.Conv2d(
+             num_features,
+             num_features,
+             kernel_size=1,
+             stride=1,
+             padding=0,
+             bias=True,
+         )
+
+         self.skip_add = nn.quantized.FloatFunctional()
+
+     def forward(self, x0: torch.Tensor, x1: torch.Tensor | None = None) -> torch.Tensor:
+         """Process and fuse input features."""
+         x = x0
+
+         if x1 is not None:
+             res = self.resnet1(x1)
+             x = self.skip_add.add(x, res)
+
+         x = self.resnet2(x)
+
+         if self.use_deconv:
+             x = self.deconv(x)
+         x = self.out_conv(x)
+
+         return x
+
+     @staticmethod
+     def _residual_block(num_features: int, batch_norm: bool):
+         """Create a residual block."""
+
+         def _create_block(dim: int, batch_norm: bool) -> list[nn.Module]:
+             layers = [
+                 nn.ReLU(False),
+                 nn.Conv2d(
+                     num_features,
+                     num_features,
+                     kernel_size=3,
+                     stride=1,
+                     padding=1,
+                     bias=not batch_norm,
+                 ),
+             ]
+             if batch_norm:
+                 layers.append(nn.BatchNorm2d(dim))
+             return layers
+
+         residual = nn.Sequential(
+             *_create_block(dim=num_features, batch_norm=batch_norm),
+             *_create_block(dim=num_features, batch_norm=batch_norm),
+         )
+         return ResidualBlock(residual)
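A minimal usage sketch for the decoder added above (not part of the commit). The encoder dims, tensor shapes, batch size, and dotted import path are illustrative assumptions; the real constraint is that each 2x deconv in a fusion block must land on the spatial size of the next finer encoder level:

import torch

from finetune.modules.depth_warping.depth_pro.network.decoder import MultiresConvDecoder

# Hypothetical encoder channel dims, ordered from level 0 (finest) to level 3 (coarsest).
decoder = MultiresConvDecoder(dims_encoder=[256, 512, 1024, 1024], dim_decoder=256)

# Illustrative multi-resolution features; each level halves the spatial size of the previous one.
encodings = [
    torch.randn(1, 256, 192, 192),   # level 0
    torch.randn(1, 512, 96, 96),     # level 1
    torch.randn(1, 1024, 48, 48),    # level 2
    torch.randn(1, 1024, 24, 24),    # level 3
]

features, lowres_features = decoder(encodings)
print(features.shape)         # torch.Size([1, 256, 192, 192]), fused and upsampled back to level 0
print(lowres_features.shape)  # torch.Size([1, 256, 24, 24]), coarsest projected features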