HReynaud committed
Commit dab5199 · 0 Parent(s)

first commit
.gitattributes ADDED
@@ -0,0 +1,35 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,180 @@
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+ cover/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ .pybuilder/
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ # For a library or package, you might want to ignore these files since the code is
+ # intended to run in multiple environments; otherwise, check them in:
+ # .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # UV
+ # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
+ # commonly ignored for libraries.
+ #uv.lock
+
+ # poetry
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
+ # commonly ignored for libraries.
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+ #poetry.lock
+
+ # pdm
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+ #pdm.lock
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+ # in version control.
+ # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
+ .pdm.toml
+ .pdm-python
+ .pdm-build/
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+
+ # pytype static type analyzer
+ .pytype/
+
+ # Cython debug symbols
+ cython_debug/
+
+ # PyCharm
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
+ #.idea/
+
+ # Ruff stuff:
+ .ruff_cache/
+
+ # PyPI configuration file
+ .pypirc
+
+ tmp/
+ .vscode/
+ .gradio/
+ .cursor/
+ *.mp4
README.md ADDED
@@ -0,0 +1,18 @@
+ ---
+ title: EchoFlow
+ emoji: 💙
+ colorFrom: gray
+ colorTo: red
+ sdk: gradio
+ sdk_version: 5.22.0
+ app_file: demo.py
+ pinned: true
+ license: apache-2.0
+ python_version: 3.11.8
+ models:
+ - HReynaud/EchoFlow
+ datasets:
+ - HReynaud/EchoFlow
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
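The front matter above wires the Space to demo.py and references HReynaud/EchoFlow as both a model and a dataset repo. A minimal sketch (assuming huggingface_hub is installed and the repo is accessible to you) of pulling those weights locally:

# Sketch: fetch the model repo referenced in the Space front matter.
from huggingface_hub import snapshot_download

local_path = snapshot_download(
    repo_id="HReynaud/EchoFlow",
    repo_type="model",   # the same ID is also published as a dataset repo
    # token="hf_...",    # only needed for gated/private access
)
print(local_path)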
assets/anatomies_dynamic.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:3d8bf0fa238ca8b4ccdf8457fc8b248cebd52b005d9385115db773ec8005dc29
+ size 10271965
assets/anatomies_lvh.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:dfe6ff14cb9e6ba9a8d79e770423096f3bd9fa072b2a8fc984150f6e5fd91fe9
+ size 11179209
assets/anatomies_ped_a4c.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a2675b28071004ad15f060f057ae13330f1f61369500d7507fadefe7b5ae9c74
+ size 3364061
assets/anatomies_ped_psax.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:881e51666a580b2830a27d7e97055a0f2ab037152547aa79af1448a2b5f65ccb
+ size 4635874
assets/h1.png ADDED
assets/h2.png ADDED
assets/h3.png ADDED
assets/h4.png ADDED
assets/scaling.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:bbcd8f8cf990d57b96ce7e544e5f9b48b7ad2400dfc4080e0651575f666b19ac
+ size 1432
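The .pt entries above are Git LFS pointer files (spec version, content hash, size); the tensors themselves are fetched at checkout. A small sketch of how assets/scaling.pt is consumed: demo.py loads it with torch.load, and unscale_latents expects "mean" and "std" entries.

# Sketch: load the VAE scaling statistics the way get_vae_scaler() in demo.py does.
import torch

scaler = torch.load("assets/scaling.pt")               # dict with "mean" and "std" tensors
scaler = {k: v.to("cpu") for k, v in scaler.items()}   # demo.py moves these to the GPU instead
print({k: tuple(v.shape) for k, v in scaler.items()})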
assets/seg.png ADDED
demo.py ADDED
@@ -0,0 +1,945 @@
1
+ import json
2
+ import os
3
+ import types
4
+ from urllib.parse import urlparse
5
+
6
+ import cv2
7
+ import diffusers
8
+ import gradio as gr
9
+ import numpy as np
10
+ import torch
11
+ from einops import rearrange
12
+ from huggingface_hub import hf_hub_download
13
+ from omegaconf import OmegaConf
14
+ from PIL import Image, ImageOps
15
+ from safetensors.torch import load_file
16
+ from torch.nn import functional as F
17
+ from torchdiffeq import odeint_adjoint as odeint
18
+
19
+ from echoflow.common import instantiate_class_from_config, unscale_latents
20
+ from echoflow.common.models import (
21
+ ContrastiveModel,
22
+ DiffuserSTDiT,
23
+ ResNet18,
24
+ SegDiTTransformer2DModel,
25
+ )
26
+
27
+ torch.set_grad_enabled(False)
28
+
29
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
30
+ dtype = torch.float32
31
+
32
+ # 4f4 latent space
33
+ B, T, C, H, W = 1, 64, 4, 28, 28
34
+
35
+ VIEWS = ["A4C", "PSAX", "PLAX"]
36
+
37
+
38
+ def load_model(path):
39
+ if path.startswith("http"):
40
+ parsed_url = urlparse(path)
41
+ if "huggingface.co" in parsed_url.netloc:
42
+ parts = parsed_url.path.strip("/").split("/")
43
+ repo_id = "/".join(parts[:2])
44
+
45
+ subfolder = None
46
+ if len(parts) > 3:
47
+ subfolder = "/".join(parts[4:])
48
+
49
+ local_root = "./tmp"
50
+ local_dir = os.path.join(local_root, repo_id.replace("/", "_"))
51
+ if subfolder:
52
+ local_dir = os.path.join(local_root, subfolder)
53
+ os.makedirs(local_root, exist_ok=True)
54
+
55
+ config_file = hf_hub_download(
56
+ repo_id=repo_id,
57
+ subfolder=subfolder,
58
+ filename="config.json",
59
+ local_dir=local_root,
60
+ repo_type="model",
61
+ token=os.getenv("READ_HF_TOKEN"),
62
+ local_dir_use_symlinks=False,
63
+ )
64
+
65
+ assert os.path.exists(config_file)
66
+
67
+ hf_hub_download(
68
+ repo_id=repo_id,
69
+ filename="diffusion_pytorch_model.safetensors",
70
+ subfolder=subfolder,
71
+ local_dir=local_root,
72
+ local_dir_use_symlinks=False,
73
+ token=os.getenv("READ_HF_TOKEN"),
74
+ )
75
+
76
+ path = local_dir
77
+
78
+ model_root = os.path.join(config_file.split("config.json")[0])
79
+ json_path = os.path.join(model_root, "config.json")
80
+ assert os.path.exists(json_path)
81
+
82
+ with open(json_path, "r") as f:
83
+ config = json.load(f)
84
+
85
+ klass_name = config["_class_name"]
86
+ klass = getattr(diffusers, klass_name, None) or globals().get(klass_name, None)
87
+ assert (
88
+ klass is not None
89
+ ), f"Could not find class {klass_name} in diffusers or global scope."
90
+ assert hasattr(
91
+ klass, "from_pretrained"
92
+ ), f"Class {klass_name} does not support 'from_pretrained'."
93
+
94
+ return klass.from_pretrained(path)
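load_model turns a huggingface.co "tree" URL into a repo ID plus subfolder, downloads config.json and diffusion_pytorch_model.safetensors, and instantiates whatever class _class_name names via from_pretrained. A sketch of a call, mirroring the checkpoint loaded further down in this file:

# Sketch only; the same URL is used for the real `lifm` model below.
lifm_example = load_model(
    "https://huggingface.co/HReynaud/EchoFlow/tree/main/lifm/FMiT-S2-4f4"
)
lifm_example = lifm_example.to(device, dtype=dtype).eval()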
95
+
96
+
97
+ def load_reid(path):
98
+ parsed_url = urlparse(path)
99
+ parts = parsed_url.path.strip("/").split("/")
100
+ repo_id = "/".join(parts[:2])
101
+ subfolder = "/".join(parts[4:])
102
+
103
+ local_root = "./tmp"
104
+
105
+ config_file = hf_hub_download(
106
+ repo_id=repo_id,
107
+ subfolder=subfolder,
108
+ filename="config.yaml",
109
+ local_dir=local_root,
110
+ repo_type="model",
111
+ token=os.getenv("READ_HF_TOKEN"),
112
+ local_dir_use_symlinks=False,
113
+ )
114
+
115
+ weights_file = hf_hub_download(
116
+ repo_id=repo_id,
117
+ subfolder=subfolder,
118
+ filename="backbone.safetensors",
119
+ local_dir=local_root,
120
+ repo_type="model",
121
+ token=os.getenv("READ_HF_TOKEN"),
122
+ local_dir_use_symlinks=False,
123
+ )
124
+
125
+ config = OmegaConf.load(config_file)
126
+ backbone = instantiate_class_from_config(config.backbone)
127
+ backbone = ContrastiveModel.patch_backbone(
128
+ backbone, config.model.args.in_channels, config.model.args.out_channels
129
+ )
130
+ state_dict = load_file(weights_file)
131
+ backbone.load_state_dict(state_dict)
132
+ backbone = backbone.to(device, dtype=dtype)
133
+ backbone.eval()
134
+ return backbone
135
+
136
+
137
+ def get_vae_scaler(path):
138
+ scaler = torch.load(path)
139
+ scaler = {k: v.to(device) for k, v in scaler.items()}
140
+ return scaler
141
+
142
+
143
+ generator = torch.Generator(device=device).manual_seed(0)
144
+
145
+ lifm = load_model("https://huggingface.co/HReynaud/EchoFlow/tree/main/lifm/FMiT-S2-4f4")
146
+ lifm = lifm.to(device, dtype=dtype)
147
+ lifm.eval()
148
+
149
+ vae = load_model("https://huggingface.co/HReynaud/EchoFlow/tree/main/vae/avae-4f4")
150
+ vae = vae.to(device, dtype=dtype)
151
+ vae.eval()
152
+ vae_scaler = get_vae_scaler("assets/scaling.pt")
153
+
154
+ reid = {
155
+ "anatomies": {
156
+ "A4C": torch.cat(
157
+ [
158
+ torch.load("assets/anatomies_dynamic.pt"),
159
+ torch.load("assets/anatomies_ped_a4c.pt"),
160
+ ],
161
+ dim=0,
162
+ ),
163
+ "PSAX": torch.load("assets/anatomies_ped_psax.pt"),
164
+ "PLAX": torch.load("assets/anatomies_lvh.pt"),
165
+ },
166
+ "models": {
167
+ "A4C": load_reid(
168
+ "https://huggingface.co/HReynaud/EchoFlow/tree/main/reid/dynamic-4f4"
169
+ ),
170
+ "PSAX": load_reid(
171
+ "https://huggingface.co/HReynaud/EchoFlow/tree/main/reid/ped_psax-4f4"
172
+ ),
173
+ "PLAX": load_reid(
174
+ "https://huggingface.co/HReynaud/EchoFlow/tree/main/reid/lvh-4f4"
175
+ ),
176
+ },
177
+ "tau": {
178
+ "A4C": 0.9997,
179
+ "PSAX": 0.9953,
180
+ "PLAX": 0.9950,
181
+ },
182
+ }
183
+
184
+ lvfm = load_model("https://huggingface.co/HReynaud/EchoFlow/tree/main/lvfm/FMvT-S2-4f4")
185
+ lvfm = lvfm.to(device, dtype=dtype)
186
+ lvfm.eval()
187
+
188
+
189
+ def load_default_mask():
190
+ """Load the default mask from disk. If not found, return a blank black mask."""
191
+ default_mask_path = os.path.join("assets", "default_mask.png")
192
+ try:
193
+ if os.path.exists(default_mask_path):
194
+ mask = Image.open(default_mask_path).convert("L")
195
+ # Ensure the mask is square and of proper size
196
+ mask = mask.resize((400, 400), Image.Resampling.LANCZOS)
197
+ # Make sure it's binary (0 or 255)
198
+ mask = ImageOps.autocontrast(mask, cutoff=0)
199
+ return np.array(mask)
200
+ except Exception as e:
201
+ print(f"Error loading default mask: {e}")
202
+
203
+ # Return a blank black mask if no default mask is found
204
+ return np.zeros((400, 400), dtype=np.uint8)
205
+
206
+
207
+ def preprocess_mask(mask):
208
+ """Ensure mask is properly formatted for the model."""
209
+ if mask is None:
210
+ return np.zeros((112, 112), dtype=np.uint8)
211
+
212
+ # Check if mask is an EditorValue with multiple parts
213
+ if isinstance(mask, dict) and "composite" in mask:
214
+ # Use the composite image from the ImageEditor
215
+ mask = mask["composite"]
216
+
217
+ # If mask is already a numpy array, convert to PIL for processing
218
+ if isinstance(mask, np.ndarray):
219
+ mask_pil = Image.fromarray(mask)
220
+ else:
221
+ mask_pil = mask
222
+
223
+ # Ensure the mask is in L mode (grayscale)
224
+ mask_pil = mask_pil.convert("L")
225
+
226
+ # Apply contrast to make it binary (0 or 255)
227
+ mask_pil = ImageOps.autocontrast(mask_pil, cutoff=0)
228
+
229
+ # Threshold to ensure binary values
230
+ mask_pil = mask_pil.point(lambda p: 255 if p > 127 else 0)
231
+
232
+ # Print sizes for debugging
233
+ # print(f"Original mask size: {mask_pil.size}")
234
+
235
+ # Resize to 112x112 for the model
236
+ mask_pil = mask_pil.resize((112, 112), Image.Resampling.LANCZOS)
237
+
238
+ # Convert back to numpy array
239
+ return np.array(mask_pil)
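preprocess_mask accepts either a plain array or the dict that gr.ImageEditor produces and always returns a 112x112 binary uint8 mask. A tiny usage sketch with a placeholder input:

# Sketch: the gr.ImageEditor value is a dict; preprocess_mask reads its "composite" image.
import numpy as np

editor_value = {"composite": np.zeros((400, 400), dtype=np.uint8)}  # placeholder input
model_mask = preprocess_mask(editor_value)
print(model_mask.shape, model_mask.dtype)  # (112, 112) uint8, values 0 or 255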
240
+
241
+
242
+ def generate_latent_image(mask, class_selection, sampling_steps=50):
243
+ """Generate a latent image based on mask, class selection, and sampling steps"""
244
+
245
+ # Mask
246
+ mask = preprocess_mask(mask)
247
+ mask = torch.from_numpy(mask).to(device, dtype=dtype)
248
+ mask = mask.unsqueeze(0).unsqueeze(0)
249
+ mask = F.interpolate(mask, size=(H, W), mode="bilinear", align_corners=False)
250
+ mask = 1.0 * (mask > 0)
251
+
252
+ # print(mask.shape, mask.min(), mask.max(), mask.mean(), mask.std())
253
+
254
+ # Class
255
+ class_idx = VIEWS.index(class_selection)
256
+ class_idx = torch.tensor([class_idx], device=device, dtype=torch.long)
257
+
258
+ # Timesteps
259
+ timesteps = torch.linspace(
260
+ 1.0, 0.0, steps=sampling_steps + 1, device=device, dtype=dtype
261
+ )
262
+
263
+ forward_kwargs = {
264
+ "class_labels": class_idx, # B x 1
265
+ "segmentation": mask, # B x 1 x H x W
266
+ }
267
+
268
+ z_1 = torch.randn(
269
+ (B, C, H, W),
270
+ device=device,
271
+ dtype=dtype,
272
+ generator=generator,
273
+ )
274
+
275
+ lifm.forward_original = lifm.forward
276
+
277
+ def new_forward(self, t, y, *args, **kwargs):
278
+ kwargs = {**kwargs, **forward_kwargs}
279
+ return self.forward_original(y, t.view(1), *args, **kwargs).sample
280
+
281
+ lifm.forward = types.MethodType(new_forward, lifm)
282
+
283
+ # Use odeint to integrate
284
+ with torch.autocast("cuda"):
285
+ latent_image = odeint(
286
+ lifm,
287
+ z_1,
288
+ timesteps,
289
+ atol=1e-5,
290
+ rtol=1e-5,
291
+ adjoint_params=lifm.parameters(),
292
+ method="euler",
293
+ )[-1]
294
+
295
+ lifm.forward = lifm.forward_original
296
+
297
+ latent_image = latent_image.detach().cpu().numpy()
298
+
299
+ # callm VAE here
300
+
301
+ return latent_image # B x C x H x W
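The sampler above integrates the learned flow from t = 1 (noise) to t = 0 with fixed-step Euler via odeint. Roughly, each step is z <- z + (t_next - t_curr) * v(z, t); a minimal hand-rolled sketch of the same loop, assuming the same forward_kwargs and skipping the adjoint machinery:

# Hypothetical fixed-step Euler sampler equivalent to the odeint call above (not used by the app).
def euler_sample(model, z, timesteps, forward_kwargs):
    for t_curr, t_next in zip(timesteps[:-1], timesteps[1:]):
        v = model(z, t_curr.view(1), **forward_kwargs).sample  # predicted velocity field
        z = z + (t_next - t_curr) * v  # dt < 0: walks the sample from noise back to data
    return z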
302
+
303
+
304
+ def decode_images(latents, vae):
305
+ """Decode latent representations to pixel space using a VAE.
306
+
307
+ Args:
308
+ latents: A numpy array of shape [B, C, H, W] for single image
309
+ or [B, C, T, H, W] for sequences/animations
310
+ vae: The VAE model for decoding
311
+
312
+ Returns:
313
+ numpy array of decoded images in [H, W, 3] format for a single image (batch squeezed)
314
+ or [B, C, T, H, W] for sequences
315
+ """
316
+ if latents is None:
317
+ return None
318
+
319
+ # Convert to torch tensor if needed
320
+ if not isinstance(latents, torch.Tensor):
321
+ latents = torch.from_numpy(latents).to(device, dtype=dtype)
322
+
323
+ # Unscale latents
324
+ latents = unscale_latents(latents, vae_scaler)
325
+
326
+ # Handle both single images and sequences
327
+ is_sequence = len(latents.shape) == 5 # B C T H W
328
+
329
+ # print("Sequence:", is_sequence)
330
+
331
+ if is_sequence:
332
+ B, C, T, H, W = latents.shape
333
+ latents = rearrange(latents[0], "c t h w -> t c h w")
334
+ else:
335
+ B, C, H, W = latents.shape
336
+
337
+ # print("Latents:", latents.shape)
338
+
339
+ with torch.no_grad():
340
+ # Decode latents to pixel space
341
+ # decode one by one
342
+ decoded = []
343
+ for i in range(latents.shape[0]):
344
+ decoded.append(vae.decode(latents[i : i + 1].float()).sample)
345
+ decoded = torch.cat(decoded, dim=0)
346
+
347
+ decoded = (decoded + 1) * 128
348
+ decoded = decoded.clamp(0, 255).to(torch.uint8).cpu()
349
+
350
+ if is_sequence:
351
+ # Reshape back to [B, C, T, H, W] for sequences
352
+ decoded = rearrange(decoded, "t c h w -> c t h w").unsqueeze(0)
353
+ else:
354
+ decoded = decoded.squeeze()
355
+ decoded = decoded.permute(1, 2, 0)
356
+
357
+ # print("Decoded:", decoded.shape)
358
+ return decoded.numpy()
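The (decoded + 1) * 128 line above maps the VAE's roughly [-1, 1] output into 8-bit pixels before clamping; in isolation:

# Value-range sketch for the latent-to-uint8 mapping used above.
import numpy as np

x = np.array([-1.0, 0.0, 1.0])
pixels = np.clip((x + 1) * 128, 0, 255).astype(np.uint8)  # -> [0, 128, 255]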
359
+
360
+
361
+ def decode_latent_to_pixel(latent_image):
362
+ """Decode a single latent image to pixel space"""
363
+ global vae
364
+ if latent_image is None:
365
+ return None
366
+
367
+ # Add batch dimension if needed
368
+ if len(latent_image.shape) == 3:
369
+ latent_image = latent_image[None, ...]
370
+
371
+ decoded_image = decode_images(latent_image, vae)
372
+ decoded_image = cv2.resize(
373
+ decoded_image, (400, 400), interpolation=cv2.INTER_NEAREST
374
+ )
375
+
376
+ return decoded_image
377
+
378
+
379
+ def check_privacy(latent_image_numpy, class_selection):
380
+ """Check if the latent image is too similar to database images"""
381
+ latent_image = torch.from_numpy(latent_image_numpy).to(device, dtype=dtype)
382
+ reid_model = reid["models"][class_selection].to(device, dtype=dtype)
383
+ real_anatomies = reid["anatomies"][class_selection] # already scaled
384
+ tau = reid["tau"][class_selection]
385
+
386
+ with torch.no_grad():
387
+ features = reid_model(latent_image).sigmoid().cpu()
388
+
389
+ corr = torch.corrcoef(torch.cat([real_anatomies, features], dim=0))[0, 1:]
390
+ corr = corr.max()
391
+
392
+ if corr > tau:
393
+ return (
394
+ None,
395
+ f"⚠️ **Warning:** Generated image is too similar to training data. Privacy check failed (corr = {corr:.4f} / tau = {tau:.4f})",
396
+ )
397
+ else:
398
+ return (
399
+ latent_image_numpy,
400
+ f"✅ **Success:** Generated image passed privacy check (corr = {corr:.4f} / tau = {tau:.4f})",
401
+ )
402
+
403
+
404
+ def generate_animation(
405
+ latent_image, ejection_fraction, sampling_steps=50, cfg_scale=1.0
406
+ ):
407
+ """Generate an animated sequence of latent images based on EF"""
408
+ # print(
409
+ # f"Generating animation with EF = {ejection_fraction}, steps = {sampling_steps}, CFG = {cfg_scale}"
410
+ # )
411
+ # print(latent_image.shape, type(latent_image))
412
+
413
+ if latent_image is None:
414
+ return None
415
+
416
+ lvefs = torch.tensor([ejection_fraction / 100.0], device=device, dtype=dtype)
417
+ lvefs = lvefs[:, None, None].to(device, dtype)
418
+ uncond_lvefs = -1 * torch.ones_like(lvefs)
419
+
420
+ ref_images = torch.from_numpy(latent_image).to(device, dtype)
421
+ ref_images = ref_images[:, :, None, :, :] # B x C x 1 x H x W
422
+ ref_images = ref_images.repeat(1, 1, T, 1, 1) # B x C x T x H x W
423
+ uncond_images = torch.zeros_like(ref_images)
424
+
425
+ timesteps = torch.linspace(
426
+ 1.0, 0.0, steps=sampling_steps + 1, device=device, dtype=dtype
427
+ )
428
+
429
+ forward_kwargs = {
430
+ "encoder_hidden_states": lvefs,
431
+ "cond_image": ref_images,
432
+ }
433
+
434
+ z_1 = torch.randn(
435
+ (B, C, T, H, W),
436
+ device=device,
437
+ dtype=dtype,
438
+ generator=generator,
439
+ )
440
+
441
+ # print(
442
+ # z_1.shape,
443
+ # forward_kwargs["encoder_hidden_states"].shape,
444
+ # forward_kwargs["cond_image"].shape,
445
+ # )
446
+
447
+ lvfm.forward_original = lvfm.forward
448
+
449
+ def new_forward(self, t, y, *args, **kwargs):
450
+ kwargs = {**kwargs, **forward_kwargs}
451
+ # y has shape (B, C, T, H, W)
452
+
453
+ pred = self.forward_original(y, t.repeat(y.size(0)), *args, **kwargs).sample
454
+
455
+ if cfg_scale != 1.0:
456
+ uncond_kwargs = {
457
+ "encoder_hidden_states": uncond_lvefs,
458
+ "cond_image": uncond_images,
459
+ }
460
+ uncond_pred = self.forward_original(
461
+ y, t.repeat(y.size(0)), *args, **uncond_kwargs
462
+ ).sample
463
+
464
+ pred = uncond_pred + cfg_scale * (pred - uncond_pred)
465
+
466
+ return pred
467
+
468
+ lvfm.forward = types.MethodType(new_forward, lvfm)
469
+
470
+ with torch.autocast("cuda"):
471
+ synthetic_video = odeint(
472
+ lvfm,
473
+ z_1,
474
+ timesteps,
475
+ atol=1e-5,
476
+ rtol=1e-5,
477
+ adjoint_params=lvfm.parameters(),
478
+ method="euler",
479
+ )[-1]
480
+
481
+ lvfm.forward = lvfm.forward_original
482
+
483
+ # print("Synthetic video:", synthetic_video.shape)
484
+
485
+ return synthetic_video # B x C x T x H x W
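When cfg_scale != 1.0, the patched forward above runs the model twice, once with the EF/reference-frame conditioning and once with null conditioning, and blends the two predictions. The blend is the standard classifier-free guidance formula; as a standalone sketch:

# Sketch of the classifier-free guidance blend used in new_forward above.
def cfg_blend(cond_pred, uncond_pred, cfg_scale):
    # cfg_scale == 1.0 keeps the conditional prediction unchanged; larger values
    # push the output further in the direction of the conditioning signal.
    return uncond_pred + cfg_scale * (cond_pred - uncond_pred)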
486
+
487
+
488
+ def decode_animation(latent_animation):
489
+ """Decode a latent animation to pixel space"""
490
+ global vae
491
+ if latent_animation is None:
492
+ return None
493
+
494
+ # Convert to torch tensor if needed
495
+ if not isinstance(latent_animation, torch.Tensor):
496
+ latent_animation = torch.from_numpy(latent_animation).to(device, dtype=dtype)
497
+
498
+ # Ensure shape is B x C x T x H x W
499
+ if len(latent_animation.shape) == 4: # [T, C, H, W]
500
+ latent_animation = latent_animation[None, ...] # Add batch dimension
501
+
502
+ # Decode using VAE
503
+ decoded = decode_images(
504
+ latent_animation, vae
505
+ ) # Returns B x C x T x H x W numpy array
506
+
507
+ # Remove batch dimension and transpose to T x H x W x C
508
+ decoded = np.transpose(decoded[0], (1, 2, 3, 0)) # [T, H, W, C]
509
+
510
+ # Resize frames to 400x400
511
+ decoded = np.stack(
512
+ [
513
+ cv2.resize(frame, (400, 400), interpolation=cv2.INTER_NEAREST)
514
+ for frame in decoded
515
+ ]
516
+ )
517
+
518
+ # Save to temporary file
519
+ temp_file = "temp_video_2.mp4"
520
+ fps = 32
521
+ fourcc = cv2.VideoWriter_fourcc(*"mp4v")
522
+ out = cv2.VideoWriter(temp_file, fourcc, fps, (400, 400))
523
+
524
+ # Write frames
525
+ for frame in decoded:
526
+ out.write(frame)
527
+ out.release()
528
+
529
+ return temp_file
530
+
531
+
532
+ def convert_latent_to_display(latent_image):
533
+ """Convert multi-channel latent image to grayscale for display"""
534
+ if latent_image is None:
535
+ return None
536
+
537
+ # Check shape
538
+ if len(latent_image.shape) == 4: # [B, C, H, W]
539
+ # Remove batch dimension and average across channels
540
+ display_image = np.squeeze(latent_image, axis=0) # [C, H, W]
541
+ display_image = np.mean(display_image, axis=0) # [H, W]
542
+ elif len(latent_image.shape) == 3: # [C, H, W]
543
+ # Average across channels
544
+ display_image = np.mean(latent_image, axis=0) # [H, W]
545
+ else:
546
+ display_image = latent_image
547
+
548
+ # Normalize to 0-1 range
549
+ display_image = (display_image - display_image.min()) / (
550
+ display_image.max() - display_image.min() + 1e-8
551
+ )
552
+
553
+ # Convert to grayscale image
554
+ display_image = (display_image * 255).astype(np.uint8)
555
+
556
+ # Resize to a larger size (e.g., 400x400) using nearest-neighbor interpolation
557
+ display_image = cv2.resize(
558
+ display_image, (400, 400), interpolation=cv2.INTER_NEAREST
559
+ )
560
+
561
+ return display_image
562
+
563
+
564
+ def latent_animation_to_grayscale(latent_animation):
565
+ """Convert multi-channel latent animation to grayscale for display"""
566
+ if latent_animation is None:
567
+ return None
568
+
569
+ # print("Input shape:", latent_animation.shape)
570
+
571
+ # Convert to numpy if it's a torch tensor
572
+ if torch.is_tensor(latent_animation):
573
+ latent_animation = latent_animation.detach().cpu().numpy()
574
+
575
+ # Handle shape B x C x T x H x W -> T x H x W
576
+ if len(latent_animation.shape) == 5: # [B, C, T, H, W]
577
+ latent_animation = np.squeeze(latent_animation, axis=0) # [C, T, H, W]
578
+ latent_animation = np.transpose(latent_animation, (1, 0, 2, 3)) # [T, C, H, W]
579
+
580
+ # print("After transpose:", latent_animation.shape)
581
+
582
+ # Average across channels
583
+ latent_animation = np.mean(latent_animation, axis=1) # [T, H, W]
584
+
585
+ # print("After channel reduction:", latent_animation.shape)
586
+
587
+ # Normalize each frame independently
588
+ min_vals = latent_animation.min(axis=(1, 2), keepdims=True)
589
+ max_vals = latent_animation.max(axis=(1, 2), keepdims=True)
590
+ latent_animation = (latent_animation - min_vals) / (max_vals - min_vals + 1e-8)
591
+
592
+ # Convert to uint8
593
+ latent_animation = (latent_animation * 255).astype(np.uint8)
594
+
595
+ # print("Before resize:", latent_animation.shape)
596
+
597
+ # Resize each frame
598
+ resized_frames = []
599
+ for frame in latent_animation:
600
+ resized = cv2.resize(frame, (400, 400), interpolation=cv2.INTER_NEAREST)
601
+ resized_frames.append(resized)
602
+
603
+ # Stack back into video
604
+ grayscale_video = np.stack(resized_frames)
605
+
606
+ # print("Final shape:", grayscale_video.shape)
607
+
608
+ # Add a dummy channel dimension for grayscale video
609
+ grayscale_video = grayscale_video[..., None].repeat(3, axis=-1) # Convert to RGB
610
+
611
+ # print("Output shape with channels:", grayscale_video.shape)
612
+
613
+ # Save to temporary file
614
+ temp_file = "temp_video.mp4"
615
+ fps = 32
616
+
617
+ # Create VideoWriter object
618
+ fourcc = cv2.VideoWriter_fourcc(*"mp4v")
619
+ out = cv2.VideoWriter(temp_file, fourcc, fps, (400, 400))
620
+
621
+ # Write frames
622
+ for frame in grayscale_video:
623
+ out.write(frame)
624
+
625
+ out.release()
626
+
627
+ return temp_file
628
+
629
+
630
+ def create_demo():
631
+ # Define the theme and layout
632
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
633
+ gr.Markdown("# EchoFlow Demo")
634
+ gr.Markdown("## Dataset Generation Pipeline")
635
+
636
+ gr.Markdown(
637
+ """
638
+ ### 🎯 Purpose
639
+ This demo showcases EchoFlow's ability to generate synthetic echocardiogram images and videos while preserving patient privacy. The pipeline consists of four main steps:
640
+
641
+ 1. **Latent Image Generation**: Draw a mask to indicate the region where the Left Ventricle should appear. Select the desired cardiac view, and click "Generate Latent Image". This outputs a latent image, which can be decoded into a pixel space image by clicking "Decode to Pixel Space".
642
+ 2. **Privacy Filter**: When clicking "Run Privacy Check", the generated image will be checked against a database of all training anatomies to ensure it is sufficiently different from real patient data.
643
+ 3. **Latent Video Generation**: If the privacy check passes, the latent image can be animated into a video with the desired Ejection Fraction.
644
+ 4. **Video Decoding**: The video can be decoded into a pixel space video by clicking "Decode Video".
645
+
646
+ ### ⚙️ Parameters
647
+ - **Sampling Steps**: Higher values produce better quality but take longer
648
+ - **Ejection Fraction**: Controls the strength of heart contraction in the animation
649
+ - **CFG Scale**: Controls how closely the animation follows the specified conditions
650
+ """
651
+ )
652
+
653
+ # Main container with 4 columns
654
+ with gr.Row():
655
+ # Column 1: Latent Image Generation
656
+ with gr.Column():
657
+ gr.Markdown(
658
+ '<img src="https://i.ibb.co/MysCHY1M/h1.png" style="width: 100%; height: 75px; object-fit: contain;">'
659
+ )
660
+ gr.Markdown("### Latent Image Generation")
661
+
662
+ with gr.Row():
663
+ # Input mask (binary image)
664
+ with gr.Column(scale=1):
665
+ # gr.Markdown("#### Mask Condition")
666
+ gr.Markdown("Draw the LV mask (white = region of interest)")
667
+ # Create a black background for the canvas
668
+ black_background = np.zeros((400, 400), dtype=np.uint8)
669
+
670
+ # Load the default mask image if it exists
671
+ try:
672
+ mask_image = Image.open("assets/seg.png").convert("L")
673
+ mask_image = mask_image.resize(
674
+ (400, 400), Image.Resampling.LANCZOS
675
+ )
676
+ # Make it binary (0 or 255)
677
+ mask_image = ImageOps.autocontrast(mask_image, cutoff=0)
678
+ mask_image = mask_image.point(
679
+ lambda p: 255 if p > 127 else 0
680
+ )
681
+ mask_array = np.array(mask_image)
682
+
683
+ # Create the editor value structure
684
+ editor_value = {
685
+ "background": black_background, # Black background
686
+ "layers": [mask_array], # The mask as an editable layer
687
+ "composite": mask_array, # The composite image (what's displayed)
688
+ }
689
+ except Exception as e:
690
+ print(f"Error loading mask image: {e}")
691
+ # Fall back to empty canvas
692
+ editor_value = black_background
693
+
694
+ mask_input = gr.ImageEditor(
695
+ label="Binary Mask",
696
+ height=400,
697
+ width=400,
698
+ image_mode="L",
699
+ value=editor_value,
700
+ type="numpy",
701
+ brush=gr.Brush(
702
+ colors=["#ffffff"],
703
+ color_mode="fixed",
704
+ default_size=20,
705
+ default_color="#ffffff",
706
+ ),
707
+ eraser=gr.Eraser(default_size=20),
708
+ # show_label=False,
709
+ show_download_button=True,
710
+ sources=[],
711
+ canvas_size=(400, 400),
712
+ fixed_canvas=True,
713
+ layers=False,  # layer controls disabled; the mask is drawn directly on the canvas
714
+ )
715
+
716
+ # # Class selection
717
+ # with gr.Column(scale=1):
718
+ # gr.Markdown("#### View Condition")
719
+ class_selection = gr.Radio(
720
+ choices=["A4C", "PSAX", "PLAX"],
721
+ label="View Class",
722
+ value="A4C",
723
+ )
724
+
725
+ # gr.Markdown("#### Sampling Steps")
726
+ sampling_steps = gr.Slider(
727
+ minimum=1,
728
+ maximum=200,
729
+ value=100,
730
+ step=1,
731
+ label="Number of Sampling Steps",
732
+ info="Higher values = better quality but slower generation",
733
+ )
734
+
735
+ # Generate button
736
+ generate_btn = gr.Button("Generate Latent Image", variant="primary")
737
+
738
+ # Display area for latent image (grayscale visualization)
739
+ latent_image_display = gr.Image(
740
+ label="Latent Image",
741
+ type="numpy",
742
+ height=400,
743
+ width=400,
744
+ # show_label=False,
745
+ )
746
+
747
+ # Decode button (initially disabled)
748
+ decode_btn = gr.Button(
749
+ "Decode to Pixel Space (Optional)",
750
+ interactive=False,
751
+ variant="primary",
752
+ )
753
+
754
+ # Display area for decoded image
755
+ decoded_image_display = gr.Image(
756
+ label="Decoded Image",
757
+ type="numpy",
758
+ height=400,
759
+ width=400,
760
+ # show_label=False,
761
+ )
762
+
763
+ # Column 2: Privacy Filter
764
+ with gr.Column():
765
+ gr.Markdown(
766
+ '<img src="https://i.ibb.co/MysCHY1M/h1.png" style="width: 100%; height: 75px; object-fit: contain;">'
767
+ )
768
+ gr.Markdown("### Privacy Filter")
769
+ gr.Markdown(
770
+ "Checks if the generated image is too similar to training data"
771
+ )
772
+
773
+ # Privacy check button
774
+ privacy_btn = gr.Button(
775
+ "Run Privacy Check", interactive=False, variant="primary"
776
+ )
777
+
778
+ # Display area for privacy result status
779
+ privacy_status = gr.Markdown("No image processed yet")
780
+
781
+ # Display area for privacy-filtered latent image
782
+ filtered_latent_display = gr.Image(
783
+ label="Filtered Latent Image", type="numpy", height=400, width=400
784
+ )
785
+
786
+ # Column 3: Animation
787
+ with gr.Column():
788
+ gr.Markdown(
789
+ '<img src="https://i.ibb.co/MysCHY1M/h1.png" style="width: 100%; height: 75px; object-fit: contain;">'
790
+ )
791
+ gr.Markdown("### Latent Video Generation")
792
+
793
+ # Ejection Fraction slider
794
+ ef_slider = gr.Slider(
795
+ minimum=0,
796
+ maximum=100,
797
+ value=65,
798
+ label="Ejection Fraction (%)",
799
+ info="Higher values = stronger contraction",
800
+ )
801
+
802
+ # Add sampling steps slider for animation
803
+ animation_steps = gr.Slider(
804
+ minimum=1,
805
+ maximum=200,
806
+ value=100,
807
+ step=1,
808
+ label="Number of Sampling Steps",
809
+ info="Higher values = better quality but slower generation",
810
+ )
811
+
812
+ # Add CFG slider
813
+ cfg_slider = gr.Slider(
814
+ minimum=0,
815
+ maximum=10,
816
+ value=1,
817
+ step=1,
818
+ label="Classifier-Free Guidance Scale",
819
+ # info="Higher values = better quality but slower generation",
820
+ )
821
+
822
+ # Animate button
823
+ animate_btn = gr.Button(
824
+ "Generate Video", interactive=False, variant="primary"
825
+ )
826
+
827
+ # Display area for latent animation (grayscale)
828
+ latent_animation_display = gr.Video(
829
+ label="Latent Video", format="mp4", autoplay=True, loop=True
830
+ )
831
+
832
+ # Column 4: Video Decoding
833
+ with gr.Column():
834
+ gr.Markdown(
835
+ '<img src="https://i.ibb.co/MysCHY1M/h1.png" style="width: 100%; height: 75px; object-fit: contain;">'
836
+ )
837
+ gr.Markdown("### Video Decoding")
838
+
839
+ # Decode animation button
840
+ decode_animation_btn = gr.Button(
841
+ "Decode Video", interactive=False, variant="primary"
842
+ )
843
+
844
+ # Display area for decoded animation
845
+ decoded_animation_display = gr.Video(
846
+ label="Decoded Video", format="mp4", autoplay=True, loop=True
847
+ )
848
+
849
+ # Hidden state variables to store the full latent representations
850
+ latent_image_state = gr.State(None)
851
+ filtered_latent_state = gr.State(None)
852
+ latent_animation_state = gr.State(None)
853
+
854
+ # Event handlers
855
+ generate_btn.click(
856
+ fn=generate_latent_image,
857
+ inputs=[mask_input, class_selection, sampling_steps],
858
+ outputs=[latent_image_state],
859
+ queue=True,
860
+ ).then(
861
+ fn=convert_latent_to_display,
862
+ inputs=[latent_image_state],
863
+ outputs=[latent_image_display],
864
+ queue=False,
865
+ ).then(
866
+ fn=lambda x: gr.Button(
867
+ interactive=x is not None
868
+ ), # Properly update button state
869
+ inputs=[latent_image_state],
870
+ outputs=[decode_btn],
871
+ queue=False,
872
+ ).then(
873
+ fn=lambda x: gr.Button(
874
+ interactive=x is not None
875
+ ), # Properly update button state
876
+ inputs=[latent_image_state],
877
+ outputs=[privacy_btn],
878
+ queue=False,
879
+ )
880
+
881
+ decode_btn.click(
882
+ fn=decode_latent_to_pixel,
883
+ inputs=[latent_image_state],
884
+ outputs=[decoded_image_display],
885
+ queue=True,
886
+ ).then(
887
+ fn=lambda x: gr.Button(
888
+ interactive=x is not None
889
+ ), # Properly update button state
890
+ inputs=[decoded_image_display],
891
+ outputs=[privacy_btn],
892
+ queue=False,
893
+ )
894
+
895
+ privacy_btn.click(
896
+ fn=check_privacy,
897
+ inputs=[latent_image_state, class_selection],
898
+ outputs=[filtered_latent_state, privacy_status],
899
+ queue=True,
900
+ ).then(
901
+ fn=convert_latent_to_display,
902
+ inputs=[filtered_latent_state],
903
+ outputs=[filtered_latent_display],
904
+ queue=False,
905
+ ).then(
906
+ fn=lambda x: gr.Button(
907
+ interactive=x is not None
908
+ ), # Properly update button state
909
+ inputs=[filtered_latent_state],
910
+ outputs=[animate_btn],
911
+ queue=False,
912
+ )
913
+
914
+ animate_btn.click(
915
+ fn=generate_animation,
916
+ inputs=[filtered_latent_state, ef_slider, animation_steps, cfg_slider],
917
+ outputs=[latent_animation_state],
918
+ queue=True,
919
+ ).then(
920
+ fn=latent_animation_to_grayscale,
921
+ inputs=[latent_animation_state],
922
+ outputs=[latent_animation_display],
923
+ queue=False,
924
+ ).then(
925
+ fn=lambda x: gr.Button(
926
+ interactive=x is not None
927
+ ), # Properly update button state
928
+ inputs=[latent_animation_state],
929
+ outputs=[decode_animation_btn],
930
+ queue=False,
931
+ )
932
+
933
+ decode_animation_btn.click(
934
+ fn=decode_animation,
935
+ inputs=[latent_animation_state], # Remove vae_state from inputs
936
+ outputs=[decoded_animation_display],
937
+ queue=True,
938
+ )
939
+
940
+ return demo
941
+
942
+
943
+ if __name__ == "__main__":
944
+ demo = create_demo()
945
+ demo.launch()
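The Gradio UI above chains the four stages described in the "Purpose" section. A rough headless sketch of the same pipeline, calling the functions in this file directly (the mask and EF values are placeholders):

# Hypothetical headless run of the demo pipeline (sketch only).
import numpy as np

mask = np.zeros((400, 400), dtype=np.uint8)
mask[150:260, 160:250] = 255  # placeholder LV region

latent = generate_latent_image(mask, "A4C", sampling_steps=100)
filtered, status = check_privacy(latent, "A4C")  # filtered is None if the check fails
print(status)

if filtered is not None:
    video_latent = generate_animation(filtered, ejection_fraction=65, sampling_steps=100, cfg_scale=1.0)
    print(decode_animation(video_latent))  # path to the decoded temp_video_2.mp4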
echoflow/common/__init__.py ADDED
@@ -0,0 +1,90 @@
+ import importlib
+
+ import omegaconf
+
+ from .models import ContrastiveModel, DiffuserSTDiT, ResNet18, SegDiTTransformer2DModel
+
+
+ def parse_klass_arg(value, full_config):
+     """
+     Parse an argument value that might represent a class, enum, or basic data type.
+     This function tries to dynamically import and resolve nested attributes.
+     It also resolves OmegaConf interpolations if found.
+     """
+     if isinstance(value, str) and "." in value:
+         # Check if the value is an interpolation and try to resolve it
+         if value.startswith("${") and value.endswith("}"):
+             try:
+                 # Attempt to resolve the interpolation directly using OmegaConf
+                 value = omegaconf.OmegaConf.resolve(full_config)[value[2:-1]]
+             except Exception as e:
+                 print(f"Error resolving OmegaConf interpolation {value}: {e}")
+                 return None
+
+         parts = value.split(".")
+         for i in range(len(parts) - 1, 0, -1):
+             module_name = ".".join(parts[:i])
+             attr_name = parts[i]
+             try:
+                 module = importlib.import_module(module_name)
+                 result = module
+                 for j in range(i, len(parts)):
+                     result = getattr(result, parts[j])
+                 return result
+             except ImportError as e:
+                 continue
+             except AttributeError as e:
+                 print(
+                     f"Warning: Could not resolve attribute {parts[j]} from {module_name}, error: {e}"
+                 )
+                 continue
+         # print(f"Warning: Failed to import or resolve {value}. Falling back to string.")
+         return (
+             value  # Return the original string if no valid import and resolution occurs
+         )
+     return value
+
+
+ def instantiate_class_from_config(config, *args, **kwargs):
+     """
+     Dynamically instantiate a class based on a configuration object.
+     Supports passing additional positional and keyword arguments.
+     """
+     module_name, class_name = config.target.rsplit(".", 1)
+     klass = globals().get(class_name)
+     # module = importlib.import_module(module_name)
+     # klass = getattr(module, class_name)
+
+     # Assuming config might be a part of a larger OmegaConf structure:
+     # if not isinstance(config, omegaconf.DictConfig):
+     #     config = omegaconf.OmegaConf.create(config)
+     config = omegaconf.OmegaConf.to_container(config, resolve=True)
+     # Resolve args and kwargs from the configuration
+     # conf_args = [parse_klass_arg(arg, config) for arg in config.get('args', [])]
+     # conf_kwargs = {key: parse_klass_arg(value, config) for key, value in config.get('kwargs', {}).items()}
+     conf_kwargs = {
+         key: parse_klass_arg(value, config) for key, value in config["args"].items()
+     }
+     # Combine conf_args with explicitly passed *args
+     all_args = list(args)  # + conf_args
+
+     # Combine conf_kwargs with explicitly passed **kwargs
+     all_kwargs = {**conf_kwargs, **kwargs}
+
+     # Instantiate the class with the processed arguments
+     instance = klass(*all_args, **all_kwargs)
+     return instance
+
+
+ def unscale_latents(latents, vae_scaling=None):
+     if vae_scaling is not None:
+         if latents.ndim == 4:
+             v = (1, -1, 1, 1)
+         elif latents.ndim == 5:
+             v = (1, -1, 1, 1, 1)
+         else:
+             raise ValueError("Latents should be 4D or 5D")
+         latents *= vae_scaling["std"].view(*v)
+         latents += vae_scaling["mean"].view(*v)
+
+     return latents
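instantiate_class_from_config looks the target class up in this module's globals (so it must be one of the classes imported from .models) and forwards everything under "args" as keyword arguments. A hedged sketch of the kind of OmegaConf config it expects; the field values here are illustrative, not taken from the released checkpoints:

# Illustrative config: the keys follow instantiate_class_from_config's expectations
# (a "target" dotted path and an "args" mapping), but the values are hypothetical.
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "target": "echoflow.common.models.ResNet18",
        "args": {"num_classes": 10},  # hypothetical constructor arguments
    }
)
model = instantiate_class_from_config(cfg)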
echoflow/common/models.py ADDED
@@ -0,0 +1,1730 @@
1
+ # This file contains modified code from the HuggingFace Diffusers library.
2
+
3
+ import math
4
+ from dataclasses import dataclass
5
+ from typing import Any, Dict, Optional, Tuple, Union
6
+
7
+ import numpy as np
8
+ import torch
9
+ import torch._dynamo
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ import xformers
13
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
14
+ from diffusers.loaders import UNet2DConditionLoadersMixin
15
+ from diffusers.models.attention import BasicTransformerBlock
16
+ from diffusers.models.attention_processor import (
17
+ CROSS_ATTENTION_PROCESSORS,
18
+ AttentionProcessor,
19
+ AttnProcessor,
20
+ )
21
+ from diffusers.models.embeddings import PatchEmbed, TimestepEmbedding, Timesteps
22
+ from diffusers.models.modeling_outputs import Transformer2DModelOutput
23
+ from diffusers.models.modeling_utils import ModelMixin
24
+ from diffusers.models.unets.unet_3d_blocks import UNetMidBlockSpatioTemporal
25
+ from diffusers.models.unets.unet_3d_blocks import get_down_block as get_down_block_3d
26
+ from diffusers.models.unets.unet_3d_blocks import get_up_block as get_up_block_3d
27
+ from diffusers.utils import BaseOutput, is_torch_version
28
+ from einops import rearrange
29
+ from timm.layers.drop import DropPath
30
+ from timm.layers.mlp import Mlp
31
+ from torchvision.models import resnet18
32
+
33
+ approx_gelu = lambda: nn.GELU(approximate="tanh")
34
+
35
+
36
+ class SegDiTTransformer2DModel(ModelMixin, ConfigMixin):
37
+ r"""
38
+ A 2D Transformer model as introduced in DiT (https://arxiv.org/abs/2212.09748).
39
+
40
+ Parameters:
41
+ num_attention_heads (int, optional, defaults to 16): The number of heads to use for multi-head attention.
42
+ attention_head_dim (int, optional, defaults to 72): The number of channels in each head.
43
+ in_channels (int, defaults to 4): The number of channels in the input.
44
+ out_channels (int, optional):
45
+ The number of channels in the output. Specify this parameter if the output channel number differs from the
46
+ input.
47
+ num_layers (int, optional, defaults to 28): The number of layers of Transformer blocks to use.
48
+ dropout (float, optional, defaults to 0.0): The dropout probability to use within the Transformer blocks.
49
+ norm_num_groups (int, optional, defaults to 32):
50
+ Number of groups for group normalization within Transformer blocks.
51
+ attention_bias (bool, optional, defaults to True):
52
+ Configure if the Transformer blocks' attention should contain a bias parameter.
53
+ sample_size (int, defaults to 32):
54
+ The width of the latent images. This parameter is fixed during training.
55
+ patch_size (int, defaults to 2):
56
+ Size of the patches the model processes, relevant for architectures working on non-sequential data.
57
+ activation_fn (str, optional, defaults to "gelu-approximate"):
58
+ Activation function to use in feed-forward networks within Transformer blocks.
59
+ num_embeds_ada_norm (int, optional, defaults to 1000):
60
+ Number of embeddings for AdaLayerNorm, fixed during training and affects the maximum denoising steps during
61
+ inference.
62
+ upcast_attention (bool, optional, defaults to False):
63
+ If true, upcasts the attention mechanism dimensions for potentially improved performance.
64
+ norm_type (str, optional, defaults to "ada_norm_zero"):
65
+ Specifies the type of normalization used, can be 'ada_norm_zero'.
66
+ norm_elementwise_affine (bool, optional, defaults to False):
67
+ If true, enables element-wise affine parameters in the normalization layers.
68
+ norm_eps (float, optional, defaults to 1e-5):
69
+ A small constant added to the denominator in normalization layers to prevent division by zero.
70
+ """
71
+
72
+ _supports_gradient_checkpointing = True
73
+
74
+ @register_to_config
75
+ def __init__(
76
+ self,
77
+ num_attention_heads: int = 16,
78
+ attention_head_dim: int = 72,
79
+ in_channels: int = 4,
80
+ out_channels: Optional[int] = None,
81
+ num_layers: int = 28,
82
+ dropout: float = 0.0,
83
+ norm_num_groups: int = 32,
84
+ attention_bias: bool = True,
85
+ sample_size: int = 32,
86
+ patch_size: int = 2,
87
+ activation_fn: str = "gelu-approximate",
88
+ num_embeds_ada_norm: Optional[int] = 1000,
89
+ upcast_attention: bool = False,
90
+ norm_type: str = "ada_norm_zero",
91
+ norm_elementwise_affine: bool = False,
92
+ norm_eps: float = 1e-5,
93
+ ):
94
+ super().__init__()
95
+
96
+ # Validate inputs.
97
+ if norm_type != "ada_norm_zero":
98
+ raise NotImplementedError(
99
+ f"Forward pass is not implemented when `patch_size` is not None and `norm_type` is '{norm_type}'."
100
+ )
101
+ elif norm_type == "ada_norm_zero" and num_embeds_ada_norm is None:
102
+ raise ValueError(
103
+ f"When using a `patch_size` and this `norm_type` ({norm_type}), `num_embeds_ada_norm` cannot be None."
104
+ )
105
+
106
+ # Set some common variables used across the board.
107
+ self.attention_head_dim = attention_head_dim
108
+ self.inner_dim = (
109
+ self.config.num_attention_heads * self.config.attention_head_dim
110
+ )
111
+ self.out_channels = in_channels if out_channels is None else out_channels
112
+ self.gradient_checkpointing = False
113
+
114
+ # 2. Initialize the position embedding and transformer blocks.
115
+ self.height = self.config.sample_size
116
+ self.width = self.config.sample_size
117
+
118
+ self.patch_size = self.config.patch_size
119
+ self.pos_embed = PatchEmbed(
120
+ height=self.config.sample_size,
121
+ width=self.config.sample_size,
122
+ patch_size=self.config.patch_size,
123
+ in_channels=self.config.in_channels,
124
+ embed_dim=self.inner_dim,
125
+ )
126
+
127
+ self.transformer_blocks = nn.ModuleList(
128
+ [
129
+ BasicTransformerBlock(
130
+ self.inner_dim,
131
+ self.config.num_attention_heads,
132
+ self.config.attention_head_dim,
133
+ dropout=self.config.dropout,
134
+ activation_fn=self.config.activation_fn,
135
+ num_embeds_ada_norm=self.config.num_embeds_ada_norm,
136
+ attention_bias=self.config.attention_bias,
137
+ upcast_attention=self.config.upcast_attention,
138
+ norm_type=norm_type,
139
+ norm_elementwise_affine=self.config.norm_elementwise_affine,
140
+ norm_eps=self.config.norm_eps,
141
+ )
142
+ for _ in range(self.config.num_layers)
143
+ ]
144
+ )
145
+
146
+ # 3. Output blocks.
147
+ self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6)
148
+ self.proj_out_1 = nn.Linear(self.inner_dim, 2 * self.inner_dim)
149
+ self.proj_out_2 = nn.Linear(
150
+ self.inner_dim,
151
+ self.config.patch_size * self.config.patch_size * self.out_channels,
152
+ )
153
+
154
+ def _set_gradient_checkpointing(self, module, value=False):
155
+ if hasattr(module, "gradient_checkpointing"):
156
+ module.gradient_checkpointing = value
157
+
158
+ def forward(
159
+ self,
160
+ hidden_states: torch.Tensor,
161
+ timestep: Optional[torch.LongTensor] = None,
162
+ class_labels: Optional[torch.LongTensor] = None,
163
+ cross_attention_kwargs: Dict[str, Any] = None,
164
+ segmentation: Optional[torch.LongTensor] = None,
165
+ return_dict: bool = True,
166
+ ):
167
+ """
168
+ The [`DiTTransformer2DModel`] forward method.
169
+
170
+ Args:
171
+ hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
172
+ Input `hidden_states`.
173
+ timestep ( `torch.LongTensor`, *optional*):
174
+ Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
175
+ class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
176
+ Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
177
+ `AdaLayerZeroNorm`.
178
+ cross_attention_kwargs ( `Dict[str, Any]`, *optional*):
179
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
180
+ `self.processor` in
181
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
182
+ return_dict (`bool`, *optional*, defaults to `True`):
183
+ Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
184
+ tuple.
185
+
186
+ Returns:
187
+ If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
188
+ `tuple` where the first element is the sample tensor.
189
+ """
190
+
191
+ # 0. If segmentation is provided, apply it to the input.
192
+ if segmentation is not None:
193
+ hidden_states = torch.cat([hidden_states, segmentation], dim=1) # B C+1 H W
194
+
195
+ # 1. Input
196
+ height, width = (
197
+ hidden_states.shape[-2] // self.patch_size,
198
+ hidden_states.shape[-1] // self.patch_size,
199
+ )
200
+ hidden_states = self.pos_embed(hidden_states)
201
+
202
+ # 2. Blocks
203
+ for block in self.transformer_blocks:
204
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
205
+
206
+ def create_custom_forward(module, return_dict=None):
207
+ def custom_forward(*inputs):
208
+ if return_dict is not None:
209
+ return module(*inputs, return_dict=return_dict)
210
+ else:
211
+ return module(*inputs)
212
+
213
+ return custom_forward
214
+
215
+ ckpt_kwargs: Dict[str, Any] = (
216
+ {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
217
+ )
218
+ hidden_states = torch.utils.checkpoint.checkpoint(
219
+ create_custom_forward(block),
220
+ hidden_states,
221
+ None,
222
+ None,
223
+ None,
224
+ timestep,
225
+ cross_attention_kwargs,
226
+ class_labels,
227
+ **ckpt_kwargs,
228
+ )
229
+ else:
230
+ hidden_states = block(
231
+ hidden_states,
232
+ attention_mask=None,
233
+ encoder_hidden_states=None,
234
+ encoder_attention_mask=None,
235
+ timestep=timestep,
236
+ cross_attention_kwargs=cross_attention_kwargs,
237
+ class_labels=class_labels,
238
+ )
239
+
240
+ # 3. Output
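+ # reuse the timestep/class embedding of the first block's AdaLayerNormZero to produce
+ # the final shift/scale modulation before the output projection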
241
+ conditioning = self.transformer_blocks[0].norm1.emb(
242
+ timestep, class_labels, hidden_dtype=hidden_states.dtype
243
+ )
244
+ shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
245
+ hidden_states = (
246
+ self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
247
+ )
248
+ hidden_states = self.proj_out_2(hidden_states)
249
+
250
+ # unpatchify
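+ # (assumes a square token grid, i.e. the latent height and width are equal)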
251
+ height = width = int(hidden_states.shape[1] ** 0.5)
252
+ hidden_states = hidden_states.reshape(
253
+ shape=(
254
+ -1,
255
+ height,
256
+ width,
257
+ self.patch_size,
258
+ self.patch_size,
259
+ self.out_channels,
260
+ )
261
+ )
262
+ hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
263
+ output = hidden_states.reshape(
264
+ shape=(
265
+ -1,
266
+ self.out_channels,
267
+ height * self.patch_size,
268
+ width * self.patch_size,
269
+ )
270
+ )
271
+
272
+ if not return_dict:
273
+ return (output,)
274
+
275
+ return Transformer2DModelOutput(sample=output)
276
+
277
+
278
+ def get_2d_sincos_pos_embed(
279
+ embed_dim, grid_size, cls_token=False, extra_tokens=0, scale=1.0, base_size=None
280
+ ):
281
+ """
282
+ grid_size: int or (int, int) giving the grid height and width
283
+ return:
284
+ pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
285
+ """
286
+ if not isinstance(grid_size, tuple):
287
+ grid_size = (grid_size, grid_size)
288
+
289
+ grid_h = np.arange(grid_size[0], dtype=np.float32) / scale
290
+ grid_w = np.arange(grid_size[1], dtype=np.float32) / scale
291
+ if base_size is not None:
292
+ grid_h *= base_size / grid_size[0]
293
+ grid_w *= base_size / grid_size[1]
294
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
295
+ grid = np.stack(grid, axis=0)
296
+
297
+ grid = grid.reshape([2, 1, grid_size[1], grid_size[0]])
298
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
299
+ if cls_token and extra_tokens > 0:
300
+ pos_embed = np.concatenate(
301
+ [np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0
302
+ )
303
+ return pos_embed
304
+
305
+
306
+ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
307
+ assert embed_dim % 2 == 0
308
+
309
+ # use half of dimensions to encode grid_h
310
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
311
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
312
+
313
+ emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
314
+ return emb
315
+
316
+
317
+ def get_1d_sincos_pos_embed(embed_dim, length, scale=1.0):
318
+ pos = np.arange(0, length)[..., None] / scale
319
+ return get_1d_sincos_pos_embed_from_grid(embed_dim, pos)
320
+
321
+
322
+ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
323
+ """
324
+ embed_dim: output dimension for each position
325
+ pos: a list of positions to be encoded: size (M,)
326
+ out: (M, D)
327
+ """
328
+ assert embed_dim % 2 == 0
329
+ omega = np.arange(embed_dim // 2, dtype=np.float64)
330
+ omega /= embed_dim / 2.0
331
+ omega = 1.0 / 10000**omega # (D/2,)
332
+
333
+ pos = pos.reshape(-1) # (M,)
334
+ out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
335
+
336
+ emb_sin = np.sin(out) # (M, D/2)
337
+ emb_cos = np.cos(out) # (M, D/2)
338
+
339
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
340
+ return emb
341
+
342
+
343
+ def t2i_modulate(x, shift, scale):
344
+ return x * (1 + scale) + shift
345
+
346
+
347
+ class PatchEmbed3D(nn.Module):
348
+ """Video to Patch Embedding.
349
+
350
+ Args:
351
+ patch_size (tuple[int]): Patch token size (T, H, W). Default: (2, 4, 4).
352
+ in_chans (int): Number of input video channels. Default: 3.
353
+ embed_dim (int): Number of linear projection output channels. Default: 96.
354
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
355
+ """
356
+
357
+ def __init__(
358
+ self,
359
+ patch_size=(2, 4, 4),
360
+ in_chans=3,
361
+ embed_dim=96,
362
+ norm_layer=None,
363
+ flatten=True,
364
+ ):
365
+ super().__init__()
366
+ self.patch_size = patch_size
367
+ self.flatten = flatten
368
+
369
+ self.in_chans = in_chans
370
+ self.embed_dim = embed_dim
371
+
372
+ self.proj = nn.Conv3d(
373
+ in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
374
+ )
375
+ if norm_layer is not None:
376
+ self.norm = norm_layer(embed_dim)
377
+ else:
378
+ self.norm = None
379
+
380
+ def forward(self, x):
381
+ """Forward function."""
382
+ # padding
383
+ _, _, D, H, W = x.size()
384
+ if W % self.patch_size[2] != 0:
385
+ x = F.pad(x, (0, self.patch_size[2] - W % self.patch_size[2]))
386
+ if H % self.patch_size[1] != 0:
387
+ x = F.pad(x, (0, 0, 0, self.patch_size[1] - H % self.patch_size[1]))
388
+ if D % self.patch_size[0] != 0:
389
+ x = F.pad(x, (0, 0, 0, 0, 0, self.patch_size[0] - D % self.patch_size[0]))
390
+
391
+ x = self.proj(x) # (B C T H W)
392
+ if self.norm is not None:
393
+ D, Wh, Ww = x.size(2), x.size(3), x.size(4)
394
+ x = x.flatten(2).transpose(1, 2)
395
+ x = self.norm(x)
396
+ x = x.transpose(1, 2).view(-1, self.embed_dim, D, Wh, Ww)
397
+ if self.flatten:
398
+ x = x.flatten(2).transpose(1, 2) # BCTHW -> BNC
399
+ return x
400
+
401
+
402
+ class Attention(nn.Module):
403
+ def __init__(
404
+ self,
405
+ dim: int,
406
+ num_heads: int = 8,
407
+ qkv_bias: bool = False,
408
+ qk_norm: bool = False,
409
+ attn_drop: float = 0.0,
410
+ proj_drop: float = 0.0,
411
+ norm_layer: nn.Module = nn.LayerNorm,
412
+ enable_flashattn: bool = False,
413
+ ) -> None:
414
+ super().__init__()
415
+ assert dim % num_heads == 0, "dim should be divisible by num_heads"
416
+ self.dim = dim
417
+ self.num_heads = num_heads
418
+ self.head_dim = dim // num_heads
419
+ self.scale = self.head_dim**-0.5
420
+
421
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
422
+ self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
423
+ self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
424
+ self.attn_drop = nn.Dropout(attn_drop)
425
+ self.proj = nn.Linear(dim, dim)
426
+ self.proj_drop = nn.Dropout(proj_drop)
427
+
428
+ if enable_flashattn:
429
+ print(
430
+ "[WARNING] FlashAttention cannot be used. Set enable_flashattn to False."
431
+ )
432
+
433
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
434
+ B, N, C = x.shape
435
+ qkv = self.qkv(x)
436
+ qkv_shape = (B, N, 3, self.num_heads, self.head_dim)
437
+ qkv_permute_shape = (2, 0, 3, 1, 4)
438
+ qkv = qkv.view(qkv_shape).permute(qkv_permute_shape)
439
+ q, k, v = qkv.unbind(0)
440
+ q, k = self.q_norm(q), self.k_norm(k)
441
+
442
+ dtype = q.dtype
443
+ q = q * self.scale
444
+ attn = q @ k.transpose(-2, -1)
445
+ attn = attn.to(torch.float32)  # compute the softmax in float32 for numerical stability
446
+ attn = attn.softmax(dim=-1)
447
+ attn = attn.to(dtype) # cast back attn to original dtype
448
+ attn = self.attn_drop(attn)
449
+ x = attn @ v
450
+
451
+ x_output_shape = (B, N, C)
452
+ x = x.reshape(x_output_shape)
453
+ x = self.proj(x)
454
+ x = self.proj_drop(x)
455
+ return x
456
+
457
+
458
+ class MultiHeadCrossAttention(nn.Module):
459
+ def __init__(self, d_model, num_heads, attn_drop=0.0, proj_drop=0.0):
460
+ super(MultiHeadCrossAttention, self).__init__()
461
+ assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
462
+
463
+ self.d_model = d_model
464
+ self.num_heads = num_heads
465
+ self.head_dim = d_model // num_heads
466
+
467
+ self.q_linear = nn.Linear(d_model, d_model)
468
+ self.kv_linear = nn.Linear(d_model, d_model * 2)
469
+ self.attn_drop = nn.Dropout(attn_drop)
470
+ self.proj = nn.Linear(d_model, d_model)
471
+ self.proj_drop = nn.Dropout(proj_drop)
472
+
473
+ @torch._dynamo.disable
474
+ def forward(self, x, cond, mask=None):
475
+ # query: image tokens; key/value: condition tokens; mask: per-sample valid condition-token lengths
476
+ B, N, C = x.shape
477
+
478
+ q = self.q_linear(x).view(1, -1, self.num_heads, self.head_dim)
479
+ kv = self.kv_linear(cond).view(1, -1, 2, self.num_heads, self.head_dim)
480
+ k, v = kv.unbind(2)
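+ # tokens from all samples are packed into a single sequence (batch dim 1); the block-diagonal
+ # attention bias below keeps each sample attending only to its own condition tokens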
481
+
482
+ attn_bias = None
483
+ if mask is not None:
484
+ attn_bias = xformers.ops.fmha.BlockDiagonalMask.from_seqlens([N] * B, mask)
485
+ x = xformers.ops.memory_efficient_attention(
486
+ q, k, v, p=self.attn_drop.p, attn_bias=attn_bias
487
+ )
488
+
489
+ x = x.view(B, -1, C)
490
+ x = self.proj(x)
491
+ x = self.proj_drop(x)
492
+ return x
493
+
494
+
495
+ class TimestepEmbedder(nn.Module):
496
+ """
497
+ Embeds scalar timesteps into vector representations.
498
+ """
499
+
500
+ def __init__(self, hidden_size, frequency_embedding_size=256):
501
+ super().__init__()
502
+ self.mlp = nn.Sequential(
503
+ nn.Linear(frequency_embedding_size, hidden_size, bias=True),
504
+ nn.SiLU(),
505
+ nn.Linear(hidden_size, hidden_size, bias=True),
506
+ )
507
+ self.frequency_embedding_size = frequency_embedding_size
508
+
509
+ @staticmethod
510
+ def timestep_embedding(t, dim, max_period=10000):
511
+ """
512
+ Create sinusoidal timestep embeddings.
513
+ :param t: a 1-D Tensor of N indices, one per batch element.
514
+ These may be fractional.
515
+ :param dim: the dimension of the output.
516
+ :param max_period: controls the minimum frequency of the embeddings.
517
+ :return: an (N, D) Tensor of positional embeddings.
518
+ """
519
+ # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
520
+ half = dim // 2
521
+ freqs = torch.exp(
522
+ -math.log(max_period)
523
+ * torch.arange(start=0, end=half, dtype=torch.float32)
524
+ / half
525
+ )
526
+ freqs = freqs.to(device=t.device)
527
+ args = t[:, None].float() * freqs[None]
528
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
529
+ if dim % 2:
530
+ embedding = torch.cat(
531
+ [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
532
+ )
533
+ return embedding
534
+
535
+ def forward(self, t, dtype):
536
+ t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
537
+ if t_freq.dtype != dtype:
538
+ t_freq = t_freq.to(dtype)
539
+ t_emb = self.mlp(t_freq)
540
+ return t_emb
541
+
542
+
543
+ class CaptionEmbedder(nn.Module):
544
+ """
545
+ Projects caption embeddings into the model's hidden space. Also handles caption dropout for classifier-free guidance.
546
+ """
547
+
548
+ def __init__(
549
+ self,
550
+ in_channels,
551
+ hidden_size,
552
+ uncond_prob,
553
+ act_layer=nn.GELU(approximate="tanh"),
554
+ token_num=120,
555
+ ):
556
+ super().__init__()
557
+ self.y_proj = Mlp(
558
+ in_features=in_channels,
559
+ hidden_features=hidden_size,
560
+ out_features=hidden_size,
561
+ act_layer=act_layer,
562
+ drop=0,
563
+ )
564
+ self.register_buffer(
565
+ "y_embedding",
566
+ nn.Parameter(torch.randn(token_num, in_channels) / in_channels**0.5),
567
+ )
568
+ self.uncond_prob = uncond_prob
569
+
570
+ def token_drop(self, caption, force_drop_ids=None):
571
+ """
572
+ Drops labels to enable classifier-free guidance.
573
+ """
574
+ if force_drop_ids is None:
575
+ drop_ids = torch.rand(caption.shape[0], device=caption.device) < self.uncond_prob
576
+ else:
577
+ drop_ids = force_drop_ids == 1
578
+ caption = torch.where(drop_ids[:, None, None, None], self.y_embedding, caption)
579
+ return caption
580
+
581
+ @torch._dynamo.disable
582
+ def forward(self, caption, train, force_drop_ids=None):
583
+ if train:
584
+ assert caption.shape[2:] == self.y_embedding.shape
585
+ use_dropout = self.uncond_prob > 0
586
+ if (train and use_dropout) or (force_drop_ids is not None):
587
+ caption = self.token_drop(caption, force_drop_ids)
588
+ caption = self.y_proj(caption)
589
+ return caption
590
+
591
+
592
+ class T2IFinalLayer(nn.Module):
593
+ """
594
+ The final layer of PixArt.
595
+ """
596
+
597
+ def __init__(self, hidden_size, num_patch, out_channels):
598
+ super().__init__()
599
+ self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
600
+ self.linear = nn.Linear(hidden_size, num_patch * out_channels, bias=True)
601
+ self.scale_shift_table = nn.Parameter(
602
+ torch.randn(2, hidden_size) / hidden_size**0.5
603
+ )
604
+ self.out_channels = out_channels
605
+
606
+ def forward(self, x, t):
607
+ shift, scale = (self.scale_shift_table[None] + t[:, None]).chunk(2, dim=1)
608
+ x = t2i_modulate(self.norm_final(x), shift, scale)
609
+ x = self.linear(x)
610
+ return x
611
+
612
+
613
+ class STDiTBlock(nn.Module):
614
+ """
615
+ A single block of STDiT (Spatio-Temporal Diffusion Transformer).
616
+
617
+ Args:
618
+ hidden_size (int): Hidden size of the model.
619
+ num_heads (int): Number of attention heads.
620
+ d_s (int): Number of spatial tokens per frame.
621
+ d_t (int): Number of temporal tokens.
622
+ mlp_ratio (float): Ratio of the MLP hidden dimension to `hidden_size`.
623
+ drop_path (float): Drop path rate.
624
+ enable_flashattn (bool): Enable FlashAttention.
625
+ """
626
+
627
+ def __init__(
628
+ self,
629
+ hidden_size,
630
+ num_heads,
631
+ d_s=None,
632
+ d_t=None,
633
+ mlp_ratio=4.0,
634
+ drop_path=0.0,
635
+ enable_flashattn=False,
636
+ uncond=False,
637
+ ):
638
+ super().__init__()
639
+ self.hidden_size = hidden_size
640
+ self.enable_flashattn = enable_flashattn
641
+
642
+ self.attn_cls = Attention
643
+ self.mha_cls = MultiHeadCrossAttention
644
+
645
+ self.norm1 = nn.LayerNorm(hidden_size, eps=1e-6, elementwise_affine=False)
646
+ self.attn = self.attn_cls(
647
+ hidden_size,
648
+ num_heads=num_heads,
649
+ qkv_bias=True,
650
+ enable_flashattn=False,
651
+ )
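+ # note: despite its name, `uncond` enables the cross-attention branch used for caption conditioning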
652
+ if uncond:
653
+ self.cross_attn = self.mha_cls(hidden_size, num_heads)
654
+ self.norm2 = nn.LayerNorm(hidden_size, eps=1e-6, elementwise_affine=False)
655
+ self.mlp = Mlp(
656
+ in_features=hidden_size,
657
+ hidden_features=int(hidden_size * mlp_ratio),
658
+ act_layer=approx_gelu,
659
+ drop=0,
660
+ )
661
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
662
+ self.scale_shift_table = nn.Parameter(
663
+ torch.randn(6, hidden_size) / hidden_size**0.5
664
+ )
665
+
666
+ # temporal attention
667
+ self.d_s = d_s
668
+ self.d_t = d_t
669
+
670
+ self.attn_temp = self.attn_cls(
671
+ hidden_size,
672
+ num_heads=num_heads,
673
+ qkv_bias=True,
674
+ enable_flashattn=self.enable_flashattn,
675
+ )
676
+
677
+ def forward(self, x, t, y=None, mask=None, tpe=None):
678
+ """
679
+ Args:
680
+ x (torch.Tensor): noisy input tokens of shape [B, N, C]
681
+ y (torch.Tensor): packed cross-attention condition tokens of shape [1, N_token_total, C]
682
+ t (torch.Tensor): timestep modulation tensor of shape [B, 6 * C]
683
+ mask (list[int]): number of valid condition tokens per sample (used to build the attention bias)
684
+ tpe (torch.Tensor): temporal positional embedding of shape [1, T, C]
685
+ """
686
+ B, N, C = x.shape
687
+
688
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
689
+ self.scale_shift_table[None] + t.reshape(B, 6, -1)
690
+ ).chunk(6, dim=1)
691
+ x_m = t2i_modulate(self.norm1(x), shift_msa, scale_msa)
692
+
693
+ # spatial branch
694
+ x_s = rearrange(x_m, "B (T S) C -> (B T) S C", T=self.d_t, S=self.d_s)
695
+ x_s = self.attn(x_s)
696
+ x_s = rearrange(x_s, "(B T) S C -> B (T S) C", T=self.d_t, S=self.d_s)
697
+ x = x + self.drop_path(gate_msa * x_s)
698
+
699
+ # temporal branch
700
+ x_t = rearrange(x, "B (T S) C -> (B S) T C", T=self.d_t, S=self.d_s)
701
+ if tpe is not None:
702
+ x_t = x_t + tpe
703
+ x_t = self.attn_temp(x_t)
704
+ x_t = rearrange(x_t, "(B S) T C -> B (T S) C", T=self.d_t, S=self.d_s)
705
+ x = x + self.drop_path(gate_msa * x_t)
706
+
707
+ # cross attn
708
+ if y is not None:
709
+ x = x + self.cross_attn(x, y, mask)
710
+
711
+ # mlp
712
+ x = x + self.drop_path(
713
+ gate_mlp * self.mlp(t2i_modulate(self.norm2(x), shift_mlp, scale_mlp))
714
+ )
715
+
716
+ return x
717
+
718
+
719
+ # | Model | Layers N | Hidden size d | Heads | Gflops (I=32, p=4) |
720
+ # |-------|----------|---------------|-------|---------------------|
721
+ # | DiT-S | 12 | 384 | 6 | 1.4 |
722
+ # | DiT-B | 12 | 768 | 12 | 5.6 |
723
+ # | DiT-L | 24 | 1024 | 16 | 19.7 |
724
+ # | DiT-XL| 28 | 1152 | 16 | 29.1 |
725
+ class STDiT(nn.Module):
726
+ def __init__(
727
+ self,
728
+ input_size=(1, 32, 32), # T, H, W
729
+ in_channels=4,
730
+ out_channels=4,
731
+ patch_size=(1, 2, 2), # T, H, W
732
+ hidden_size=1152, #
733
+ depth=28, # Number of layers
734
+ num_heads=16,
735
+ mlp_ratio=4.0,
736
+ class_dropout_prob=0.1,
737
+ drop_path=0.0,
738
+ no_temporal_pos_emb=False,
739
+ caption_channels=4096, # 0 to disable
740
+ model_max_length=120,
741
+ space_scale=1.0,
742
+ time_scale=1.0,
743
+ enable_flashattn=False,
744
+ ):
745
+ super().__init__()
746
+ self.in_channels = in_channels
747
+ self.out_channels = out_channels
748
+ self.hidden_size = hidden_size
749
+ self.patch_size = patch_size
750
+ self.input_size = input_size
751
+ num_patches = np.prod([input_size[i] // patch_size[i] for i in range(3)])
752
+ self.num_patches = num_patches
753
+ self.num_temporal = input_size[0] // patch_size[0]
754
+ self.num_spatial = num_patches // self.num_temporal
755
+ self.num_heads = num_heads
756
+ self.no_temporal_pos_emb = no_temporal_pos_emb
757
+ self.depth = depth
758
+ self.mlp_ratio = mlp_ratio
759
+ self.enable_flashattn = enable_flashattn
760
+ self.space_scale = space_scale
761
+ self.time_scale = time_scale
762
+
763
+ if caption_channels == 0:
764
+ print("Warning: caption_channels is 0, disabling text conditioning.")
765
+
766
+ self.register_buffer("pos_embed", self.get_spatial_pos_embed())
767
+ self.register_buffer("pos_embed_temporal", self.get_temporal_pos_embed())
768
+
769
+ self.x_embedder = PatchEmbed3D(patch_size, in_channels, hidden_size)
770
+ self.t_embedder = TimestepEmbedder(hidden_size)
771
+ self.t_block = nn.Sequential(
772
+ nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True)
773
+ )
774
+ self.y_embedder = (
775
+ CaptionEmbedder(
776
+ in_channels=caption_channels,
777
+ hidden_size=hidden_size,
778
+ uncond_prob=class_dropout_prob,
779
+ act_layer=approx_gelu,
780
+ token_num=model_max_length,
781
+ )
782
+ if caption_channels > 0
783
+ else None
784
+ )
785
+
786
+ drop_path = [x.item() for x in torch.linspace(0, drop_path, depth)]
787
+ self.blocks = nn.ModuleList(
788
+ [
789
+ STDiTBlock(
790
+ self.hidden_size,
791
+ self.num_heads,
792
+ mlp_ratio=self.mlp_ratio,
793
+ drop_path=drop_path[i],
794
+ enable_flashattn=self.enable_flashattn,
795
+ d_t=self.num_temporal,
796
+ d_s=self.num_spatial,
797
+ uncond=(caption_channels > 0),
798
+ )
799
+ for i in range(self.depth)
800
+ ]
801
+ )
802
+ self.final_layer = T2IFinalLayer(
803
+ hidden_size, np.prod(self.patch_size), self.out_channels
804
+ )
805
+
806
+ # init model
807
+ self.initialize_weights()
808
+ self.initialize_temporal()
809
+
810
+ # sequence parallel related configs
811
+ self.sp_rank = None
812
+
813
+ def forward(self, x, timestep, y=None, mask=None, cond_image=None):
814
+ """
815
+ Forward pass of STDiT.
816
+ Args:
817
+ x (torch.Tensor): latent representation of video; of shape [B, C, T, H, W]
818
+ timestep (torch.Tensor): diffusion time steps; of shape [B]
819
+ y (torch.Tensor): representation of prompts; of shape [B, 1, N_token, C]
820
+ mask (torch.Tensor): mask for selecting prompt tokens; of shape [B, N_token]
821
+
822
+ Returns:
823
+ x (torch.Tensor): output latent representation; of shape [B, C, T, H, W]
824
+ """
825
+
826
+ # x = x.to(self.dtype)
827
+ # timestep = timestep.to(self.dtype)
828
+ # y = y.to(self.dtype)
829
+
830
+ # embedding
831
+ x = self.x_embedder(x) # [B, N, C]
832
+ # print(x.shape, self.num_temporal, self.num_spatial)
833
+ x = rearrange(
834
+ x, "B (T S) C -> B T S C", T=self.num_temporal, S=self.num_spatial
835
+ )
836
+ x = x + self.pos_embed
837
+ x = rearrange(x, "B T S C -> B (T S) C")
838
+
839
+ # shard over the sequence dim if sp is enabled
840
+ # if self.enable_sequence_parallelism:
841
+ # x = split_forward_gather_backward(x, get_sequence_parallel_group(), dim=1, grad_scale="down")
842
+
843
+ t = self.t_embedder(timestep, dtype=x.dtype) # [B, C]
844
+ t0 = self.t_block(t) # [B, C]
845
+ if self.y_embedder is not None and y is not None:
846
+ y = self.y_embedder(y, self.training) # [B, 1, N_token, C]
847
+
848
+ if mask is not None:
849
+ if mask.shape[0] != y.shape[0]:
850
+ mask = mask.repeat(y.shape[0] // mask.shape[0], 1)
851
+ mask = mask.squeeze(1).squeeze(1)
852
+ y = (
853
+ y.squeeze(1)
854
+ .masked_select(mask.unsqueeze(-1) != 0)
855
+ .view(1, -1, x.shape[-1])
856
+ )
857
+ y_lens = mask.sum(dim=1).tolist()
858
+ else:
859
+ y_lens = [y.shape[2]] * y.shape[0] # N_token * B
860
+ y = y.squeeze(1).view(1, -1, x.shape[-1])
861
+ else:
862
+ y = None
863
+ y_lens = None
864
+
865
+ # blocks
866
+ for i, block in enumerate(self.blocks):
867
+ if i == 0:
868
+ tpe = self.pos_embed_temporal
869
+ else:
870
+ tpe = None
871
+ x = block(x=x, t=t0, y=y, mask=y_lens, tpe=tpe)
872
+ # x.shape: [B, N, C]
873
+
874
+ # final process
875
+ x = self.final_layer(x, t) # [B, N, C=T_p * H_p * W_p * C_out]
876
+ x = self.unpatchify(x) # [B, C_out, T, H, W]
877
+
878
+ return x
879
+
880
+ def unpatchify(self, x):
881
+ """
882
+ Args:
883
+ x (torch.Tensor): of shape [B, N, C]
884
+
885
+ Return:
886
+ x (torch.Tensor): of shape [B, C_out, T, H, W]
887
+ """
888
+
889
+ N_t, N_h, N_w = [self.input_size[i] // self.patch_size[i] for i in range(3)]
890
+ T_p, H_p, W_p = self.patch_size
891
+ x = rearrange(
892
+ x,
893
+ "B (N_t N_h N_w) (T_p H_p W_p C_out) -> B C_out (N_t T_p) (N_h H_p) (N_w W_p)",
894
+ N_t=N_t,
895
+ N_h=N_h,
896
+ N_w=N_w,
897
+ T_p=T_p,
898
+ H_p=H_p,
899
+ W_p=W_p,
900
+ C_out=self.out_channels,
901
+ )
902
+ return x
903
+
904
+ def unpatchify_old(self, x):
905
+ c = self.out_channels
906
+ t, h, w = [self.input_size[i] // self.patch_size[i] for i in range(3)]
907
+ pt, ph, pw = self.patch_size
908
+
909
+ x = x.reshape(shape=(x.shape[0], t, h, w, pt, ph, pw, c))
910
+ x = rearrange(x, "n t h w r p q c -> n c t r h p w q")
911
+ imgs = x.reshape(shape=(x.shape[0], c, t * pt, h * ph, w * pw))
912
+ return imgs
913
+
914
+ def get_spatial_pos_embed(self, grid_size=None):
915
+ if grid_size is None:
916
+ grid_size = self.input_size[1:]
917
+ pos_embed = get_2d_sincos_pos_embed(
918
+ self.hidden_size,
919
+ (grid_size[0] // self.patch_size[1], grid_size[1] // self.patch_size[2]),
920
+ scale=self.space_scale,
921
+ )
922
+ pos_embed = (
923
+ torch.from_numpy(pos_embed).float().unsqueeze(0).requires_grad_(False)
924
+ )
925
+ return pos_embed
926
+
927
+ def get_temporal_pos_embed(self):
928
+ pos_embed = get_1d_sincos_pos_embed(
929
+ self.hidden_size,
930
+ self.input_size[0] // self.patch_size[0],
931
+ scale=self.time_scale,
932
+ )
933
+ pos_embed = (
934
+ torch.from_numpy(pos_embed).float().unsqueeze(0).requires_grad_(False)
935
+ )
936
+ return pos_embed
937
+
938
+ def freeze_not_temporal(self):
939
+ for n, p in self.named_parameters():
940
+ if "attn_temp" not in n:
941
+ p.requires_grad = False
942
+
943
+ def freeze_text(self):
944
+ for n, p in self.named_parameters():
945
+ if "cross_attn" in n:
946
+ p.requires_grad = False
947
+
948
+ def initialize_temporal(self):
949
+ for block in self.blocks:
950
+ nn.init.constant_(block.attn_temp.proj.weight, 0)
951
+ nn.init.constant_(block.attn_temp.proj.bias, 0)
952
+
953
+ def initialize_weights(self):
954
+ # Initialize transformer layers:
955
+ def _basic_init(module):
956
+ if isinstance(module, nn.Linear):
957
+ torch.nn.init.xavier_uniform_(module.weight)
958
+ if module.bias is not None:
959
+ nn.init.constant_(module.bias, 0)
960
+
961
+ self.apply(_basic_init)
962
+
963
+ # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
964
+ w = self.x_embedder.proj.weight.data
965
+ nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
966
+
967
+ # Initialize timestep embedding MLP:
968
+ nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
969
+ nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
970
+ nn.init.normal_(self.t_block[1].weight, std=0.02)
971
+
972
+ # Initialize caption embedding MLP:
973
+ if self.y_embedder is not None:
974
+ nn.init.normal_(self.y_embedder.y_proj.fc1.weight, std=0.02)
975
+ nn.init.normal_(self.y_embedder.y_proj.fc2.weight, std=0.02)
976
+
977
+ # Zero-out the cross-attention output projections so the residual branch starts as a no-op:
978
+ for block in self.blocks:
979
+ if hasattr(block, "cross_attn"):  # cross-attention only exists when caption conditioning is enabled
+ nn.init.constant_(block.cross_attn.proj.weight, 0)
980
+ nn.init.constant_(block.cross_attn.proj.bias, 0)
981
+
982
+ # Zero-out output layers:
983
+ nn.init.constant_(self.final_layer.linear.weight, 0)
984
+ nn.init.constant_(self.final_layer.linear.bias, 0)
985
+
986
+
987
+ @dataclass
988
+ class DiffuserSTDiTModelOutput(BaseOutput):
989
+ """
990
+ The output of [`DiffuserSTDiT`].
991
+
992
+ Args:
993
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, num_frames, height, width)`):
994
+ The denoised latent video output conditioned on the `encoder_hidden_states` (prompt embedding)
995
+ input.
996
+ """
997
+
998
+ sample: torch.FloatTensor
999
+
1000
+
1001
+ class DiffuserSTDiT(ModelMixin, ConfigMixin):
1002
+ """
1003
+ STDiT: Spatio-Temporal Diffusion Transformer.
1004
+
1005
+ Parameters:
1006
+ input_size (tuple): Input size of the video. Default: (1, 32, 32).
1007
+ in_channels (int): Number of input video channels. Default: 4.
1008
+ out_channels (int): Number of output video channels. Default: 4.
1009
+ patch_size (tuple): Patch token size. Default: (1, 2, 2).
1010
+ hidden_size (int): Hidden size of the model. Default: 1152.
1011
+ depth (int): Number of layers. Default: 28.
1012
+ num_heads (int): Number of attention heads. Default: 16.
1013
+ mlp_ratio (float): Ratio of hidden to mlp hidden size. Default: 4.0.
1014
+ class_dropout_prob (float): Probability of dropping class tokens. Default: 0.1.
1015
+ drop_path (float): Drop path rate. Default: 0.0.
1016
+ no_temporal_pos_emb (bool): Disable temporal positional embeddings. Default: False.
1017
+ caption_channels (int): Number of caption channels. Default: 4096.
1018
+ model_max_length (int): Maximum number of caption tokens. Default: 120.
1019
+ space_scale (float): Spatial scale. Default: 1.0.
1020
+ time_scale (float): Temporal scale. Default: 1.0.
1021
+ enable_flashattn (bool): Enable FlashAttention. Default: False.
1022
+ """
1023
+
1024
+ @register_to_config
1025
+ def __init__(
1026
+ self,
1027
+ input_size=(1, 32, 32), # T, H, W
1028
+ in_channels=4,
1029
+ out_channels=4,
1030
+ patch_size=(1, 2, 2), # T, H, W
1031
+ hidden_size=1152, #
1032
+ depth=28, # Number of layers
1033
+ num_heads=16,
1034
+ mlp_ratio=4.0,
1035
+ class_dropout_prob=0.1,
1036
+ drop_path=0.0,
1037
+ no_temporal_pos_emb=False,
1038
+ caption_channels=4096, # 0 to disable
1039
+ model_max_length=120,
1040
+ space_scale=1.0,
1041
+ time_scale=1.0,
1042
+ enable_flashattn=False,
1043
+ ):
1044
+
1045
+ super().__init__()
1046
+
1047
+ self.model = STDiT(
1048
+ input_size=input_size,
1049
+ in_channels=in_channels,
1050
+ out_channels=out_channels,
1051
+ patch_size=patch_size,
1052
+ hidden_size=hidden_size,
1053
+ depth=depth,
1054
+ num_heads=num_heads,
1055
+ mlp_ratio=mlp_ratio,
1056
+ class_dropout_prob=class_dropout_prob,
1057
+ drop_path=drop_path,
1058
+ no_temporal_pos_emb=no_temporal_pos_emb,
1059
+ caption_channels=caption_channels,
1060
+ model_max_length=model_max_length,
1061
+ space_scale=space_scale,
1062
+ time_scale=time_scale,
1063
+ enable_flashattn=enable_flashattn,
1064
+ )
1065
+
1066
+ def forward(
1067
+ self,
1068
+ x,
1069
+ timestep,
1070
+ encoder_hidden_states=None,
1071
+ cond_image=None,
1072
+ mask=None,
1073
+ return_dict=True,
1074
+ *args,
1075
+ **kwargs,
1076
+ ):
1077
+ """
1078
+ Args:
1079
+ x (torch.Tensor): latent representation of video; of shape [B, C, T, H, W]
1080
+ timestep (torch.Tensor): diffusion time steps; of shape [B]
1081
+ y (torch.Tensor): representation of prompts; of shape [B, 1, N_token, C]
1082
+ mask (torch.Tensor): mask for selecting prompt tokens; of shape [B, N_token]
1083
+ return_dict (bool): return a dictionary or not. Default: True.
1084
+ """
1085
+ if isinstance(timestep, int) or timestep.ndim == 0:
1086
+ timestep = torch.ones(x.shape[0], device=x.device) * timestep
1087
+
1088
+ encoder_hidden_states = (
1089
+ encoder_hidden_states.unsqueeze(1)
1090
+ if encoder_hidden_states is not None
1091
+ else None
1092
+ )
1093
+
1094
+ if cond_image is not None:
1095
+ assert (
1096
+ x.shape == cond_image.shape
1097
+ ), "x and cond_image must have the same shape"
1098
+ x = torch.cat([x, cond_image], dim=1) # B x 2C x T x H x W
1099
+
1100
+ output = self.model(x, timestep, encoder_hidden_states, mask)
1101
+ if not return_dict:
1102
+ return (output,)
1103
+
1104
+ return DiffuserSTDiTModelOutput(sample=output)
1105
+
1106
+
1107
+ ##############################
1108
+ # Image-Conditioned ST UNet  #
1109
+ ##############################
1110
+
1111
+
1112
+ @torch._dynamo.disable
1113
+ @dataclass
1114
+ class UNetSTICOutput(BaseOutput): # UNet-SpatioTemporal-ImageConditioned
1115
+ """
1116
+ The output of [`UNetSpatioTemporalConditionModel`].
1117
+
1118
+ Args:
1119
+ sample (`torch.Tensor` of shape `(batch_size, num_channels, num_frames, height, width)`):
1120
+ The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
1121
+ """
1122
+
1123
+ sample: torch.Tensor = None
1124
+
1125
+
1126
+ class UNetSTIC(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
1127
+ r"""
1128
+ A conditional spatio-temporal UNet model that takes noisy video frames, a conditioning state, and a timestep, and
1129
+ returns a sample-shaped output.
1130
+
1131
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
1132
+ for all models (such as downloading or saving).
1133
+
1134
+ Parameters:
1135
+ sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
1136
+ Height and width of input/output sample.
1137
+ in_channels (`int`, *optional*, defaults to 8): Number of channels in the input sample.
1138
+ out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
1139
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlockSpatioTemporal", "CrossAttnDownBlockSpatioTemporal", "CrossAttnDownBlockSpatioTemporal", "DownBlockSpatioTemporal")`):
1140
+ The tuple of downsample blocks to use.
1141
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal")`):
1142
+ The tuple of upsample blocks to use.
1143
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
1144
+ The tuple of output channels for each block.
1145
+ addition_time_embed_dim: (`int`, defaults to 256):
1146
+ Dimension used to encode the additional time ids.
1147
+ projection_class_embeddings_input_dim (`int`, defaults to 768):
1148
+ The dimension of the projection of encoded `added_time_ids`.
1149
+ layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
1150
+ cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1024):
1151
+ The dimension of the cross attention features.
1152
+ transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1):
1153
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
1154
+ [`~models.unets.unet_3d_blocks.CrossAttnDownBlockSpatioTemporal`],
1155
+ [`~models.unets.unet_3d_blocks.CrossAttnUpBlockSpatioTemporal`],
1156
+ [`~models.unets.unet_3d_blocks.UNetMidBlockSpatioTemporal`].
1157
+ num_attention_heads (`int`, `Tuple[int]`, defaults to `(5, 10, 20, 20)`):
1158
+ The number of attention heads.
1159
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
1160
+ """
1161
+
1162
+ _supports_gradient_checkpointing = True
1163
+
1164
+ @register_to_config
1165
+ def __init__(
1166
+ self,
1167
+ sample_size: Optional[int] = None,
1168
+ in_channels: int = 8,
1169
+ out_channels: int = 4,
1170
+ down_block_types: Tuple[str] = (
1171
+ "CrossAttnDownBlockSpatioTemporal",
1172
+ "CrossAttnDownBlockSpatioTemporal",
1173
+ "CrossAttnDownBlockSpatioTemporal",
1174
+ "DownBlockSpatioTemporal",
1175
+ ),
1176
+ up_block_types: Tuple[str] = (
1177
+ "UpBlockSpatioTemporal",
1178
+ "CrossAttnUpBlockSpatioTemporal",
1179
+ "CrossAttnUpBlockSpatioTemporal",
1180
+ "CrossAttnUpBlockSpatioTemporal",
1181
+ ),
1182
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
1183
+ addition_time_embed_dim: int = 256,
1184
+ projection_class_embeddings_input_dim: int = 768,
1185
+ layers_per_block: Union[int, Tuple[int]] = 2,
1186
+ cross_attention_dim: Union[int, Tuple[int]] = 1024,
1187
+ transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
1188
+ num_attention_heads: Union[int, Tuple[int]] = (5, 10, 20, 20),
1189
+ num_frames: int = 25,
1190
+ ):
1191
+ super().__init__()
1192
+
1193
+ self.sample_size = sample_size
1194
+
1195
+ # Check inputs
1196
+ if len(down_block_types) != len(up_block_types):
1197
+ raise ValueError(
1198
+ f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
1199
+ )
1200
+
1201
+ if len(block_out_channels) != len(down_block_types):
1202
+ raise ValueError(
1203
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
1204
+ )
1205
+
1206
+ if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(
1207
+ down_block_types
1208
+ ):
1209
+ raise ValueError(
1210
+ f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
1211
+ )
1212
+
1213
+ if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(
1214
+ down_block_types
1215
+ ):
1216
+ raise ValueError(
1217
+ f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
1218
+ )
1219
+
1220
+ if not isinstance(layers_per_block, int) and len(layers_per_block) != len(
1221
+ down_block_types
1222
+ ):
1223
+ raise ValueError(
1224
+ f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
1225
+ )
1226
+
1227
+ # input
1228
+ self.conv_in = nn.Conv2d(
1229
+ in_channels,
1230
+ block_out_channels[0],
1231
+ kernel_size=3,
1232
+ padding=1,
1233
+ )
1234
+
1235
+ # time
1236
+ time_embed_dim = block_out_channels[0] * 4
1237
+
1238
+ self.time_proj = Timesteps(block_out_channels[0], True, downscale_freq_shift=0)
1239
+ timestep_input_dim = block_out_channels[0]
1240
+
1241
+ self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
1242
+
1243
+ # self.add_time_proj = Timesteps(
1244
+ # addition_time_embed_dim, True, downscale_freq_shift=0
1245
+ # )
1246
+ # self.add_embedding = TimestepEmbedding(
1247
+ # projection_class_embeddings_input_dim, time_embed_dim
1248
+ # )
1249
+
1250
+ self.down_blocks = nn.ModuleList([])
1251
+ self.up_blocks = nn.ModuleList([])
1252
+
1253
+ if isinstance(num_attention_heads, int):
1254
+ num_attention_heads = (num_attention_heads,) * len(down_block_types)
1255
+
1256
+ if isinstance(cross_attention_dim, int):
1257
+ cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
1258
+
1259
+ if isinstance(layers_per_block, int):
1260
+ layers_per_block = [layers_per_block] * len(down_block_types)
1261
+
1262
+ if isinstance(transformer_layers_per_block, int):
1263
+ transformer_layers_per_block = [transformer_layers_per_block] * len(
1264
+ down_block_types
1265
+ )
1266
+
1267
+ blocks_time_embed_dim = time_embed_dim
1268
+
1269
+ # down
1270
+ output_channel = block_out_channels[0]
1271
+ for i, down_block_type in enumerate(down_block_types):
1272
+ input_channel = output_channel
1273
+ output_channel = block_out_channels[i]
1274
+ is_final_block = i == len(block_out_channels) - 1
1275
+
1276
+ down_block = get_down_block_3d(
1277
+ down_block_type,
1278
+ num_layers=layers_per_block[i],
1279
+ transformer_layers_per_block=transformer_layers_per_block[i],
1280
+ in_channels=input_channel,
1281
+ out_channels=output_channel,
1282
+ temb_channels=blocks_time_embed_dim,
1283
+ add_downsample=not is_final_block,
1284
+ resnet_eps=1e-5,
1285
+ cross_attention_dim=cross_attention_dim[i],
1286
+ num_attention_heads=num_attention_heads[i],
1287
+ resnet_act_fn="silu",
1288
+ )
1289
+ self.down_blocks.append(down_block)
1290
+
1291
+ # mid
1292
+ self.mid_block = UNetMidBlockSpatioTemporal(
1293
+ block_out_channels[-1],
1294
+ temb_channels=blocks_time_embed_dim,
1295
+ transformer_layers_per_block=transformer_layers_per_block[-1],
1296
+ cross_attention_dim=cross_attention_dim[-1],
1297
+ num_attention_heads=num_attention_heads[-1],
1298
+ )
1299
+
1300
+ # count how many layers upsample the images
1301
+ self.num_upsamplers = 0
1302
+
1303
+ # up
1304
+ reversed_block_out_channels = list(reversed(block_out_channels))
1305
+ reversed_num_attention_heads = list(reversed(num_attention_heads))
1306
+ reversed_layers_per_block = list(reversed(layers_per_block))
1307
+ reversed_cross_attention_dim = list(reversed(cross_attention_dim))
1308
+ reversed_transformer_layers_per_block = list(
1309
+ reversed(transformer_layers_per_block)
1310
+ )
1311
+
1312
+ output_channel = reversed_block_out_channels[0]
1313
+ for i, up_block_type in enumerate(up_block_types):
1314
+ is_final_block = i == len(block_out_channels) - 1
1315
+
1316
+ prev_output_channel = output_channel
1317
+ output_channel = reversed_block_out_channels[i]
1318
+ input_channel = reversed_block_out_channels[
1319
+ min(i + 1, len(block_out_channels) - 1)
1320
+ ]
1321
+
1322
+ # add upsample block for all BUT final layer
1323
+ if not is_final_block:
1324
+ add_upsample = True
1325
+ self.num_upsamplers += 1
1326
+ else:
1327
+ add_upsample = False
1328
+
1329
+ up_block = get_up_block_3d(
1330
+ up_block_type,
1331
+ num_layers=reversed_layers_per_block[i] + 1,
1332
+ transformer_layers_per_block=reversed_transformer_layers_per_block[i],
1333
+ in_channels=input_channel,
1334
+ out_channels=output_channel,
1335
+ prev_output_channel=prev_output_channel,
1336
+ temb_channels=blocks_time_embed_dim,
1337
+ add_upsample=add_upsample,
1338
+ resnet_eps=1e-5,
1339
+ resolution_idx=i,
1340
+ cross_attention_dim=reversed_cross_attention_dim[i],
1341
+ num_attention_heads=reversed_num_attention_heads[i],
1342
+ resnet_act_fn="silu",
1343
+ )
1344
+ self.up_blocks.append(up_block)
1345
+ prev_output_channel = output_channel
1346
+
1347
+ # out
1348
+ self.conv_norm_out = nn.GroupNorm(
1349
+ num_channels=block_out_channels[0], num_groups=32, eps=1e-5
1350
+ )
1351
+ self.conv_act = nn.SiLU()
1352
+
1353
+ self.conv_out = nn.Conv2d(
1354
+ block_out_channels[0],
1355
+ out_channels,
1356
+ kernel_size=3,
1357
+ padding=1,
1358
+ )
1359
+
1360
+ # self.set_default_attn_processor()
1361
+
1362
+ @property
1363
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
1364
+ r"""
1365
+ Returns:
1366
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
1367
+ indexed by its weight name.
1368
+ """
1369
+ # set recursively
1370
+ processors = {}
1371
+
1372
+ def fn_recursive_add_processors(
1373
+ name: str,
1374
+ module: torch.nn.Module,
1375
+ processors: Dict[str, AttentionProcessor],
1376
+ ):
1377
+ if hasattr(module, "get_processor"):
1378
+ processors[f"{name}.processor"] = module.get_processor()
1379
+
1380
+ for sub_name, child in module.named_children():
1381
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
1382
+
1383
+ return processors
1384
+
1385
+ for name, module in self.named_children():
1386
+ fn_recursive_add_processors(name, module, processors)
1387
+
1388
+ return processors
1389
+
1390
+ def set_attn_processor(self, processor):
1391
+ r"""
1392
+ Sets the attention processor to use to compute attention.
1393
+
1394
+ Parameters:
1395
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
1396
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
1397
+ for **all** `Attention` layers.
1398
+
1399
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
1400
+ processor. This is strongly recommended when setting trainable attention processors.
1401
+
1402
+ """
1403
+ count = len(self.attn_processors.keys())
1404
+
1405
+ if isinstance(processor, dict) and len(processor) != count:
1406
+ raise ValueError(
1407
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
1408
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
1409
+ )
1410
+
1411
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
1412
+ if hasattr(module, "set_processor"):
1413
+ if not isinstance(processor, dict):
1414
+ module.set_processor(processor)
1415
+ else:
1416
+ module.set_processor(processor.pop(f"{name}.processor"))
1417
+
1418
+ for sub_name, child in module.named_children():
1419
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
1420
+
1421
+ for name, module in self.named_children():
1422
+ fn_recursive_attn_processor(name, module, processor)
1423
+
1424
+ def set_default_attn_processor(self):
1425
+ """
1426
+ Disables custom attention processors and sets the default attention implementation.
1427
+ """
1428
+ if all(
1429
+ proc.__class__ in CROSS_ATTENTION_PROCESSORS
1430
+ for proc in self.attn_processors.values()
1431
+ ):
1432
+ processor = AttnProcessor()
1433
+ else:
1434
+ raise ValueError(
1435
+ f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
1436
+ )
1437
+
1438
+ self.set_attn_processor(processor)
1439
+
1440
+ def _set_gradient_checkpointing(self, module, value=False):
1441
+ if hasattr(module, "gradient_checkpointing"):
1442
+ module.gradient_checkpointing = value
1443
+
1444
+ # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
1445
+ def enable_forward_chunking(
1446
+ self, chunk_size: Optional[int] = None, dim: int = 0
1447
+ ) -> None:
1448
+ """
1449
+ Sets the attention processor to use [feed forward
1450
+ chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers).
1451
+
1452
+ Parameters:
1453
+ chunk_size (`int`, *optional*):
1454
+ The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually
1455
+ over each tensor of dim=`dim`.
1456
+ dim (`int`, *optional*, defaults to `0`):
1457
+ The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch)
1458
+ or dim=1 (sequence length).
1459
+ """
1460
+ if dim not in [0, 1]:
1461
+ raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")
1462
+
1463
+ # By default chunk size is 1
1464
+ chunk_size = chunk_size or 1
1465
+
1466
+ def fn_recursive_feed_forward(
1467
+ module: torch.nn.Module, chunk_size: int, dim: int
1468
+ ):
1469
+ if hasattr(module, "set_chunk_feed_forward"):
1470
+ module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
1471
+
1472
+ for child in module.children():
1473
+ fn_recursive_feed_forward(child, chunk_size, dim)
1474
+
1475
+ for module in self.children():
1476
+ fn_recursive_feed_forward(module, chunk_size, dim)
1477
+
1478
+ def forward(
1479
+ self,
1480
+ x: torch.Tensor,
1481
+ timestep: Union[torch.Tensor, float, int],
1482
+ encoder_hidden_states: torch.Tensor,
1483
+ cond_image=None,
1484
+ mask=None,
1485
+ # added_time_ids: torch.Tensor,
1486
+ return_dict: bool = True,
1487
+ ) -> Union[UNetSTICOutput, Tuple]:
1488
+ r"""
1489
+ The [`UNetSpatioTemporalConditionModel`] forward method.
1490
+
1491
+ Args:
1492
+ x (`torch.Tensor`):
1493
+ The noisy input tensor of shape `(batch, channel, num_frames, height, width)`.
1494
+ timestep (`torch.Tensor` or `float` or `int`): The number of timesteps to denoise an input.
1495
+ encoder_hidden_states (`torch.Tensor`):
1496
+ The encoder hidden states with shape `(batch, sequence_length, cross_attention_dim)`.
1497
+ cond_image (`torch.Tensor`):
1498
+ Conditioning latents with the same temporal and spatial shape as `x`, concatenated to `x` along the
1499
+ channel dimension before the first convolution.
1500
+ return_dict (`bool`, *optional*, defaults to `True`):
1501
+ Whether or not to return a [`~models.unet_spatio_temporal.UNetSTICOutput`] instead
1502
+ of a plain tuple.
1503
+ Returns:
1504
+ [`~models.unet_spatio_temporal.UNetSTICOutput`] or `tuple`:
1505
+ If `return_dict` is True, an [`~models.unet_spatio_temporal.UNetSTICOutput`] is
1506
+ returned, otherwise a `tuple` is returned where the first element is the sample tensor.
1507
+ """
1508
+
1509
+ sample = torch.cat([x, cond_image], dim=1)  # concatenate conditioning along the channel dim: (B, C_x + C_cond, T, H, W)
1510
+
1511
+ # pad to multiple of 2**n
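+ # (padding is computed from the width and applied to both spatial dims, so square inputs with an
+ # even padding amount are assumed; circular padding wraps content instead of inserting zeros)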
1512
+ res_target = 2 ** (np.ceil(np.log2(sample.shape[-1])).astype(int))
1513
+ padding = (res_target - sample.shape[-1]) // 2
1514
+ sample = F.pad(
1515
+ sample, (padding, padding, padding, padding, 0, 0), mode="circular"
1516
+ )
1517
+
1518
+ # reshape from B C T H W to B T C H W
1519
+ sample = sample.permute(0, 2, 1, 3, 4)
1520
+
1521
+ # 1. time
1522
+ timesteps = timestep
1523
+ if not torch.is_tensor(timesteps):
1524
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
1525
+ # This would be a good case for the `match` statement (Python 3.10+)
1526
+ is_mps = sample.device.type == "mps"
1527
+ if isinstance(timestep, float):
1528
+ dtype = torch.float32 if is_mps else torch.float64
1529
+ else:
1530
+ dtype = torch.int32 if is_mps else torch.int64
1531
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
1532
+ elif len(timesteps.shape) == 0:
1533
+ timesteps = timesteps[None].to(sample.device)
1534
+
1535
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
1536
+ batch_size, num_frames = sample.shape[:2]
1537
+ timesteps = timesteps.expand(batch_size)
1538
+
1539
+ t_emb = self.time_proj(timesteps)
1540
+
1541
+ # `Timesteps` does not contain any weights and will always return f32 tensors
1542
+ # but time_embedding might actually be running in fp16. so we need to cast here.
1543
+ # there might be better ways to encapsulate this.
1544
+ t_emb = t_emb.to(dtype=sample.dtype)
1545
+
1546
+ emb = self.time_embedding(t_emb)
1547
+
1548
+ # time_embeds = self.add_time_proj(added_time_ids.flatten())
1549
+ # time_embeds = time_embeds.reshape((batch_size, -1))
1550
+ # time_embeds = time_embeds.to(emb.dtype)
1551
+ # aug_emb = self.add_embedding(time_embeds)
1552
+ # emb = emb + aug_emb
1553
+
1554
+ # Flatten the batch and frames dimensions
1555
+ # sample: [batch, frames, channels, height, width] -> [batch * frames, channels, height, width]
1556
+ sample = sample.flatten(0, 1)
1557
+ # Repeat the embeddings num_video_frames times
1558
+ # emb: [batch, channels] -> [batch * frames, channels]
1559
+ emb = emb.repeat_interleave(num_frames, dim=0)
1560
+ # encoder_hidden_states: [batch, 1, channels] -> [batch * frames, 1, channels]
1561
+ encoder_hidden_states = encoder_hidden_states.repeat_interleave(
1562
+ num_frames, dim=0
1563
+ )
1564
+
1565
+ # 2. pre-process
1566
+ sample = self.conv_in(sample)
1567
+
1568
+ image_only_indicator = torch.zeros(
1569
+ batch_size, num_frames, dtype=sample.dtype, device=sample.device
1570
+ )
1571
+
1572
+ down_block_res_samples = (sample,)
1573
+ for downsample_block in self.down_blocks:
1574
+ if (
1575
+ hasattr(downsample_block, "has_cross_attention")
1576
+ and downsample_block.has_cross_attention
1577
+ ):
1578
+ sample, res_samples = downsample_block(
1579
+ hidden_states=sample,
1580
+ temb=emb,
1581
+ encoder_hidden_states=encoder_hidden_states,
1582
+ image_only_indicator=image_only_indicator,
1583
+ )
1584
+ else:
1585
+ sample, res_samples = downsample_block(
1586
+ hidden_states=sample,
1587
+ temb=emb,
1588
+ image_only_indicator=image_only_indicator,
1589
+ )
1590
+
1591
+ down_block_res_samples += res_samples
1592
+
1593
+ # 4. mid
1594
+ sample = self.mid_block(
1595
+ hidden_states=sample,
1596
+ temb=emb,
1597
+ encoder_hidden_states=encoder_hidden_states,
1598
+ image_only_indicator=image_only_indicator,
1599
+ )
1600
+
1601
+ # 5. up
1602
+ for i, upsample_block in enumerate(self.up_blocks):
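+ # pop the skip-connection residuals produced by the matching down block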
1603
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
1604
+ down_block_res_samples = down_block_res_samples[
1605
+ : -len(upsample_block.resnets)
1606
+ ]
1607
+
1608
+ if (
1609
+ hasattr(upsample_block, "has_cross_attention")
1610
+ and upsample_block.has_cross_attention
1611
+ ):
1612
+ sample = upsample_block(
1613
+ hidden_states=sample,
1614
+ temb=emb,
1615
+ res_hidden_states_tuple=res_samples,
1616
+ encoder_hidden_states=encoder_hidden_states,
1617
+ image_only_indicator=image_only_indicator,
1618
+ )
1619
+ else:
1620
+ sample = upsample_block(
1621
+ hidden_states=sample,
1622
+ temb=emb,
1623
+ res_hidden_states_tuple=res_samples,
1624
+ image_only_indicator=image_only_indicator,
1625
+ )
1626
+
1627
+ # 6. post-process
1628
+ sample = self.conv_norm_out(sample)
1629
+ sample = self.conv_act(sample)
1630
+ sample = self.conv_out(sample)
1631
+
1632
+ # 7. Reshape back to original shape
1633
+ sample = sample.reshape(batch_size, num_frames, *sample.shape[1:])
1634
+
1635
+ if padding > 0:
1636
+ sample = sample[:, :, :, padding:-padding, padding:-padding]
1637
+
1638
+ # reshape back to B C T H W
1639
+ sample = sample.permute(0, 2, 1, 3, 4)
1640
+
1641
+ if not return_dict:
1642
+ return (sample,)
1643
+
1644
+ return UNetSTICOutput(sample=sample)
1645
+
1646
+
1647
+ class ContrastiveModel(nn.Module):
1648
+ def __init__(self, in_channels, out_channels, backbone=None, kl_loss_weight=0.0):
1649
+ super(ContrastiveModel, self).__init__()
1650
+
1651
+ assert backbone is not None, "Backbone must be provided."
1652
+ self.backbone = backbone
1653
+
1654
+ self.backbone = self.patch_backbone(self.backbone, in_channels, out_channels)
1655
+
1656
+ self.fc_end = nn.Linear(out_channels, 1)
1657
+
1658
+ self.kl_loss_weight = kl_loss_weight
1659
+
1660
+ @classmethod
1661
+ def patch_backbone(cls, backbone, in_channels, out_channels):
1662
+ if "ResNet" in backbone.__class__.__name__:
1663
+ backbone.model.conv1 = nn.Conv2d(
1664
+ in_channels,
1665
+ 64,
1666
+ kernel_size=(7, 7),
1667
+ stride=(2, 2),
1668
+ padding=(3, 3),
1669
+ bias=False,
1670
+ )
1671
+ backbone.model.fc = nn.Linear(
1672
+ in_features=512, out_features=out_channels, bias=True
1673
+ )
1674
+ else:
1675
+ raise Exception(
1676
+ "Invalid argument: "
1677
+ + backbone.__class__.__name__
1678
+ + "\nChoose ResNet! Other architectures are not yet implemented in this framework."
1679
+ )
1680
+
1681
+ return backbone
1682
+
1683
+ def forward_once(self, x):
1684
+ features = self.backbone(x)
1685
+ output = torch.sigmoid(features)
1686
+ return output, features
1687
+
1688
+ def forward_constrastive(self, input1, input2):
1689
+ y1, _ = self.forward_once(input1)  # forward_once returns (sigmoid output, raw features)
1690
+ y2, _ = self.forward_once(input2)
1691
+
1692
+ difference = torch.abs(y1 - y2)
1693
+ output = self.fc_end(difference) # linear layer
1694
+
1695
+ return output # B x 1
1696
+
1697
+ def forward_fused(self, input1, input2):
1698
+ inputs = torch.cat((input1, input2), dim=0) # 2B x C x H x W
1699
+ outputs, features = self.forward_once(inputs)
1700
+ y1, y2 = torch.split(outputs, outputs.size(0) // 2, dim=0)
1701
+ difference = torch.abs(y1 - y2)
1702
+ output = self.fc_end(difference)
1703
+
1704
+ # Compute KL divergence
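+ # (KL between a diagonal Gaussian fitted to the batch of features and a standard normal prior,
+ #  used to regularize the embedding distribution)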
1705
+ if self.kl_loss_weight > 0:
1706
+ mu = torch.mean(features, dim=0)
1707
+ var = torch.var(features, dim=0) + 1e-6 # Add epsilon to avoid log(0)
1708
+ kl_loss = 0.5 * torch.sum(mu.pow(2) + var - torch.log(var) - 1)
1709
+ else:
1710
+ kl_loss = torch.zeros((1,), device=output.device)
1711
+ return output, kl_loss
1712
+
1713
+ def loss(self, output, target):
1714
+ return nn.functional.binary_cross_entropy_with_logits(output, target[:, None])
1715
+
1716
+ def forward(self, input1, input2, target):
1717
+ y_hat, kl_loss = self.forward_fused(input1, input2)
1718
+ loss = self.loss(y_hat, target)
1719
+ total_loss = loss + self.kl_loss_weight * kl_loss
1720
+ return total_loss, loss, kl_loss
1721
+
1722
+
1723
+ class ResNet18(ModelMixin, ConfigMixin):
1724
+ @register_to_config
1725
+ def __init__(self, weights=None, progress=False):
1726
+ super(ResNet18, self).__init__()
1727
+ self.model = resnet18(weights=weights, progress=progress)
1728
+
1729
+ def forward(self, x):
1730
+ return self.model(x)
requirements.txt ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ opencv-python==4.9.0.80
2
+ diffusers==0.30.3
3
+ einops==0.7.0
4
+ gradio==5.22.0
5
+ huggingface-hub==0.29.3
6
+ numpy==1.26.4
7
+ omegaconf==2.3.0
8
+ pillow==10.2.0
9
+ safetensors==0.4.5
10
+ torch==2.2.2
11
+ torchdiffeq==0.2.4
12
+ xformers==0.0.25.post1
13
+ timm==0.9.16
14
+ accelerate==0.34.2