bubbliiiing committed on
Commit a5c8285 · 1 Parent(s): 4f2d355

Update Space

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .gitignore +160 -0
  2. asset/1.png +0 -0
  3. asset/2.png +0 -0
  4. asset/3.png +0 -0
  5. asset/4.png +0 -0
  6. asset/5.png +0 -0
  7. cogvideox/__init__.py +0 -0
  8. cogvideox/api/api.py +173 -0
  9. cogvideox/api/post_infer.py +89 -0
  10. cogvideox/data/bucket_sampler.py +379 -0
  11. cogvideox/data/dataset_image.py +76 -0
  12. cogvideox/data/dataset_image_video.py +589 -0
  13. cogvideox/data/dataset_video.py +262 -0
  14. cogvideox/models/__init__.py +8 -0
  15. cogvideox/models/cogvideox_transformer3d.py +797 -0
  16. cogvideox/models/cogvideox_vae.py +1675 -0
  17. cogvideox/models/wan_image_encoder.py +553 -0
  18. cogvideox/models/wan_text_encoder.py +324 -0
  19. cogvideox/models/wan_transformer3d.py +961 -0
  20. cogvideox/models/wan_vae.py +706 -0
  21. cogvideox/models/wan_xlm_roberta.py +170 -0
  22. cogvideox/pipeline/__init__.py +8 -0
  23. cogvideox/pipeline/pipeline_cogvideox_fun.py +877 -0
  24. cogvideox/pipeline/pipeline_cogvideox_fun_control.py +971 -0
  25. cogvideox/pipeline/pipeline_cogvideox_fun_inpaint.py +1151 -0
  26. cogvideox/pipeline/pipeline_wan_fun.py +562 -0
  27. cogvideox/pipeline/pipeline_wan_fun_inpaint.py +729 -0
  28. cogvideox/ui/cogvideox_fun_ui.py +722 -0
  29. cogvideox/ui/controller.py +390 -0
  30. cogvideox/ui/ui.py +288 -0
  31. cogvideox/ui/wan_fun_ui.py +742 -0
  32. cogvideox/ui/wan_ui.py +730 -0
  33. cogvideox/utils/__init__.py +0 -0
  34. cogvideox/utils/discrete_sampler.py +46 -0
  35. cogvideox/utils/fp8_optimization.py +56 -0
  36. cogvideox/utils/lora_utils.py +502 -0
  37. cogvideox/utils/utils.py +215 -0
  38. config/wan2.1/wan_civitai.yaml +39 -0
  39. config/zero_stage2_config.json +16 -0
  40. config/zero_stage3_config.json +28 -0
  41. examples/cogvideox_fun/app.py +67 -0
  42. examples/cogvideox_fun/predict_i2v.py +268 -0
  43. examples/cogvideox_fun/predict_t2v.py +208 -0
  44. examples/cogvideox_fun/predict_v2v.py +203 -0
  45. examples/cogvideox_fun/predict_v2v_control.py +188 -0
  46. examples/wan2.1/app.py +69 -0
  47. examples/wan2.1/predict_i2v.py +198 -0
  48. examples/wan2.1/predict_t2v.py +179 -0
  49. examples/wan2.1_fun/app.py +69 -0
  50. examples/wan2.1_fun/predict_i2v.py +199 -0
.gitignore ADDED
@@ -0,0 +1,160 @@
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # poetry
98
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102
+ #poetry.lock
103
+
104
+ # pdm
105
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106
+ #pdm.lock
107
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108
+ # in version control.
109
+ # https://pdm.fming.dev/#use-with-ide
110
+ .pdm.toml
111
+
112
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113
+ __pypackages__/
114
+
115
+ # Celery stuff
116
+ celerybeat-schedule
117
+ celerybeat.pid
118
+
119
+ # SageMath parsed files
120
+ *.sage.py
121
+
122
+ # Environments
123
+ .env
124
+ .venv
125
+ env/
126
+ venv/
127
+ ENV/
128
+ env.bak/
129
+ venv.bak/
130
+
131
+ # Spyder project settings
132
+ .spyderproject
133
+ .spyproject
134
+
135
+ # Rope project settings
136
+ .ropeproject
137
+
138
+ # mkdocs documentation
139
+ /site
140
+
141
+ # mypy
142
+ .mypy_cache/
143
+ .dmypy.json
144
+ dmypy.json
145
+
146
+ # Pyre type checker
147
+ .pyre/
148
+
149
+ # pytype static type analyzer
150
+ .pytype/
151
+
152
+ # Cython debug symbols
153
+ cython_debug/
154
+
155
+ # PyCharm
156
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
159
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160
+ #.idea/
asset/1.png ADDED
asset/2.png ADDED
asset/3.png ADDED
asset/4.png ADDED
asset/5.png ADDED
cogvideox/__init__.py ADDED
File without changes
cogvideox/api/api.py ADDED
@@ -0,0 +1,173 @@
1
+ import io
2
+ import gc
3
+ import base64
4
+ import torch
5
+ import gradio as gr
6
+ import tempfile
7
+ import hashlib
8
+ import os
9
+
10
+ from fastapi import FastAPI
11
+ from io import BytesIO
12
+ from PIL import Image
13
+
14
+ # Function to encode a file to Base64
15
+ def encode_file_to_base64(file_path):
16
+ with open(file_path, "rb") as file:
17
+ # Encode the data to Base64
18
+ file_base64 = base64.b64encode(file.read())
19
+ return file_base64
20
+
21
+ def update_edition_api(_: gr.Blocks, app: FastAPI, controller):
22
+ @app.post("/cogvideox_fun/update_edition")
23
+ def _update_edition_api(
24
+ datas: dict,
25
+ ):
26
+ edition = datas.get('edition', 'v2')
27
+
28
+ try:
29
+ controller.update_edition(
30
+ edition
31
+ )
32
+ comment = "Success"
33
+ except Exception as e:
34
+ torch.cuda.empty_cache()
35
+ comment = f"Error. Error information: {str(e)}"
36
+
37
+ return {"message": comment}
38
+
39
+ def update_diffusion_transformer_api(_: gr.Blocks, app: FastAPI, controller):
40
+ @app.post("/cogvideox_fun/update_diffusion_transformer")
41
+ def _update_diffusion_transformer_api(
42
+ datas: dict,
43
+ ):
44
+ diffusion_transformer_path = datas.get('diffusion_transformer_path', 'none')
45
+
46
+ try:
47
+ controller.update_diffusion_transformer(
48
+ diffusion_transformer_path
49
+ )
50
+ comment = "Success"
51
+ except Exception as e:
52
+ torch.cuda.empty_cache()
53
+ comment = f"Error. Error information: {str(e)}"
54
+
55
+ return {"message": comment}
56
+
57
+ def save_base64_video(base64_string):
58
+ video_data = base64.b64decode(base64_string)
59
+
60
+ md5_hash = hashlib.md5(video_data).hexdigest()
61
+ filename = f"{md5_hash}.mp4"
62
+
63
+ temp_dir = tempfile.gettempdir()
64
+ file_path = os.path.join(temp_dir, filename)
65
+
66
+ with open(file_path, 'wb') as video_file:
67
+ video_file.write(video_data)
68
+
69
+ return file_path
70
+
71
+ def save_base64_image(base64_string):
72
+ video_data = base64.b64decode(base64_string)
73
+
74
+ md5_hash = hashlib.md5(video_data).hexdigest()
75
+ filename = f"{md5_hash}.jpg"
76
+
77
+ temp_dir = tempfile.gettempdir()
78
+ file_path = os.path.join(temp_dir, filename)
79
+
80
+ with open(file_path, 'wb') as video_file:
81
+ video_file.write(video_data)
82
+
83
+ return file_path
84
+
85
+ def infer_forward_api(_: gr.Blocks, app: FastAPI, controller):
86
+ @app.post("/cogvideox_fun/infer_forward")
87
+ def _infer_forward_api(
88
+ datas: dict,
89
+ ):
90
+ base_model_path = datas.get('base_model_path', 'none')
91
+ lora_model_path = datas.get('lora_model_path', 'none')
92
+ lora_alpha_slider = datas.get('lora_alpha_slider', 0.55)
93
+ prompt_textbox = datas.get('prompt_textbox', None)
94
+ negative_prompt_textbox = datas.get('negative_prompt_textbox', 'The video is not of a high quality, it has a low resolution. Watermark present in each frame. The background is solid. Strange body and strange trajectory. Distortion. ')
95
+ sampler_dropdown = datas.get('sampler_dropdown', 'Euler')
96
+ sample_step_slider = datas.get('sample_step_slider', 30)
97
+ resize_method = datas.get('resize_method', "Generate by")
98
+ width_slider = datas.get('width_slider', 672)
99
+ height_slider = datas.get('height_slider', 384)
100
+ base_resolution = datas.get('base_resolution', 512)
101
+ is_image = datas.get('is_image', False)
102
+ generation_method = datas.get('generation_method', False)
103
+ length_slider = datas.get('length_slider', 49)
104
+ overlap_video_length = datas.get('overlap_video_length', 4)
105
+ partial_video_length = datas.get('partial_video_length', 72)
106
+ cfg_scale_slider = datas.get('cfg_scale_slider', 6)
107
+ start_image = datas.get('start_image', None)
108
+ end_image = datas.get('end_image', None)
109
+ validation_video = datas.get('validation_video', None)
110
+ validation_video_mask = datas.get('validation_video_mask', None)
111
+ control_video = datas.get('control_video', None)
112
+ denoise_strength = datas.get('denoise_strength', 0.70)
113
+ seed_textbox = datas.get("seed_textbox", 43)
114
+
115
+ generation_method = "Image Generation" if is_image else generation_method
116
+
117
+ if start_image is not None:
118
+ start_image = base64.b64decode(start_image)
119
+ start_image = [Image.open(BytesIO(start_image))]
120
+
121
+ if end_image is not None:
122
+ end_image = base64.b64decode(end_image)
123
+ end_image = [Image.open(BytesIO(end_image))]
124
+
125
+ if validation_video is not None:
126
+ validation_video = save_base64_video(validation_video)
127
+
128
+ if validation_video_mask is not None:
129
+ validation_video_mask = save_base64_image(validation_video_mask)
130
+
131
+ if control_video is not None:
132
+ control_video = save_base64_video(control_video)
133
+
134
+ try:
135
+ save_sample_path, comment = controller.generate(
136
+ "",
137
+ base_model_path,
138
+ lora_model_path,
139
+ lora_alpha_slider,
140
+ prompt_textbox,
141
+ negative_prompt_textbox,
142
+ sampler_dropdown,
143
+ sample_step_slider,
144
+ resize_method,
145
+ width_slider,
146
+ height_slider,
147
+ base_resolution,
148
+ generation_method,
149
+ length_slider,
150
+ overlap_video_length,
151
+ partial_video_length,
152
+ cfg_scale_slider,
153
+ start_image,
154
+ end_image,
155
+ validation_video,
156
+ validation_video_mask,
157
+ control_video,
158
+ denoise_strength,
159
+ seed_textbox,
160
+ is_api = True,
161
+ )
162
+ except Exception as e:
163
+ gc.collect()
164
+ torch.cuda.empty_cache()
165
+ torch.cuda.ipc_collect()
166
+ save_sample_path = ""
167
+ comment = f"Error. Error information: {str(e)}"
168
+ return {"message": comment}
169
+
170
+ if save_sample_path != "":
171
+ return {"message": comment, "save_sample_path": save_sample_path, "base64_encoding": encode_file_to_base64(save_sample_path)}
172
+ else:
173
+ return {"message": comment, "save_sample_path": save_sample_path}
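The three functions above are route registrars: each expects an already-constructed FastAPI app plus a controller object exposing update_edition, update_diffusion_transformer, and generate, and the gr.Blocks argument is unused. A minimal wiring sketch, assuming such a controller is available (for example the UI controller added in cogvideox/ui/controller.py):

from fastapi import FastAPI

from cogvideox.api.api import (infer_forward_api,
                               update_diffusion_transformer_api,
                               update_edition_api)


def build_api(controller) -> FastAPI:
    # `controller` is an assumption: any object providing update_edition,
    # update_diffusion_transformer and generate as called by the handlers above.
    app = FastAPI()
    update_edition_api(None, app, controller)  # the gr.Blocks argument is unused
    update_diffusion_transformer_api(None, app, controller)
    infer_forward_api(None, app, controller)
    return app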
cogvideox/api/post_infer.py ADDED
@@ -0,0 +1,89 @@
1
+ import base64
2
+ import json
3
+ import sys
4
+ import time
5
+ from datetime import datetime
6
+ from io import BytesIO
7
+
8
+ import cv2
9
+ import requests
10
+ import base64
11
+
12
+
13
+ def post_diffusion_transformer(diffusion_transformer_path, url='http://127.0.0.1:7860'):
14
+ datas = json.dumps({
15
+ "diffusion_transformer_path": diffusion_transformer_path
16
+ })
17
+ r = requests.post(f'{url}/cogvideox_fun/update_diffusion_transformer', data=datas, timeout=1500)
18
+ data = r.content.decode('utf-8')
19
+ return data
20
+
21
+ def post_update_edition(edition, url='http://0.0.0.0:7860'):
22
+ datas = json.dumps({
23
+ "edition": edition
24
+ })
25
+ r = requests.post(f'{url}/cogvideox_fun/update_edition', data=datas, timeout=1500)
26
+ data = r.content.decode('utf-8')
27
+ return data
28
+
29
+ def post_infer(generation_method, length_slider, url='http://127.0.0.1:7860'):
30
+ datas = json.dumps({
31
+ "base_model_path": "none",
32
+ "motion_module_path": "none",
33
+ "lora_model_path": "none",
34
+ "lora_alpha_slider": 0.55,
35
+ "prompt_textbox": "A young woman with beautiful and clear eyes and blonde hair standing and white dress in a forest wearing a crown. She seems to be lost in thought, and the camera focuses on her face. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
36
+ "negative_prompt_textbox": "The video is not of a high quality, it has a low resolution. Watermark present in each frame. The background is solid. Strange body and strange trajectory. Distortion. ",
37
+ "sampler_dropdown": "Euler",
38
+ "sample_step_slider": 50,
39
+ "width_slider": 672,
40
+ "height_slider": 384,
41
+ "generation_method": "Video Generation",
42
+ "length_slider": length_slider,
43
+ "cfg_scale_slider": 6,
44
+ "seed_textbox": 43,
45
+ })
46
+ r = requests.post(f'{url}/cogvideox_fun/infer_forward', data=datas, timeout=1500)
47
+ data = r.content.decode('utf-8')
48
+ return data
49
+
50
+ if __name__ == '__main__':
51
+ # initiate time
52
+ now_date = datetime.now()
53
+ time_start = time.time()
54
+
55
+ # -------------------------- #
56
+ # Step 1: update the diffusion transformer
57
+ # -------------------------- #
58
+ diffusion_transformer_path = "models/Diffusion_Transformer/CogVideoX-Fun-2b-InP"
59
+ outputs = post_diffusion_transformer(diffusion_transformer_path)
60
+ print('Output of update_diffusion_transformer: ', outputs)
61
+
62
+ # -------------------------- #
63
+ # Step 2: infer
64
+ # -------------------------- #
65
+ # "Video Generation" and "Image Generation"
66
+ generation_method = "Video Generation"
67
+ length_slider = 49
68
+ outputs = post_infer(generation_method, length_slider)
69
+
70
+ # Get decoded data
71
+ outputs = json.loads(outputs)
72
+ base64_encoding = outputs["base64_encoding"]
73
+ decoded_data = base64.b64decode(base64_encoding)
74
+
75
+ is_image = True if generation_method == "Image Generation" else False
76
+ if is_image or length_slider == 1:
77
+ file_path = "1.png"
78
+ else:
79
+ file_path = "1.mp4"
80
+ with open(file_path, "wb") as file:
81
+ file.write(decoded_data)
82
+
83
+ # End of record time
84
+ # The time difference below is the total execution time of the program, in seconds
85
+ time_end = time.time()
86
+ time_sum = time_end - time_start
87
+ print('# --------------------------------------------------------- #')
88
+ print(f'# Total time: {time_sum}s')
89
+ print('# --------------------------------------------------------- #')
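The script above only exercises the text-to-video path. A hedged sketch of the matching image-to-video request, assuming a server already running on port 7860 and a local start.png (hypothetical file): the start_image field carries the raw image bytes base64-encoded, which is what _infer_forward_api decodes with PIL above.

import base64
import json

import requests

with open("start.png", "rb") as f:  # hypothetical input image
    start_image_b64 = base64.b64encode(f.read()).decode("utf-8")

payload = {
    "prompt_textbox": "A dog runs across a sunny meadow.",
    "generation_method": "Video Generation",
    "length_slider": 49,
    "width_slider": 672,
    "height_slider": 384,
    "start_image": start_image_b64,
    "seed_textbox": 43,
}
r = requests.post("http://127.0.0.1:7860/cogvideox_fun/infer_forward",
                  data=json.dumps(payload), timeout=1500)
print(json.loads(r.content.decode("utf-8"))["message"])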
cogvideox/data/bucket_sampler.py ADDED
@@ -0,0 +1,379 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import os
3
+ from typing import (Generic, Iterable, Iterator, List, Optional, Sequence,
4
+ Sized, TypeVar, Union)
5
+
6
+ import cv2
7
+ import numpy as np
8
+ import torch
9
+ from PIL import Image
10
+ from torch.utils.data import BatchSampler, Dataset, Sampler
11
+
12
+ ASPECT_RATIO_512 = {
13
+ '0.25': [256.0, 1024.0], '0.26': [256.0, 992.0], '0.27': [256.0, 960.0], '0.28': [256.0, 928.0],
14
+ '0.32': [288.0, 896.0], '0.33': [288.0, 864.0], '0.35': [288.0, 832.0], '0.4': [320.0, 800.0],
15
+ '0.42': [320.0, 768.0], '0.48': [352.0, 736.0], '0.5': [352.0, 704.0], '0.52': [352.0, 672.0],
16
+ '0.57': [384.0, 672.0], '0.6': [384.0, 640.0], '0.68': [416.0, 608.0], '0.72': [416.0, 576.0],
17
+ '0.78': [448.0, 576.0], '0.82': [448.0, 544.0], '0.88': [480.0, 544.0], '0.94': [480.0, 512.0],
18
+ '1.0': [512.0, 512.0], '1.07': [512.0, 480.0], '1.13': [544.0, 480.0], '1.21': [544.0, 448.0],
19
+ '1.29': [576.0, 448.0], '1.38': [576.0, 416.0], '1.46': [608.0, 416.0], '1.67': [640.0, 384.0],
20
+ '1.75': [672.0, 384.0], '2.0': [704.0, 352.0], '2.09': [736.0, 352.0], '2.4': [768.0, 320.0],
21
+ '2.5': [800.0, 320.0], '2.89': [832.0, 288.0], '3.0': [864.0, 288.0], '3.11': [896.0, 288.0],
22
+ '3.62': [928.0, 256.0], '3.75': [960.0, 256.0], '3.88': [992.0, 256.0], '4.0': [1024.0, 256.0]
23
+ }
24
+ ASPECT_RATIO_RANDOM_CROP_512 = {
25
+ '0.42': [320.0, 768.0], '0.5': [352.0, 704.0],
26
+ '0.57': [384.0, 672.0], '0.68': [416.0, 608.0], '0.78': [448.0, 576.0], '0.88': [480.0, 544.0],
27
+ '0.94': [480.0, 512.0], '1.0': [512.0, 512.0], '1.07': [512.0, 480.0],
28
+ '1.13': [544.0, 480.0], '1.29': [576.0, 448.0], '1.46': [608.0, 416.0], '1.75': [672.0, 384.0],
29
+ '2.0': [704.0, 352.0], '2.4': [768.0, 320.0]
30
+ }
31
+ ASPECT_RATIO_RANDOM_CROP_PROB = [
32
+ 1, 2,
33
+ 4, 4, 4, 4,
34
+ 8, 8, 8,
35
+ 4, 4, 4, 4,
36
+ 2, 1
37
+ ]
38
+ ASPECT_RATIO_RANDOM_CROP_PROB = np.array(ASPECT_RATIO_RANDOM_CROP_PROB) / sum(ASPECT_RATIO_RANDOM_CROP_PROB)
39
+
40
+ def get_closest_ratio(height: float, width: float, ratios: dict = ASPECT_RATIO_512):
41
+ aspect_ratio = height / width
42
+ closest_ratio = min(ratios.keys(), key=lambda ratio: abs(float(ratio) - aspect_ratio))
43
+ return ratios[closest_ratio], float(closest_ratio)
44
+
45
+ def get_image_size_without_loading(path):
46
+ with Image.open(path) as img:
47
+ return img.size # (width, height)
48
+
49
+ class RandomSampler(Sampler[int]):
50
+ r"""Samples elements randomly. If without replacement, then sample from a shuffled dataset.
51
+
52
+ If with replacement, then user can specify :attr:`num_samples` to draw.
53
+
54
+ Args:
55
+ data_source (Dataset): dataset to sample from
56
+ replacement (bool): samples are drawn on-demand with replacement if ``True``, default=``False``
57
+ num_samples (int): number of samples to draw, default=`len(dataset)`.
58
+ generator (Generator): Generator used in sampling.
59
+ """
60
+
61
+ data_source: Sized
62
+ replacement: bool
63
+
64
+ def __init__(self, data_source: Sized, replacement: bool = False,
65
+ num_samples: Optional[int] = None, generator=None) -> None:
66
+ self.data_source = data_source
67
+ self.replacement = replacement
68
+ self._num_samples = num_samples
69
+ self.generator = generator
70
+ self._pos_start = 0
71
+
72
+ if not isinstance(self.replacement, bool):
73
+ raise TypeError(f"replacement should be a boolean value, but got replacement={self.replacement}")
74
+
75
+ if not isinstance(self.num_samples, int) or self.num_samples <= 0:
76
+ raise ValueError(f"num_samples should be a positive integer value, but got num_samples={self.num_samples}")
77
+
78
+ @property
79
+ def num_samples(self) -> int:
80
+ # dataset size might change at runtime
81
+ if self._num_samples is None:
82
+ return len(self.data_source)
83
+ return self._num_samples
84
+
85
+ def __iter__(self) -> Iterator[int]:
86
+ n = len(self.data_source)
87
+ if self.generator is None:
88
+ seed = int(torch.empty((), dtype=torch.int64).random_().item())
89
+ generator = torch.Generator()
90
+ generator.manual_seed(seed)
91
+ else:
92
+ generator = self.generator
93
+
94
+ if self.replacement:
95
+ for _ in range(self.num_samples // 32):
96
+ yield from torch.randint(high=n, size=(32,), dtype=torch.int64, generator=generator).tolist()
97
+ yield from torch.randint(high=n, size=(self.num_samples % 32,), dtype=torch.int64, generator=generator).tolist()
98
+ else:
99
+ for _ in range(self.num_samples // n):
100
+ xx = torch.randperm(n, generator=generator).tolist()
101
+ if self._pos_start >= n:
102
+ self._pos_start = 0
103
+ print("xx top 10", xx[:10], self._pos_start)
104
+ for idx in range(self._pos_start, n):
105
+ yield xx[idx]
106
+ self._pos_start = (self._pos_start + 1) % n
107
+ self._pos_start = 0
108
+ yield from torch.randperm(n, generator=generator).tolist()[:self.num_samples % n]
109
+
110
+ def __len__(self) -> int:
111
+ return self.num_samples
112
+
113
+ class AspectRatioBatchImageSampler(BatchSampler):
114
+ """A sampler wrapper for grouping images with similar aspect ratio into a same batch.
115
+
116
+ Args:
117
+ sampler (Sampler): Base sampler.
118
+ dataset (Dataset): Dataset providing data information.
119
+ batch_size (int): Size of mini-batch.
120
+ drop_last (bool): If ``True``, the sampler will drop the last batch if
121
+ its size would be less than ``batch_size``.
122
+ aspect_ratios (dict): The predefined aspect ratios.
123
+ """
124
+ def __init__(
125
+ self,
126
+ sampler: Sampler,
127
+ dataset: Dataset,
128
+ batch_size: int,
129
+ train_folder: str = None,
130
+ aspect_ratios: dict = ASPECT_RATIO_512,
131
+ drop_last: bool = False,
132
+ config=None,
133
+ **kwargs
134
+ ) -> None:
135
+ if not isinstance(sampler, Sampler):
136
+ raise TypeError('sampler should be an instance of ``Sampler``, '
137
+ f'but got {sampler}')
138
+ if not isinstance(batch_size, int) or batch_size <= 0:
139
+ raise ValueError('batch_size should be a positive integer value, '
140
+ f'but got batch_size={batch_size}')
141
+ self.sampler = sampler
142
+ self.dataset = dataset
143
+ self.train_folder = train_folder
144
+ self.batch_size = batch_size
145
+ self.aspect_ratios = aspect_ratios
146
+ self.drop_last = drop_last
147
+ self.config = config
148
+ # buckets for each aspect ratio
149
+ self._aspect_ratio_buckets = {ratio: [] for ratio in aspect_ratios}
150
+ # [str(k) for k, v in aspect_ratios]
151
+ self.current_available_bucket_keys = list(aspect_ratios.keys())
152
+
153
+ def __iter__(self):
154
+ for idx in self.sampler:
155
+ try:
156
+ image_dict = self.dataset[idx]
157
+
158
+ width, height = image_dict.get("width", None), image_dict.get("height", None)
159
+ if width is None or height is None:
160
+ image_id, name = image_dict['file_path'], image_dict['text']
161
+ if self.train_folder is None:
162
+ image_dir = image_id
163
+ else:
164
+ image_dir = os.path.join(self.train_folder, image_id)
165
+
166
+ width, height = get_image_size_without_loading(image_dir)
167
+
168
+ ratio = height / width # self.dataset[idx]
169
+ else:
170
+ height = int(height)
171
+ width = int(width)
172
+ ratio = height / width # self.dataset[idx]
173
+ except Exception as e:
174
+ print(e)
175
+ continue
176
+ # find the closest aspect ratio
177
+ closest_ratio = min(self.aspect_ratios.keys(), key=lambda r: abs(float(r) - ratio))
178
+ if closest_ratio not in self.current_available_bucket_keys:
179
+ continue
180
+ bucket = self._aspect_ratio_buckets[closest_ratio]
181
+ bucket.append(idx)
182
+ # yield a batch of indices in the same aspect ratio group
183
+ if len(bucket) == self.batch_size:
184
+ yield bucket[:]
185
+ del bucket[:]
186
+
187
+ class AspectRatioBatchSampler(BatchSampler):
188
+ """A sampler wrapper for grouping images with similar aspect ratio into a same batch.
189
+
190
+ Args:
191
+ sampler (Sampler): Base sampler.
192
+ dataset (Dataset): Dataset providing data information.
193
+ batch_size (int): Size of mini-batch.
194
+ drop_last (bool): If ``True``, the sampler will drop the last batch if
195
+ its size would be less than ``batch_size``.
196
+ aspect_ratios (dict): The predefined aspect ratios.
197
+ """
198
+ def __init__(
199
+ self,
200
+ sampler: Sampler,
201
+ dataset: Dataset,
202
+ batch_size: int,
203
+ video_folder: str = None,
204
+ train_data_format: str = "webvid",
205
+ aspect_ratios: dict = ASPECT_RATIO_512,
206
+ drop_last: bool = False,
207
+ config=None,
208
+ **kwargs
209
+ ) -> None:
210
+ if not isinstance(sampler, Sampler):
211
+ raise TypeError('sampler should be an instance of ``Sampler``, '
212
+ f'but got {sampler}')
213
+ if not isinstance(batch_size, int) or batch_size <= 0:
214
+ raise ValueError('batch_size should be a positive integer value, '
215
+ f'but got batch_size={batch_size}')
216
+ self.sampler = sampler
217
+ self.dataset = dataset
218
+ self.video_folder = video_folder
219
+ self.train_data_format = train_data_format
220
+ self.batch_size = batch_size
221
+ self.aspect_ratios = aspect_ratios
222
+ self.drop_last = drop_last
223
+ self.config = config
224
+ # buckets for each aspect ratio
225
+ self._aspect_ratio_buckets = {ratio: [] for ratio in aspect_ratios}
226
+ # [str(k) for k, v in aspect_ratios]
227
+ self.current_available_bucket_keys = list(aspect_ratios.keys())
228
+
229
+ def __iter__(self):
230
+ for idx in self.sampler:
231
+ try:
232
+ video_dict = self.dataset[idx]
233
+ width, height = video_dict.get("width", None), video_dict.get("height", None)
234
+
235
+ if width is None or height is None:
236
+ if self.train_data_format == "normal":
237
+ video_id, name = video_dict['file_path'], video_dict['text']
238
+ if self.video_folder is None:
239
+ video_dir = video_id
240
+ else:
241
+ video_dir = os.path.join(self.video_folder, video_id)
242
+ else:
243
+ videoid, name, page_dir = video_dict['videoid'], video_dict['name'], video_dict['page_dir']
244
+ video_dir = os.path.join(self.video_folder, f"{videoid}.mp4")
245
+ cap = cv2.VideoCapture(video_dir)
246
+
247
+ # Get the video dimensions
248
+ width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) # cast the float value to int
249
+ height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) # cast the float value to int
250
+
251
+ ratio = height / width # self.dataset[idx]
252
+ else:
253
+ height = int(height)
254
+ width = int(width)
255
+ ratio = height / width # self.dataset[idx]
256
+ except Exception as e:
257
+ print(e)
258
+ continue
259
+ # find the closest aspect ratio
260
+ closest_ratio = min(self.aspect_ratios.keys(), key=lambda r: abs(float(r) - ratio))
261
+ if closest_ratio not in self.current_available_bucket_keys:
262
+ continue
263
+ bucket = self._aspect_ratio_buckets[closest_ratio]
264
+ bucket.append(idx)
265
+ # yield a batch of indices in the same aspect ratio group
266
+ if len(bucket) == self.batch_size:
267
+ yield bucket[:]
268
+ del bucket[:]
269
+
270
+ class AspectRatioBatchImageVideoSampler(BatchSampler):
271
+ """A sampler wrapper for grouping images with similar aspect ratio into a same batch.
272
+
273
+ Args:
274
+ sampler (Sampler): Base sampler.
275
+ dataset (Dataset): Dataset providing data information.
276
+ batch_size (int): Size of mini-batch.
277
+ drop_last (bool): If ``True``, the sampler will drop the last batch if
278
+ its size would be less than ``batch_size``.
279
+ aspect_ratios (dict): The predefined aspect ratios.
280
+ """
281
+
282
+ def __init__(self,
283
+ sampler: Sampler,
284
+ dataset: Dataset,
285
+ batch_size: int,
286
+ train_folder: str = None,
287
+ aspect_ratios: dict = ASPECT_RATIO_512,
288
+ drop_last: bool = False
289
+ ) -> None:
290
+ if not isinstance(sampler, Sampler):
291
+ raise TypeError('sampler should be an instance of ``Sampler``, '
292
+ f'but got {sampler}')
293
+ if not isinstance(batch_size, int) or batch_size <= 0:
294
+ raise ValueError('batch_size should be a positive integer value, '
295
+ f'but got batch_size={batch_size}')
296
+ self.sampler = sampler
297
+ self.dataset = dataset
298
+ self.train_folder = train_folder
299
+ self.batch_size = batch_size
300
+ self.aspect_ratios = aspect_ratios
301
+ self.drop_last = drop_last
302
+
303
+ # buckets for each aspect ratio
304
+ self.current_available_bucket_keys = list(aspect_ratios.keys())
305
+ self.bucket = {
306
+ 'image':{ratio: [] for ratio in aspect_ratios},
307
+ 'video':{ratio: [] for ratio in aspect_ratios}
308
+ }
309
+
310
+ def __iter__(self):
311
+ for idx in self.sampler:
312
+ content_type = self.dataset[idx].get('type', 'image')
313
+ if content_type == 'image':
314
+ try:
315
+ image_dict = self.dataset[idx]
316
+
317
+ width, height = image_dict.get("width", None), image_dict.get("height", None)
318
+ if width is None or height is None:
319
+ image_id, name = image_dict['file_path'], image_dict['text']
320
+ if self.train_folder is None:
321
+ image_dir = image_id
322
+ else:
323
+ image_dir = os.path.join(self.train_folder, image_id)
324
+
325
+ width, height = get_image_size_without_loading(image_dir)
326
+
327
+ ratio = height / width # self.dataset[idx]
328
+ else:
329
+ height = int(height)
330
+ width = int(width)
331
+ ratio = height / width # self.dataset[idx]
332
+ except Exception as e:
333
+ print(e)
334
+ continue
335
+ # find the closest aspect ratio
336
+ closest_ratio = min(self.aspect_ratios.keys(), key=lambda r: abs(float(r) - ratio))
337
+ if closest_ratio not in self.current_available_bucket_keys:
338
+ continue
339
+ bucket = self.bucket['image'][closest_ratio]
340
+ bucket.append(idx)
341
+ # yield a batch of indices in the same aspect ratio group
342
+ if len(bucket) == self.batch_size:
343
+ yield bucket[:]
344
+ del bucket[:]
345
+ else:
346
+ try:
347
+ video_dict = self.dataset[idx]
348
+ width, height = video_dict.get("width", None), video_dict.get("height", None)
349
+
350
+ if width is None or height is None:
351
+ video_id, name = video_dict['file_path'], video_dict['text']
352
+ if self.train_folder is None:
353
+ video_dir = video_id
354
+ else:
355
+ video_dir = os.path.join(self.train_folder, video_id)
356
+ cap = cv2.VideoCapture(video_dir)
357
+
358
+ # Get the video dimensions
359
+ width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) # cast the float value to int
360
+ height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) # cast the float value to int
361
+
362
+ ratio = height / width # self.dataset[idx]
363
+ else:
364
+ height = int(height)
365
+ width = int(width)
366
+ ratio = height / width # self.dataset[idx]
367
+ except Exception as e:
368
+ print(e)
369
+ continue
370
+ # find the closest aspect ratio
371
+ closest_ratio = min(self.aspect_ratios.keys(), key=lambda r: abs(float(r) - ratio))
372
+ if closest_ratio not in self.current_available_bucket_keys:
373
+ continue
374
+ bucket = self.bucket['video'][closest_ratio]
375
+ bucket.append(idx)
376
+ # yield a batch of indices in the same aspect ratio group
377
+ if len(bucket) == self.batch_size:
378
+ yield bucket[:]
379
+ del bucket[:]
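A minimal, self-contained sketch of the bucketing behaviour: the samplers above only read annotation dicts (width/height, or a file_path they can probe), so a plain list of dicts is enough to see batches grouped by their closest aspect-ratio bucket. The file names below are placeholders.

from cogvideox.data.bucket_sampler import (ASPECT_RATIO_512,
                                            AspectRatioBatchImageSampler,
                                            RandomSampler)

# 16 fake annotations alternating between 16:9 and 9:16 resolutions.
annotations = [
    {"file_path": f"img_{i}.jpg", "text": "a demo image",
     "width": 1280 if i % 2 else 720, "height": 720 if i % 2 else 1280}
    for i in range(16)
]

batch_sampler = AspectRatioBatchImageSampler(
    sampler=RandomSampler(annotations),
    dataset=annotations,
    batch_size=4,
    aspect_ratios=ASPECT_RATIO_512,
)
for batch_indices in batch_sampler:
    # Each yielded batch only mixes indices whose closest ratio bucket matches,
    # so landscape and portrait items never share a batch.
    print(batch_indices)

In training, such a batch sampler would typically be handed to a DataLoader via batch_sampler=, paired with a dataset that resizes every item to the resolution of its bucket.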
cogvideox/data/dataset_image.py ADDED
@@ -0,0 +1,76 @@
1
+ import json
2
+ import os
3
+ import random
4
+
5
+ import numpy as np
6
+ import torch
7
+ import torchvision.transforms as transforms
8
+ from PIL import Image
9
+ from torch.utils.data.dataset import Dataset
10
+
11
+
12
+ class CC15M(Dataset):
13
+ def __init__(
14
+ self,
15
+ json_path,
16
+ video_folder=None,
17
+ resolution=512,
18
+ enable_bucket=False,
19
+ ):
20
+ print(f"loading annotations from {json_path} ...")
21
+ self.dataset = json.load(open(json_path, 'r'))
22
+ self.length = len(self.dataset)
23
+ print(f"data scale: {self.length}")
24
+
25
+ self.enable_bucket = enable_bucket
26
+ self.video_folder = video_folder
27
+
28
+ resolution = tuple(resolution) if not isinstance(resolution, int) else (resolution, resolution)
29
+ self.pixel_transforms = transforms.Compose([
30
+ transforms.Resize(resolution[0]),
31
+ transforms.CenterCrop(resolution),
32
+ transforms.ToTensor(),
33
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
34
+ ])
35
+
36
+ def get_batch(self, idx):
37
+ video_dict = self.dataset[idx]
38
+ video_id, name = video_dict['file_path'], video_dict['text']
39
+
40
+ if self.video_folder is None:
41
+ video_dir = video_id
42
+ else:
43
+ video_dir = os.path.join(self.video_folder, video_id)
44
+
45
+ pixel_values = Image.open(video_dir).convert("RGB")
46
+ return pixel_values, name
47
+
48
+ def __len__(self):
49
+ return self.length
50
+
51
+ def __getitem__(self, idx):
52
+ while True:
53
+ try:
54
+ pixel_values, name = self.get_batch(idx)
55
+ break
56
+ except Exception as e:
57
+ print(e)
58
+ idx = random.randint(0, self.length-1)
59
+
60
+ if not self.enable_bucket:
61
+ pixel_values = self.pixel_transforms(pixel_values)
62
+ else:
63
+ pixel_values = np.array(pixel_values)
64
+
65
+ sample = dict(pixel_values=pixel_values, text=name)
66
+ return sample
67
+
68
+ if __name__ == "__main__":
69
+ dataset = CC15M(
70
+ json_path="/mnt_wg/zhoumo.xjq/CCUtils/cc15m_add_index.json",
71
+ resolution=512,
72
+ )
73
+
74
+ dataloader = torch.utils.data.DataLoader(dataset, batch_size=4, num_workers=0,)
75
+ for idx, batch in enumerate(dataloader):
76
+ print(batch["pixel_values"].shape, len(batch["text"]))
cogvideox/data/dataset_image_video.py ADDED
@@ -0,0 +1,589 @@
1
+ import csv
2
+ import io
3
+ import json
4
+ import math
5
+ import os
6
+ import random
7
+ from threading import Thread
8
+
9
+ import albumentations
10
+ import cv2
11
+ import gc
12
+ import numpy as np
13
+ import torch
14
+ import torchvision.transforms as transforms
15
+
16
+ from func_timeout import func_timeout, FunctionTimedOut
17
+ from decord import VideoReader
18
+ from PIL import Image
19
+ from torch.utils.data import BatchSampler, Sampler
20
+ from torch.utils.data.dataset import Dataset
21
+ from contextlib import contextmanager
22
+
23
+ VIDEO_READER_TIMEOUT = 20
24
+
25
+ def get_random_mask(shape, image_start_only=False):
26
+ f, c, h, w = shape
27
+ mask = torch.zeros((f, 1, h, w), dtype=torch.uint8)
28
+
29
+ if not image_start_only:
30
+ if f != 1:
31
+ mask_index = np.random.choice([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], p=[0.05, 0.2, 0.2, 0.2, 0.05, 0.05, 0.05, 0.1, 0.05, 0.05])
32
+ else:
33
+ mask_index = np.random.choice([0, 1], p = [0.2, 0.8])
34
+ if mask_index == 0:
35
+ center_x = torch.randint(0, w, (1,)).item()
36
+ center_y = torch.randint(0, h, (1,)).item()
37
+ block_size_x = torch.randint(w // 4, w // 4 * 3, (1,)).item() # width range of the block
38
+ block_size_y = torch.randint(h // 4, h // 4 * 3, (1,)).item() # height range of the block
39
+
40
+ start_x = max(center_x - block_size_x // 2, 0)
41
+ end_x = min(center_x + block_size_x // 2, w)
42
+ start_y = max(center_y - block_size_y // 2, 0)
43
+ end_y = min(center_y + block_size_y // 2, h)
44
+ mask[:, :, start_y:end_y, start_x:end_x] = 1
45
+ elif mask_index == 1:
46
+ mask[:, :, :, :] = 1
47
+ elif mask_index == 2:
48
+ mask_frame_index = np.random.randint(1, 5)
49
+ mask[mask_frame_index:, :, :, :] = 1
50
+ elif mask_index == 3:
51
+ mask_frame_index = np.random.randint(1, 5)
52
+ mask[mask_frame_index:-mask_frame_index, :, :, :] = 1
53
+ elif mask_index == 4:
54
+ center_x = torch.randint(0, w, (1,)).item()
55
+ center_y = torch.randint(0, h, (1,)).item()
56
+ block_size_x = torch.randint(w // 4, w // 4 * 3, (1,)).item() # width range of the block
57
+ block_size_y = torch.randint(h // 4, h // 4 * 3, (1,)).item() # height range of the block
58
+
59
+ start_x = max(center_x - block_size_x // 2, 0)
60
+ end_x = min(center_x + block_size_x // 2, w)
61
+ start_y = max(center_y - block_size_y // 2, 0)
62
+ end_y = min(center_y + block_size_y // 2, h)
63
+
64
+ mask_frame_before = np.random.randint(0, f // 2)
65
+ mask_frame_after = np.random.randint(f // 2, f)
66
+ mask[mask_frame_before:mask_frame_after, :, start_y:end_y, start_x:end_x] = 1
67
+ elif mask_index == 5:
68
+ mask = torch.randint(0, 2, (f, 1, h, w), dtype=torch.uint8)
69
+ elif mask_index == 6:
70
+ num_frames_to_mask = random.randint(1, max(f // 2, 1))
71
+ frames_to_mask = random.sample(range(f), num_frames_to_mask)
72
+
73
+ for i in frames_to_mask:
74
+ block_height = random.randint(1, h // 4)
75
+ block_width = random.randint(1, w // 4)
76
+ top_left_y = random.randint(0, h - block_height)
77
+ top_left_x = random.randint(0, w - block_width)
78
+ mask[i, 0, top_left_y:top_left_y + block_height, top_left_x:top_left_x + block_width] = 1
79
+ elif mask_index == 7:
80
+ center_x = torch.randint(0, w, (1,)).item()
81
+ center_y = torch.randint(0, h, (1,)).item()
82
+ a = torch.randint(min(w, h) // 8, min(w, h) // 4, (1,)).item() # semi-major axis
83
+ b = torch.randint(min(h, w) // 8, min(h, w) // 4, (1,)).item() # semi-minor axis
84
+
85
+ for i in range(h):
86
+ for j in range(w):
87
+ if ((i - center_y) ** 2) / (b ** 2) + ((j - center_x) ** 2) / (a ** 2) < 1:
88
+ mask[:, :, i, j] = 1
89
+ elif mask_index == 8:
90
+ center_x = torch.randint(0, w, (1,)).item()
91
+ center_y = torch.randint(0, h, (1,)).item()
92
+ radius = torch.randint(min(h, w) // 8, min(h, w) // 4, (1,)).item()
93
+ for i in range(h):
94
+ for j in range(w):
95
+ if (i - center_y) ** 2 + (j - center_x) ** 2 < radius ** 2:
96
+ mask[:, :, i, j] = 1
97
+ elif mask_index == 9:
98
+ for idx in range(f):
99
+ if np.random.rand() > 0.5:
100
+ mask[idx, :, :, :] = 1
101
+ else:
102
+ raise ValueError(f"The mask_index {mask_index} is not defined")
103
+ else:
104
+ if f != 1:
105
+ mask[1:, :, :, :] = 1
106
+ else:
107
+ mask[:, :, :, :] = 1
108
+ return mask
109
+
110
+ class ImageVideoSampler(BatchSampler):
111
+ """A sampler wrapper for grouping images with similar aspect ratio into a same batch.
112
+
113
+ Args:
114
+ sampler (Sampler): Base sampler.
115
+ dataset (Dataset): Dataset providing data information.
116
+ batch_size (int): Size of mini-batch.
117
+ drop_last (bool): If ``True``, the sampler will drop the last batch if
118
+ its size would be less than ``batch_size``.
119
+ aspect_ratios (dict): The predefined aspect ratios.
120
+ """
121
+
122
+ def __init__(self,
123
+ sampler: Sampler,
124
+ dataset: Dataset,
125
+ batch_size: int,
126
+ drop_last: bool = False
127
+ ) -> None:
128
+ if not isinstance(sampler, Sampler):
129
+ raise TypeError('sampler should be an instance of ``Sampler``, '
130
+ f'but got {sampler}')
131
+ if not isinstance(batch_size, int) or batch_size <= 0:
132
+ raise ValueError('batch_size should be a positive integer value, '
133
+ f'but got batch_size={batch_size}')
134
+ self.sampler = sampler
135
+ self.dataset = dataset
136
+ self.batch_size = batch_size
137
+ self.drop_last = drop_last
138
+
139
+ # buckets for each aspect ratio
140
+ self.bucket = {'image':[], 'video':[]}
141
+
142
+ def __iter__(self):
143
+ for idx in self.sampler:
144
+ content_type = self.dataset.dataset[idx].get('type', 'image')
145
+ self.bucket[content_type].append(idx)
146
+
147
+ # yield a batch of indices in the same aspect ratio group
148
+ if len(self.bucket['video']) == self.batch_size:
149
+ bucket = self.bucket['video']
150
+ yield bucket[:]
151
+ del bucket[:]
152
+ elif len(self.bucket['image']) == self.batch_size:
153
+ bucket = self.bucket['image']
154
+ yield bucket[:]
155
+ del bucket[:]
156
+
157
+ @contextmanager
158
+ def VideoReader_contextmanager(*args, **kwargs):
159
+ vr = VideoReader(*args, **kwargs)
160
+ try:
161
+ yield vr
162
+ finally:
163
+ del vr
164
+ gc.collect()
165
+
166
+ def get_video_reader_batch(video_reader, batch_index):
167
+ frames = video_reader.get_batch(batch_index).asnumpy()
168
+ return frames
169
+
170
+ def resize_frame(frame, target_short_side):
171
+ h, w, _ = frame.shape
172
+ if h < w:
173
+ if target_short_side > h:
174
+ return frame
175
+ new_h = target_short_side
176
+ new_w = int(target_short_side * w / h)
177
+ else:
178
+ if target_short_side > w:
179
+ return frame
180
+ new_w = target_short_side
181
+ new_h = int(target_short_side * h / w)
182
+
183
+ resized_frame = cv2.resize(frame, (new_w, new_h))
184
+ return resized_frame
185
+
186
+ class ImageVideoDataset(Dataset):
187
+ def __init__(
188
+ self,
189
+ ann_path, data_root=None,
190
+ video_sample_size=512, video_sample_stride=4, video_sample_n_frames=16,
191
+ image_sample_size=512,
192
+ video_repeat=0,
193
+ text_drop_ratio=0.1,
194
+ enable_bucket=False,
195
+ video_length_drop_start=0.0,
196
+ video_length_drop_end=1.0,
197
+ enable_inpaint=False,
198
+ ):
199
+ # Loading annotations from files
200
+ print(f"loading annotations from {ann_path} ...")
201
+ if ann_path.endswith('.csv'):
202
+ with open(ann_path, 'r') as csvfile:
203
+ dataset = list(csv.DictReader(csvfile))
204
+ elif ann_path.endswith('.json'):
205
+ dataset = json.load(open(ann_path))
206
+
207
+ self.data_root = data_root
208
+
209
+ # It's used to balance num of images and videos.
210
+ self.dataset = []
211
+ for data in dataset:
212
+ if data.get('type', 'image') != 'video':
213
+ self.dataset.append(data)
214
+ if video_repeat > 0:
215
+ for _ in range(video_repeat):
216
+ for data in dataset:
217
+ if data.get('type', 'image') == 'video':
218
+ self.dataset.append(data)
219
+ del dataset
220
+
221
+ self.length = len(self.dataset)
222
+ print(f"data scale: {self.length}")
223
+ # TODO: enable bucket training
224
+ self.enable_bucket = enable_bucket
225
+ self.text_drop_ratio = text_drop_ratio
226
+ self.enable_inpaint = enable_inpaint
227
+
228
+ self.video_length_drop_start = video_length_drop_start
229
+ self.video_length_drop_end = video_length_drop_end
230
+
231
+ # Video params
232
+ self.video_sample_stride = video_sample_stride
233
+ self.video_sample_n_frames = video_sample_n_frames
234
+ self.video_sample_size = tuple(video_sample_size) if not isinstance(video_sample_size, int) else (video_sample_size, video_sample_size)
235
+ self.video_transforms = transforms.Compose(
236
+ [
237
+ transforms.Resize(min(self.video_sample_size)),
238
+ transforms.CenterCrop(self.video_sample_size),
239
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
240
+ ]
241
+ )
242
+
243
+ # Image params
244
+ self.image_sample_size = tuple(image_sample_size) if not isinstance(image_sample_size, int) else (image_sample_size, image_sample_size)
245
+ self.image_transforms = transforms.Compose([
246
+ transforms.Resize(min(self.image_sample_size)),
247
+ transforms.CenterCrop(self.image_sample_size),
248
+ transforms.ToTensor(),
249
+ transforms.Normalize([0.5, 0.5, 0.5],[0.5, 0.5, 0.5])
250
+ ])
251
+
252
+ self.larger_side_of_image_and_video = max(min(self.image_sample_size), min(self.video_sample_size))
253
+
254
+ def get_batch(self, idx):
255
+ data_info = self.dataset[idx % len(self.dataset)]
256
+
257
+ if data_info.get('type', 'image')=='video':
258
+ video_id, text = data_info['file_path'], data_info['text']
259
+
260
+ if self.data_root is None:
261
+ video_dir = video_id
262
+ else:
263
+ video_dir = os.path.join(self.data_root, video_id)
264
+
265
+ with VideoReader_contextmanager(video_dir, num_threads=2) as video_reader:
266
+ min_sample_n_frames = min(
267
+ self.video_sample_n_frames,
268
+ int(len(video_reader) * (self.video_length_drop_end - self.video_length_drop_start) // self.video_sample_stride)
269
+ )
270
+ if min_sample_n_frames == 0:
271
+ raise ValueError("No frames in video.")
272
+
273
+ video_length = int(self.video_length_drop_end * len(video_reader))
274
+ clip_length = min(video_length, (min_sample_n_frames - 1) * self.video_sample_stride + 1)
275
+ start_idx = random.randint(int(self.video_length_drop_start * video_length), video_length - clip_length) if video_length != clip_length else 0
276
+ batch_index = np.linspace(start_idx, start_idx + clip_length - 1, min_sample_n_frames, dtype=int)
277
+
278
+ try:
279
+ sample_args = (video_reader, batch_index)
280
+ pixel_values = func_timeout(
281
+ VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
282
+ )
283
+ resized_frames = []
284
+ for i in range(len(pixel_values)):
285
+ frame = pixel_values[i]
286
+ resized_frame = resize_frame(frame, self.larger_side_of_image_and_video)
287
+ resized_frames.append(resized_frame)
288
+ pixel_values = np.array(resized_frames)
289
+ except FunctionTimedOut:
290
+ raise ValueError(f"Read {idx} timeout.")
291
+ except Exception as e:
292
+ raise ValueError(f"Failed to extract frames from video. Error is {e}.")
293
+
294
+ if not self.enable_bucket:
295
+ pixel_values = torch.from_numpy(pixel_values).permute(0, 3, 1, 2).contiguous()
296
+ pixel_values = pixel_values / 255.
297
+ del video_reader
298
+ else:
299
+ pixel_values = pixel_values
300
+
301
+ if not self.enable_bucket:
302
+ pixel_values = self.video_transforms(pixel_values)
303
+
304
+ # Random use no text generation
305
+ if random.random() < self.text_drop_ratio:
306
+ text = ''
307
+ return pixel_values, text, 'video'
308
+ else:
309
+ image_path, text = data_info['file_path'], data_info['text']
310
+ if self.data_root is not None:
311
+ image_path = os.path.join(self.data_root, image_path)
312
+ image = Image.open(image_path).convert('RGB')
313
+ if not self.enable_bucket:
314
+ image = self.image_transforms(image).unsqueeze(0)
315
+ else:
316
+ image = np.expand_dims(np.array(image), 0)
317
+ if random.random() < self.text_drop_ratio:
318
+ text = ''
319
+ return image, text, 'image'
320
+
321
+ def __len__(self):
322
+ return self.length
323
+
324
+ def __getitem__(self, idx):
325
+ data_info = self.dataset[idx % len(self.dataset)]
326
+ data_type = data_info.get('type', 'image')
327
+ while True:
328
+ sample = {}
329
+ try:
330
+ data_info_local = self.dataset[idx % len(self.dataset)]
331
+ data_type_local = data_info_local.get('type', 'image')
332
+ if data_type_local != data_type:
333
+ raise ValueError("data_type_local != data_type")
334
+
335
+ pixel_values, name, data_type = self.get_batch(idx)
336
+ sample["pixel_values"] = pixel_values
337
+ sample["text"] = name
338
+ sample["data_type"] = data_type
339
+ sample["idx"] = idx
340
+
341
+ if len(sample) > 0:
342
+ break
343
+ except Exception as e:
344
+ print(e, self.dataset[idx % len(self.dataset)])
345
+ idx = random.randint(0, self.length-1)
346
+
347
+ if self.enable_inpaint and not self.enable_bucket:
348
+ mask = get_random_mask(pixel_values.size())
349
+ mask_pixel_values = pixel_values * (1 - mask) + torch.ones_like(pixel_values) * -1 * mask
350
+ sample["mask_pixel_values"] = mask_pixel_values
351
+ sample["mask"] = mask
352
+
353
+ clip_pixel_values = sample["pixel_values"][0].permute(1, 2, 0).contiguous()
354
+ clip_pixel_values = (clip_pixel_values * 0.5 + 0.5) * 255
355
+ sample["clip_pixel_values"] = clip_pixel_values
356
+
357
+ ref_pixel_values = sample["pixel_values"][0].unsqueeze(0)
358
+ if (mask == 1).all():
359
+ ref_pixel_values = torch.ones_like(ref_pixel_values) * -1
360
+ sample["ref_pixel_values"] = ref_pixel_values
361
+
362
+ return sample
363
+
364
+
365
+ class ImageVideoControlDataset(Dataset):
366
+ def __init__(
367
+ self,
368
+ ann_path, data_root=None,
369
+ video_sample_size=512, video_sample_stride=4, video_sample_n_frames=16,
370
+ image_sample_size=512,
371
+ video_repeat=0,
372
+ text_drop_ratio=0.1,
373
+ enable_bucket=False,
374
+ video_length_drop_start=0.0,
375
+ video_length_drop_end=1.0,
376
+ enable_inpaint=False,
377
+ ):
378
+ # Loading annotations from files
379
+ print(f"loading annotations from {ann_path} ...")
380
+ if ann_path.endswith('.csv'):
381
+ with open(ann_path, 'r') as csvfile:
382
+ dataset = list(csv.DictReader(csvfile))
383
+ elif ann_path.endswith('.json'):
384
+ dataset = json.load(open(ann_path))
385
+
386
+ self.data_root = data_root
387
+
388
+ # It's used to balance num of images and videos.
389
+ self.dataset = []
390
+ for data in dataset:
391
+ if data.get('type', 'image') != 'video':
392
+ self.dataset.append(data)
393
+ if video_repeat > 0:
394
+ for _ in range(video_repeat):
395
+ for data in dataset:
396
+ if data.get('type', 'image') == 'video':
397
+ self.dataset.append(data)
398
+ del dataset
399
+
400
+ self.length = len(self.dataset)
401
+ print(f"data scale: {self.length}")
402
+ # TODO: enable bucket training
403
+ self.enable_bucket = enable_bucket
404
+ self.text_drop_ratio = text_drop_ratio
405
+ self.enable_inpaint = enable_inpaint
406
+
407
+ self.video_length_drop_start = video_length_drop_start
408
+ self.video_length_drop_end = video_length_drop_end
409
+
410
+ # Video params
411
+ self.video_sample_stride = video_sample_stride
412
+ self.video_sample_n_frames = video_sample_n_frames
413
+ self.video_sample_size = tuple(video_sample_size) if not isinstance(video_sample_size, int) else (video_sample_size, video_sample_size)
414
+ self.video_transforms = transforms.Compose(
415
+ [
416
+ transforms.Resize(min(self.video_sample_size)),
417
+ transforms.CenterCrop(self.video_sample_size),
418
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
419
+ ]
420
+ )
421
+
422
+ # Image params
423
+ self.image_sample_size = tuple(image_sample_size) if not isinstance(image_sample_size, int) else (image_sample_size, image_sample_size)
424
+ self.image_transforms = transforms.Compose([
425
+ transforms.Resize(min(self.image_sample_size)),
426
+ transforms.CenterCrop(self.image_sample_size),
427
+ transforms.ToTensor(),
428
+ transforms.Normalize([0.5, 0.5, 0.5],[0.5, 0.5, 0.5])
429
+ ])
430
+
431
+ self.larger_side_of_image_and_video = max(min(self.image_sample_size), min(self.video_sample_size))
432
+
433
+ def get_batch(self, idx):
434
+ data_info = self.dataset[idx % len(self.dataset)]
435
+ video_id, text = data_info['file_path'], data_info['text']
436
+
437
+ if data_info.get('type', 'image')=='video':
438
+ if self.data_root is None:
439
+ video_dir = video_id
440
+ else:
441
+ video_dir = os.path.join(self.data_root, video_id)
442
+
443
+ with VideoReader_contextmanager(video_dir, num_threads=2) as video_reader:
444
+ min_sample_n_frames = min(
445
+ self.video_sample_n_frames,
446
+ int(len(video_reader) * (self.video_length_drop_end - self.video_length_drop_start) // self.video_sample_stride)
447
+ )
448
+ if min_sample_n_frames == 0:
449
+ raise ValueError("No frames in video.")
450
+
451
+ video_length = int(self.video_length_drop_end * len(video_reader))
452
+ clip_length = min(video_length, (min_sample_n_frames - 1) * self.video_sample_stride + 1)
453
+ start_idx = random.randint(int(self.video_length_drop_start * video_length), video_length - clip_length) if video_length != clip_length else 0
454
+ batch_index = np.linspace(start_idx, start_idx + clip_length - 1, min_sample_n_frames, dtype=int)
455
+
456
+ try:
457
+ sample_args = (video_reader, batch_index)
458
+ pixel_values = func_timeout(
459
+ VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
460
+ )
461
+ resized_frames = []
462
+ for i in range(len(pixel_values)):
463
+ frame = pixel_values[i]
464
+ resized_frame = resize_frame(frame, self.larger_side_of_image_and_video)
465
+ resized_frames.append(resized_frame)
466
+ pixel_values = np.array(resized_frames)
467
+ except FunctionTimedOut:
468
+ raise ValueError(f"Read {idx} timeout.")
469
+ except Exception as e:
470
+ raise ValueError(f"Failed to extract frames from video. Error is {e}.")
471
+
472
+ if not self.enable_bucket:
473
+ pixel_values = torch.from_numpy(pixel_values).permute(0, 3, 1, 2).contiguous()
474
+ pixel_values = pixel_values / 255.
475
+ del video_reader
476
+ else:
477
+ pixel_values = pixel_values
478
+
479
+ if not self.enable_bucket:
480
+ pixel_values = self.video_transforms(pixel_values)
481
+
482
+ # Random use no text generation
483
+ if random.random() < self.text_drop_ratio:
484
+ text = ''
485
+
486
+ control_video_id = data_info['control_file_path']
487
+
488
+ if self.data_root is None:
489
+ control_video_id = control_video_id
490
+ else:
491
+ control_video_id = os.path.join(self.data_root, control_video_id)
492
+
493
+ with VideoReader_contextmanager(control_video_id, num_threads=2) as control_video_reader:
494
+ try:
495
+ sample_args = (control_video_reader, batch_index)
496
+ control_pixel_values = func_timeout(
497
+ VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
498
+ )
499
+ resized_frames = []
500
+ for i in range(len(control_pixel_values)):
501
+ frame = control_pixel_values[i]
502
+ resized_frame = resize_frame(frame, self.larger_side_of_image_and_video)
503
+ resized_frames.append(resized_frame)
504
+ control_pixel_values = np.array(resized_frames)
505
+ except FunctionTimedOut:
506
+ raise ValueError(f"Read {idx} timeout.")
507
+ except Exception as e:
508
+ raise ValueError(f"Failed to extract frames from video. Error is {e}.")
509
+
510
+ if not self.enable_bucket:
511
+ control_pixel_values = torch.from_numpy(control_pixel_values).permute(0, 3, 1, 2).contiguous()
512
+ control_pixel_values = control_pixel_values / 255.
513
+ del control_video_reader
514
+ else:
515
+ control_pixel_values = control_pixel_values
516
+
517
+ if not self.enable_bucket:
518
+ control_pixel_values = self.video_transforms(control_pixel_values)
519
+ return pixel_values, control_pixel_values, text, "video"
520
+ else:
521
+ image_path, text = data_info['file_path'], data_info['text']
522
+ if self.data_root is not None:
523
+ image_path = os.path.join(self.data_root, image_path)
524
+ image = Image.open(image_path).convert('RGB')
525
+ if not self.enable_bucket:
526
+ image = self.image_transforms(image).unsqueeze(0)
527
+ else:
528
+ image = np.expand_dims(np.array(image), 0)
529
+
530
+ if random.random() < self.text_drop_ratio:
531
+ text = ''
532
+
533
+ control_image_id = data_info['control_file_path']
534
+
535
+ if self.data_root is None:
536
+ control_image_id = control_image_id
537
+ else:
538
+ control_image_id = os.path.join(self.data_root, control_image_id)
539
+
540
+ control_image = Image.open(control_image_id).convert('RGB')
541
+ if not self.enable_bucket:
542
+ control_image = self.image_transforms(control_image).unsqueeze(0)
543
+ else:
544
+ control_image = np.expand_dims(np.array(control_image), 0)
545
+ return image, control_image, text, 'image'
546
+
547
+ def __len__(self):
548
+ return self.length
549
+
550
+ def __getitem__(self, idx):
551
+ data_info = self.dataset[idx % len(self.dataset)]
552
+ data_type = data_info.get('type', 'image')
553
+ while True:
554
+ sample = {}
555
+ try:
556
+ data_info_local = self.dataset[idx % len(self.dataset)]
557
+ data_type_local = data_info_local.get('type', 'image')
558
+ if data_type_local != data_type:
559
+ raise ValueError("data_type_local != data_type")
560
+
561
+ pixel_values, control_pixel_values, name, data_type = self.get_batch(idx)
562
+ sample["pixel_values"] = pixel_values
563
+ sample["control_pixel_values"] = control_pixel_values
564
+ sample["text"] = name
565
+ sample["data_type"] = data_type
566
+ sample["idx"] = idx
567
+
568
+ if len(sample) > 0:
569
+ break
570
+ except Exception as e:
571
+ print(e, self.dataset[idx % len(self.dataset)])
572
+ idx = random.randint(0, self.length-1)
573
+
574
+ if self.enable_inpaint and not self.enable_bucket:
575
+ mask = get_random_mask(pixel_values.size())
576
+ mask_pixel_values = pixel_values * (1 - mask) + torch.ones_like(pixel_values) * -1 * mask
577
+ sample["mask_pixel_values"] = mask_pixel_values
578
+ sample["mask"] = mask
579
+
580
+ clip_pixel_values = sample["pixel_values"][0].permute(1, 2, 0).contiguous()
581
+ clip_pixel_values = (clip_pixel_values * 0.5 + 0.5) * 255
582
+ sample["clip_pixel_values"] = clip_pixel_values
583
+
584
+ ref_pixel_values = sample["pixel_values"][0].unsqueeze(0)
585
+ if (mask == 1).all():
586
+ ref_pixel_values = torch.ones_like(ref_pixel_values) * -1
587
+ sample["ref_pixel_values"] = ref_pixel_values
588
+
589
+ return sample
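# Reviewer sketch (not part of the diff): consuming one sample from the control
# dataset defined above. `dataset` stands for an instance of whatever class wraps
# this get_batch/__getitem__ pair; its constructor sits outside this hunk.
sample = dataset[0]
print(sample["data_type"])                   # "video" or "image"
print(sample["pixel_values"].shape)          # (F, C, H, W) when enable_bucket=False
print(sample["control_pixel_values"].shape)  # same layout as pixel_values
print(sample["text"])                        # possibly '' due to text_drop_ratio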
cogvideox/data/dataset_video.py ADDED
@@ -0,0 +1,262 @@
1
+ import csv
2
+ import gc
3
+ import io
4
+ import json
5
+ import math
6
+ import os
7
+ import random
8
+ from contextlib import contextmanager
9
+ from threading import Thread
10
+
11
+ import albumentations
12
+ import cv2
13
+ import numpy as np
14
+ import torch
15
+ import torchvision.transforms as transforms
16
+ from decord import VideoReader
17
+ from einops import rearrange
18
+ from func_timeout import FunctionTimedOut, func_timeout
19
+ from PIL import Image
20
+ from torch.utils.data import BatchSampler, Sampler
21
+ from torch.utils.data.dataset import Dataset
22
+
23
+ VIDEO_READER_TIMEOUT = 20
24
+
25
+ def get_random_mask(shape):
26
+ f, c, h, w = shape
27
+
28
+ mask_index = np.random.randint(0, 4)
29
+ mask = torch.zeros((f, 1, h, w), dtype=torch.uint8)
30
+ if mask_index == 0:
31
+ mask[1:, :, :, :] = 1
32
+ elif mask_index == 1:
33
+ mask_frame_index = 1
34
+ mask[mask_frame_index:-mask_frame_index, :, :, :] = 1
35
+ elif mask_index == 2:
36
+ center_x = torch.randint(0, w, (1,)).item()
37
+ center_y = torch.randint(0, h, (1,)).item()
38
+ block_size_x = torch.randint(w // 4, w // 4 * 3, (1,)).item() # width range of the block
39
+ block_size_y = torch.randint(h // 4, h // 4 * 3, (1,)).item() # height range of the block
40
+
41
+ start_x = max(center_x - block_size_x // 2, 0)
42
+ end_x = min(center_x + block_size_x // 2, w)
43
+ start_y = max(center_y - block_size_y // 2, 0)
44
+ end_y = min(center_y + block_size_y // 2, h)
45
+ mask[:, :, start_y:end_y, start_x:end_x] = 1
46
+ elif mask_index == 3:
47
+ center_x = torch.randint(0, w, (1,)).item()
48
+ center_y = torch.randint(0, h, (1,)).item()
49
+ block_size_x = torch.randint(w // 4, w // 4 * 3, (1,)).item() # width range of the block
50
+ block_size_y = torch.randint(h // 4, h // 4 * 3, (1,)).item() # height range of the block
51
+
52
+ start_x = max(center_x - block_size_x // 2, 0)
53
+ end_x = min(center_x + block_size_x // 2, w)
54
+ start_y = max(center_y - block_size_y // 2, 0)
55
+ end_y = min(center_y + block_size_y // 2, h)
56
+
57
+ mask_frame_before = np.random.randint(0, f // 2)
58
+ mask_frame_after = np.random.randint(f // 2, f)
59
+ mask[mask_frame_before:mask_frame_after, :, start_y:end_y, start_x:end_x] = 1
60
+ else:
61
+ raise ValueError(f"The mask_index {mask_index} is not define")
62
+ return mask
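# Reviewer sketch (not part of the diff): sanity-checking get_random_mask against
# the masking formula used in __getitem__ below, where masked pixels become -1.
# torch is already imported at the top of this file.
pixel_values = torch.rand(16, 3, 64, 64) * 2 - 1          # (F, C, H, W) in [-1, 1]
mask = get_random_mask(pixel_values.size())               # (F, 1, H, W), values in {0, 1}
mask_pixel_values = pixel_values * (1 - mask) + torch.ones_like(pixel_values) * -1 * mask
assert mask_pixel_values.shape == pixel_values.shape
assert mask_pixel_values[mask.expand_as(pixel_values) == 1].eq(-1).all()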
63
+
64
+
65
+ @contextmanager
66
+ def VideoReader_contextmanager(*args, **kwargs):
67
+ vr = VideoReader(*args, **kwargs)
68
+ try:
69
+ yield vr
70
+ finally:
71
+ del vr
72
+ gc.collect()
73
+
74
+
75
+ def get_video_reader_batch(video_reader, batch_index):
76
+ frames = video_reader.get_batch(batch_index).asnumpy()
77
+ return frames
78
+
79
+
80
+ class WebVid10M(Dataset):
81
+ def __init__(
82
+ self,
83
+ csv_path, video_folder,
84
+ sample_size=256, sample_stride=4, sample_n_frames=16,
85
+ enable_bucket=False, enable_inpaint=False, is_image=False,
86
+ ):
87
+ print(f"loading annotations from {csv_path} ...")
88
+ with open(csv_path, 'r') as csvfile:
89
+ self.dataset = list(csv.DictReader(csvfile))
90
+ self.length = len(self.dataset)
91
+ print(f"data scale: {self.length}")
92
+
93
+ self.video_folder = video_folder
94
+ self.sample_stride = sample_stride
95
+ self.sample_n_frames = sample_n_frames
96
+ self.enable_bucket = enable_bucket
97
+ self.enable_inpaint = enable_inpaint
98
+ self.is_image = is_image
99
+
100
+ sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size)
101
+ self.pixel_transforms = transforms.Compose([
102
+ transforms.Resize(sample_size[0]),
103
+ transforms.CenterCrop(sample_size),
104
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
105
+ ])
106
+
107
+ def get_batch(self, idx):
108
+ video_dict = self.dataset[idx]
109
+ videoid, name, page_dir = video_dict['videoid'], video_dict['name'], video_dict['page_dir']
110
+
111
+ video_dir = os.path.join(self.video_folder, f"{videoid}.mp4")
112
+ video_reader = VideoReader(video_dir)
113
+ video_length = len(video_reader)
114
+
115
+ if not self.is_image:
116
+ clip_length = min(video_length, (self.sample_n_frames - 1) * self.sample_stride + 1)
117
+ start_idx = random.randint(0, video_length - clip_length)
118
+ batch_index = np.linspace(start_idx, start_idx + clip_length - 1, self.sample_n_frames, dtype=int)
119
+ else:
120
+ batch_index = [random.randint(0, video_length - 1)]
121
+
122
+ if not self.enable_bucket:
123
+ pixel_values = torch.from_numpy(video_reader.get_batch(batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous()
124
+ pixel_values = pixel_values / 255.
125
+ del video_reader
126
+ else:
127
+ pixel_values = video_reader.get_batch(batch_index).asnumpy()
128
+
129
+ if self.is_image:
130
+ pixel_values = pixel_values[0]
131
+ return pixel_values, name
132
+
133
+ def __len__(self):
134
+ return self.length
135
+
136
+ def __getitem__(self, idx):
137
+ while True:
138
+ try:
139
+ pixel_values, name = self.get_batch(idx)
140
+ break
141
+
142
+ except Exception as e:
143
+ print("Error info:", e)
144
+ idx = random.randint(0, self.length-1)
145
+
146
+ if not self.enable_bucket:
147
+ pixel_values = self.pixel_transforms(pixel_values)
148
+ if self.enable_inpaint:
149
+ mask = get_random_mask(pixel_values.size())
150
+ mask_pixel_values = pixel_values * (1 - mask) + torch.ones_like(pixel_values) * -1 * mask
151
+ sample = dict(pixel_values=pixel_values, mask_pixel_values=mask_pixel_values, mask=mask, text=name)
152
+ else:
153
+ sample = dict(pixel_values=pixel_values, text=name)
154
+ return sample
155
+
156
+
157
+ class VideoDataset(Dataset):
158
+ def __init__(
159
+ self,
160
+ json_path, video_folder=None,
161
+ sample_size=256, sample_stride=4, sample_n_frames=16,
162
+ enable_bucket=False, enable_inpaint=False
163
+ ):
164
+ print(f"loading annotations from {json_path} ...")
165
+ self.dataset = json.load(open(json_path, 'r'))
166
+ self.length = len(self.dataset)
167
+ print(f"data scale: {self.length}")
168
+
169
+ self.video_folder = video_folder
170
+ self.sample_stride = sample_stride
171
+ self.sample_n_frames = sample_n_frames
172
+ self.enable_bucket = enable_bucket
173
+ self.enable_inpaint = enable_inpaint
174
+
175
+ sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size)
176
+ self.pixel_transforms = transforms.Compose(
177
+ [
178
+ transforms.Resize(sample_size[0]),
179
+ transforms.CenterCrop(sample_size),
180
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
181
+ ]
182
+ )
183
+
184
+ def get_batch(self, idx):
185
+ video_dict = self.dataset[idx]
186
+ video_id, name = video_dict['file_path'], video_dict['text']
187
+
188
+ if self.video_folder is None:
189
+ video_dir = video_id
190
+ else:
191
+ video_dir = os.path.join(self.video_folder, video_id)
192
+
193
+ with VideoReader_contextmanager(video_dir, num_threads=2) as video_reader:
194
+ video_length = len(video_reader)
195
+
196
+ clip_length = min(video_length, (self.sample_n_frames - 1) * self.sample_stride + 1)
197
+ start_idx = random.randint(0, video_length - clip_length)
198
+ batch_index = np.linspace(start_idx, start_idx + clip_length - 1, self.sample_n_frames, dtype=int)
199
+
200
+ try:
201
+ sample_args = (video_reader, batch_index)
202
+ pixel_values = func_timeout(
203
+ VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
204
+ )
205
+ except FunctionTimedOut:
206
+ raise ValueError(f"Read {idx} timeout.")
207
+ except Exception as e:
208
+ raise ValueError(f"Failed to extract frames from video. Error is {e}.")
209
+
210
+ if not self.enable_bucket:
211
+ pixel_values = torch.from_numpy(pixel_values).permute(0, 3, 1, 2).contiguous()
212
+ pixel_values = pixel_values / 255.
213
+ del video_reader
214
+ else:
215
+ pixel_values = pixel_values
216
+
217
+ return pixel_values, name
218
+
219
+ def __len__(self):
220
+ return self.length
221
+
222
+ def __getitem__(self, idx):
223
+ while True:
224
+ try:
225
+ pixel_values, name = self.get_batch(idx)
226
+ break
227
+
228
+ except Exception as e:
229
+ print("Error info:", e)
230
+ idx = random.randint(0, self.length-1)
231
+
232
+ if not self.enable_bucket:
233
+ pixel_values = self.pixel_transforms(pixel_values)
234
+ if self.enable_inpaint:
235
+ mask = get_random_mask(pixel_values.size())
236
+ mask_pixel_values = pixel_values * (1 - mask) + torch.ones_like(pixel_values) * -1 * mask
237
+ sample = dict(pixel_values=pixel_values, mask_pixel_values=mask_pixel_values, mask=mask, text=name)
238
+ else:
239
+ sample = dict(pixel_values=pixel_values, text=name)
240
+ return sample
241
+
242
+
243
+ if __name__ == "__main__":
244
+ if 1:
245
+ dataset = VideoDataset(
246
+ json_path="/home/zhoumo.xjq/disk3/datasets/webvidval/results_2M_val.json",
247
+ sample_size=256,
248
+ sample_stride=4, sample_n_frames=16,
249
+ )
250
+
251
+ if 0:
252
+ dataset = WebVid10M(
253
+ csv_path="/mnt/petrelfs/guoyuwei/projects/datasets/webvid/results_2M_val.csv",
254
+ video_folder="/mnt/petrelfs/guoyuwei/projects/datasets/webvid/2M_val",
255
+ sample_size=256,
256
+ sample_stride=4, sample_n_frames=16,
257
+ is_image=False,
258
+ )
259
+
260
+ dataloader = torch.utils.data.DataLoader(dataset, batch_size=4, num_workers=0,)
261
+ for idx, batch in enumerate(dataloader):
262
+ print(batch["pixel_values"].shape, len(batch["text"]))
cogvideox/models/__init__.py ADDED
@@ -0,0 +1,8 @@
+ from transformers import AutoTokenizer, T5EncoderModel, T5Tokenizer
+
+ from .cogvideox_transformer3d import CogVideoXTransformer3DModel
+ from .cogvideox_vae import AutoencoderKLCogVideoX
+ from .wan_image_encoder import CLIPModel
+ from .wan_text_encoder import WanT5EncoderModel
+ from .wan_transformer3d import WanTransformer3DModel
+ from .wan_vae import AutoencoderKLWan
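# Reviewer note: with these re-exports, downstream code can import the models
# from the package root, e.g. (assuming the repository root is on PYTHONPATH):
#
#     from cogvideox.models import (AutoencoderKLCogVideoX, AutoencoderKLWan,
#                                   CogVideoXTransformer3DModel, WanTransformer3DModel)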
cogvideox/models/cogvideox_transformer3d.py ADDED
@@ -0,0 +1,797 @@
1
+ # Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import glob
17
+ import json
18
+ import os
19
+ from typing import Any, Dict, Optional, Tuple, Union
20
+
21
+ import torch
22
+ import torch.nn.functional as F
23
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
24
+ from diffusers.models.attention import Attention, FeedForward
25
+ from diffusers.models.attention_processor import (
26
+ AttentionProcessor, CogVideoXAttnProcessor2_0,
27
+ FusedCogVideoXAttnProcessor2_0)
28
+ from diffusers.models.embeddings import (CogVideoXPatchEmbed,
29
+ TimestepEmbedding, Timesteps,
30
+ get_2d_sincos_pos_embed,
31
+ get_3d_sincos_pos_embed)
32
+ from diffusers.models.modeling_outputs import Transformer2DModelOutput
33
+ from diffusers.models.modeling_utils import ModelMixin
34
+ from diffusers.models.normalization import AdaLayerNorm, CogVideoXLayerNormZero
35
+ from diffusers.utils import is_torch_version, logging
36
+ from diffusers.utils.torch_utils import maybe_allow_in_graph
37
+ from torch import nn
38
+
39
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
40
+
41
+ class CogVideoXPatchEmbed(nn.Module):
42
+ def __init__(
43
+ self,
44
+ patch_size: int = 2,
45
+ patch_size_t: Optional[int] = None,
46
+ in_channels: int = 16,
47
+ embed_dim: int = 1920,
48
+ text_embed_dim: int = 4096,
49
+ bias: bool = True,
50
+ sample_width: int = 90,
51
+ sample_height: int = 60,
52
+ sample_frames: int = 49,
53
+ temporal_compression_ratio: int = 4,
54
+ max_text_seq_length: int = 226,
55
+ spatial_interpolation_scale: float = 1.875,
56
+ temporal_interpolation_scale: float = 1.0,
57
+ use_positional_embeddings: bool = True,
58
+ use_learned_positional_embeddings: bool = True,
59
+ ) -> None:
60
+ super().__init__()
61
+
62
+ post_patch_height = sample_height // patch_size
63
+ post_patch_width = sample_width // patch_size
64
+ post_time_compression_frames = (sample_frames - 1) // temporal_compression_ratio + 1
65
+ self.num_patches = post_patch_height * post_patch_width * post_time_compression_frames
66
+ self.post_patch_height = post_patch_height
67
+ self.post_patch_width = post_patch_width
68
+ self.post_time_compression_frames = post_time_compression_frames
69
+ self.patch_size = patch_size
70
+ self.patch_size_t = patch_size_t
71
+ self.embed_dim = embed_dim
72
+ self.sample_height = sample_height
73
+ self.sample_width = sample_width
74
+ self.sample_frames = sample_frames
75
+ self.temporal_compression_ratio = temporal_compression_ratio
76
+ self.max_text_seq_length = max_text_seq_length
77
+ self.spatial_interpolation_scale = spatial_interpolation_scale
78
+ self.temporal_interpolation_scale = temporal_interpolation_scale
79
+ self.use_positional_embeddings = use_positional_embeddings
80
+ self.use_learned_positional_embeddings = use_learned_positional_embeddings
81
+
82
+ if patch_size_t is None:
83
+ # CogVideoX 1.0 checkpoints
84
+ self.proj = nn.Conv2d(
85
+ in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
86
+ )
87
+ else:
88
+ # CogVideoX 1.5 checkpoints
89
+ self.proj = nn.Linear(in_channels * patch_size * patch_size * patch_size_t, embed_dim)
90
+
91
+ self.text_proj = nn.Linear(text_embed_dim, embed_dim)
92
+
93
+ if use_positional_embeddings or use_learned_positional_embeddings:
94
+ persistent = use_learned_positional_embeddings
95
+ pos_embedding = self._get_positional_embeddings(sample_height, sample_width, sample_frames)
96
+ self.register_buffer("pos_embedding", pos_embedding, persistent=persistent)
97
+
98
+ def _get_positional_embeddings(self, sample_height: int, sample_width: int, sample_frames: int) -> torch.Tensor:
99
+ post_patch_height = sample_height // self.patch_size
100
+ post_patch_width = sample_width // self.patch_size
101
+ post_time_compression_frames = (sample_frames - 1) // self.temporal_compression_ratio + 1
102
+ num_patches = post_patch_height * post_patch_width * post_time_compression_frames
103
+
104
+ pos_embedding = get_3d_sincos_pos_embed(
105
+ self.embed_dim,
106
+ (post_patch_width, post_patch_height),
107
+ post_time_compression_frames,
108
+ self.spatial_interpolation_scale,
109
+ self.temporal_interpolation_scale,
110
+ )
111
+ pos_embedding = torch.from_numpy(pos_embedding).flatten(0, 1)
112
+ joint_pos_embedding = torch.zeros(
113
+ 1, self.max_text_seq_length + num_patches, self.embed_dim, requires_grad=False
114
+ )
115
+ joint_pos_embedding.data[:, self.max_text_seq_length :].copy_(pos_embedding)
116
+
117
+ return joint_pos_embedding
118
+
119
+ def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
120
+ r"""
121
+ Args:
122
+ text_embeds (`torch.Tensor`):
123
+ Input text embeddings. Expected shape: (batch_size, seq_length, embedding_dim).
124
+ image_embeds (`torch.Tensor`):
125
+ Input image embeddings. Expected shape: (batch_size, num_frames, channels, height, width).
126
+ """
127
+ text_embeds = self.text_proj(text_embeds)
128
+
129
+ text_batch_size, text_seq_length, text_channels = text_embeds.shape
130
+ batch_size, num_frames, channels, height, width = image_embeds.shape
131
+
132
+ if self.patch_size_t is None:
133
+ image_embeds = image_embeds.reshape(-1, channels, height, width)
134
+ image_embeds = self.proj(image_embeds)
135
+ image_embeds = image_embeds.view(batch_size, num_frames, *image_embeds.shape[1:])
136
+ image_embeds = image_embeds.flatten(3).transpose(2, 3) # [batch, num_frames, height x width, channels]
137
+ image_embeds = image_embeds.flatten(1, 2) # [batch, num_frames x height x width, channels]
138
+ else:
139
+ p = self.patch_size
140
+ p_t = self.patch_size_t
141
+
142
+ image_embeds = image_embeds.permute(0, 1, 3, 4, 2)
143
+ # b, f, h, w, c => b, f // 2, 2, h // 2, 2, w // 2, 2, c
144
+ image_embeds = image_embeds.reshape(
145
+ batch_size, num_frames // p_t, p_t, height // p, p, width // p, p, channels
146
+ )
147
+ # b, f // 2, 2, h // 2, 2, w // 2, 2, c => b, f // 2, h // 2, w // 2, c, 2, 2, 2
148
+ image_embeds = image_embeds.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(4, 7).flatten(1, 3)
149
+ image_embeds = self.proj(image_embeds)
150
+
151
+ embeds = torch.cat(
152
+ [text_embeds, image_embeds], dim=1
153
+ ).contiguous() # [batch, seq_length + num_frames x height x width, channels]
154
+
155
+ if self.use_positional_embeddings or self.use_learned_positional_embeddings:
156
+ seq_length = height * width * num_frames // (self.patch_size**2)
157
+ # pos_embeds = self.pos_embedding[:, : text_seq_length + seq_length]
158
+ pos_embeds = self.pos_embedding
159
+ emb_size = embeds.size()[-1]
160
+ pos_embeds_without_text = pos_embeds[:, text_seq_length: ].view(1, self.post_time_compression_frames, self.post_patch_height, self.post_patch_width, emb_size)
161
+ pos_embeds_without_text = pos_embeds_without_text.permute([0, 4, 1, 2, 3])
162
+ pos_embeds_without_text = F.interpolate(pos_embeds_without_text,size=[self.post_time_compression_frames, height // self.patch_size, width // self.patch_size], mode='trilinear', align_corners=False)
163
+ pos_embeds_without_text = pos_embeds_without_text.permute([0, 2, 3, 4, 1]).view(1, -1, emb_size)
164
+ pos_embeds = torch.cat([pos_embeds[:, :text_seq_length], pos_embeds_without_text], dim = 1)
165
+ pos_embeds = pos_embeds[:, : text_seq_length + seq_length]
166
+ embeds = embeds + pos_embeds
167
+
168
+ return embeds
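# Reviewer sketch (not part of the diff): shape check for the patch embed above
# with small illustrative sizes. Text tokens are prepended to patchified video
# tokens before the (interpolated) positional embedding is added.
import torch

pe = CogVideoXPatchEmbed(
    patch_size=2, in_channels=4, embed_dim=32, text_embed_dim=16,
    sample_width=8, sample_height=8, sample_frames=9,   # (9 - 1) // 4 + 1 = 3 latent frames
    max_text_seq_length=6,
)
text_embeds = torch.randn(1, 6, 16)                     # (B, max_text_seq_length, text_embed_dim)
image_embeds = torch.randn(1, 3, 4, 8, 8)               # (B, latent_frames, C, H, W)
out = pe(text_embeds, image_embeds)
print(out.shape)                                        # (1, 6 + 3 * 4 * 4, 32) == (1, 54, 32)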
169
+
170
+ @maybe_allow_in_graph
171
+ class CogVideoXBlock(nn.Module):
172
+ r"""
173
+ Transformer block used in [CogVideoX](https://github.com/THUDM/CogVideo) model.
174
+
175
+ Parameters:
176
+ dim (`int`):
177
+ The number of channels in the input and output.
178
+ num_attention_heads (`int`):
179
+ The number of heads to use for multi-head attention.
180
+ attention_head_dim (`int`):
181
+ The number of channels in each head.
182
+ time_embed_dim (`int`):
183
+ The number of channels in timestep embedding.
184
+ dropout (`float`, defaults to `0.0`):
185
+ The dropout probability to use.
186
+ activation_fn (`str`, defaults to `"gelu-approximate"`):
187
+ Activation function to be used in feed-forward.
188
+ attention_bias (`bool`, defaults to `False`):
189
+ Whether or not to use bias in attention projection layers.
190
+ qk_norm (`bool`, defaults to `True`):
191
+ Whether or not to use normalization after query and key projections in Attention.
192
+ norm_elementwise_affine (`bool`, defaults to `True`):
193
+ Whether to use learnable elementwise affine parameters for normalization.
194
+ norm_eps (`float`, defaults to `1e-5`):
195
+ Epsilon value for normalization layers.
196
+ final_dropout (`bool` defaults to `False`):
197
+ Whether to apply a final dropout after the last feed-forward layer.
198
+ ff_inner_dim (`int`, *optional*, defaults to `None`):
199
+ Custom hidden dimension of Feed-forward layer. If not provided, `4 * dim` is used.
200
+ ff_bias (`bool`, defaults to `True`):
201
+ Whether or not to use bias in Feed-forward layer.
202
+ attention_out_bias (`bool`, defaults to `True`):
203
+ Whether or not to use bias in Attention output projection layer.
204
+ """
205
+
206
+ def __init__(
207
+ self,
208
+ dim: int,
209
+ num_attention_heads: int,
210
+ attention_head_dim: int,
211
+ time_embed_dim: int,
212
+ dropout: float = 0.0,
213
+ activation_fn: str = "gelu-approximate",
214
+ attention_bias: bool = False,
215
+ qk_norm: bool = True,
216
+ norm_elementwise_affine: bool = True,
217
+ norm_eps: float = 1e-5,
218
+ final_dropout: bool = True,
219
+ ff_inner_dim: Optional[int] = None,
220
+ ff_bias: bool = True,
221
+ attention_out_bias: bool = True,
222
+ ):
223
+ super().__init__()
224
+
225
+ # 1. Self Attention
226
+ self.norm1 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)
227
+
228
+ self.attn1 = Attention(
229
+ query_dim=dim,
230
+ dim_head=attention_head_dim,
231
+ heads=num_attention_heads,
232
+ qk_norm="layer_norm" if qk_norm else None,
233
+ eps=1e-6,
234
+ bias=attention_bias,
235
+ out_bias=attention_out_bias,
236
+ processor=CogVideoXAttnProcessor2_0(),
237
+ )
238
+
239
+ # 2. Feed Forward
240
+ self.norm2 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)
241
+
242
+ self.ff = FeedForward(
243
+ dim,
244
+ dropout=dropout,
245
+ activation_fn=activation_fn,
246
+ final_dropout=final_dropout,
247
+ inner_dim=ff_inner_dim,
248
+ bias=ff_bias,
249
+ )
250
+
251
+ def forward(
252
+ self,
253
+ hidden_states: torch.Tensor,
254
+ encoder_hidden_states: torch.Tensor,
255
+ temb: torch.Tensor,
256
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
257
+ ) -> torch.Tensor:
258
+ text_seq_length = encoder_hidden_states.size(1)
259
+
260
+ # norm & modulate
261
+ norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1(
262
+ hidden_states, encoder_hidden_states, temb
263
+ )
264
+
265
+ # attention
266
+ attn_hidden_states, attn_encoder_hidden_states = self.attn1(
267
+ hidden_states=norm_hidden_states,
268
+ encoder_hidden_states=norm_encoder_hidden_states,
269
+ image_rotary_emb=image_rotary_emb,
270
+ )
271
+
272
+ hidden_states = hidden_states + gate_msa * attn_hidden_states
273
+ encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_encoder_hidden_states
274
+
275
+ # norm & modulate
276
+ norm_hidden_states, norm_encoder_hidden_states, gate_ff, enc_gate_ff = self.norm2(
277
+ hidden_states, encoder_hidden_states, temb
278
+ )
279
+
280
+ # feed-forward
281
+ norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1)
282
+ ff_output = self.ff(norm_hidden_states)
283
+
284
+ hidden_states = hidden_states + gate_ff * ff_output[:, text_seq_length:]
285
+ encoder_hidden_states = encoder_hidden_states + enc_gate_ff * ff_output[:, :text_seq_length]
286
+
287
+ return hidden_states, encoder_hidden_states
288
+
289
+
290
+ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
291
+ """
292
+ A Transformer model for video-like data in [CogVideoX](https://github.com/THUDM/CogVideo).
293
+
294
+ Parameters:
295
+ num_attention_heads (`int`, defaults to `30`):
296
+ The number of heads to use for multi-head attention.
297
+ attention_head_dim (`int`, defaults to `64`):
298
+ The number of channels in each head.
299
+ in_channels (`int`, defaults to `16`):
300
+ The number of channels in the input.
301
+ out_channels (`int`, *optional*, defaults to `16`):
302
+ The number of channels in the output.
303
+ flip_sin_to_cos (`bool`, defaults to `True`):
304
+ Whether to flip the sin to cos in the time embedding.
305
+ time_embed_dim (`int`, defaults to `512`):
306
+ Output dimension of timestep embeddings.
307
+ text_embed_dim (`int`, defaults to `4096`):
308
+ Input dimension of text embeddings from the text encoder.
309
+ num_layers (`int`, defaults to `30`):
310
+ The number of layers of Transformer blocks to use.
311
+ dropout (`float`, defaults to `0.0`):
312
+ The dropout probability to use.
313
+ attention_bias (`bool`, defaults to `True`):
314
+ Whether or not to use bias in the attention projection layers.
315
+ sample_width (`int`, defaults to `90`):
316
+ The width of the input latents.
317
+ sample_height (`int`, defaults to `60`):
318
+ The height of the input latents.
319
+ sample_frames (`int`, defaults to `49`):
320
+ The number of frames in the input latents. Note that this parameter was incorrectly initialized to 49
321
+ instead of 13 because CogVideoX processed 13 latent frames at once in its default and recommended settings,
322
+ but cannot be changed to the correct value to ensure backwards compatibility. To create a transformer with
323
+ K latent frames, the correct value to pass here would be: ((K - 1) * temporal_compression_ratio + 1).
324
+ patch_size (`int`, defaults to `2`):
325
+ The size of the patches to use in the patch embedding layer.
326
+ temporal_compression_ratio (`int`, defaults to `4`):
327
+ The compression ratio across the temporal dimension. See documentation for `sample_frames`.
328
+ max_text_seq_length (`int`, defaults to `226`):
329
+ The maximum sequence length of the input text embeddings.
330
+ activation_fn (`str`, defaults to `"gelu-approximate"`):
331
+ Activation function to use in feed-forward.
332
+ timestep_activation_fn (`str`, defaults to `"silu"`):
333
+ Activation function to use when generating the timestep embeddings.
334
+ norm_elementwise_affine (`bool`, defaults to `True`):
335
+ Whether or not to use elementwise affine in normalization layers.
336
+ norm_eps (`float`, defaults to `1e-5`):
337
+ The epsilon value to use in normalization layers.
338
+ spatial_interpolation_scale (`float`, defaults to `1.875`):
339
+ Scaling factor to apply in 3D positional embeddings across spatial dimensions.
340
+ temporal_interpolation_scale (`float`, defaults to `1.0`):
341
+ Scaling factor to apply in 3D positional embeddings across temporal dimensions.
342
+ """
343
+
344
+ _supports_gradient_checkpointing = True
345
+
346
+ @register_to_config
347
+ def __init__(
348
+ self,
349
+ num_attention_heads: int = 30,
350
+ attention_head_dim: int = 64,
351
+ in_channels: int = 16,
352
+ out_channels: Optional[int] = 16,
353
+ flip_sin_to_cos: bool = True,
354
+ freq_shift: int = 0,
355
+ time_embed_dim: int = 512,
356
+ text_embed_dim: int = 4096,
357
+ num_layers: int = 30,
358
+ dropout: float = 0.0,
359
+ attention_bias: bool = True,
360
+ sample_width: int = 90,
361
+ sample_height: int = 60,
362
+ sample_frames: int = 49,
363
+ patch_size: int = 2,
364
+ patch_size_t: Optional[int] = None,
365
+ temporal_compression_ratio: int = 4,
366
+ max_text_seq_length: int = 226,
367
+ activation_fn: str = "gelu-approximate",
368
+ timestep_activation_fn: str = "silu",
369
+ norm_elementwise_affine: bool = True,
370
+ norm_eps: float = 1e-5,
371
+ spatial_interpolation_scale: float = 1.875,
372
+ temporal_interpolation_scale: float = 1.0,
373
+ use_rotary_positional_embeddings: bool = False,
374
+ use_learned_positional_embeddings: bool = False,
375
+ patch_bias: bool = True,
376
+ add_noise_in_inpaint_model: bool = False,
377
+ ):
378
+ super().__init__()
379
+ inner_dim = num_attention_heads * attention_head_dim
380
+ self.patch_size_t = patch_size_t
381
+ if not use_rotary_positional_embeddings and use_learned_positional_embeddings:
382
+ raise ValueError(
383
+ "There are no CogVideoX checkpoints available with disable rotary embeddings and learned positional "
384
+ "embeddings. If you're using a custom model and/or believe this should be supported, please open an "
385
+ "issue at https://github.com/huggingface/diffusers/issues."
386
+ )
387
+
388
+ # 1. Patch embedding
389
+ self.patch_embed = CogVideoXPatchEmbed(
390
+ patch_size=patch_size,
391
+ patch_size_t=patch_size_t,
392
+ in_channels=in_channels,
393
+ embed_dim=inner_dim,
394
+ text_embed_dim=text_embed_dim,
395
+ bias=patch_bias,
396
+ sample_width=sample_width,
397
+ sample_height=sample_height,
398
+ sample_frames=sample_frames,
399
+ temporal_compression_ratio=temporal_compression_ratio,
400
+ max_text_seq_length=max_text_seq_length,
401
+ spatial_interpolation_scale=spatial_interpolation_scale,
402
+ temporal_interpolation_scale=temporal_interpolation_scale,
403
+ use_positional_embeddings=not use_rotary_positional_embeddings,
404
+ use_learned_positional_embeddings=use_learned_positional_embeddings,
405
+ )
406
+ self.embedding_dropout = nn.Dropout(dropout)
407
+
408
+ # 2. Time embeddings
409
+ self.time_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift)
410
+ self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, timestep_activation_fn)
411
+
412
+ # 3. Define spatio-temporal transformers blocks
413
+ self.transformer_blocks = nn.ModuleList(
414
+ [
415
+ CogVideoXBlock(
416
+ dim=inner_dim,
417
+ num_attention_heads=num_attention_heads,
418
+ attention_head_dim=attention_head_dim,
419
+ time_embed_dim=time_embed_dim,
420
+ dropout=dropout,
421
+ activation_fn=activation_fn,
422
+ attention_bias=attention_bias,
423
+ norm_elementwise_affine=norm_elementwise_affine,
424
+ norm_eps=norm_eps,
425
+ )
426
+ for _ in range(num_layers)
427
+ ]
428
+ )
429
+ self.norm_final = nn.LayerNorm(inner_dim, norm_eps, norm_elementwise_affine)
430
+
431
+ # 4. Output blocks
432
+ self.norm_out = AdaLayerNorm(
433
+ embedding_dim=time_embed_dim,
434
+ output_dim=2 * inner_dim,
435
+ norm_elementwise_affine=norm_elementwise_affine,
436
+ norm_eps=norm_eps,
437
+ chunk_dim=1,
438
+ )
439
+
440
+ if patch_size_t is None:
441
+ # For CogVideox 1.0
442
+ output_dim = patch_size * patch_size * out_channels
443
+ else:
444
+ # For CogVideoX 1.5
445
+ output_dim = patch_size * patch_size * patch_size_t * out_channels
446
+
447
+ self.proj_out = nn.Linear(inner_dim, output_dim)
448
+
449
+ self.gradient_checkpointing = False
450
+
451
+ def _set_gradient_checkpointing(self, module, value=False):
452
+ self.gradient_checkpointing = value
453
+
454
+ @property
455
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
456
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
457
+ r"""
458
+ Returns:
459
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
460
+ indexed by its weight name.
461
+ """
462
+ # set recursively
463
+ processors = {}
464
+
465
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
466
+ if hasattr(module, "get_processor"):
467
+ processors[f"{name}.processor"] = module.get_processor()
468
+
469
+ for sub_name, child in module.named_children():
470
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
471
+
472
+ return processors
473
+
474
+ for name, module in self.named_children():
475
+ fn_recursive_add_processors(name, module, processors)
476
+
477
+ return processors
478
+
479
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
480
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
481
+ r"""
482
+ Sets the attention processor to use to compute attention.
483
+
484
+ Parameters:
485
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
486
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
487
+ for **all** `Attention` layers.
488
+
489
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
490
+ processor. This is strongly recommended when setting trainable attention processors.
491
+
492
+ """
493
+ count = len(self.attn_processors.keys())
494
+
495
+ if isinstance(processor, dict) and len(processor) != count:
496
+ raise ValueError(
497
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
498
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
499
+ )
500
+
501
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
502
+ if hasattr(module, "set_processor"):
503
+ if not isinstance(processor, dict):
504
+ module.set_processor(processor)
505
+ else:
506
+ module.set_processor(processor.pop(f"{name}.processor"))
507
+
508
+ for sub_name, child in module.named_children():
509
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
510
+
511
+ for name, module in self.named_children():
512
+ fn_recursive_attn_processor(name, module, processor)
513
+
514
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedCogVideoXAttnProcessor2_0
515
+ def fuse_qkv_projections(self):
516
+ """
517
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
518
+ are fused. For cross-attention modules, key and value projection matrices are fused.
519
+
520
+ <Tip warning={true}>
521
+
522
+ This API is 🧪 experimental.
523
+
524
+ </Tip>
525
+ """
526
+ self.original_attn_processors = None
527
+
528
+ for _, attn_processor in self.attn_processors.items():
529
+ if "Added" in str(attn_processor.__class__.__name__):
530
+ raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
531
+
532
+ self.original_attn_processors = self.attn_processors
533
+
534
+ for module in self.modules():
535
+ if isinstance(module, Attention):
536
+ module.fuse_projections(fuse=True)
537
+
538
+ self.set_attn_processor(FusedCogVideoXAttnProcessor2_0())
539
+
540
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
541
+ def unfuse_qkv_projections(self):
542
+ """Disables the fused QKV projection if enabled.
543
+
544
+ <Tip warning={true}>
545
+
546
+ This API is 🧪 experimental.
547
+
548
+ </Tip>
549
+
550
+ """
551
+ if self.original_attn_processors is not None:
552
+ self.set_attn_processor(self.original_attn_processors)
553
+
554
+ def forward(
555
+ self,
556
+ hidden_states: torch.Tensor,
557
+ encoder_hidden_states: torch.Tensor,
558
+ timestep: Union[int, float, torch.LongTensor],
559
+ timestep_cond: Optional[torch.Tensor] = None,
560
+ inpaint_latents: Optional[torch.Tensor] = None,
561
+ control_latents: Optional[torch.Tensor] = None,
562
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
563
+ return_dict: bool = True,
564
+ ):
565
+ batch_size, num_frames, channels, height, width = hidden_states.shape
566
+ if num_frames == 1 and self.patch_size_t is not None:
567
+ hidden_states = torch.cat([hidden_states, torch.zeros_like(hidden_states)], dim=1)
568
+ if inpaint_latents is not None:
569
+ inpaint_latents = torch.concat([inpaint_latents, torch.zeros_like(inpaint_latents)], dim=1)
570
+ if control_latents is not None:
571
+ control_latents = torch.concat([control_latents, torch.zeros_like(control_latents)], dim=1)
572
+ local_num_frames = num_frames + 1
573
+ else:
574
+ local_num_frames = num_frames
575
+
576
+ # 1. Time embedding
577
+ timesteps = timestep
578
+ t_emb = self.time_proj(timesteps)
579
+
580
+ # timesteps does not contain any weights and will always return f32 tensors
581
+ # but time_embedding might actually be running in fp16. so we need to cast here.
582
+ # there might be better ways to encapsulate this.
583
+ t_emb = t_emb.to(dtype=hidden_states.dtype)
584
+ emb = self.time_embedding(t_emb, timestep_cond)
585
+
586
+ # 2. Patch embedding
587
+ if inpaint_latents is not None:
588
+ hidden_states = torch.concat([hidden_states, inpaint_latents], 2)
589
+ if control_latents is not None:
590
+ hidden_states = torch.concat([hidden_states, control_latents], 2)
591
+ hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
592
+ hidden_states = self.embedding_dropout(hidden_states)
593
+
594
+ text_seq_length = encoder_hidden_states.shape[1]
595
+ encoder_hidden_states = hidden_states[:, :text_seq_length]
596
+ hidden_states = hidden_states[:, text_seq_length:]
597
+
598
+ # 3. Transformer blocks
599
+ for i, block in enumerate(self.transformer_blocks):
600
+ if self.training and self.gradient_checkpointing:
601
+
602
+ def create_custom_forward(module):
603
+ def custom_forward(*inputs):
604
+ return module(*inputs)
605
+
606
+ return custom_forward
607
+
608
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
609
+ hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
610
+ create_custom_forward(block),
611
+ hidden_states,
612
+ encoder_hidden_states,
613
+ emb,
614
+ image_rotary_emb,
615
+ **ckpt_kwargs,
616
+ )
617
+ else:
618
+ hidden_states, encoder_hidden_states = block(
619
+ hidden_states=hidden_states,
620
+ encoder_hidden_states=encoder_hidden_states,
621
+ temb=emb,
622
+ image_rotary_emb=image_rotary_emb,
623
+ )
624
+
625
+ if not self.config.use_rotary_positional_embeddings:
626
+ # CogVideoX-2B
627
+ hidden_states = self.norm_final(hidden_states)
628
+ else:
629
+ # CogVideoX-5B
630
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
631
+ hidden_states = self.norm_final(hidden_states)
632
+ hidden_states = hidden_states[:, text_seq_length:]
633
+
634
+ # 4. Final block
635
+ hidden_states = self.norm_out(hidden_states, temb=emb)
636
+ hidden_states = self.proj_out(hidden_states)
637
+
638
+ # 5. Unpatchify
639
+ p = self.config.patch_size
640
+ p_t = self.config.patch_size_t
641
+
642
+ if p_t is None:
643
+ output = hidden_states.reshape(batch_size, local_num_frames, height // p, width // p, -1, p, p)
644
+ output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
645
+ else:
646
+ output = hidden_states.reshape(
647
+ batch_size, (local_num_frames + p_t - 1) // p_t, height // p, width // p, -1, p_t, p, p
648
+ )
649
+ output = output.permute(0, 1, 5, 4, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(1, 2)
650
+
651
+ if num_frames == 1:
652
+ output = output[:, :num_frames, :]
653
+
654
+ if not return_dict:
655
+ return (output,)
656
+ return Transformer2DModelOutput(sample=output)
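# Reviewer sketch (not part of the diff): a toy forward pass through the
# non-rotary (CogVideoX-2B-style) branch above. All sizes are illustrative, and
# a recent diffusers is assumed (AdaLayerNorm with output_dim/chunk_dim).
import torch

model = CogVideoXTransformer3DModel(
    num_attention_heads=2, attention_head_dim=8,        # inner_dim = 16
    in_channels=4, out_channels=4, num_layers=1,
    time_embed_dim=16, text_embed_dim=16,
    sample_width=8, sample_height=8, sample_frames=9,   # 3 latent frames
    max_text_seq_length=8,
)
latents = torch.randn(1, 3, 4, 8, 8)                    # (B, F, C, H, W)
text = torch.randn(1, 8, 16)                            # (B, max_text_seq_length, text_embed_dim)
out = model(latents, text, timestep=torch.tensor([500]), return_dict=False)[0]
print(out.shape)                                        # torch.Size([1, 3, 4, 8, 8])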
657
+
658
+ @classmethod
659
+ def from_pretrained(
660
+ cls, pretrained_model_path, subfolder=None, transformer_additional_kwargs={},
661
+ low_cpu_mem_usage=False, torch_dtype=torch.bfloat16
662
+ ):
663
+ if subfolder is not None:
664
+ pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
665
+ print(f"loaded 3D transformer's pretrained weights from {pretrained_model_path} ...")
666
+
667
+ config_file = os.path.join(pretrained_model_path, 'config.json')
668
+ if not os.path.isfile(config_file):
669
+ raise RuntimeError(f"{config_file} does not exist")
670
+ with open(config_file, "r") as f:
671
+ config = json.load(f)
672
+
673
+ from diffusers.utils import WEIGHTS_NAME
674
+ model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
675
+ model_file_safetensors = model_file.replace(".bin", ".safetensors")
676
+
677
+ if "dict_mapping" in transformer_additional_kwargs.keys():
678
+ for key in transformer_additional_kwargs["dict_mapping"]:
679
+ transformer_additional_kwargs[transformer_additional_kwargs["dict_mapping"][key]] = config[key]
680
+
681
+ if low_cpu_mem_usage:
682
+ try:
683
+ import re
684
+ from diffusers.utils import is_accelerate_available
685
+ from diffusers.models.modeling_utils import load_model_dict_into_meta
686
+ if is_accelerate_available():
687
+ import accelerate
688
+
689
+ # Instantiate model with empty weights
690
+ with accelerate.init_empty_weights():
691
+ model = cls.from_config(config, **transformer_additional_kwargs)
692
+
693
+ param_device = "cpu"
694
+ if os.path.exists(model_file):
695
+ state_dict = torch.load(model_file, map_location="cpu")
696
+ elif os.path.exists(model_file_safetensors):
697
+ from safetensors.torch import load_file, safe_open
698
+ state_dict = load_file(model_file_safetensors)
699
+ else:
700
+ from safetensors.torch import load_file, safe_open
701
+ model_files_safetensors = glob.glob(os.path.join(pretrained_model_path, "*.safetensors"))
702
+ state_dict = {}
703
+ for _model_file_safetensors in model_files_safetensors:
704
+ _state_dict = load_file(_model_file_safetensors)
705
+ for key in _state_dict:
706
+ state_dict[key] = _state_dict[key]
707
+ model._convert_deprecated_attention_blocks(state_dict)
708
+ # move the params from meta device to cpu
709
+ missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
710
+ if len(missing_keys) > 0:
711
+ raise ValueError(
712
+ f"Cannot load {cls} from {pretrained_model_path} because the following keys are"
713
+ f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass"
714
+ " `low_cpu_mem_usage=False` and `device_map=None` if you want to randomly initialize"
715
+ " those weights or else make sure your checkpoint file is correct."
716
+ )
717
+
718
+ unexpected_keys = load_model_dict_into_meta(
719
+ model,
720
+ state_dict,
721
+ device=param_device,
722
+ dtype=torch_dtype,
723
+ model_name_or_path=pretrained_model_path,
724
+ )
725
+
726
+ if cls._keys_to_ignore_on_load_unexpected is not None:
727
+ for pat in cls._keys_to_ignore_on_load_unexpected:
728
+ unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
729
+
730
+ if len(unexpected_keys) > 0:
731
+ print(
732
+ f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
733
+ )
734
+ return model
735
+ except Exception as e:
736
+ print(
737
+ f"The low_cpu_mem_usage mode is not work because {e}. Use low_cpu_mem_usage=False instead."
738
+ )
739
+
740
+ model = cls.from_config(config, **transformer_additional_kwargs)
741
+ if os.path.exists(model_file):
742
+ state_dict = torch.load(model_file, map_location="cpu")
743
+ elif os.path.exists(model_file_safetensors):
744
+ from safetensors.torch import load_file, safe_open
745
+ state_dict = load_file(model_file_safetensors)
746
+ else:
747
+ from safetensors.torch import load_file, safe_open
748
+ model_files_safetensors = glob.glob(os.path.join(pretrained_model_path, "*.safetensors"))
749
+ state_dict = {}
750
+ for _model_file_safetensors in model_files_safetensors:
751
+ _state_dict = load_file(_model_file_safetensors)
752
+ for key in _state_dict:
753
+ state_dict[key] = _state_dict[key]
754
+
755
+ if model.state_dict()['patch_embed.proj.weight'].size() != state_dict['patch_embed.proj.weight'].size():
756
+ new_shape = model.state_dict()['patch_embed.proj.weight'].size()
757
+ if len(new_shape) == 5:
758
+ state_dict['patch_embed.proj.weight'] = state_dict['patch_embed.proj.weight'].unsqueeze(2).expand(new_shape).clone()
759
+ state_dict['patch_embed.proj.weight'][:, :, :-1] = 0
760
+ elif len(new_shape) == 2:
761
+ if model.state_dict()['patch_embed.proj.weight'].size()[1] > state_dict['patch_embed.proj.weight'].size()[1]:
762
+ model.state_dict()['patch_embed.proj.weight'][:, :state_dict['patch_embed.proj.weight'].size()[1]] = state_dict['patch_embed.proj.weight']
763
+ model.state_dict()['patch_embed.proj.weight'][:, state_dict['patch_embed.proj.weight'].size()[1]:] = 0
764
+ state_dict['patch_embed.proj.weight'] = model.state_dict()['patch_embed.proj.weight']
765
+ else:
766
+ model.state_dict()['patch_embed.proj.weight'][:, :] = state_dict['patch_embed.proj.weight'][:, :model.state_dict()['patch_embed.proj.weight'].size()[1]]
767
+ state_dict['patch_embed.proj.weight'] = model.state_dict()['patch_embed.proj.weight']
768
+ else:
769
+ if model.state_dict()['patch_embed.proj.weight'].size()[1] > state_dict['patch_embed.proj.weight'].size()[1]:
770
+ model.state_dict()['patch_embed.proj.weight'][:, :state_dict['patch_embed.proj.weight'].size()[1], :, :] = state_dict['patch_embed.proj.weight']
771
+ model.state_dict()['patch_embed.proj.weight'][:, state_dict['patch_embed.proj.weight'].size()[1]:, :, :] = 0
772
+ state_dict['patch_embed.proj.weight'] = model.state_dict()['patch_embed.proj.weight']
773
+ else:
774
+ model.state_dict()['patch_embed.proj.weight'][:, :, :, :] = state_dict['patch_embed.proj.weight'][:, :model.state_dict()['patch_embed.proj.weight'].size()[1], :, :]
775
+ state_dict['patch_embed.proj.weight'] = model.state_dict()['patch_embed.proj.weight']
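# Reviewer note: the block above adapts a checkpoint whose patch-embedding weight
# has fewer input channels than the instantiated model (e.g. when inpaint/control
# latents are concatenated along the channel dimension). The extra input channels
# are zero-initialized, so the expanded model initially reproduces the behaviour
# of the base checkpoint.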
776
+
777
+ tmp_state_dict = {}
778
+ for key in state_dict:
779
+ if key in model.state_dict().keys() and model.state_dict()[key].size() == state_dict[key].size():
780
+ tmp_state_dict[key] = state_dict[key]
781
+ else:
782
+ print(key, "Size don't match, skip")
783
+
784
+ state_dict = tmp_state_dict
785
+
786
+ m, u = model.load_state_dict(state_dict, strict=False)
787
+ print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
788
+ print(m)
789
+
790
+ params = [p.numel() if "." in n else 0 for n, p in model.named_parameters()]
791
+ print(f"### All Parameters: {sum(params) / 1e6} M")
792
+
793
+ params = [p.numel() if "attn1." in n else 0 for n, p in model.named_parameters()]
794
+ print(f"### attn1 Parameters: {sum(params) / 1e6} M")
795
+
796
+ model = model.to(torch_dtype)
797
+ return model
cogvideox/models/cogvideox_vae.py ADDED
@@ -0,0 +1,1675 @@
1
+ # Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from typing import Dict, Optional, Tuple, Union
17
+
18
+ import numpy as np
19
+ import torch
20
+ import torch.nn as nn
21
+ import torch.nn.functional as F
22
+ import json
23
+ import os
24
+
25
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
26
+ from diffusers.loaders.single_file_model import FromOriginalModelMixin
27
+ from diffusers.utils import logging
28
+ from diffusers.utils.accelerate_utils import apply_forward_hook
29
+ from diffusers.models.activations import get_activation
30
+ from diffusers.models.downsampling import CogVideoXDownsample3D
31
+ from diffusers.models.modeling_outputs import AutoencoderKLOutput
32
+ from diffusers.models.modeling_utils import ModelMixin
33
+ from diffusers.models.upsampling import CogVideoXUpsample3D
34
+ from diffusers.models.autoencoders.vae import DecoderOutput, DiagonalGaussianDistribution
35
+
36
+
37
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
38
+
39
+
40
+ class CogVideoXSafeConv3d(nn.Conv3d):
41
+ r"""
42
+ A 3D convolution layer that splits the input tensor into smaller parts to avoid OOM in CogVideoX Model.
43
+ """
44
+
45
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
46
+ memory_count = (
47
+ (input.shape[0] * input.shape[1] * input.shape[2] * input.shape[3] * input.shape[4]) * 2 / 1024**3
48
+ )
49
+
50
+ # Set to 2GB, suitable for CuDNN
51
+ if memory_count > 2:
52
+ kernel_size = self.kernel_size[0]
53
+ part_num = int(memory_count / 2) + 1
54
+ input_chunks = torch.chunk(input, part_num, dim=2)
55
+
56
+ if kernel_size > 1:
57
+ input_chunks = [input_chunks[0]] + [
58
+ torch.cat((input_chunks[i - 1][:, :, -kernel_size + 1 :], input_chunks[i]), dim=2)
59
+ for i in range(1, len(input_chunks))
60
+ ]
61
+
62
+ output_chunks = []
63
+ for input_chunk in input_chunks:
64
+ output_chunks.append(super().forward(input_chunk))
65
+ output = torch.cat(output_chunks, dim=2)
66
+ return output
67
+ else:
68
+ return super().forward(input)
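# Reviewer note, worked example of the 2 GB threshold above: a bf16/fp16 tensor of
# shape (1, 128, 49, 480, 720) holds 1*128*49*480*720 elements * 2 bytes ≈ 4.04 GB,
# so memory_count ≈ 4.04 > 2 and the input is split along the time axis into
# int(4.04 / 2) + 1 = 3 chunks; each chunk after the first is re-padded with
# kernel_size - 1 trailing frames of its predecessor so the chunked convolution
# matches the unchunked result.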
69
+
70
+
71
+ class CogVideoXCausalConv3d(nn.Module):
72
+ r"""A 3D causal convolution layer that pads the input tensor to ensure causality in CogVideoX Model.
73
+
74
+ Args:
75
+ in_channels (`int`): Number of channels in the input tensor.
76
+ out_channels (`int`): Number of output channels produced by the convolution.
77
+ kernel_size (`int` or `Tuple[int, int, int]`): Kernel size of the convolutional kernel.
78
+ stride (`int`, defaults to `1`): Stride of the convolution.
79
+ dilation (`int`, defaults to `1`): Dilation rate of the convolution.
80
+ pad_mode (`str`, defaults to `"constant"`): Padding mode.
81
+ """
82
+
83
+ def __init__(
84
+ self,
85
+ in_channels: int,
86
+ out_channels: int,
87
+ kernel_size: Union[int, Tuple[int, int, int]],
88
+ stride: int = 1,
89
+ dilation: int = 1,
90
+ pad_mode: str = "constant",
91
+ ):
92
+ super().__init__()
93
+
94
+ if isinstance(kernel_size, int):
95
+ kernel_size = (kernel_size,) * 3
96
+
97
+ time_kernel_size, height_kernel_size, width_kernel_size = kernel_size
98
+
99
+ # TODO(aryan): configure calculation based on stride and dilation in the future.
100
+ # Since CogVideoX does not use it, it is currently tailored to "just work" with Mochi
101
+ time_pad = time_kernel_size - 1
102
+ height_pad = (height_kernel_size - 1) // 2
103
+ width_pad = (width_kernel_size - 1) // 2
104
+
105
+ self.pad_mode = pad_mode
106
+ self.height_pad = height_pad
107
+ self.width_pad = width_pad
108
+ self.time_pad = time_pad
109
+ self.time_causal_padding = (width_pad, width_pad, height_pad, height_pad, time_pad, 0)
110
+
111
+ self.temporal_dim = 2
112
+ self.time_kernel_size = time_kernel_size
113
+
114
+ stride = stride if isinstance(stride, tuple) else (stride, 1, 1)
115
+ dilation = (dilation, 1, 1)
116
+ self.conv = CogVideoXSafeConv3d(
117
+ in_channels=in_channels,
118
+ out_channels=out_channels,
119
+ kernel_size=kernel_size,
120
+ stride=stride,
121
+ dilation=dilation,
122
+ )
123
+
124
+ def fake_context_parallel_forward(
125
+ self, inputs: torch.Tensor, conv_cache: Optional[torch.Tensor] = None
126
+ ) -> torch.Tensor:
127
+ if self.pad_mode == "replicate":
128
+ inputs = F.pad(inputs, self.time_causal_padding, mode="replicate")
129
+ else:
130
+ kernel_size = self.time_kernel_size
131
+ if kernel_size > 1:
132
+ cached_inputs = [conv_cache] if conv_cache is not None else [inputs[:, :, :1]] * (kernel_size - 1)
133
+ inputs = torch.cat(cached_inputs + [inputs], dim=2)
134
+ return inputs
135
+
136
+ def forward(self, inputs: torch.Tensor, conv_cache: Optional[torch.Tensor] = None) -> torch.Tensor:
137
+ inputs = self.fake_context_parallel_forward(inputs, conv_cache)
138
+
139
+ if self.pad_mode == "replicate":
140
+ conv_cache = None
141
+ else:
142
+ padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad)
143
+ conv_cache = inputs[:, :, -self.time_kernel_size + 1 :].clone()
144
+ inputs = F.pad(inputs, padding_2d, mode="constant", value=0)
145
+
146
+ output = self.conv(inputs)
147
+ return output, conv_cache
148
+
149
+
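Editor's note: a quick way to see what the returned conv_cache buys is that a clip processed in temporal chunks, with the cache threaded between calls, reproduces a single full pass. This is a sketch, not part of the commit, and it assumes the module is importable as cogvideox.models.cogvideox_vae:

import torch
from cogvideox.models.cogvideox_vae import CogVideoXCausalConv3d

conv = CogVideoXCausalConv3d(in_channels=4, out_channels=4, kernel_size=3).eval()
video = torch.randn(1, 4, 8, 16, 16)  # (B, C, T, H, W)

with torch.no_grad():
    full, _ = conv(video)                                # one pass over all 8 frames
    first, cache = conv(video[:, :, :4])                 # first 4 frames, fresh cache
    second, _ = conv(video[:, :, 4:], conv_cache=cache)  # last 4 frames, cached tail

# The chunked result matches the full pass because the cache carries the last
# kernel_size - 1 frames across the chunk boundary.
assert torch.allclose(full, torch.cat([first, second], dim=2), atol=1e-5)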
150
+ class CogVideoXSpatialNorm3D(nn.Module):
151
+ r"""
152
+ Spatially conditioned normalization as defined in https://arxiv.org/abs/2209.09002. This implementation is specific
153
+ to 3D video-like data.
154
+
155
+ CogVideoXSafeConv3d is used instead of nn.Conv3d to avoid OOM in CogVideoX Model.
156
+
157
+ Args:
158
+ f_channels (`int`):
159
+ The number of channels for input to group normalization layer, and output of the spatial norm layer.
160
+ zq_channels (`int`):
161
+ The number of channels for the quantized vector as described in the paper.
162
+ groups (`int`):
163
+ Number of groups to separate the channels into for group normalization.
164
+ """
165
+
166
+ def __init__(
167
+ self,
168
+ f_channels: int,
169
+ zq_channels: int,
170
+ groups: int = 32,
171
+ ):
172
+ super().__init__()
173
+ self.norm_layer = nn.GroupNorm(num_channels=f_channels, num_groups=groups, eps=1e-6, affine=True)
174
+ self.conv_y = CogVideoXCausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1)
175
+ self.conv_b = CogVideoXCausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1)
176
+
177
+ def forward(
178
+ self, f: torch.Tensor, zq: torch.Tensor, conv_cache: Optional[Dict[str, torch.Tensor]] = None
179
+ ) -> torch.Tensor:
180
+ new_conv_cache = {}
181
+ conv_cache = conv_cache or {}
182
+
183
+ if f.shape[2] > 1 and f.shape[2] % 2 == 1:
184
+ f_first, f_rest = f[:, :, :1], f[:, :, 1:]
185
+ f_first_size, f_rest_size = f_first.shape[-3:], f_rest.shape[-3:]
186
+ z_first, z_rest = zq[:, :, :1], zq[:, :, 1:]
187
+ z_first = F.interpolate(z_first, size=f_first_size)
188
+ z_rest = F.interpolate(z_rest, size=f_rest_size)
189
+ zq = torch.cat([z_first, z_rest], dim=2)
190
+ else:
191
+ zq = F.interpolate(zq, size=f.shape[-3:])
192
+
193
+ conv_y, new_conv_cache["conv_y"] = self.conv_y(zq, conv_cache=conv_cache.get("conv_y"))
194
+ conv_b, new_conv_cache["conv_b"] = self.conv_b(zq, conv_cache=conv_cache.get("conv_b"))
195
+
196
+ norm_f = self.norm_layer(f)
197
+ new_f = norm_f * conv_y + conv_b
198
+ return new_f, new_conv_cache
199
+
200
+
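Editor's note: ignoring the conv cache, the block above is a SpatialNorm-style modulation; the group-normalized features are scaled and shifted by two causal 1x1x1 convolutions of the latent code, after resizing it to the feature shape. A simplified restatement (not part of the commit) that takes the layers as arguments:

import torch.nn.functional as F

def spatial_norm_3d(f, zq, norm_layer, conv_y, conv_b):
    # Even-frame path of CogVideoXSpatialNorm3D.forward, cache handling omitted for brevity.
    zq = F.interpolate(zq, size=f.shape[-3:])
    # The causal convolutions above return (output, conv_cache), hence the [0] index.
    return norm_layer(f) * conv_y(zq)[0] + conv_b(zq)[0]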
201
+ class CogVideoXUpsample3D(nn.Module):
202
+ r"""
203
+ A 3D Upsample layer used in CogVideoX by Tsinghua University & ZhipuAI. # TODO: wait for paper release.
204
+
205
+ Args:
206
+ in_channels (`int`):
207
+ Number of channels in the input image.
208
+ out_channels (`int`):
209
+ Number of channels produced by the convolution.
210
+ kernel_size (`int`, defaults to `3`):
211
+ Size of the convolving kernel.
212
+ stride (`int`, defaults to `1`):
213
+ Stride of the convolution.
214
+ padding (`int`, defaults to `1`):
215
+ Padding added to all four sides of the input.
216
+ compress_time (`bool`, defaults to `False`):
217
+ Whether or not to compress the time dimension.
218
+ """
219
+
220
+ def __init__(
221
+ self,
222
+ in_channels: int,
223
+ out_channels: int,
224
+ kernel_size: int = 3,
225
+ stride: int = 1,
226
+ padding: int = 1,
227
+ compress_time: bool = False,
228
+ ) -> None:
229
+ super().__init__()
230
+
231
+ self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
232
+ self.compress_time = compress_time
233
+
234
+ self.auto_split_process = True
235
+ self.first_frame_flag = False
236
+
237
+ def forward(self, inputs: torch.Tensor) -> torch.Tensor:
238
+ if self.compress_time:
239
+ if self.auto_split_process:
240
+ if inputs.shape[2] > 1 and inputs.shape[2] % 2 == 1:
241
+ # split first frame
242
+ x_first, x_rest = inputs[:, :, 0], inputs[:, :, 1:]
243
+
244
+ x_first = F.interpolate(x_first, scale_factor=2.0)
245
+ x_rest = F.interpolate(x_rest, scale_factor=2.0)
246
+ x_first = x_first[:, :, None, :, :]
247
+ inputs = torch.cat([x_first, x_rest], dim=2)
248
+ elif inputs.shape[2] > 1:
249
+ inputs = F.interpolate(inputs, scale_factor=2.0)
250
+ else:
251
+ inputs = inputs.squeeze(2)
252
+ inputs = F.interpolate(inputs, scale_factor=2.0)
253
+ inputs = inputs[:, :, None, :, :]
254
+ else:
255
+ if self.first_frame_flag:
256
+ inputs = inputs.squeeze(2)
257
+ inputs = F.interpolate(inputs, scale_factor=2.0)
258
+ inputs = inputs[:, :, None, :, :]
259
+ else:
260
+ inputs = F.interpolate(inputs, scale_factor=2.0)
261
+ else:
262
+ # only interpolate 2D
263
+ b, c, t, h, w = inputs.shape
264
+ inputs = inputs.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
265
+ inputs = F.interpolate(inputs, scale_factor=2.0)
266
+ inputs = inputs.reshape(b, t, c, *inputs.shape[2:]).permute(0, 2, 1, 3, 4)
267
+
268
+ b, c, t, h, w = inputs.shape
269
+ inputs = inputs.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
270
+ inputs = self.conv(inputs)
271
+ inputs = inputs.reshape(b, t, *inputs.shape[1:]).permute(0, 2, 1, 3, 4)
272
+
273
+ return inputs
274
+
275
+
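Editor's note: a shape sketch for the auto-split path above (not part of the commit, import path assumed as in the earlier sketches): with compress_time=True and an odd frame count, the first frame is only upsampled spatially while the remaining frames are also doubled in time, so 1 + 2k frames become 1 + 4k:

import torch
from cogvideox.models.cogvideox_vae import CogVideoXUpsample3D

up = CogVideoXUpsample3D(in_channels=8, out_channels=8, compress_time=True)
latent = torch.randn(1, 8, 3, 16, 16)  # (B, C, T = 1 + 2, H, W)
print(up(latent).shape)                # torch.Size([1, 8, 5, 32, 32])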
276
+ class CogVideoXResnetBlock3D(nn.Module):
277
+ r"""
278
+ A 3D ResNet block used in the CogVideoX model.
279
+
280
+ Args:
281
+ in_channels (`int`):
282
+ Number of input channels.
283
+ out_channels (`int`, *optional*):
284
+ Number of output channels. If None, defaults to `in_channels`.
285
+ dropout (`float`, defaults to `0.0`):
286
+ Dropout rate.
287
+ temb_channels (`int`, defaults to `512`):
288
+ Number of time embedding channels.
289
+ groups (`int`, defaults to `32`):
290
+ Number of groups to separate the channels into for group normalization.
291
+ eps (`float`, defaults to `1e-6`):
292
+ Epsilon value for normalization layers.
293
+ non_linearity (`str`, defaults to `"swish"`):
294
+ Activation function to use.
295
+ conv_shortcut (bool, defaults to `False`):
296
+ Whether or not to use a convolution shortcut.
297
+ spatial_norm_dim (`int`, *optional*):
298
+ The dimension to use for spatial norm if it is to be used instead of group norm.
299
+ pad_mode (str, defaults to `"first"`):
300
+ Padding mode.
301
+ """
302
+
303
+ def __init__(
304
+ self,
305
+ in_channels: int,
306
+ out_channels: Optional[int] = None,
307
+ dropout: float = 0.0,
308
+ temb_channels: int = 512,
309
+ groups: int = 32,
310
+ eps: float = 1e-6,
311
+ non_linearity: str = "swish",
312
+ conv_shortcut: bool = False,
313
+ spatial_norm_dim: Optional[int] = None,
314
+ pad_mode: str = "first",
315
+ ):
316
+ super().__init__()
317
+
318
+ out_channels = out_channels or in_channels
319
+
320
+ self.in_channels = in_channels
321
+ self.out_channels = out_channels
322
+ self.nonlinearity = get_activation(non_linearity)
323
+ self.use_conv_shortcut = conv_shortcut
324
+ self.spatial_norm_dim = spatial_norm_dim
325
+
326
+ if spatial_norm_dim is None:
327
+ self.norm1 = nn.GroupNorm(num_channels=in_channels, num_groups=groups, eps=eps)
328
+ self.norm2 = nn.GroupNorm(num_channels=out_channels, num_groups=groups, eps=eps)
329
+ else:
330
+ self.norm1 = CogVideoXSpatialNorm3D(
331
+ f_channels=in_channels,
332
+ zq_channels=spatial_norm_dim,
333
+ groups=groups,
334
+ )
335
+ self.norm2 = CogVideoXSpatialNorm3D(
336
+ f_channels=out_channels,
337
+ zq_channels=spatial_norm_dim,
338
+ groups=groups,
339
+ )
340
+
341
+ self.conv1 = CogVideoXCausalConv3d(
342
+ in_channels=in_channels, out_channels=out_channels, kernel_size=3, pad_mode=pad_mode
343
+ )
344
+
345
+ if temb_channels > 0:
346
+ self.temb_proj = nn.Linear(in_features=temb_channels, out_features=out_channels)
347
+
348
+ self.dropout = nn.Dropout(dropout)
349
+ self.conv2 = CogVideoXCausalConv3d(
350
+ in_channels=out_channels, out_channels=out_channels, kernel_size=3, pad_mode=pad_mode
351
+ )
352
+
353
+ if self.in_channels != self.out_channels:
354
+ if self.use_conv_shortcut:
355
+ self.conv_shortcut = CogVideoXCausalConv3d(
356
+ in_channels=in_channels, out_channels=out_channels, kernel_size=3, pad_mode=pad_mode
357
+ )
358
+ else:
359
+ self.conv_shortcut = CogVideoXSafeConv3d(
360
+ in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0
361
+ )
362
+
363
+ def forward(
364
+ self,
365
+ inputs: torch.Tensor,
366
+ temb: Optional[torch.Tensor] = None,
367
+ zq: Optional[torch.Tensor] = None,
368
+ conv_cache: Optional[Dict[str, torch.Tensor]] = None,
369
+ ) -> torch.Tensor:
370
+ new_conv_cache = {}
371
+ conv_cache = conv_cache or {}
372
+
373
+ hidden_states = inputs
374
+
375
+ if zq is not None:
376
+ hidden_states, new_conv_cache["norm1"] = self.norm1(hidden_states, zq, conv_cache=conv_cache.get("norm1"))
377
+ else:
378
+ hidden_states = self.norm1(hidden_states)
379
+
380
+ hidden_states = self.nonlinearity(hidden_states)
381
+ hidden_states, new_conv_cache["conv1"] = self.conv1(hidden_states, conv_cache=conv_cache.get("conv1"))
382
+
383
+ if temb is not None:
384
+ hidden_states = hidden_states + self.temb_proj(self.nonlinearity(temb))[:, :, None, None, None]
385
+
386
+ if zq is not None:
387
+ hidden_states, new_conv_cache["norm2"] = self.norm2(hidden_states, zq, conv_cache=conv_cache.get("norm2"))
388
+ else:
389
+ hidden_states = self.norm2(hidden_states)
390
+
391
+ hidden_states = self.nonlinearity(hidden_states)
392
+ hidden_states = self.dropout(hidden_states)
393
+ hidden_states, new_conv_cache["conv2"] = self.conv2(hidden_states, conv_cache=conv_cache.get("conv2"))
394
+
395
+ if self.in_channels != self.out_channels:
396
+ if self.use_conv_shortcut:
397
+ inputs, new_conv_cache["conv_shortcut"] = self.conv_shortcut(
398
+ inputs, conv_cache=conv_cache.get("conv_shortcut")
399
+ )
400
+ else:
401
+ inputs = self.conv_shortcut(inputs)
402
+
403
+ hidden_states = hidden_states + inputs
404
+ return hidden_states, new_conv_cache
405
+
406
+
407
+ class CogVideoXDownBlock3D(nn.Module):
408
+ r"""
409
+ A downsampling block used in the CogVideoX model.
410
+
411
+ Args:
412
+ in_channels (`int`):
413
+ Number of input channels.
414
+ out_channels (`int`, *optional*):
415
+ Number of output channels. If None, defaults to `in_channels`.
416
+ temb_channels (`int`, defaults to `512`):
417
+ Number of time embedding channels.
418
+ num_layers (`int`, defaults to `1`):
419
+ Number of resnet layers.
420
+ dropout (`float`, defaults to `0.0`):
421
+ Dropout rate.
422
+ resnet_eps (`float`, defaults to `1e-6`):
423
+ Epsilon value for normalization layers.
424
+ resnet_act_fn (`str`, defaults to `"swish"`):
425
+ Activation function to use.
426
+ resnet_groups (`int`, defaults to `32`):
427
+ Number of groups to separate the channels into for group normalization.
428
+ add_downsample (`bool`, defaults to `True`):
429
+ Whether or not to use a downsampling layer. If not used, the output dimension is the same as the input dimension.
430
+ compress_time (`bool`, defaults to `False`):
431
+ Whether or not to downsample across the temporal dimension.
432
+ pad_mode (str, defaults to `"first"`):
433
+ Padding mode.
434
+ """
435
+
436
+ _supports_gradient_checkpointing = True
437
+
438
+ def __init__(
439
+ self,
440
+ in_channels: int,
441
+ out_channels: int,
442
+ temb_channels: int,
443
+ dropout: float = 0.0,
444
+ num_layers: int = 1,
445
+ resnet_eps: float = 1e-6,
446
+ resnet_act_fn: str = "swish",
447
+ resnet_groups: int = 32,
448
+ add_downsample: bool = True,
449
+ downsample_padding: int = 0,
450
+ compress_time: bool = False,
451
+ pad_mode: str = "first",
452
+ ):
453
+ super().__init__()
454
+
455
+ resnets = []
456
+ for i in range(num_layers):
457
+ in_channel = in_channels if i == 0 else out_channels
458
+ resnets.append(
459
+ CogVideoXResnetBlock3D(
460
+ in_channels=in_channel,
461
+ out_channels=out_channels,
462
+ dropout=dropout,
463
+ temb_channels=temb_channels,
464
+ groups=resnet_groups,
465
+ eps=resnet_eps,
466
+ non_linearity=resnet_act_fn,
467
+ pad_mode=pad_mode,
468
+ )
469
+ )
470
+
471
+ self.resnets = nn.ModuleList(resnets)
472
+ self.downsamplers = None
473
+
474
+ if add_downsample:
475
+ self.downsamplers = nn.ModuleList(
476
+ [
477
+ CogVideoXDownsample3D(
478
+ out_channels, out_channels, padding=downsample_padding, compress_time=compress_time
479
+ )
480
+ ]
481
+ )
482
+
483
+ self.gradient_checkpointing = False
484
+
485
+ def forward(
486
+ self,
487
+ hidden_states: torch.Tensor,
488
+ temb: Optional[torch.Tensor] = None,
489
+ zq: Optional[torch.Tensor] = None,
490
+ conv_cache: Optional[Dict[str, torch.Tensor]] = None,
491
+ ) -> torch.Tensor:
492
+ r"""Forward method of the `CogVideoXDownBlock3D` class."""
493
+
494
+ new_conv_cache = {}
495
+ conv_cache = conv_cache or {}
496
+
497
+ for i, resnet in enumerate(self.resnets):
498
+ conv_cache_key = f"resnet_{i}"
499
+
500
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
501
+
502
+ def create_custom_forward(module):
503
+ def create_forward(*inputs):
504
+ return module(*inputs)
505
+
506
+ return create_forward
507
+
508
+ hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
509
+ create_custom_forward(resnet),
510
+ hidden_states,
511
+ temb,
512
+ zq,
513
+ conv_cache.get(conv_cache_key),
514
+ )
515
+ else:
516
+ hidden_states, new_conv_cache[conv_cache_key] = resnet(
517
+ hidden_states, temb, zq, conv_cache=conv_cache.get(conv_cache_key)
518
+ )
519
+
520
+ if self.downsamplers is not None:
521
+ for downsampler in self.downsamplers:
522
+ hidden_states = downsampler(hidden_states)
523
+
524
+ return hidden_states, new_conv_cache
525
+
526
+
527
+ class CogVideoXMidBlock3D(nn.Module):
528
+ r"""
529
+ A middle block used in the CogVideoX model.
530
+
531
+ Args:
532
+ in_channels (`int`):
533
+ Number of input channels.
534
+ temb_channels (`int`, defaults to `512`):
535
+ Number of time embedding channels.
536
+ dropout (`float`, defaults to `0.0`):
537
+ Dropout rate.
538
+ num_layers (`int`, defaults to `1`):
539
+ Number of resnet layers.
540
+ resnet_eps (`float`, defaults to `1e-6`):
541
+ Epsilon value for normalization layers.
542
+ resnet_act_fn (`str`, defaults to `"swish"`):
543
+ Activation function to use.
544
+ resnet_groups (`int`, defaults to `32`):
545
+ Number of groups to separate the channels into for group normalization.
546
+ spatial_norm_dim (`int`, *optional*):
547
+ The dimension to use for spatial norm if it is to be used instead of group norm.
548
+ pad_mode (str, defaults to `"first"`):
549
+ Padding mode.
550
+ """
551
+
552
+ _supports_gradient_checkpointing = True
553
+
554
+ def __init__(
555
+ self,
556
+ in_channels: int,
557
+ temb_channels: int,
558
+ dropout: float = 0.0,
559
+ num_layers: int = 1,
560
+ resnet_eps: float = 1e-6,
561
+ resnet_act_fn: str = "swish",
562
+ resnet_groups: int = 32,
563
+ spatial_norm_dim: Optional[int] = None,
564
+ pad_mode: str = "first",
565
+ ):
566
+ super().__init__()
567
+
568
+ resnets = []
569
+ for _ in range(num_layers):
570
+ resnets.append(
571
+ CogVideoXResnetBlock3D(
572
+ in_channels=in_channels,
573
+ out_channels=in_channels,
574
+ dropout=dropout,
575
+ temb_channels=temb_channels,
576
+ groups=resnet_groups,
577
+ eps=resnet_eps,
578
+ spatial_norm_dim=spatial_norm_dim,
579
+ non_linearity=resnet_act_fn,
580
+ pad_mode=pad_mode,
581
+ )
582
+ )
583
+ self.resnets = nn.ModuleList(resnets)
584
+
585
+ self.gradient_checkpointing = False
586
+
587
+ def forward(
588
+ self,
589
+ hidden_states: torch.Tensor,
590
+ temb: Optional[torch.Tensor] = None,
591
+ zq: Optional[torch.Tensor] = None,
592
+ conv_cache: Optional[Dict[str, torch.Tensor]] = None,
593
+ ) -> torch.Tensor:
594
+ r"""Forward method of the `CogVideoXMidBlock3D` class."""
595
+
596
+ new_conv_cache = {}
597
+ conv_cache = conv_cache or {}
598
+
599
+ for i, resnet in enumerate(self.resnets):
600
+ conv_cache_key = f"resnet_{i}"
601
+
602
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
603
+
604
+ def create_custom_forward(module):
605
+ def create_forward(*inputs):
606
+ return module(*inputs)
607
+
608
+ return create_forward
609
+
610
+ hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
611
+ create_custom_forward(resnet), hidden_states, temb, zq, conv_cache.get(conv_cache_key)
612
+ )
613
+ else:
614
+ hidden_states, new_conv_cache[conv_cache_key] = resnet(
615
+ hidden_states, temb, zq, conv_cache=conv_cache.get(conv_cache_key)
616
+ )
617
+
618
+ return hidden_states, new_conv_cache
619
+
620
+
621
+ class CogVideoXUpBlock3D(nn.Module):
622
+ r"""
623
+ An upsampling block used in the CogVideoX model.
624
+
625
+ Args:
626
+ in_channels (`int`):
627
+ Number of input channels.
628
+ out_channels (`int`, *optional*):
629
+ Number of output channels. If None, defaults to `in_channels`.
630
+ temb_channels (`int`, defaults to `512`):
631
+ Number of time embedding channels.
632
+ dropout (`float`, defaults to `0.0`):
633
+ Dropout rate.
634
+ num_layers (`int`, defaults to `1`):
635
+ Number of resnet layers.
636
+ resnet_eps (`float`, defaults to `1e-6`):
637
+ Epsilon value for normalization layers.
638
+ resnet_act_fn (`str`, defaults to `"swish"`):
639
+ Activation function to use.
640
+ resnet_groups (`int`, defaults to `32`):
641
+ Number of groups to separate the channels into for group normalization.
642
+ spatial_norm_dim (`int`, defaults to `16`):
643
+ The dimension to use for spatial norm if it is to be used instead of group norm.
644
+ add_upsample (`bool`, defaults to `True`):
645
+ Whether or not to use an upsampling layer. If not used, the output dimension is the same as the input dimension.
646
+ compress_time (`bool`, defaults to `False`):
647
+ Whether or not to upsample across the temporal dimension.
648
+ pad_mode (str, defaults to `"first"`):
649
+ Padding mode.
650
+ """
651
+
652
+ def __init__(
653
+ self,
654
+ in_channels: int,
655
+ out_channels: int,
656
+ temb_channels: int,
657
+ dropout: float = 0.0,
658
+ num_layers: int = 1,
659
+ resnet_eps: float = 1e-6,
660
+ resnet_act_fn: str = "swish",
661
+ resnet_groups: int = 32,
662
+ spatial_norm_dim: int = 16,
663
+ add_upsample: bool = True,
664
+ upsample_padding: int = 1,
665
+ compress_time: bool = False,
666
+ pad_mode: str = "first",
667
+ ):
668
+ super().__init__()
669
+
670
+ resnets = []
671
+ for i in range(num_layers):
672
+ in_channel = in_channels if i == 0 else out_channels
673
+ resnets.append(
674
+ CogVideoXResnetBlock3D(
675
+ in_channels=in_channel,
676
+ out_channels=out_channels,
677
+ dropout=dropout,
678
+ temb_channels=temb_channels,
679
+ groups=resnet_groups,
680
+ eps=resnet_eps,
681
+ non_linearity=resnet_act_fn,
682
+ spatial_norm_dim=spatial_norm_dim,
683
+ pad_mode=pad_mode,
684
+ )
685
+ )
686
+
687
+ self.resnets = nn.ModuleList(resnets)
688
+ self.upsamplers = None
689
+
690
+ if add_upsample:
691
+ self.upsamplers = nn.ModuleList(
692
+ [
693
+ CogVideoXUpsample3D(
694
+ out_channels, out_channels, padding=upsample_padding, compress_time=compress_time
695
+ )
696
+ ]
697
+ )
698
+
699
+ self.gradient_checkpointing = False
700
+
701
+ def forward(
702
+ self,
703
+ hidden_states: torch.Tensor,
704
+ temb: Optional[torch.Tensor] = None,
705
+ zq: Optional[torch.Tensor] = None,
706
+ conv_cache: Optional[Dict[str, torch.Tensor]] = None,
707
+ ) -> torch.Tensor:
708
+ r"""Forward method of the `CogVideoXUpBlock3D` class."""
709
+
710
+ new_conv_cache = {}
711
+ conv_cache = conv_cache or {}
712
+
713
+ for i, resnet in enumerate(self.resnets):
714
+ conv_cache_key = f"resnet_{i}"
715
+
716
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
717
+
718
+ def create_custom_forward(module):
719
+ def create_forward(*inputs):
720
+ return module(*inputs)
721
+
722
+ return create_forward
723
+
724
+ hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
725
+ create_custom_forward(resnet),
726
+ hidden_states,
727
+ temb,
728
+ zq,
729
+ conv_cache.get(conv_cache_key),
730
+ )
731
+ else:
732
+ hidden_states, new_conv_cache[conv_cache_key] = resnet(
733
+ hidden_states, temb, zq, conv_cache=conv_cache.get(conv_cache_key)
734
+ )
735
+
736
+ if self.upsamplers is not None:
737
+ for upsampler in self.upsamplers:
738
+ hidden_states = upsampler(hidden_states)
739
+
740
+ return hidden_states, new_conv_cache
741
+
742
+
743
+ class CogVideoXEncoder3D(nn.Module):
744
+ r"""
745
+ The `CogVideoXEncoder3D` layer of a variational autoencoder that encodes its input into a latent representation.
746
+
747
+ Args:
748
+ in_channels (`int`, *optional*, defaults to 3):
749
+ The number of input channels.
750
+ out_channels (`int`, *optional*, defaults to 3):
751
+ The number of output channels.
752
+ down_block_types (`Tuple[str, ...]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
753
+ The types of down blocks to use. See `~diffusers.models.unet_2d_blocks.get_down_block` for available
754
+ options.
755
+ block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
756
+ The number of output channels for each block.
757
+ act_fn (`str`, *optional*, defaults to `"silu"`):
758
+ The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
759
+ layers_per_block (`int`, *optional*, defaults to 2):
760
+ The number of layers per block.
761
+ norm_num_groups (`int`, *optional*, defaults to 32):
762
+ The number of groups for normalization.
763
+ """
764
+
765
+ _supports_gradient_checkpointing = True
766
+
767
+ def __init__(
768
+ self,
769
+ in_channels: int = 3,
770
+ out_channels: int = 16,
771
+ down_block_types: Tuple[str, ...] = (
772
+ "CogVideoXDownBlock3D",
773
+ "CogVideoXDownBlock3D",
774
+ "CogVideoXDownBlock3D",
775
+ "CogVideoXDownBlock3D",
776
+ ),
777
+ block_out_channels: Tuple[int, ...] = (128, 256, 256, 512),
778
+ layers_per_block: int = 3,
779
+ act_fn: str = "silu",
780
+ norm_eps: float = 1e-6,
781
+ norm_num_groups: int = 32,
782
+ dropout: float = 0.0,
783
+ pad_mode: str = "first",
784
+ temporal_compression_ratio: float = 4,
785
+ ):
786
+ super().__init__()
787
+
788
+ # log2 of temporal_compress_times
789
+ temporal_compress_level = int(np.log2(temporal_compression_ratio))
790
+
791
+ self.conv_in = CogVideoXCausalConv3d(in_channels, block_out_channels[0], kernel_size=3, pad_mode=pad_mode)
792
+ self.down_blocks = nn.ModuleList([])
793
+
794
+ # down blocks
795
+ output_channel = block_out_channels[0]
796
+ for i, down_block_type in enumerate(down_block_types):
797
+ input_channel = output_channel
798
+ output_channel = block_out_channels[i]
799
+ is_final_block = i == len(block_out_channels) - 1
800
+ compress_time = i < temporal_compress_level
801
+
802
+ if down_block_type == "CogVideoXDownBlock3D":
803
+ down_block = CogVideoXDownBlock3D(
804
+ in_channels=input_channel,
805
+ out_channels=output_channel,
806
+ temb_channels=0,
807
+ dropout=dropout,
808
+ num_layers=layers_per_block,
809
+ resnet_eps=norm_eps,
810
+ resnet_act_fn=act_fn,
811
+ resnet_groups=norm_num_groups,
812
+ add_downsample=not is_final_block,
813
+ compress_time=compress_time,
814
+ )
815
+ else:
816
+ raise ValueError("Invalid `down_block_type` encountered. Must be `CogVideoXDownBlock3D`")
817
+
818
+ self.down_blocks.append(down_block)
819
+
820
+ # mid block
821
+ self.mid_block = CogVideoXMidBlock3D(
822
+ in_channels=block_out_channels[-1],
823
+ temb_channels=0,
824
+ dropout=dropout,
825
+ num_layers=2,
826
+ resnet_eps=norm_eps,
827
+ resnet_act_fn=act_fn,
828
+ resnet_groups=norm_num_groups,
829
+ pad_mode=pad_mode,
830
+ )
831
+
832
+ self.norm_out = nn.GroupNorm(norm_num_groups, block_out_channels[-1], eps=1e-6)
833
+ self.conv_act = nn.SiLU()
834
+ self.conv_out = CogVideoXCausalConv3d(
835
+ block_out_channels[-1], 2 * out_channels, kernel_size=3, pad_mode=pad_mode
836
+ )
837
+
838
+ self.gradient_checkpointing = False
839
+
840
+ def forward(
841
+ self,
842
+ sample: torch.Tensor,
843
+ temb: Optional[torch.Tensor] = None,
844
+ conv_cache: Optional[Dict[str, torch.Tensor]] = None,
845
+ ) -> torch.Tensor:
846
+ r"""The forward method of the `CogVideoXEncoder3D` class."""
847
+
848
+ new_conv_cache = {}
849
+ conv_cache = conv_cache or {}
850
+
851
+ hidden_states, new_conv_cache["conv_in"] = self.conv_in(sample, conv_cache=conv_cache.get("conv_in"))
852
+
853
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
854
+
855
+ def create_custom_forward(module):
856
+ def custom_forward(*inputs):
857
+ return module(*inputs)
858
+
859
+ return custom_forward
860
+
861
+ # 1. Down
862
+ for i, down_block in enumerate(self.down_blocks):
863
+ conv_cache_key = f"down_block_{i}"
864
+ hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
865
+ create_custom_forward(down_block),
866
+ hidden_states,
867
+ temb,
868
+ None,
869
+ conv_cache.get(conv_cache_key),
870
+ )
871
+
872
+ # 2. Mid
873
+ hidden_states, new_conv_cache["mid_block"] = torch.utils.checkpoint.checkpoint(
874
+ create_custom_forward(self.mid_block),
875
+ hidden_states,
876
+ temb,
877
+ None,
878
+ conv_cache.get("mid_block"),
879
+ )
880
+ else:
881
+ # 1. Down
882
+ for i, down_block in enumerate(self.down_blocks):
883
+ conv_cache_key = f"down_block_{i}"
884
+ hidden_states, new_conv_cache[conv_cache_key] = down_block(
885
+ hidden_states, temb, None, conv_cache=conv_cache.get(conv_cache_key)
886
+ )
887
+
888
+ # 2. Mid
889
+ hidden_states, new_conv_cache["mid_block"] = self.mid_block(
890
+ hidden_states, temb, None, conv_cache=conv_cache.get("mid_block")
891
+ )
892
+
893
+ # 3. Post-process
894
+ hidden_states = self.norm_out(hidden_states)
895
+ hidden_states = self.conv_act(hidden_states)
896
+
897
+ hidden_states, new_conv_cache["conv_out"] = self.conv_out(hidden_states, conv_cache=conv_cache.get("conv_out"))
898
+
899
+ return hidden_states, new_conv_cache
900
+
901
+
902
+ class CogVideoXDecoder3D(nn.Module):
903
+ r"""
904
+ The `CogVideoXDecoder3D` layer of a variational autoencoder that decodes its latent representation into an output
905
+ sample.
906
+
907
+ Args:
908
+ in_channels (`int`, *optional*, defaults to 3):
909
+ The number of input channels.
910
+ out_channels (`int`, *optional*, defaults to 3):
911
+ The number of output channels.
912
+ up_block_types (`Tuple[str, ...]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
913
+ The types of up blocks to use. See `~diffusers.models.unet_2d_blocks.get_up_block` for available options.
914
+ block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
915
+ The number of output channels for each block.
916
+ act_fn (`str`, *optional*, defaults to `"silu"`):
917
+ The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
918
+ layers_per_block (`int`, *optional*, defaults to 2):
919
+ The number of layers per block.
920
+ norm_num_groups (`int`, *optional*, defaults to 32):
921
+ The number of groups for normalization.
922
+ """
923
+
924
+ _supports_gradient_checkpointing = True
925
+
926
+ def __init__(
927
+ self,
928
+ in_channels: int = 16,
929
+ out_channels: int = 3,
930
+ up_block_types: Tuple[str, ...] = (
931
+ "CogVideoXUpBlock3D",
932
+ "CogVideoXUpBlock3D",
933
+ "CogVideoXUpBlock3D",
934
+ "CogVideoXUpBlock3D",
935
+ ),
936
+ block_out_channels: Tuple[int, ...] = (128, 256, 256, 512),
937
+ layers_per_block: int = 3,
938
+ act_fn: str = "silu",
939
+ norm_eps: float = 1e-6,
940
+ norm_num_groups: int = 32,
941
+ dropout: float = 0.0,
942
+ pad_mode: str = "first",
943
+ temporal_compression_ratio: float = 4,
944
+ ):
945
+ super().__init__()
946
+
947
+ reversed_block_out_channels = list(reversed(block_out_channels))
948
+
949
+ self.conv_in = CogVideoXCausalConv3d(
950
+ in_channels, reversed_block_out_channels[0], kernel_size=3, pad_mode=pad_mode
951
+ )
952
+
953
+ # mid block
954
+ self.mid_block = CogVideoXMidBlock3D(
955
+ in_channels=reversed_block_out_channels[0],
956
+ temb_channels=0,
957
+ num_layers=2,
958
+ resnet_eps=norm_eps,
959
+ resnet_act_fn=act_fn,
960
+ resnet_groups=norm_num_groups,
961
+ spatial_norm_dim=in_channels,
962
+ pad_mode=pad_mode,
963
+ )
964
+
965
+ # up blocks
966
+ self.up_blocks = nn.ModuleList([])
967
+
968
+ output_channel = reversed_block_out_channels[0]
969
+ temporal_compress_level = int(np.log2(temporal_compression_ratio))
970
+
971
+ for i, up_block_type in enumerate(up_block_types):
972
+ prev_output_channel = output_channel
973
+ output_channel = reversed_block_out_channels[i]
974
+ is_final_block = i == len(block_out_channels) - 1
975
+ compress_time = i < temporal_compress_level
976
+
977
+ if up_block_type == "CogVideoXUpBlock3D":
978
+ up_block = CogVideoXUpBlock3D(
979
+ in_channels=prev_output_channel,
980
+ out_channels=output_channel,
981
+ temb_channels=0,
982
+ dropout=dropout,
983
+ num_layers=layers_per_block + 1,
984
+ resnet_eps=norm_eps,
985
+ resnet_act_fn=act_fn,
986
+ resnet_groups=norm_num_groups,
987
+ spatial_norm_dim=in_channels,
988
+ add_upsample=not is_final_block,
989
+ compress_time=compress_time,
990
+ pad_mode=pad_mode,
991
+ )
992
+ prev_output_channel = output_channel
993
+ else:
994
+ raise ValueError("Invalid `up_block_type` encountered. Must be `CogVideoXUpBlock3D`")
995
+
996
+ self.up_blocks.append(up_block)
997
+
998
+ self.norm_out = CogVideoXSpatialNorm3D(reversed_block_out_channels[-1], in_channels, groups=norm_num_groups)
999
+ self.conv_act = nn.SiLU()
1000
+ self.conv_out = CogVideoXCausalConv3d(
1001
+ reversed_block_out_channels[-1], out_channels, kernel_size=3, pad_mode=pad_mode
1002
+ )
1003
+
1004
+ self.gradient_checkpointing = False
1005
+
1006
+ def forward(
1007
+ self,
1008
+ sample: torch.Tensor,
1009
+ temb: Optional[torch.Tensor] = None,
1010
+ conv_cache: Optional[Dict[str, torch.Tensor]] = None,
1011
+ ) -> torch.Tensor:
1012
+ r"""The forward method of the `CogVideoXDecoder3D` class."""
1013
+
1014
+ new_conv_cache = {}
1015
+ conv_cache = conv_cache or {}
1016
+
1017
+ hidden_states, new_conv_cache["conv_in"] = self.conv_in(sample, conv_cache=conv_cache.get("conv_in"))
1018
+
1019
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
1020
+
1021
+ def create_custom_forward(module):
1022
+ def custom_forward(*inputs):
1023
+ return module(*inputs)
1024
+
1025
+ return custom_forward
1026
+
1027
+ # 1. Mid
1028
+ hidden_states, new_conv_cache["mid_block"] = torch.utils.checkpoint.checkpoint(
1029
+ create_custom_forward(self.mid_block),
1030
+ hidden_states,
1031
+ temb,
1032
+ sample,
1033
+ conv_cache.get("mid_block"),
1034
+ )
1035
+
1036
+ # 2. Up
1037
+ for i, up_block in enumerate(self.up_blocks):
1038
+ conv_cache_key = f"up_block_{i}"
1039
+ hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
1040
+ create_custom_forward(up_block),
1041
+ hidden_states,
1042
+ temb,
1043
+ sample,
1044
+ conv_cache.get(conv_cache_key),
1045
+ )
1046
+ else:
1047
+ # 1. Mid
1048
+ hidden_states, new_conv_cache["mid_block"] = self.mid_block(
1049
+ hidden_states, temb, sample, conv_cache=conv_cache.get("mid_block")
1050
+ )
1051
+
1052
+ # 2. Up
1053
+ for i, up_block in enumerate(self.up_blocks):
1054
+ conv_cache_key = f"up_block_{i}"
1055
+ hidden_states, new_conv_cache[conv_cache_key] = up_block(
1056
+ hidden_states, temb, sample, conv_cache=conv_cache.get(conv_cache_key)
1057
+ )
1058
+
1059
+ # 3. Post-process
1060
+ hidden_states, new_conv_cache["norm_out"] = self.norm_out(
1061
+ hidden_states, sample, conv_cache=conv_cache.get("norm_out")
1062
+ )
1063
+ hidden_states = self.conv_act(hidden_states)
1064
+ hidden_states, new_conv_cache["conv_out"] = self.conv_out(hidden_states, conv_cache=conv_cache.get("conv_out"))
1065
+
1066
+ return hidden_states, new_conv_cache
1067
+
1068
+
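Editor's note: a rough end-to-end shape check for the two modules above (not part of the commit, default configs and the same import path assumed). The encoder emits 2 * out_channels feature maps (mean and log-variance for a DiagonalGaussianDistribution), spatial size shrinks 8x and 1 + 4k frames become 1 + k; the decoder reverses this, reusing the latents as the zq conditioning of its spatial norms:

import torch
from cogvideox.models.cogvideox_vae import CogVideoXEncoder3D, CogVideoXDecoder3D

enc = CogVideoXEncoder3D(in_channels=3, out_channels=16)
dec = CogVideoXDecoder3D(in_channels=16, out_channels=3)

video = torch.randn(1, 3, 9, 64, 64)   # 9 = 1 + 8 sample frames
moments, _ = enc(video)                 # (1, 32, 3, 8, 8): 32 = 2 * 16, 3 = 1 + 8 // 4
latents = moments[:, :16]               # mean half, standing in for posterior.mode()
frames, _ = dec(latents)                # (1, 3, 9, 64, 64)
print(moments.shape, frames.shape)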
1069
+ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
1070
+ r"""
1071
+ A VAE model with KL loss for encoding images into latents and decoding latent representations into images. Used in
1072
+ [CogVideoX](https://github.com/THUDM/CogVideo).
1073
+
1074
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
1075
+ for all models (such as downloading or saving).
1076
+
1077
+ Parameters:
1078
+ in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
1079
+ out_channels (int, *optional*, defaults to 3): Number of channels in the output.
1080
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
1081
+ Tuple of downsample block types.
1082
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
1083
+ Tuple of upsample block types.
1084
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
1085
+ Tuple of block output channels.
1086
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
1087
+ sample_size (`int`, *optional*, defaults to `32`): Sample input size.
1088
+ scaling_factor (`float`, *optional*, defaults to `1.15258426`):
1089
+ The component-wise standard deviation of the trained latent space computed using the first batch of the
1090
+ training set. This is used to scale the latent space to have unit variance when training the diffusion
1091
+ model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
1092
+ diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
1093
+ / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
1094
+ Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
1095
+ force_upcast (`bool`, *optional*, defaults to `True`):
1096
+ If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. VAE
1097
+ can be fine-tuned / trained to a lower range without losing too much precision, in which case
1098
+ `force_upcast` can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix
1099
+ """
1100
+
1101
+ _supports_gradient_checkpointing = True
1102
+ _no_split_modules = ["CogVideoXResnetBlock3D"]
1103
+
1104
+ @register_to_config
1105
+ def __init__(
1106
+ self,
1107
+ in_channels: int = 3,
1108
+ out_channels: int = 3,
1109
+ down_block_types: Tuple[str] = (
1110
+ "CogVideoXDownBlock3D",
1111
+ "CogVideoXDownBlock3D",
1112
+ "CogVideoXDownBlock3D",
1113
+ "CogVideoXDownBlock3D",
1114
+ ),
1115
+ up_block_types: Tuple[str] = (
1116
+ "CogVideoXUpBlock3D",
1117
+ "CogVideoXUpBlock3D",
1118
+ "CogVideoXUpBlock3D",
1119
+ "CogVideoXUpBlock3D",
1120
+ ),
1121
+ block_out_channels: Tuple[int] = (128, 256, 256, 512),
1122
+ latent_channels: int = 16,
1123
+ layers_per_block: int = 3,
1124
+ act_fn: str = "silu",
1125
+ norm_eps: float = 1e-6,
1126
+ norm_num_groups: int = 32,
1127
+ temporal_compression_ratio: float = 4,
1128
+ sample_height: int = 480,
1129
+ sample_width: int = 720,
1130
+ scaling_factor: float = 1.15258426,
1131
+ shift_factor: Optional[float] = None,
1132
+ latents_mean: Optional[Tuple[float]] = None,
1133
+ latents_std: Optional[Tuple[float]] = None,
1134
+ force_upcast: float = True,
1135
+ use_quant_conv: bool = False,
1136
+ use_post_quant_conv: bool = False,
1137
+ invert_scale_latents: bool = False,
1138
+ ):
1139
+ super().__init__()
1140
+
1141
+ self.encoder = CogVideoXEncoder3D(
1142
+ in_channels=in_channels,
1143
+ out_channels=latent_channels,
1144
+ down_block_types=down_block_types,
1145
+ block_out_channels=block_out_channels,
1146
+ layers_per_block=layers_per_block,
1147
+ act_fn=act_fn,
1148
+ norm_eps=norm_eps,
1149
+ norm_num_groups=norm_num_groups,
1150
+ temporal_compression_ratio=temporal_compression_ratio,
1151
+ )
1152
+ self.decoder = CogVideoXDecoder3D(
1153
+ in_channels=latent_channels,
1154
+ out_channels=out_channels,
1155
+ up_block_types=up_block_types,
1156
+ block_out_channels=block_out_channels,
1157
+ layers_per_block=layers_per_block,
1158
+ act_fn=act_fn,
1159
+ norm_eps=norm_eps,
1160
+ norm_num_groups=norm_num_groups,
1161
+ temporal_compression_ratio=temporal_compression_ratio,
1162
+ )
1163
+ self.quant_conv = CogVideoXSafeConv3d(2 * out_channels, 2 * out_channels, 1) if use_quant_conv else None
1164
+ self.post_quant_conv = CogVideoXSafeConv3d(out_channels, out_channels, 1) if use_post_quant_conv else None
1165
+
1166
+ self.use_slicing = False
1167
+ self.use_tiling = False
1168
+ self.auto_split_process = False
1169
+
1170
+ # Can be increased to decode more latent frames at once, but comes at a reasonable memory cost and it is not
1171
+ # recommended because the temporal parts of the VAE, here, are tricky to understand.
1172
+ # If you decode X latent frames together, the number of output frames is:
1173
+ # (X + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale)) => X + 6 frames
1174
+ #
1175
+ # Example with num_latent_frames_batch_size = 2:
1176
+ # - 12 latent frames: (0, 1), (2, 3), (4, 5), (6, 7), (8, 9), (10, 11) are processed together
1177
+ # => (12 // 2 frame slices) * ((2 num_latent_frames_batch_size) + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale))
1178
+ # => 6 * 8 = 48 frames
1179
+ # - 13 latent frames: (0, 1, 2) (special case), (3, 4), (5, 6), (7, 8), (9, 10), (11, 12) are processed together
1180
+ # => (1 frame slice) * ((3 num_latent_frames_batch_size) + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale)) +
1181
+ # ((13 - 3) // 2) * ((2 num_latent_frames_batch_size) + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale))
1182
+ # => 1 * 9 + 5 * 8 = 49 frames
1183
+ # It has been implemented this way so as to not have "magic values" in the code base that would be hard to explain. Note that
1184
+ # setting it to anything other than 2 would give poor results because the VAE hasn't been trained to be adaptive to a different
1185
+ # number of temporal frames.
1186
+ self.num_latent_frames_batch_size = 2
1187
+ self.num_sample_frames_batch_size = 8
1188
+
1189
+ # We make the minimum sample height and width for tiling half of the generally supported resolution
1190
+ self.tile_sample_min_height = sample_height // 2
1191
+ self.tile_sample_min_width = sample_width // 2
1192
+ self.tile_latent_min_height = int(
1193
+ self.tile_sample_min_height / (2 ** (len(self.config.block_out_channels) - 1))
1194
+ )
1195
+ self.tile_latent_min_width = int(self.tile_sample_min_width / (2 ** (len(self.config.block_out_channels) - 1)))
1196
+
1197
+ # These are experimental overlap factors that were chosen based on experimentation and seem to work best for
1198
+ # 720x480 (WxH) resolution. The above resolution is the strongly recommended generation resolution in CogVideoX
1199
+ # and so the tiling implementation has only been tested on those specific resolutions.
1200
+ self.tile_overlap_factor_height = 1 / 6
1201
+ self.tile_overlap_factor_width = 1 / 5
1202
+
1203
+ def _set_gradient_checkpointing(self, module, value=False):
1204
+ if isinstance(module, (CogVideoXEncoder3D, CogVideoXDecoder3D)):
1205
+ module.gradient_checkpointing = value
1206
+
1207
+ def enable_tiling(
1208
+ self,
1209
+ tile_sample_min_height: Optional[int] = None,
1210
+ tile_sample_min_width: Optional[int] = None,
1211
+ tile_overlap_factor_height: Optional[float] = None,
1212
+ tile_overlap_factor_width: Optional[float] = None,
1213
+ ) -> None:
1214
+ r"""
1215
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
1216
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and allows
1217
+ processing larger images.
1218
+
1219
+ Args:
1220
+ tile_sample_min_height (`int`, *optional*):
1221
+ The minimum height required for a sample to be separated into tiles across the height dimension.
1222
+ tile_sample_min_width (`int`, *optional*):
1223
+ The minimum width required for a sample to be separated into tiles across the width dimension.
1224
+ tile_overlap_factor_height (`int`, *optional*):
1225
+ The minimum amount of overlap between two consecutive vertical tiles. This is to ensure that there are
1226
+ no tiling artifacts produced across the height dimension. Must be between 0 and 1. Setting a higher
1227
+ value might cause more tiles to be processed leading to slow down of the decoding process.
1228
+ tile_overlap_factor_width (`int`, *optional*):
1229
+ The minimum amount of overlap between two consecutive horizontal tiles. This is to ensure that there
1230
+ are no tiling artifacts produced across the width dimension. Must be between 0 and 1. Setting a higher
1231
+ value might cause more tiles to be processed leading to slow down of the decoding process.
1232
+ """
1233
+ self.use_tiling = True
1234
+ self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height
1235
+ self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
1236
+ self.tile_latent_min_height = int(
1237
+ self.tile_sample_min_height / (2 ** (len(self.config.block_out_channels) - 1))
1238
+ )
1239
+ self.tile_latent_min_width = int(self.tile_sample_min_width / (2 ** (len(self.config.block_out_channels) - 1)))
1240
+ self.tile_overlap_factor_height = tile_overlap_factor_height or self.tile_overlap_factor_height
1241
+ self.tile_overlap_factor_width = tile_overlap_factor_width or self.tile_overlap_factor_width
1242
+
1243
+ def disable_tiling(self) -> None:
1244
+ r"""
1245
+ Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
1246
+ decoding in one step.
1247
+ """
1248
+ self.use_tiling = False
1249
+
1250
+ def enable_slicing(self) -> None:
1251
+ r"""
1252
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor into slices to
1253
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
1254
+ """
1255
+ self.use_slicing = True
1256
+
1257
+ def disable_slicing(self) -> None:
1258
+ r"""
1259
+ Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
1260
+ decoding in one step.
1261
+ """
1262
+ self.use_slicing = False
1263
+
1264
+ def _set_first_frame(self):
1265
+ for name, module in self.named_modules():
1266
+ if isinstance(module, CogVideoXUpsample3D):
1267
+ module.auto_split_process = False
1268
+ module.first_frame_flag = True
1269
+
1270
+ def _set_rest_frame(self):
1271
+ for name, module in self.named_modules():
1272
+ if isinstance(module, CogVideoXUpsample3D):
1273
+ module.auto_split_process = False
1274
+ module.first_frame_flag = False
1275
+
1276
+ def enable_auto_split_process(self) -> None:
1277
+ self.auto_split_process = True
1278
+ for name, module in self.named_modules():
1279
+ if isinstance(module, CogVideoXUpsample3D):
1280
+ module.auto_split_process = True
1281
+
1282
+ def disable_auto_split_process(self) -> None:
1283
+ self.auto_split_process = False
1284
+
1285
+ def _encode(self, x: torch.Tensor) -> torch.Tensor:
1286
+ batch_size, num_channels, num_frames, height, width = x.shape
1287
+
1288
+ if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
1289
+ return self.tiled_encode(x)
1290
+
1291
+ frame_batch_size = self.num_sample_frames_batch_size
1292
+ # Note: We expect the number of frames to be either `1` or `frame_batch_size * k` or `frame_batch_size * k + 1` for some k.
1293
+ # As the extra single frame is handled inside the loop, it is not required to round up here.
1294
+ num_batches = max(num_frames // frame_batch_size, 1)
1295
+ conv_cache = None
1296
+ enc = []
1297
+
1298
+ for i in range(num_batches):
1299
+ remaining_frames = num_frames % frame_batch_size
1300
+ start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
1301
+ end_frame = frame_batch_size * (i + 1) + remaining_frames
1302
+ x_intermediate = x[:, :, start_frame:end_frame]
1303
+ x_intermediate, conv_cache = self.encoder(x_intermediate, conv_cache=conv_cache)
1304
+ if self.quant_conv is not None:
1305
+ x_intermediate = self.quant_conv(x_intermediate)
1306
+ enc.append(x_intermediate)
1307
+
1308
+ enc = torch.cat(enc, dim=2)
1309
+ return enc
1310
+
1311
+ @apply_forward_hook
1312
+ def encode(
1313
+ self, x: torch.Tensor, return_dict: bool = True
1314
+ ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
1315
+ """
1316
+ Encode a batch of images into latents.
1317
+
1318
+ Args:
1319
+ x (`torch.Tensor`): Input batch of images.
1320
+ return_dict (`bool`, *optional*, defaults to `True`):
1321
+ Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
1322
+
1323
+ Returns:
1324
+ The latent representations of the encoded videos. If `return_dict` is True, a
1325
+ [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
1326
+ """
1327
+ if self.use_slicing and x.shape[0] > 1:
1328
+ encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
1329
+ h = torch.cat(encoded_slices)
1330
+ else:
1331
+ h = self._encode(x)
1332
+
1333
+ posterior = DiagonalGaussianDistribution(h)
1334
+
1335
+ if not return_dict:
1336
+ return (posterior,)
1337
+ return AutoencoderKLOutput(latent_dist=posterior)
1338
+
1339
+ def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
1340
+ batch_size, num_channels, num_frames, height, width = z.shape
1341
+
1342
+ if self.use_tiling and (width > self.tile_latent_min_width or height > self.tile_latent_min_height):
1343
+ return self.tiled_decode(z, return_dict=return_dict)
1344
+
1345
+ if self.auto_split_process:
1346
+ frame_batch_size = self.num_latent_frames_batch_size
1347
+ num_batches = max(num_frames // frame_batch_size, 1)
1348
+ conv_cache = None
1349
+ dec = []
1350
+
1351
+ for i in range(num_batches):
1352
+ remaining_frames = num_frames % frame_batch_size
1353
+ start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
1354
+ end_frame = frame_batch_size * (i + 1) + remaining_frames
1355
+ z_intermediate = z[:, :, start_frame:end_frame]
1356
+ if self.post_quant_conv is not None:
1357
+ z_intermediate = self.post_quant_conv(z_intermediate)
1358
+ z_intermediate, conv_cache = self.decoder(z_intermediate, conv_cache=conv_cache)
1359
+ dec.append(z_intermediate)
1360
+ else:
1361
+ conv_cache = None
1362
+ start_frame = 0
1363
+ end_frame = 1
1364
+ dec = []
1365
+
1366
+ self._set_first_frame()
1367
+ z_intermediate = z[:, :, start_frame:end_frame]
1368
+ if self.post_quant_conv is not None:
1369
+ z_intermediate = self.post_quant_conv(z_intermediate)
1370
+ z_intermediate, conv_cache = self.decoder(z_intermediate, conv_cache=conv_cache)
1371
+ dec.append(z_intermediate)
1372
+
1373
+ self._set_rest_frame()
1374
+ start_frame = end_frame
1375
+ end_frame += self.num_latent_frames_batch_size
1376
+
1377
+ while start_frame < num_frames:
1378
+ z_intermediate = z[:, :, start_frame:end_frame]
1379
+ if self.post_quant_conv is not None:
1380
+ z_intermediate = self.post_quant_conv(z_intermediate)
1381
+ z_intermediate, conv_cache = self.decoder(z_intermediate, conv_cache=conv_cache)
1382
+ dec.append(z_intermediate)
1383
+ start_frame = end_frame
1384
+ end_frame += self.num_latent_frames_batch_size
1385
+
1386
+ dec = torch.cat(dec, dim=2)
1387
+
1388
+ if not return_dict:
1389
+ return (dec,)
1390
+
1391
+ return DecoderOutput(sample=dec)
1392
+
1393
+ @apply_forward_hook
1394
+ def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
1395
+ """
1396
+ Decode a batch of images.
1397
+
1398
+ Args:
1399
+ z (`torch.Tensor`): Input batch of latent vectors.
1400
+ return_dict (`bool`, *optional*, defaults to `True`):
1401
+ Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
1402
+
1403
+ Returns:
1404
+ [`~models.vae.DecoderOutput`] or `tuple`:
1405
+ If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
1406
+ returned.
1407
+ """
1408
+ if self.use_slicing and z.shape[0] > 1:
1409
+ decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
1410
+ decoded = torch.cat(decoded_slices)
1411
+ else:
1412
+ decoded = self._decode(z).sample
1413
+
1414
+ if not return_dict:
1415
+ return (decoded,)
1416
+ return DecoderOutput(sample=decoded)
1417
+
1418
+ def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
1419
+ blend_extent = min(a.shape[3], b.shape[3], blend_extent)
1420
+ for y in range(blend_extent):
1421
+ b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (
1422
+ y / blend_extent
1423
+ )
1424
+ return b
1425
+
1426
+ def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
1427
+ blend_extent = min(a.shape[4], b.shape[4], blend_extent)
1428
+ for x in range(blend_extent):
1429
+ b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (
1430
+ x / blend_extent
1431
+ )
1432
+ return b
1433
+
1434
+ def tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
1435
+ r"""Encode a batch of images using a tiled encoder.
1436
+
1437
+ When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
1438
+ steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is
1439
+ different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the
1440
+ tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
1441
+ output, but they should be much less noticeable.
1442
+
1443
+ Args:
1444
+ x (`torch.Tensor`): Input batch of videos.
1445
+
1446
+ Returns:
1447
+ `torch.Tensor`:
1448
+ The latent representation of the encoded videos.
1449
+ """
1450
+ # For a rough memory estimate, take a look at the `tiled_decode` method.
1451
+ batch_size, num_channels, num_frames, height, width = x.shape
1452
+
1453
+ overlap_height = int(self.tile_sample_min_height * (1 - self.tile_overlap_factor_height))
1454
+ overlap_width = int(self.tile_sample_min_width * (1 - self.tile_overlap_factor_width))
1455
+ blend_extent_height = int(self.tile_latent_min_height * self.tile_overlap_factor_height)
1456
+ blend_extent_width = int(self.tile_latent_min_width * self.tile_overlap_factor_width)
1457
+ row_limit_height = self.tile_latent_min_height - blend_extent_height
1458
+ row_limit_width = self.tile_latent_min_width - blend_extent_width
1459
+ frame_batch_size = self.num_sample_frames_batch_size
1460
+
1461
+ # Split x into overlapping tiles and encode them separately.
1462
+ # The tiles have an overlap to avoid seams between tiles.
1463
+ rows = []
1464
+ for i in range(0, height, overlap_height):
1465
+ row = []
1466
+ for j in range(0, width, overlap_width):
1467
+ # Note: We expect the number of frames to be either `1` or `frame_batch_size * k` or `frame_batch_size * k + 1` for some k.
1468
+ # As the extra single frame is handled inside the loop, it is not required to round up here.
1469
+ num_batches = max(num_frames // frame_batch_size, 1)
1470
+ conv_cache = None
1471
+ time = []
1472
+
1473
+ for k in range(num_batches):
1474
+ remaining_frames = num_frames % frame_batch_size
1475
+ start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames)
1476
+ end_frame = frame_batch_size * (k + 1) + remaining_frames
1477
+ tile = x[
1478
+ :,
1479
+ :,
1480
+ start_frame:end_frame,
1481
+ i : i + self.tile_sample_min_height,
1482
+ j : j + self.tile_sample_min_width,
1483
+ ]
1484
+ tile, conv_cache = self.encoder(tile, conv_cache=conv_cache)
1485
+ if self.quant_conv is not None:
1486
+ tile = self.quant_conv(tile)
1487
+ time.append(tile)
1488
+
1489
+ row.append(torch.cat(time, dim=2))
1490
+ rows.append(row)
1491
+
1492
+ result_rows = []
1493
+ for i, row in enumerate(rows):
1494
+ result_row = []
1495
+ for j, tile in enumerate(row):
1496
+ # blend the above tile and the left tile
1497
+ # to the current tile and add the current tile to the result row
1498
+ if i > 0:
1499
+ tile = self.blend_v(rows[i - 1][j], tile, blend_extent_height)
1500
+ if j > 0:
1501
+ tile = self.blend_h(row[j - 1], tile, blend_extent_width)
1502
+ result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width])
1503
+ result_rows.append(torch.cat(result_row, dim=4))
1504
+
1505
+ enc = torch.cat(result_rows, dim=3)
1506
+ return enc
1507
+
1508
+ def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
1509
+ r"""
1510
+ Decode a batch of images using a tiled decoder.
1511
+
1512
+ Args:
1513
+ z (`torch.Tensor`): Input batch of latent vectors.
1514
+ return_dict (`bool`, *optional*, defaults to `True`):
1515
+ Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
1516
+
1517
+ Returns:
1518
+ [`~models.vae.DecoderOutput`] or `tuple`:
1519
+ If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
1520
+ returned.
1521
+ """
1522
+ # Rough memory assessment:
1523
+ # - In CogVideoX-2B, there are a total of 24 CausalConv3d layers.
1524
+ # - The biggest intermediate dimensions are: [1, 128, 9, 480, 720].
1525
+ # - Assume fp16 (2 bytes per value).
1526
+ # Memory required: 1 * 128 * 9 * 480 * 720 * 24 * 2 / 1024**3 = 17.8 GB
1527
+ #
1528
+ # Memory assessment when using tiling:
1529
+ # - Assume everything as above but now HxW is 240x360 by tiling in half
1530
+ # Memory required: 1 * 128 * 9 * 240 * 360 * 24 * 2 / 1024**3 = 4.5 GB
1531
+
1532
+ batch_size, num_channels, num_frames, height, width = z.shape
1533
+
1534
+ overlap_height = int(self.tile_latent_min_height * (1 - self.tile_overlap_factor_height))
1535
+ overlap_width = int(self.tile_latent_min_width * (1 - self.tile_overlap_factor_width))
1536
+ blend_extent_height = int(self.tile_sample_min_height * self.tile_overlap_factor_height)
1537
+ blend_extent_width = int(self.tile_sample_min_width * self.tile_overlap_factor_width)
1538
+ row_limit_height = self.tile_sample_min_height - blend_extent_height
1539
+ row_limit_width = self.tile_sample_min_width - blend_extent_width
1540
+ frame_batch_size = self.num_latent_frames_batch_size
1541
+
1542
+ # Split z into overlapping tiles and decode them separately.
1543
+ # The tiles have an overlap to avoid seams between tiles.
1544
+ rows = []
1545
+ for i in range(0, height, overlap_height):
1546
+ row = []
1547
+ for j in range(0, width, overlap_width):
1548
+ if self.auto_split_process:
1549
+ num_batches = max(num_frames // frame_batch_size, 1)
1550
+ conv_cache = None
1551
+ time = []
1552
+
1553
+ for k in range(num_batches):
1554
+ remaining_frames = num_frames % frame_batch_size
1555
+ start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames)
1556
+ end_frame = frame_batch_size * (k + 1) + remaining_frames
1557
+ tile = z[
1558
+ :,
1559
+ :,
1560
+ start_frame:end_frame,
1561
+ i : i + self.tile_latent_min_height,
1562
+ j : j + self.tile_latent_min_width,
1563
+ ]
1564
+ if self.post_quant_conv is not None:
1565
+ tile = self.post_quant_conv(tile)
1566
+ tile, conv_cache = self.decoder(tile, conv_cache=conv_cache)
1567
+ time.append(tile)
1568
+
1569
+ row.append(torch.cat(time, dim=2))
1570
+ else:
1571
+ conv_cache = None
1572
+ start_frame = 0
1573
+ end_frame = 1
1574
+ dec = []
1575
+
1576
+ tile = z[
1577
+ :,
1578
+ :,
1579
+ start_frame:end_frame,
1580
+ i : i + self.tile_latent_min_height,
1581
+ j : j + self.tile_latent_min_width,
1582
+ ]
1583
+
1584
+ self._set_first_frame()
1585
+ if self.post_quant_conv is not None:
1586
+ tile = self.post_quant_conv(tile)
1587
+ tile, conv_cache = self.decoder(tile, conv_cache=conv_cache)
1588
+ dec.append(tile)
1589
+
1590
+ self._set_rest_frame()
1591
+ start_frame = end_frame
1592
+ end_frame += self.num_latent_frames_batch_size
1593
+
1594
+ while start_frame < num_frames:
1595
+ tile = z[
1596
+ :,
1597
+ :,
1598
+ start_frame:end_frame,
1599
+ i : i + self.tile_latent_min_height,
1600
+ j : j + self.tile_latent_min_width,
1601
+ ]
1602
+ if self.post_quant_conv is not None:
1603
+ tile = self.post_quant_conv(tile)
1604
+ tile, conv_cache = self.decoder(tile, conv_cache=conv_cache)
1605
+ dec.append(tile)
1606
+ start_frame = end_frame
1607
+ end_frame += self.num_latent_frames_batch_size
1608
+
1609
+ row.append(torch.cat(dec, dim=2))
1610
+ rows.append(row)
1611
+
1612
+ result_rows = []
1613
+ for i, row in enumerate(rows):
1614
+ result_row = []
1615
+ for j, tile in enumerate(row):
1616
+ # blend the above tile and the left tile
1617
+ # to the current tile and add the current tile to the result row
1618
+ if i > 0:
1619
+ tile = self.blend_v(rows[i - 1][j], tile, blend_extent_height)
1620
+ if j > 0:
1621
+ tile = self.blend_h(row[j - 1], tile, blend_extent_width)
1622
+ result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width])
1623
+ result_rows.append(torch.cat(result_row, dim=4))
1624
+
1625
+ dec = torch.cat(result_rows, dim=3)
1626
+
1627
+ if not return_dict:
1628
+ return (dec,)
1629
+
1630
+ return DecoderOutput(sample=dec)
1631
+
1632
+ def forward(
1633
+ self,
1634
+ sample: torch.Tensor,
1635
+ sample_posterior: bool = False,
1636
+ return_dict: bool = True,
1637
+ generator: Optional[torch.Generator] = None,
1638
+ ) -> Union[torch.Tensor, torch.Tensor]:
1639
+ x = sample
1640
+ posterior = self.encode(x).latent_dist
1641
+ if sample_posterior:
1642
+ z = posterior.sample(generator=generator)
1643
+ else:
1644
+ z = posterior.mode()
1645
+ dec = self.decode(z)
1646
+ if not return_dict:
1647
+ return (dec,)
1648
+ return dec
1649
+
1650
+ @classmethod
1651
+ def from_pretrained(cls, pretrained_model_path, subfolder=None, **vae_additional_kwargs):
1652
+ if subfolder is not None:
1653
+ pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
1654
+
1655
+ config_file = os.path.join(pretrained_model_path, 'config.json')
1656
+ if not os.path.isfile(config_file):
1657
+ raise RuntimeError(f"{config_file} does not exist")
1658
+ with open(config_file, "r") as f:
1659
+ config = json.load(f)
1660
+
1661
+ model = cls.from_config(config, **vae_additional_kwargs)
1662
+ from diffusers.utils import WEIGHTS_NAME
1663
+ model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
1664
+ model_file_safetensors = model_file.replace(".bin", ".safetensors")
1665
+ if os.path.exists(model_file_safetensors):
1666
+ from safetensors.torch import load_file, safe_open
1667
+ state_dict = load_file(model_file_safetensors)
1668
+ else:
1669
+ if not os.path.isfile(model_file):
1670
+ raise RuntimeError(f"{model_file} does not exist")
1671
+ state_dict = torch.load(model_file, map_location="cpu")
1672
+ m, u = model.load_state_dict(state_dict, strict=False)
1673
+ print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
1674
+ print(m, u)
1675
+ return model
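
The frame batching in `tiled_decode` above is easy to misread: when `auto_split_process` is enabled, the first latent batch absorbs the remainder (`num_frames % frame_batch_size`) so that every later batch holds exactly `frame_batch_size` frames. A minimal standalone sketch of that arithmetic, using hypothetical numbers (num_frames=13, frame_batch_size=4) rather than anything taken from the model:

# Hypothetical numbers; mirrors the start_frame/end_frame arithmetic used in
# tiled_decode above: the first batch absorbs the remainder, all later
# batches hold exactly frame_batch_size latent frames.
num_frames = 13
frame_batch_size = 4

num_batches = max(num_frames // frame_batch_size, 1)
remaining_frames = num_frames % frame_batch_size
for k in range(num_batches):
    start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames)
    end_frame = frame_batch_size * (k + 1) + remaining_frames
    print(f"batch {k}: latent frames [{start_frame}, {end_frame})")
# batch 0: latent frames [0, 5)
# batch 1: latent frames [5, 9)
# batch 2: latent frames [9, 13)
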
cogvideox/models/wan_image_encoder.py ADDED
@@ -0,0 +1,553 @@
1
+ # Modified from ``https://github.com/openai/CLIP'' and ``https://github.com/mlfoundations/open_clip''
2
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
3
+ import math
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ import torchvision.transforms as T
9
+
10
+ from .wan_transformer3d import attention, flash_attention  # flash_attention is used by AttentionPool below
11
+ from .wan_xlm_roberta import XLMRoberta
12
+ from diffusers.configuration_utils import ConfigMixin
13
+ from diffusers.loaders.single_file_model import FromOriginalModelMixin
14
+ from diffusers.models.modeling_utils import ModelMixin
15
+
16
+
17
+ __all__ = [
18
+ 'XLMRobertaCLIP',
19
+ 'clip_xlm_roberta_vit_h_14',
20
+ 'CLIPModel',
21
+ ]
22
+
23
+
24
+ def pos_interpolate(pos, seq_len):
25
+ if pos.size(1) == seq_len:
26
+ return pos
27
+ else:
28
+ src_grid = int(math.sqrt(pos.size(1)))
29
+ tar_grid = int(math.sqrt(seq_len))
30
+ n = pos.size(1) - src_grid * src_grid
31
+ return torch.cat([
32
+ pos[:, :n],
33
+ F.interpolate(
34
+ pos[:, n:].float().reshape(1, src_grid, src_grid, -1).permute(
35
+ 0, 3, 1, 2),
36
+ size=(tar_grid, tar_grid),
37
+ mode='bicubic',
38
+ align_corners=False).flatten(2).transpose(1, 2)
39
+ ],
40
+ dim=1)
41
+
42
+
43
+ class QuickGELU(nn.Module):
44
+
45
+ def forward(self, x):
46
+ return x * torch.sigmoid(1.702 * x)
47
+
48
+
49
+ class LayerNorm(nn.LayerNorm):
50
+
51
+ def forward(self, x):
52
+ return super().forward(x.float()).type_as(x)
53
+
54
+
55
+ class SelfAttention(nn.Module):
56
+
57
+ def __init__(self,
58
+ dim,
59
+ num_heads,
60
+ causal=False,
61
+ attn_dropout=0.0,
62
+ proj_dropout=0.0):
63
+ assert dim % num_heads == 0
64
+ super().__init__()
65
+ self.dim = dim
66
+ self.num_heads = num_heads
67
+ self.head_dim = dim // num_heads
68
+ self.causal = causal
69
+ self.attn_dropout = attn_dropout
70
+ self.proj_dropout = proj_dropout
71
+
72
+ # layers
73
+ self.to_qkv = nn.Linear(dim, dim * 3)
74
+ self.proj = nn.Linear(dim, dim)
75
+
76
+ def forward(self, x):
77
+ """
78
+ x: [B, L, C].
79
+ """
80
+ b, s, c, n, d = *x.size(), self.num_heads, self.head_dim
81
+
82
+ # compute query, key, value
83
+ q, k, v = self.to_qkv(x).view(b, s, 3, n, d).unbind(2)
84
+
85
+ # compute attention
86
+ p = self.attn_dropout if self.training else 0.0
87
+ x = attention(q, k, v, dropout_p=p, causal=self.causal)
88
+ x = x.reshape(b, s, c)
89
+
90
+ # output
91
+ x = self.proj(x)
92
+ x = F.dropout(x, self.proj_dropout, self.training)
93
+ return x
94
+
95
+
96
+ class SwiGLU(nn.Module):
97
+
98
+ def __init__(self, dim, mid_dim):
99
+ super().__init__()
100
+ self.dim = dim
101
+ self.mid_dim = mid_dim
102
+
103
+ # layers
104
+ self.fc1 = nn.Linear(dim, mid_dim)
105
+ self.fc2 = nn.Linear(dim, mid_dim)
106
+ self.fc3 = nn.Linear(mid_dim, dim)
107
+
108
+ def forward(self, x):
109
+ x = F.silu(self.fc1(x)) * self.fc2(x)
110
+ x = self.fc3(x)
111
+ return x
112
+
113
+
114
+ class AttentionBlock(nn.Module):
115
+
116
+ def __init__(self,
117
+ dim,
118
+ mlp_ratio,
119
+ num_heads,
120
+ post_norm=False,
121
+ causal=False,
122
+ activation='quick_gelu',
123
+ attn_dropout=0.0,
124
+ proj_dropout=0.0,
125
+ norm_eps=1e-5):
126
+ assert activation in ['quick_gelu', 'gelu', 'swi_glu']
127
+ super().__init__()
128
+ self.dim = dim
129
+ self.mlp_ratio = mlp_ratio
130
+ self.num_heads = num_heads
131
+ self.post_norm = post_norm
132
+ self.causal = causal
133
+ self.norm_eps = norm_eps
134
+
135
+ # layers
136
+ self.norm1 = LayerNorm(dim, eps=norm_eps)
137
+ self.attn = SelfAttention(dim, num_heads, causal, attn_dropout,
138
+ proj_dropout)
139
+ self.norm2 = LayerNorm(dim, eps=norm_eps)
140
+ if activation == 'swi_glu':
141
+ self.mlp = SwiGLU(dim, int(dim * mlp_ratio))
142
+ else:
143
+ self.mlp = nn.Sequential(
144
+ nn.Linear(dim, int(dim * mlp_ratio)),
145
+ QuickGELU() if activation == 'quick_gelu' else nn.GELU(),
146
+ nn.Linear(int(dim * mlp_ratio), dim), nn.Dropout(proj_dropout))
147
+
148
+ def forward(self, x):
149
+ if self.post_norm:
150
+ x = x + self.norm1(self.attn(x))
151
+ x = x + self.norm2(self.mlp(x))
152
+ else:
153
+ x = x + self.attn(self.norm1(x))
154
+ x = x + self.mlp(self.norm2(x))
155
+ return x
156
+
157
+
158
+ class AttentionPool(nn.Module):
159
+
160
+ def __init__(self,
161
+ dim,
162
+ mlp_ratio,
163
+ num_heads,
164
+ activation='gelu',
165
+ proj_dropout=0.0,
166
+ norm_eps=1e-5):
167
+ assert dim % num_heads == 0
168
+ super().__init__()
169
+ self.dim = dim
170
+ self.mlp_ratio = mlp_ratio
171
+ self.num_heads = num_heads
172
+ self.head_dim = dim // num_heads
173
+ self.proj_dropout = proj_dropout
174
+ self.norm_eps = norm_eps
175
+
176
+ # layers
177
+ gain = 1.0 / math.sqrt(dim)
178
+ self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim))
179
+ self.to_q = nn.Linear(dim, dim)
180
+ self.to_kv = nn.Linear(dim, dim * 2)
181
+ self.proj = nn.Linear(dim, dim)
182
+ self.norm = LayerNorm(dim, eps=norm_eps)
183
+ self.mlp = nn.Sequential(
184
+ nn.Linear(dim, int(dim * mlp_ratio)),
185
+ QuickGELU() if activation == 'quick_gelu' else nn.GELU(),
186
+ nn.Linear(int(dim * mlp_ratio), dim), nn.Dropout(proj_dropout))
187
+
188
+ def forward(self, x):
189
+ """
190
+ x: [B, L, C].
191
+ """
192
+ b, s, c, n, d = *x.size(), self.num_heads, self.head_dim
193
+
194
+ # compute query, key, value
195
+ q = self.to_q(self.cls_embedding).view(1, 1, n, d).expand(b, -1, -1, -1)
196
+ k, v = self.to_kv(x).view(b, s, 2, n, d).unbind(2)
197
+
198
+ # compute attention
199
+ x = flash_attention(q, k, v, version=2)
200
+ x = x.reshape(b, 1, c)
201
+
202
+ # output
203
+ x = self.proj(x)
204
+ x = F.dropout(x, self.proj_dropout, self.training)
205
+
206
+ # mlp
207
+ x = x + self.mlp(self.norm(x))
208
+ return x[:, 0]
209
+
210
+
211
+ class VisionTransformer(nn.Module):
212
+
213
+ def __init__(self,
214
+ image_size=224,
215
+ patch_size=16,
216
+ dim=768,
217
+ mlp_ratio=4,
218
+ out_dim=512,
219
+ num_heads=12,
220
+ num_layers=12,
221
+ pool_type='token',
222
+ pre_norm=True,
223
+ post_norm=False,
224
+ activation='quick_gelu',
225
+ attn_dropout=0.0,
226
+ proj_dropout=0.0,
227
+ embedding_dropout=0.0,
228
+ norm_eps=1e-5):
229
+ if image_size % patch_size != 0:
230
+ print(
231
+ '[WARNING] image_size is not divisible by patch_size',
232
+ flush=True)
233
+ assert pool_type in ('token', 'token_fc', 'attn_pool')
234
+ out_dim = out_dim or dim
235
+ super().__init__()
236
+ self.image_size = image_size
237
+ self.patch_size = patch_size
238
+ self.num_patches = (image_size // patch_size)**2
239
+ self.dim = dim
240
+ self.mlp_ratio = mlp_ratio
241
+ self.out_dim = out_dim
242
+ self.num_heads = num_heads
243
+ self.num_layers = num_layers
244
+ self.pool_type = pool_type
245
+ self.post_norm = post_norm
246
+ self.norm_eps = norm_eps
247
+
248
+ # embeddings
249
+ gain = 1.0 / math.sqrt(dim)
250
+ self.patch_embedding = nn.Conv2d(
251
+ 3,
252
+ dim,
253
+ kernel_size=patch_size,
254
+ stride=patch_size,
255
+ bias=not pre_norm)
256
+ if pool_type in ('token', 'token_fc'):
257
+ self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim))
258
+ self.pos_embedding = nn.Parameter(gain * torch.randn(
259
+ 1, self.num_patches +
260
+ (1 if pool_type in ('token', 'token_fc') else 0), dim))
261
+ self.dropout = nn.Dropout(embedding_dropout)
262
+
263
+ # transformer
264
+ self.pre_norm = LayerNorm(dim, eps=norm_eps) if pre_norm else None
265
+ self.transformer = nn.Sequential(*[
266
+ AttentionBlock(dim, mlp_ratio, num_heads, post_norm, False,
267
+ activation, attn_dropout, proj_dropout, norm_eps)
268
+ for _ in range(num_layers)
269
+ ])
270
+ self.post_norm = LayerNorm(dim, eps=norm_eps)
271
+
272
+ # head
273
+ if pool_type == 'token':
274
+ self.head = nn.Parameter(gain * torch.randn(dim, out_dim))
275
+ elif pool_type == 'token_fc':
276
+ self.head = nn.Linear(dim, out_dim)
277
+ elif pool_type == 'attn_pool':
278
+ self.head = AttentionPool(dim, mlp_ratio, num_heads, activation,
279
+ proj_dropout, norm_eps)
280
+
281
+ def forward(self, x, interpolation=False, use_31_block=False):
282
+ b = x.size(0)
283
+
284
+ # embeddings
285
+ x = self.patch_embedding(x).flatten(2).permute(0, 2, 1)
286
+ if self.pool_type in ('token', 'token_fc'):
287
+ x = torch.cat([self.cls_embedding.expand(b, -1, -1), x], dim=1)
288
+ if interpolation:
289
+ e = pos_interpolate(self.pos_embedding, x.size(1))
290
+ else:
291
+ e = self.pos_embedding
292
+ x = self.dropout(x + e)
293
+ if self.pre_norm is not None:
294
+ x = self.pre_norm(x)
295
+
296
+ # transformer
297
+ if use_31_block:
298
+ x = self.transformer[:-1](x)
299
+ return x
300
+ else:
301
+ x = self.transformer(x)
302
+ return x
303
+
304
+
305
+ class XLMRobertaWithHead(XLMRoberta):
306
+
307
+ def __init__(self, **kwargs):
308
+ self.out_dim = kwargs.pop('out_dim')
309
+ super().__init__(**kwargs)
310
+
311
+ # head
312
+ mid_dim = (self.dim + self.out_dim) // 2
313
+ self.head = nn.Sequential(
314
+ nn.Linear(self.dim, mid_dim, bias=False), nn.GELU(),
315
+ nn.Linear(mid_dim, self.out_dim, bias=False))
316
+
317
+ def forward(self, ids):
318
+ # xlm-roberta
319
+ x = super().forward(ids)
320
+
321
+ # average pooling
322
+ mask = ids.ne(self.pad_id).unsqueeze(-1).to(x)
323
+ x = (x * mask).sum(dim=1) / mask.sum(dim=1)
324
+
325
+ # head
326
+ x = self.head(x)
327
+ return x
328
+
329
+
330
+ class XLMRobertaCLIP(nn.Module):
331
+
332
+ def __init__(self,
333
+ embed_dim=1024,
334
+ image_size=224,
335
+ patch_size=14,
336
+ vision_dim=1280,
337
+ vision_mlp_ratio=4,
338
+ vision_heads=16,
339
+ vision_layers=32,
340
+ vision_pool='token',
341
+ vision_pre_norm=True,
342
+ vision_post_norm=False,
343
+ activation='gelu',
344
+ vocab_size=250002,
345
+ max_text_len=514,
346
+ type_size=1,
347
+ pad_id=1,
348
+ text_dim=1024,
349
+ text_heads=16,
350
+ text_layers=24,
351
+ text_post_norm=True,
352
+ text_dropout=0.1,
353
+ attn_dropout=0.0,
354
+ proj_dropout=0.0,
355
+ embedding_dropout=0.0,
356
+ norm_eps=1e-5):
357
+ super().__init__()
358
+ self.embed_dim = embed_dim
359
+ self.image_size = image_size
360
+ self.patch_size = patch_size
361
+ self.vision_dim = vision_dim
362
+ self.vision_mlp_ratio = vision_mlp_ratio
363
+ self.vision_heads = vision_heads
364
+ self.vision_layers = vision_layers
365
+ self.vision_pre_norm = vision_pre_norm
366
+ self.vision_post_norm = vision_post_norm
367
+ self.activation = activation
368
+ self.vocab_size = vocab_size
369
+ self.max_text_len = max_text_len
370
+ self.type_size = type_size
371
+ self.pad_id = pad_id
372
+ self.text_dim = text_dim
373
+ self.text_heads = text_heads
374
+ self.text_layers = text_layers
375
+ self.text_post_norm = text_post_norm
376
+ self.norm_eps = norm_eps
377
+
378
+ # models
379
+ self.visual = VisionTransformer(
380
+ image_size=image_size,
381
+ patch_size=patch_size,
382
+ dim=vision_dim,
383
+ mlp_ratio=vision_mlp_ratio,
384
+ out_dim=embed_dim,
385
+ num_heads=vision_heads,
386
+ num_layers=vision_layers,
387
+ pool_type=vision_pool,
388
+ pre_norm=vision_pre_norm,
389
+ post_norm=vision_post_norm,
390
+ activation=activation,
391
+ attn_dropout=attn_dropout,
392
+ proj_dropout=proj_dropout,
393
+ embedding_dropout=embedding_dropout,
394
+ norm_eps=norm_eps)
395
+ self.textual = XLMRobertaWithHead(
396
+ vocab_size=vocab_size,
397
+ max_seq_len=max_text_len,
398
+ type_size=type_size,
399
+ pad_id=pad_id,
400
+ dim=text_dim,
401
+ out_dim=embed_dim,
402
+ num_heads=text_heads,
403
+ num_layers=text_layers,
404
+ post_norm=text_post_norm,
405
+ dropout=text_dropout)
406
+ self.log_scale = nn.Parameter(math.log(1 / 0.07) * torch.ones([]))
407
+
408
+ def forward(self, imgs, txt_ids):
409
+ """
410
+ imgs: [B, 3, H, W] of torch.float32.
411
+ - mean: [0.48145466, 0.4578275, 0.40821073]
412
+ - std: [0.26862954, 0.26130258, 0.27577711]
413
+ txt_ids: [B, L] of torch.long.
414
+ Encoded by data.CLIPTokenizer.
415
+ """
416
+ xi = self.visual(imgs)
417
+ xt = self.textual(txt_ids)
418
+ return xi, xt
419
+
420
+ def param_groups(self):
421
+ groups = [{
422
+ 'params': [
423
+ p for n, p in self.named_parameters()
424
+ if 'norm' in n or n.endswith('bias')
425
+ ],
426
+ 'weight_decay': 0.0
427
+ }, {
428
+ 'params': [
429
+ p for n, p in self.named_parameters()
430
+ if not ('norm' in n or n.endswith('bias'))
431
+ ]
432
+ }]
433
+ return groups
434
+
435
+
436
+ def _clip(pretrained=False,
437
+ pretrained_name=None,
438
+ model_cls=XLMRobertaCLIP,
439
+ return_transforms=False,
440
+ return_tokenizer=False,
441
+ tokenizer_padding='eos',
442
+ dtype=torch.float32,
443
+ device='cpu',
444
+ **kwargs):
445
+ # init a model on device
446
+ with torch.device(device):
447
+ model = model_cls(**kwargs)
448
+
449
+ # set device
450
+ model = model.to(dtype=dtype, device=device)
451
+ output = (model,)
452
+
453
+ # init transforms
454
+ if return_transforms:
455
+ # mean and std
456
+ if 'siglip' in pretrained_name.lower():
457
+ mean, std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]
458
+ else:
459
+ mean = [0.48145466, 0.4578275, 0.40821073]
460
+ std = [0.26862954, 0.26130258, 0.27577711]
461
+
462
+ # transforms
463
+ transforms = T.Compose([
464
+ T.Resize((model.image_size, model.image_size),
465
+ interpolation=T.InterpolationMode.BICUBIC),
466
+ T.ToTensor(),
467
+ T.Normalize(mean=mean, std=std)
468
+ ])
469
+ output += (transforms,)
470
+ return output[0] if len(output) == 1 else output
471
+
472
+
473
+ def clip_xlm_roberta_vit_h_14(
474
+ pretrained=False,
475
+ pretrained_name='open-clip-xlm-roberta-large-vit-huge-14',
476
+ **kwargs):
477
+ cfg = dict(
478
+ embed_dim=1024,
479
+ image_size=224,
480
+ patch_size=14,
481
+ vision_dim=1280,
482
+ vision_mlp_ratio=4,
483
+ vision_heads=16,
484
+ vision_layers=32,
485
+ vision_pool='token',
486
+ activation='gelu',
487
+ vocab_size=250002,
488
+ max_text_len=514,
489
+ type_size=1,
490
+ pad_id=1,
491
+ text_dim=1024,
492
+ text_heads=16,
493
+ text_layers=24,
494
+ text_post_norm=True,
495
+ text_dropout=0.1,
496
+ attn_dropout=0.0,
497
+ proj_dropout=0.0,
498
+ embedding_dropout=0.0)
499
+ cfg.update(**kwargs)
500
+ return _clip(pretrained, pretrained_name, XLMRobertaCLIP, **cfg)
501
+
502
+
503
+ class CLIPModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
504
+
505
+ def __init__(self):
506
+ super(CLIPModel, self).__init__()
507
+ # init model
508
+ self.model, self.transforms = clip_xlm_roberta_vit_h_14(
509
+ pretrained=False,
510
+ return_transforms=True,
511
+ return_tokenizer=False)
512
+
513
+ def forward(self, videos):
514
+ # preprocess
515
+ size = (self.model.image_size,) * 2
516
+ videos = torch.cat([
517
+ F.interpolate(
518
+ u.transpose(0, 1),
519
+ size=size,
520
+ mode='bicubic',
521
+ align_corners=False) for u in videos
522
+ ])
523
+ videos = self.transforms.transforms[-1](videos.mul_(0.5).add_(0.5))
524
+
525
+ # forward
526
+ with torch.cuda.amp.autocast(dtype=self.dtype):
527
+ out = self.model.visual(videos, use_31_block=True)
528
+ return out
529
+
530
+ @classmethod
531
+ def from_pretrained(cls, pretrained_model_path, transformer_additional_kwargs={}):
532
+ def filter_kwargs(cls, kwargs):
533
+ import inspect
534
+ sig = inspect.signature(cls.__init__)
535
+ valid_params = set(sig.parameters.keys()) - {'self', 'cls'}
536
+ filtered_kwargs = {k: v for k, v in kwargs.items() if k in valid_params}
537
+ return filtered_kwargs
538
+
539
+ model = cls(**filter_kwargs(cls, transformer_additional_kwargs))
540
+ if pretrained_model_path.endswith(".safetensors"):
541
+ from safetensors.torch import load_file, safe_open
542
+ state_dict = load_file(pretrained_model_path)
543
+ else:
544
+ state_dict = torch.load(pretrained_model_path, map_location="cpu")
545
+ tmp_state_dict = {}
546
+ for key in state_dict:
547
+ tmp_state_dict["model." + key] = state_dict[key]
548
+ state_dict = tmp_state_dict
549
+ m, u = model.load_state_dict(state_dict)
550
+
551
+ print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
552
+ print(m, u)
553
+ return model
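
For reference, `pos_interpolate` above keeps the leading class-token position and bicubically resizes the square patch grid when the input resolution differs from the pretraining size. A minimal sketch of the same steps with assumed shapes (ViT-H/14 width 1280, 16x16 source grid, 32x32 target grid); the tensors here are illustrative only:

import math

import torch
import torch.nn.functional as F

# Assumed shapes: a class-token position plus a 16x16 patch grid of
# 1280-dim embeddings, resized for a 32x32 target grid.
pos = torch.randn(1, 1 + 16 * 16, 1280)
seq_len = 1 + 32 * 32

src_grid = int(math.sqrt(pos.size(1)))            # 16
tar_grid = int(math.sqrt(seq_len))                # 32
n = pos.size(1) - src_grid * src_grid             # 1 leading (class) position
grid = pos[:, n:].float().reshape(1, src_grid, src_grid, -1).permute(0, 3, 1, 2)
grid = F.interpolate(grid, size=(tar_grid, tar_grid), mode='bicubic',
                     align_corners=False).flatten(2).transpose(1, 2)
new_pos = torch.cat([pos[:, :n], grid], dim=1)
print(new_pos.shape)                              # torch.Size([1, 1025, 1280])
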
cogvideox/models/wan_text_encoder.py ADDED
@@ -0,0 +1,324 @@
1
+ # Modified from https://github.com/Wan-Video/Wan2.1/blob/main/wan/modules/t5.py
2
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
3
+ import math
4
+ from typing import Optional
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from diffusers.configuration_utils import ConfigMixin
10
+ from diffusers.loaders.single_file_model import FromOriginalModelMixin
11
+ from diffusers.models.modeling_utils import ModelMixin
12
+
13
+
14
+ def fp16_clamp(x):
15
+ if x.dtype == torch.float16 and torch.isinf(x).any():
16
+ clamp = torch.finfo(x.dtype).max - 1000
17
+ x = torch.clamp(x, min=-clamp, max=clamp)
18
+ return x
19
+
20
+
21
+ def init_weights(m):
22
+ if isinstance(m, T5LayerNorm):
23
+ nn.init.ones_(m.weight)
24
+ elif isinstance(m, T5FeedForward):
25
+ nn.init.normal_(m.gate[0].weight, std=m.dim**-0.5)
26
+ nn.init.normal_(m.fc1.weight, std=m.dim**-0.5)
27
+ nn.init.normal_(m.fc2.weight, std=m.dim_ffn**-0.5)
28
+ elif isinstance(m, T5Attention):
29
+ nn.init.normal_(m.q.weight, std=(m.dim * m.dim_attn)**-0.5)
30
+ nn.init.normal_(m.k.weight, std=m.dim**-0.5)
31
+ nn.init.normal_(m.v.weight, std=m.dim**-0.5)
32
+ nn.init.normal_(m.o.weight, std=(m.num_heads * m.dim_attn)**-0.5)
33
+ elif isinstance(m, T5RelativeEmbedding):
34
+ nn.init.normal_(
35
+ m.embedding.weight, std=(2 * m.num_buckets * m.num_heads)**-0.5)
36
+
37
+
38
+ class GELU(nn.Module):
39
+ def forward(self, x):
40
+ return 0.5 * x * (1.0 + torch.tanh(
41
+ math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
42
+
43
+
44
+ class T5LayerNorm(nn.Module):
45
+ def __init__(self, dim, eps=1e-6):
46
+ super(T5LayerNorm, self).__init__()
47
+ self.dim = dim
48
+ self.eps = eps
49
+ self.weight = nn.Parameter(torch.ones(dim))
50
+
51
+ def forward(self, x):
52
+ x = x * torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) +
53
+ self.eps)
54
+ if self.weight.dtype in [torch.float16, torch.bfloat16]:
55
+ x = x.type_as(self.weight)
56
+ return self.weight * x
57
+
58
+
59
+ class T5Attention(nn.Module):
60
+ def __init__(self, dim, dim_attn, num_heads, dropout=0.1):
61
+ assert dim_attn % num_heads == 0
62
+ super(T5Attention, self).__init__()
63
+ self.dim = dim
64
+ self.dim_attn = dim_attn
65
+ self.num_heads = num_heads
66
+ self.head_dim = dim_attn // num_heads
67
+
68
+ # layers
69
+ self.q = nn.Linear(dim, dim_attn, bias=False)
70
+ self.k = nn.Linear(dim, dim_attn, bias=False)
71
+ self.v = nn.Linear(dim, dim_attn, bias=False)
72
+ self.o = nn.Linear(dim_attn, dim, bias=False)
73
+ self.dropout = nn.Dropout(dropout)
74
+
75
+ def forward(self, x, context=None, mask=None, pos_bias=None):
76
+ """
77
+ x: [B, L1, C].
78
+ context: [B, L2, C] or None.
79
+ mask: [B, L2] or [B, L1, L2] or None.
80
+ """
81
+ # check inputs
82
+ context = x if context is None else context
83
+ b, n, c = x.size(0), self.num_heads, self.head_dim
84
+
85
+ # compute query, key, value
86
+ q = self.q(x).view(b, -1, n, c)
87
+ k = self.k(context).view(b, -1, n, c)
88
+ v = self.v(context).view(b, -1, n, c)
89
+
90
+ # attention bias
91
+ attn_bias = x.new_zeros(b, n, q.size(1), k.size(1))
92
+ if pos_bias is not None:
93
+ attn_bias += pos_bias
94
+ if mask is not None:
95
+ assert mask.ndim in [2, 3]
96
+ mask = mask.view(b, 1, 1,
97
+ -1) if mask.ndim == 2 else mask.unsqueeze(1)
98
+ attn_bias.masked_fill_(mask == 0, torch.finfo(x.dtype).min)
99
+
100
+ # compute attention (T5 does not use scaling)
101
+ attn = torch.einsum('binc,bjnc->bnij', q, k) + attn_bias
102
+ attn = F.softmax(attn.float(), dim=-1).type_as(attn)
103
+ x = torch.einsum('bnij,bjnc->binc', attn, v)
104
+
105
+ # output
106
+ x = x.reshape(b, -1, n * c)
107
+ x = self.o(x)
108
+ x = self.dropout(x)
109
+ return x
110
+
111
+
112
+ class T5FeedForward(nn.Module):
113
+
114
+ def __init__(self, dim, dim_ffn, dropout=0.1):
115
+ super(T5FeedForward, self).__init__()
116
+ self.dim = dim
117
+ self.dim_ffn = dim_ffn
118
+
119
+ # layers
120
+ self.gate = nn.Sequential(nn.Linear(dim, dim_ffn, bias=False), GELU())
121
+ self.fc1 = nn.Linear(dim, dim_ffn, bias=False)
122
+ self.fc2 = nn.Linear(dim_ffn, dim, bias=False)
123
+ self.dropout = nn.Dropout(dropout)
124
+
125
+ def forward(self, x):
126
+ x = self.fc1(x) * self.gate(x)
127
+ x = self.dropout(x)
128
+ x = self.fc2(x)
129
+ x = self.dropout(x)
130
+ return x
131
+
132
+
133
+ class T5SelfAttention(nn.Module):
134
+ def __init__(self,
135
+ dim,
136
+ dim_attn,
137
+ dim_ffn,
138
+ num_heads,
139
+ num_buckets,
140
+ shared_pos=True,
141
+ dropout=0.1):
142
+ super(T5SelfAttention, self).__init__()
143
+ self.dim = dim
144
+ self.dim_attn = dim_attn
145
+ self.dim_ffn = dim_ffn
146
+ self.num_heads = num_heads
147
+ self.num_buckets = num_buckets
148
+ self.shared_pos = shared_pos
149
+
150
+ # layers
151
+ self.norm1 = T5LayerNorm(dim)
152
+ self.attn = T5Attention(dim, dim_attn, num_heads, dropout)
153
+ self.norm2 = T5LayerNorm(dim)
154
+ self.ffn = T5FeedForward(dim, dim_ffn, dropout)
155
+ self.pos_embedding = None if shared_pos else T5RelativeEmbedding(
156
+ num_buckets, num_heads, bidirectional=True)
157
+
158
+ def forward(self, x, mask=None, pos_bias=None):
159
+ e = pos_bias if self.shared_pos else self.pos_embedding(
160
+ x.size(1), x.size(1))
161
+ x = fp16_clamp(x + self.attn(self.norm1(x), mask=mask, pos_bias=e))
162
+ x = fp16_clamp(x + self.ffn(self.norm2(x)))
163
+ return x
164
+
165
+
166
+ class T5CrossAttention(nn.Module):
167
+ def __init__(self,
168
+ dim,
169
+ dim_attn,
170
+ dim_ffn,
171
+ num_heads,
172
+ num_buckets,
173
+ shared_pos=True,
174
+ dropout=0.1):
175
+ super(T5CrossAttention, self).__init__()
176
+ self.dim = dim
177
+ self.dim_attn = dim_attn
178
+ self.dim_ffn = dim_ffn
179
+ self.num_heads = num_heads
180
+ self.num_buckets = num_buckets
181
+ self.shared_pos = shared_pos
182
+
183
+ # layers
184
+ self.norm1 = T5LayerNorm(dim)
185
+ self.self_attn = T5Attention(dim, dim_attn, num_heads, dropout)
186
+ self.norm2 = T5LayerNorm(dim)
187
+ self.cross_attn = T5Attention(dim, dim_attn, num_heads, dropout)
188
+ self.norm3 = T5LayerNorm(dim)
189
+ self.ffn = T5FeedForward(dim, dim_ffn, dropout)
190
+ self.pos_embedding = None if shared_pos else T5RelativeEmbedding(
191
+ num_buckets, num_heads, bidirectional=False)
192
+
193
+ def forward(self,
194
+ x,
195
+ mask=None,
196
+ encoder_states=None,
197
+ encoder_mask=None,
198
+ pos_bias=None):
199
+ e = pos_bias if self.shared_pos else self.pos_embedding(
200
+ x.size(1), x.size(1))
201
+ x = fp16_clamp(x + self.self_attn(self.norm1(x), mask=mask, pos_bias=e))
202
+ x = fp16_clamp(x + self.cross_attn(
203
+ self.norm2(x), context=encoder_states, mask=encoder_mask))
204
+ x = fp16_clamp(x + self.ffn(self.norm3(x)))
205
+ return x
206
+
207
+
208
+ class T5RelativeEmbedding(nn.Module):
209
+ def __init__(self, num_buckets, num_heads, bidirectional, max_dist=128):
210
+ super(T5RelativeEmbedding, self).__init__()
211
+ self.num_buckets = num_buckets
212
+ self.num_heads = num_heads
213
+ self.bidirectional = bidirectional
214
+ self.max_dist = max_dist
215
+
216
+ # layers
217
+ self.embedding = nn.Embedding(num_buckets, num_heads)
218
+
219
+ def forward(self, lq, lk):
220
+ device = self.embedding.weight.device
221
+ # rel_pos = torch.arange(lk).unsqueeze(0).to(device) - \
222
+ # torch.arange(lq).unsqueeze(1).to(device)
223
+ if torch.device(type="meta") != device:
224
+ rel_pos = torch.arange(lk, device=device).unsqueeze(0) - \
225
+ torch.arange(lq, device=device).unsqueeze(1)
226
+ else:
227
+ rel_pos = torch.arange(lk).unsqueeze(0) - \
228
+ torch.arange(lq).unsqueeze(1)
229
+ rel_pos = self._relative_position_bucket(rel_pos)
230
+ rel_pos_embeds = self.embedding(rel_pos)
231
+ rel_pos_embeds = rel_pos_embeds.permute(2, 0, 1).unsqueeze(
232
+ 0) # [1, N, Lq, Lk]
233
+ return rel_pos_embeds.contiguous()
234
+
235
+ def _relative_position_bucket(self, rel_pos):
236
+ # preprocess
237
+ if self.bidirectional:
238
+ num_buckets = self.num_buckets // 2
239
+ rel_buckets = (rel_pos > 0).long() * num_buckets
240
+ rel_pos = torch.abs(rel_pos)
241
+ else:
242
+ num_buckets = self.num_buckets
243
+ rel_buckets = 0
244
+ rel_pos = -torch.min(rel_pos, torch.zeros_like(rel_pos))
245
+
246
+ # embeddings for small and large positions
247
+ max_exact = num_buckets // 2
248
+ rel_pos_large = max_exact + (torch.log(rel_pos.float() / max_exact) /
249
+ math.log(self.max_dist / max_exact) *
250
+ (num_buckets - max_exact)).long()
251
+ rel_pos_large = torch.min(
252
+ rel_pos_large, torch.full_like(rel_pos_large, num_buckets - 1))
253
+ rel_buckets += torch.where(rel_pos < max_exact, rel_pos, rel_pos_large)
254
+ return rel_buckets
255
+
256
+ class WanT5EncoderModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
257
+ def __init__(self,
258
+ vocab,
259
+ dim,
260
+ dim_attn,
261
+ dim_ffn,
262
+ num_heads,
263
+ num_layers,
264
+ num_buckets,
265
+ shared_pos=True,
266
+ dropout=0.1):
267
+ super(WanT5EncoderModel, self).__init__()
268
+ self.dim = dim
269
+ self.dim_attn = dim_attn
270
+ self.dim_ffn = dim_ffn
271
+ self.num_heads = num_heads
272
+ self.num_layers = num_layers
273
+ self.num_buckets = num_buckets
274
+ self.shared_pos = shared_pos
275
+
276
+ # layers
277
+ self.token_embedding = vocab if isinstance(vocab, nn.Embedding) \
278
+ else nn.Embedding(vocab, dim)
279
+ self.pos_embedding = T5RelativeEmbedding(
280
+ num_buckets, num_heads, bidirectional=True) if shared_pos else None
281
+ self.dropout = nn.Dropout(dropout)
282
+ self.blocks = nn.ModuleList([
283
+ T5SelfAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets,
284
+ shared_pos, dropout) for _ in range(num_layers)
285
+ ])
286
+ self.norm = T5LayerNorm(dim)
287
+
288
+ # initialize weights
289
+ self.apply(init_weights)
290
+
291
+ def forward(
292
+ self,
293
+ input_ids: Optional[torch.LongTensor] = None,
294
+ attention_mask: Optional[torch.FloatTensor] = None,
295
+ ):
296
+ x = self.token_embedding(input_ids)
297
+ x = self.dropout(x)
298
+ e = self.pos_embedding(x.size(1),
299
+ x.size(1)) if self.shared_pos else None
300
+ for block in self.blocks:
301
+ x = block(x, attention_mask, pos_bias=e)
302
+ x = self.norm(x)
303
+ x = self.dropout(x)
304
+ return (x, )
305
+
306
+ @classmethod
307
+ def from_pretrained(cls, pretrained_model_path, additional_kwargs={}):
308
+ def filter_kwargs(cls, kwargs):
309
+ import inspect
310
+ sig = inspect.signature(cls.__init__)
311
+ valid_params = set(sig.parameters.keys()) - {'self', 'cls'}
312
+ filtered_kwargs = {k: v for k, v in kwargs.items() if k in valid_params}
313
+ return filtered_kwargs
314
+
315
+ model = cls(**filter_kwargs(cls, additional_kwargs))
316
+ if pretrained_model_path.endswith(".safetensors"):
317
+ from safetensors.torch import load_file, safe_open
318
+ state_dict = load_file(pretrained_model_path)
319
+ else:
320
+ state_dict = torch.load(pretrained_model_path, map_location="cpu")
321
+ m, u = model.load_state_dict(state_dict, strict=False)
322
+ print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
323
+ print(m, u)
324
+ return model
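
A minimal usage sketch of the relative-position bias defined above, assuming `wan_text_encoder.py` is importable as `cogvideox.models.wan_text_encoder` (the path added in this commit) and using illustrative sizes; `T5Attention` adds the returned tensor to its attention logits:

import torch

from cogvideox.models.wan_text_encoder import T5RelativeEmbedding

# Illustrative sizes: 32 buckets, 8 heads, a 16-token sequence.
emb = T5RelativeEmbedding(num_buckets=32, num_heads=8, bidirectional=True)
bias = emb(16, 16)        # relative-position bias for queries and keys of length 16
print(bias.shape)         # torch.Size([1, 8, 16, 16])
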
cogvideox/models/wan_transformer3d.py ADDED
@@ -0,0 +1,961 @@
1
+ # Modified from https://github.com/Wan-Video/Wan2.1/blob/main/wan/modules/model.py
2
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
3
+
4
+ import glob
5
+ import json
6
+ import math
7
+ import os
8
+ import warnings
9
+ from typing import Any, Dict
10
+
11
+ import torch
12
+ import torch.cuda.amp as amp
13
+ import torch.nn as nn
14
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
15
+ from diffusers.loaders.single_file_model import FromOriginalModelMixin
16
+ from diffusers.models.modeling_utils import ModelMixin
17
+ from diffusers.utils import is_torch_version, logging
18
+ from torch import nn
19
+
20
+ try:
21
+ import flash_attn_interface
22
+ FLASH_ATTN_3_AVAILABLE = True
23
+ except ModuleNotFoundError:
24
+ FLASH_ATTN_3_AVAILABLE = False
25
+
26
+ try:
27
+ import flash_attn
28
+ FLASH_ATTN_2_AVAILABLE = True
29
+ except ModuleNotFoundError:
30
+ FLASH_ATTN_2_AVAILABLE = False
31
+
32
+
33
+ def flash_attention(
34
+ q,
35
+ k,
36
+ v,
37
+ q_lens=None,
38
+ k_lens=None,
39
+ dropout_p=0.,
40
+ softmax_scale=None,
41
+ q_scale=None,
42
+ causal=False,
43
+ window_size=(-1, -1),
44
+ deterministic=False,
45
+ dtype=torch.bfloat16,
46
+ version=None,
47
+ ):
48
+ """
49
+ q: [B, Lq, Nq, C1].
50
+ k: [B, Lk, Nk, C1].
51
+ v: [B, Lk, Nk, C2]. Nq must be divisible by Nk.
52
+ q_lens: [B].
53
+ k_lens: [B].
54
+ dropout_p: float. Dropout probability.
55
+ softmax_scale: float. The scaling of QK^T before applying softmax.
56
+ causal: bool. Whether to apply causal attention mask.
57
+ window_size: (left right). If not (-1, -1), apply sliding window local attention.
58
+ deterministic: bool. If True, slightly slower and uses more memory.
59
+ dtype: torch.dtype. Apply when dtype of q/k/v is not float16/bfloat16.
60
+ """
61
+ half_dtypes = (torch.float16, torch.bfloat16)
62
+ assert dtype in half_dtypes
63
+ assert q.device.type == 'cuda' and q.size(-1) <= 256
64
+
65
+ # params
66
+ b, lq, lk, out_dtype = q.size(0), q.size(1), k.size(1), q.dtype
67
+
68
+ def half(x):
69
+ return x if x.dtype in half_dtypes else x.to(dtype)
70
+
71
+ # preprocess query
72
+ if q_lens is None:
73
+ q = half(q.flatten(0, 1))
74
+ q_lens = torch.tensor(
75
+ [lq] * b, dtype=torch.int32).to(
76
+ device=q.device, non_blocking=True)
77
+ else:
78
+ q = half(torch.cat([u[:v] for u, v in zip(q, q_lens)]))
79
+
80
+ # preprocess key, value
81
+ if k_lens is None:
82
+ k = half(k.flatten(0, 1))
83
+ v = half(v.flatten(0, 1))
84
+ k_lens = torch.tensor(
85
+ [lk] * b, dtype=torch.int32).to(
86
+ device=k.device, non_blocking=True)
87
+ else:
88
+ k = half(torch.cat([u[:v] for u, v in zip(k, k_lens)]))
89
+ v = half(torch.cat([u[:v] for u, v in zip(v, k_lens)]))
90
+
91
+ q = q.to(v.dtype)
92
+ k = k.to(v.dtype)
93
+
94
+ if q_scale is not None:
95
+ q = q * q_scale
96
+
97
+ if version is not None and version == 3 and not FLASH_ATTN_3_AVAILABLE:
98
+ warnings.warn(
99
+ 'Flash attention 3 is not available, use flash attention 2 instead.'
100
+ )
101
+
102
+ # apply attention
103
+ if (version is None or version == 3) and FLASH_ATTN_3_AVAILABLE:
104
+ # Note: dropout_p, window_size are not supported in FA3 now.
105
+ x = flash_attn_interface.flash_attn_varlen_func(
106
+ q=q,
107
+ k=k,
108
+ v=v,
109
+ cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(
110
+ 0, dtype=torch.int32).to(q.device, non_blocking=True),
111
+ cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(
112
+ 0, dtype=torch.int32).to(q.device, non_blocking=True),
113
+ seqused_q=None,
114
+ seqused_k=None,
115
+ max_seqlen_q=lq,
116
+ max_seqlen_k=lk,
117
+ softmax_scale=softmax_scale,
118
+ causal=causal,
119
+ deterministic=deterministic)[0].unflatten(0, (b, lq))
120
+ else:
121
+ assert FLASH_ATTN_2_AVAILABLE
122
+ x = flash_attn.flash_attn_varlen_func(
123
+ q=q,
124
+ k=k,
125
+ v=v,
126
+ cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(
127
+ 0, dtype=torch.int32).to(q.device, non_blocking=True),
128
+ cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(
129
+ 0, dtype=torch.int32).to(q.device, non_blocking=True),
130
+ max_seqlen_q=lq,
131
+ max_seqlen_k=lk,
132
+ dropout_p=dropout_p,
133
+ softmax_scale=softmax_scale,
134
+ causal=causal,
135
+ window_size=window_size,
136
+ deterministic=deterministic).unflatten(0, (b, lq))
137
+
138
+ # output
139
+ return x.type(out_dtype)
140
+
141
+
142
+ def attention(
143
+ q,
144
+ k,
145
+ v,
146
+ q_lens=None,
147
+ k_lens=None,
148
+ dropout_p=0.,
149
+ softmax_scale=None,
150
+ q_scale=None,
151
+ causal=False,
152
+ window_size=(-1, -1),
153
+ deterministic=False,
154
+ dtype=torch.bfloat16,
155
+ fa_version=None,
156
+ ):
157
+ if FLASH_ATTN_2_AVAILABLE or FLASH_ATTN_3_AVAILABLE:
158
+ return flash_attention(
159
+ q=q,
160
+ k=k,
161
+ v=v,
162
+ q_lens=q_lens,
163
+ k_lens=k_lens,
164
+ dropout_p=dropout_p,
165
+ softmax_scale=softmax_scale,
166
+ q_scale=q_scale,
167
+ causal=causal,
168
+ window_size=window_size,
169
+ deterministic=deterministic,
170
+ dtype=dtype,
171
+ version=fa_version,
172
+ )
173
+ else:
174
+ if q_lens is not None or k_lens is not None:
175
+ warnings.warn(
176
+ 'Padding mask is disabled when using scaled_dot_product_attention. It can have a significant impact on performance.'
177
+ )
178
+ attn_mask = None
179
+
180
+ q = q.transpose(1, 2)
181
+ k = k.transpose(1, 2)
182
+ v = v.transpose(1, 2)
183
+
184
+ out = torch.nn.functional.scaled_dot_product_attention(
185
+ q, k, v, attn_mask=attn_mask, is_causal=causal, dropout_p=dropout_p)
186
+
187
+ out = out.transpose(1, 2).contiguous()
188
+ return out
189
+
190
+
191
+ def sinusoidal_embedding_1d(dim, position):
192
+ # preprocess
193
+ assert dim % 2 == 0
194
+ half = dim // 2
195
+ position = position.type(torch.float64)
196
+
197
+ # calculation
198
+ sinusoid = torch.outer(
199
+ position, torch.pow(10000, -torch.arange(half).to(position).div(half)))
200
+ x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
201
+ return x
202
+
203
+
204
+ @amp.autocast(enabled=False)
205
+ def rope_params(max_seq_len, dim, theta=10000):
206
+ assert dim % 2 == 0
207
+ freqs = torch.outer(
208
+ torch.arange(max_seq_len),
209
+ 1.0 / torch.pow(theta,
210
+ torch.arange(0, dim, 2).to(torch.float64).div(dim)))
211
+ freqs = torch.polar(torch.ones_like(freqs), freqs)
212
+ return freqs
213
+
214
+
215
+ @amp.autocast(enabled=False)
216
+ def rope_apply(x, grid_sizes, freqs):
217
+ n, c = x.size(2), x.size(3) // 2
218
+
219
+ # split freqs
220
+ freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
221
+
222
+ # loop over samples
223
+ output = []
224
+ for i, (f, h, w) in enumerate(grid_sizes.tolist()):
225
+ seq_len = f * h * w
226
+
227
+ # precompute multipliers
228
+ x_i = torch.view_as_complex(x[i, :seq_len].to(torch.float32).reshape(
229
+ seq_len, n, -1, 2))
230
+ freqs_i = torch.cat([
231
+ freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
232
+ freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
233
+ freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
234
+ ],
235
+ dim=-1).reshape(seq_len, 1, -1)
236
+
237
+ # apply rotary embedding
238
+ x_i = torch.view_as_real(x_i * freqs_i).flatten(2)
239
+ x_i = torch.cat([x_i, x[i, seq_len:]])
240
+
241
+ # append to collection
242
+ output.append(x_i)
243
+ return torch.stack(output).float()
244
+
245
+
246
+ class WanRMSNorm(nn.Module):
247
+
248
+ def __init__(self, dim, eps=1e-5):
249
+ super().__init__()
250
+ self.dim = dim
251
+ self.eps = eps
252
+ self.weight = nn.Parameter(torch.ones(dim))
253
+
254
+ def forward(self, x):
255
+ r"""
256
+ Args:
257
+ x(Tensor): Shape [B, L, C]
258
+ """
259
+ return self._norm(x.float()).type_as(x) * self.weight
260
+
261
+ def _norm(self, x):
262
+ return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
263
+
264
+
265
+ class WanLayerNorm(nn.LayerNorm):
266
+
267
+ def __init__(self, dim, eps=1e-6, elementwise_affine=False):
268
+ super().__init__(dim, elementwise_affine=elementwise_affine, eps=eps)
269
+
270
+ def forward(self, x):
271
+ r"""
272
+ Args:
273
+ x(Tensor): Shape [B, L, C]
274
+ """
275
+ return super().forward(x.float()).type_as(x)
276
+
277
+
278
+ class WanSelfAttention(nn.Module):
279
+
280
+ def __init__(self,
281
+ dim,
282
+ num_heads,
283
+ window_size=(-1, -1),
284
+ qk_norm=True,
285
+ eps=1e-6):
286
+ assert dim % num_heads == 0
287
+ super().__init__()
288
+ self.dim = dim
289
+ self.num_heads = num_heads
290
+ self.head_dim = dim // num_heads
291
+ self.window_size = window_size
292
+ self.qk_norm = qk_norm
293
+ self.eps = eps
294
+
295
+ # layers
296
+ self.q = nn.Linear(dim, dim)
297
+ self.k = nn.Linear(dim, dim)
298
+ self.v = nn.Linear(dim, dim)
299
+ self.o = nn.Linear(dim, dim)
300
+ self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
301
+ self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
302
+
303
+ def forward(self, x, seq_lens, grid_sizes, freqs, dtype):
304
+ r"""
305
+ Args:
306
+ x(Tensor): Shape [B, L, num_heads, C / num_heads]
307
+ seq_lens(Tensor): Shape [B]
308
+ grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
309
+ freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
310
+ """
311
+ b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
312
+
313
+ # query, key, value function
314
+ def qkv_fn(x):
315
+ q = self.norm_q(self.q(x)).view(b, s, n, d)
316
+ k = self.norm_k(self.k(x)).view(b, s, n, d)
317
+ v = self.v(x).view(b, s, n, d)
318
+ return q, k, v
319
+
320
+ q, k, v = qkv_fn(x)
321
+
322
+ x = attention(
323
+ q=rope_apply(q, grid_sizes, freqs).to(dtype),
324
+ k=rope_apply(k, grid_sizes, freqs).to(dtype),
325
+ v=v.to(dtype),
326
+ k_lens=seq_lens,
327
+ window_size=self.window_size)
328
+ x = x.to(dtype)
329
+
330
+ # output
331
+ x = x.flatten(2)
332
+ x = self.o(x)
333
+ return x
334
+
335
+
336
+ class WanT2VCrossAttention(WanSelfAttention):
337
+
338
+ def forward(self, x, context, context_lens):
339
+ r"""
340
+ Args:
341
+ x(Tensor): Shape [B, L1, C]
342
+ context(Tensor): Shape [B, L2, C]
343
+ context_lens(Tensor): Shape [B]
344
+ """
345
+ b, n, d = x.size(0), self.num_heads, self.head_dim
346
+
347
+ # compute query, key, value
348
+ q = self.norm_q(self.q(x)).view(b, -1, n, d)
349
+ k = self.norm_k(self.k(context)).view(b, -1, n, d)
350
+ v = self.v(context).view(b, -1, n, d)
351
+
352
+ # compute attention
353
+ x = attention(q, k, v, k_lens=context_lens)
354
+
355
+ # output
356
+ x = x.flatten(2)
357
+ x = self.o(x)
358
+ return x
359
+
360
+
361
+ class WanI2VCrossAttention(WanSelfAttention):
362
+
363
+ def __init__(self,
364
+ dim,
365
+ num_heads,
366
+ window_size=(-1, -1),
367
+ qk_norm=True,
368
+ eps=1e-6):
369
+ super().__init__(dim, num_heads, window_size, qk_norm, eps)
370
+
371
+ self.k_img = nn.Linear(dim, dim)
372
+ self.v_img = nn.Linear(dim, dim)
373
+ # self.alpha = nn.Parameter(torch.zeros((1, )))
374
+ self.norm_k_img = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
375
+
376
+ def forward(self, x, context, context_lens):
377
+ r"""
378
+ Args:
379
+ x(Tensor): Shape [B, L1, C]
380
+ context(Tensor): Shape [B, L2, C]
381
+ context_lens(Tensor): Shape [B]
382
+ """
383
+ context_img = context[:, :257]
384
+ context = context[:, 257:]
385
+ b, n, d = x.size(0), self.num_heads, self.head_dim
386
+
387
+ # compute query, key, value
388
+ q = self.norm_q(self.q(x)).view(b, -1, n, d)
389
+ k = self.norm_k(self.k(context)).view(b, -1, n, d)
390
+ v = self.v(context).view(b, -1, n, d)
391
+ k_img = self.norm_k_img(self.k_img(context_img)).view(b, -1, n, d)
392
+ v_img = self.v_img(context_img).view(b, -1, n, d)
393
+ img_x = attention(q, k_img, v_img, k_lens=None)
394
+ # compute attention
395
+ x = attention(q, k, v, k_lens=context_lens)
396
+
397
+ # output
398
+ x = x.flatten(2)
399
+ img_x = img_x.flatten(2)
400
+ x = x + img_x
401
+ x = self.o(x)
402
+ return x
403
+
404
+
405
+ WAN_CROSSATTENTION_CLASSES = {
406
+ 't2v_cross_attn': WanT2VCrossAttention,
407
+ 'i2v_cross_attn': WanI2VCrossAttention,
408
+ }
409
+
410
+
411
+ class WanAttentionBlock(nn.Module):
412
+
413
+ def __init__(self,
414
+ cross_attn_type,
415
+ dim,
416
+ ffn_dim,
417
+ num_heads,
418
+ window_size=(-1, -1),
419
+ qk_norm=True,
420
+ cross_attn_norm=False,
421
+ eps=1e-6):
422
+ super().__init__()
423
+ self.dim = dim
424
+ self.ffn_dim = ffn_dim
425
+ self.num_heads = num_heads
426
+ self.window_size = window_size
427
+ self.qk_norm = qk_norm
428
+ self.cross_attn_norm = cross_attn_norm
429
+ self.eps = eps
430
+
431
+ # layers
432
+ self.norm1 = WanLayerNorm(dim, eps)
433
+ self.self_attn = WanSelfAttention(dim, num_heads, window_size, qk_norm,
434
+ eps)
435
+ self.norm3 = WanLayerNorm(
436
+ dim, eps,
437
+ elementwise_affine=True) if cross_attn_norm else nn.Identity()
438
+ self.cross_attn = WAN_CROSSATTENTION_CLASSES[cross_attn_type](dim,
439
+ num_heads,
440
+ (-1, -1),
441
+ qk_norm,
442
+ eps)
443
+ self.norm2 = WanLayerNorm(dim, eps)
444
+ self.ffn = nn.Sequential(
445
+ nn.Linear(dim, ffn_dim), nn.GELU(approximate='tanh'),
446
+ nn.Linear(ffn_dim, dim))
447
+
448
+ # modulation
449
+ self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
450
+
451
+ def forward(
452
+ self,
453
+ x,
454
+ e,
455
+ seq_lens,
456
+ grid_sizes,
457
+ freqs,
458
+ context,
459
+ context_lens,
460
+ dtype=torch.float32
461
+ ):
462
+ r"""
463
+ Args:
464
+ x(Tensor): Shape [B, L, C]
465
+ e(Tensor): Shape [B, 6, C]
466
+ seq_lens(Tensor): Shape [B], length of each sequence in batch
467
+ grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
468
+ freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
469
+ """
470
+ e = (self.modulation + e).chunk(6, dim=1)
471
+
472
+ # self-attention
473
+ temp_x = self.norm1(x) * (1 + e[1]) + e[0]
474
+ temp_x = temp_x.to(dtype)
475
+
476
+ y = self.self_attn(temp_x, seq_lens, grid_sizes, freqs, dtype)
477
+ x = x + y * e[2]
478
+
479
+ # cross-attention & ffn function
480
+ def cross_attn_ffn(x, context, context_lens, e):
481
+ x = x + self.cross_attn(self.norm3(x), context, context_lens)
482
+ temp_x = self.norm2(x) * (1 + e[4]) + e[3]
483
+ temp_x = temp_x.to(dtype)
484
+
485
+ y = self.ffn(temp_x)
486
+ x = x + y * e[5]
487
+ return x
488
+
489
+ x = cross_attn_ffn(x, context, context_lens, e)
490
+ return x
491
+
492
+
493
+ class Head(nn.Module):
494
+
495
+ def __init__(self, dim, out_dim, patch_size, eps=1e-6):
496
+ super().__init__()
497
+ self.dim = dim
498
+ self.out_dim = out_dim
499
+ self.patch_size = patch_size
500
+ self.eps = eps
501
+
502
+ # layers
503
+ out_dim = math.prod(patch_size) * out_dim
504
+ self.norm = WanLayerNorm(dim, eps)
505
+ self.head = nn.Linear(dim, out_dim)
506
+
507
+ # modulation
508
+ self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5)
509
+
510
+ def forward(self, x, e):
511
+ r"""
512
+ Args:
513
+ x(Tensor): Shape [B, L1, C]
514
+ e(Tensor): Shape [B, C]
515
+ """
516
+ e = (self.modulation + e.unsqueeze(1)).chunk(2, dim=1)
517
+ x = (self.head(self.norm(x) * (1 + e[1]) + e[0]))
518
+ return x
519
+
520
+
521
+ class MLPProj(torch.nn.Module):
522
+
523
+ def __init__(self, in_dim, out_dim):
524
+ super().__init__()
525
+
526
+ self.proj = torch.nn.Sequential(
527
+ torch.nn.LayerNorm(in_dim), torch.nn.Linear(in_dim, in_dim),
528
+ torch.nn.GELU(), torch.nn.Linear(in_dim, out_dim),
529
+ torch.nn.LayerNorm(out_dim))
530
+
531
+ def forward(self, image_embeds):
532
+ clip_extra_context_tokens = self.proj(image_embeds)
533
+ return clip_extra_context_tokens
534
+
535
+
536
+
537
+ class WanTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
538
+ r"""
539
+ Wan diffusion backbone supporting both text-to-video and image-to-video.
540
+ """
541
+
542
+ # ignore_for_config = [
543
+ # 'patch_size', 'cross_attn_norm', 'qk_norm', 'text_dim', 'window_size'
544
+ # ]
545
+ # _no_split_modules = ['WanAttentionBlock']
546
+ _supports_gradient_checkpointing = True
547
+
548
+ @register_to_config
549
+ def __init__(
550
+ self,
551
+ model_type='t2v',
552
+ patch_size=(1, 2, 2),
553
+ text_len=512,
554
+ in_dim=16,
555
+ dim=2048,
556
+ ffn_dim=8192,
557
+ freq_dim=256,
558
+ text_dim=4096,
559
+ out_dim=16,
560
+ num_heads=16,
561
+ num_layers=32,
562
+ window_size=(-1, -1),
563
+ qk_norm=True,
564
+ cross_attn_norm=True,
565
+ eps=1e-6,
566
+ in_channels=16,
567
+ hidden_size=2048,
568
+ ):
569
+ r"""
570
+ Initialize the diffusion model backbone.
571
+
572
+ Args:
573
+ model_type (`str`, *optional*, defaults to 't2v'):
574
+ Model variant - 't2v' (text-to-video) or 'i2v' (image-to-video)
575
+ patch_size (`tuple`, *optional*, defaults to (1, 2, 2)):
576
+ 3D patch dimensions for video embedding (t_patch, h_patch, w_patch)
577
+ text_len (`int`, *optional*, defaults to 512):
578
+ Fixed length for text embeddings
579
+ in_dim (`int`, *optional*, defaults to 16):
580
+ Input video channels (C_in)
581
+ dim (`int`, *optional*, defaults to 2048):
582
+ Hidden dimension of the transformer
583
+ ffn_dim (`int`, *optional*, defaults to 8192):
584
+ Intermediate dimension in feed-forward network
585
+ freq_dim (`int`, *optional*, defaults to 256):
586
+ Dimension for sinusoidal time embeddings
587
+ text_dim (`int`, *optional*, defaults to 4096):
588
+ Input dimension for text embeddings
589
+ out_dim (`int`, *optional*, defaults to 16):
590
+ Output video channels (C_out)
591
+ num_heads (`int`, *optional*, defaults to 16):
592
+ Number of attention heads
593
+ num_layers (`int`, *optional*, defaults to 32):
594
+ Number of transformer blocks
595
+ window_size (`tuple`, *optional*, defaults to (-1, -1)):
596
+ Window size for local attention (-1 indicates global attention)
597
+ qk_norm (`bool`, *optional*, defaults to True):
598
+ Enable query/key normalization
599
+ cross_attn_norm (`bool`, *optional*, defaults to True):
600
+ Enable cross-attention normalization
601
+ eps (`float`, *optional*, defaults to 1e-6):
602
+ Epsilon value for normalization layers
603
+ """
604
+
605
+ super().__init__()
606
+
607
+ assert model_type in ['t2v', 'i2v']
608
+ self.model_type = model_type
609
+
610
+ self.patch_size = patch_size
611
+ self.text_len = text_len
612
+ self.in_dim = in_dim
613
+ self.dim = dim
614
+ self.ffn_dim = ffn_dim
615
+ self.freq_dim = freq_dim
616
+ self.text_dim = text_dim
617
+ self.out_dim = out_dim
618
+ self.num_heads = num_heads
619
+ self.num_layers = num_layers
620
+ self.window_size = window_size
621
+ self.qk_norm = qk_norm
622
+ self.cross_attn_norm = cross_attn_norm
623
+ self.eps = eps
624
+
625
+ # embeddings
626
+ self.patch_embedding = nn.Conv3d(
627
+ in_dim, dim, kernel_size=patch_size, stride=patch_size)
628
+ self.text_embedding = nn.Sequential(
629
+ nn.Linear(text_dim, dim), nn.GELU(approximate='tanh'),
630
+ nn.Linear(dim, dim))
631
+
632
+ self.time_embedding = nn.Sequential(
633
+ nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
634
+ self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6))
635
+
636
+ # blocks
637
+ cross_attn_type = 't2v_cross_attn' if model_type == 't2v' else 'i2v_cross_attn'
638
+ self.blocks = nn.ModuleList([
639
+ WanAttentionBlock(cross_attn_type, dim, ffn_dim, num_heads,
640
+ window_size, qk_norm, cross_attn_norm, eps)
641
+ for _ in range(num_layers)
642
+ ])
643
+
644
+ # head
645
+ self.head = Head(dim, out_dim, patch_size, eps)
646
+
647
+ # buffers (don't use register_buffer otherwise dtype will be changed in to())
648
+ assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0
649
+ d = dim // num_heads
650
+ self.freqs = torch.cat(
651
+ [
652
+ rope_params(1024, d - 4 * (d // 6)),
653
+ rope_params(1024, 2 * (d // 6)),
654
+ rope_params(1024, 2 * (d // 6))
655
+ ],
656
+ dim=1
657
+ )
658
+
659
+ if model_type == 'i2v':
660
+ self.img_emb = MLPProj(1280, dim)
661
+
662
+ self.gradient_checkpointing = False
663
+
664
+ def _set_gradient_checkpointing(self, module, value=False):
665
+ self.gradient_checkpointing = value
666
+
667
+ def forward(
668
+ self,
669
+ x,
670
+ t,
671
+ context,
672
+ seq_len,
673
+ clip_fea=None,
674
+ y=None,
675
+ ):
676
+ r"""
677
+ Forward pass through the diffusion model
678
+
679
+ Args:
680
+ x (List[Tensor]):
681
+ List of input video tensors, each with shape [C_in, F, H, W]
682
+ t (Tensor):
683
+ Diffusion timesteps tensor of shape [B]
684
+ context (List[Tensor]):
685
+ List of text embeddings each with shape [L, C]
686
+ seq_len (`int`):
687
+ Maximum sequence length for positional encoding
688
+ clip_fea (Tensor, *optional*):
689
+ CLIP image features for image-to-video mode
690
+ y (List[Tensor], *optional*):
691
+ Conditional video inputs for image-to-video mode, same shape as x
692
+
693
+ Returns:
694
+ List[Tensor]:
695
+ List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8]
696
+ """
697
+ if self.model_type == 'i2v':
698
+ assert clip_fea is not None and y is not None
699
+ # params
700
+ device = self.patch_embedding.weight.device
701
+ dtype = x.dtype
702
+ if self.freqs.device != device and torch.device(type="meta") != device:
703
+ self.freqs = self.freqs.to(device)
704
+
705
+ if y is not None:
706
+ x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
707
+
708
+ # embeddings
709
+ x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
710
+ grid_sizes = torch.stack(
711
+ [torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
712
+ x = [u.flatten(2).transpose(1, 2) for u in x]
713
+ seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
714
+ assert seq_lens.max() <= seq_len
715
+ x = torch.cat([
716
+ torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
717
+ dim=1) for u in x
718
+ ])
719
+
720
+ # time embeddings
721
+ with amp.autocast(dtype=torch.float32):
722
+ e = self.time_embedding(
723
+ sinusoidal_embedding_1d(self.freq_dim, t).float())
724
+ e0 = self.time_projection(e).unflatten(1, (6, self.dim))
725
+ assert e.dtype == torch.float32 and e0.dtype == torch.float32
726
+ # to bfloat16 to save memory
727
+ e0 = e0.to(dtype)
728
+ e = e.to(dtype)
729
+
730
+ # context
731
+ context_lens = None
732
+ context = self.text_embedding(
733
+ torch.stack([
734
+ torch.cat(
735
+ [u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
736
+ for u in context
737
+ ]))
738
+
739
+ if clip_fea is not None:
740
+ context_clip = self.img_emb(clip_fea) # bs x 257 x dim
741
+ context = torch.concat([context_clip, context], dim=1)
742
+
743
+
744
+ for block in self.blocks:
745
+ if self.training and self.gradient_checkpointing:
746
+
747
+ def create_custom_forward(module):
748
+ def custom_forward(*inputs):
749
+ return module(*inputs)
750
+
751
+ return custom_forward
752
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
753
+ x = torch.utils.checkpoint.checkpoint(
754
+ create_custom_forward(block),
755
+ x,
756
+ e0,
757
+ seq_lens,
758
+ grid_sizes,
759
+ self.freqs,
760
+ context,
761
+ context_lens,
762
+ dtype,
763
+ **ckpt_kwargs,
764
+ )
765
+ else:
766
+ # arguments
767
+ kwargs = dict(
768
+ e=e0,
769
+ seq_lens=seq_lens,
770
+ grid_sizes=grid_sizes,
771
+ freqs=self.freqs,
772
+ context=context,
773
+ context_lens=context_lens,
774
+ dtype=dtype
775
+ )
776
+ x = block(x, **kwargs)
777
+
778
+ # head
779
+ x = self.head(x, e)
780
+
781
+ # unpatchify
782
+ x = self.unpatchify(x, grid_sizes)
783
+ x = torch.stack(x)
784
+ return x
785
+
786
+
787
+ def unpatchify(self, x, grid_sizes):
788
+ r"""
789
+ Reconstruct video tensors from patch embeddings.
790
+
791
+ Args:
792
+ x (List[Tensor]):
793
+ List of patchified features, each with shape [L, C_out * prod(patch_size)]
794
+ grid_sizes (Tensor):
795
+ Original spatial-temporal grid dimensions before patching,
796
+ shape [B, 3] (3 dimensions correspond to F_patches, H_patches, W_patches)
797
+
798
+ Returns:
799
+ List[Tensor]:
800
+ Reconstructed video tensors with shape [C_out, F, H / 8, W / 8]
801
+ """
802
+
803
+ c = self.out_dim
804
+ out = []
805
+ for u, v in zip(x, grid_sizes.tolist()):
806
+ u = u[:math.prod(v)].view(*v, *self.patch_size, c)
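+ # the einsum below rearranges (F, H, W, p_t, p_h, p_w, C) -> (C, F, p_t, H, p_h, W, p_w)
+ # so the patch dims can be merged back into F, H and W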
807
+ u = torch.einsum('fhwpqrc->cfphqwr', u)
808
+ u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)])
809
+ out.append(u)
810
+ return out
811
+
812
+ def init_weights(self):
813
+ r"""
814
+ Initialize model parameters using Xavier initialization.
815
+ """
816
+
817
+ # basic init
818
+ for m in self.modules():
819
+ if isinstance(m, nn.Linear):
820
+ nn.init.xavier_uniform_(m.weight)
821
+ if m.bias is not None:
822
+ nn.init.zeros_(m.bias)
823
+
824
+ # init embeddings
825
+ nn.init.xavier_uniform_(self.patch_embedding.weight.flatten(1))
826
+ for m in self.text_embedding.modules():
827
+ if isinstance(m, nn.Linear):
828
+ nn.init.normal_(m.weight, std=.02)
829
+ for m in self.time_embedding.modules():
830
+ if isinstance(m, nn.Linear):
831
+ nn.init.normal_(m.weight, std=.02)
832
+
833
+ # init output layer
834
+ nn.init.zeros_(self.head.head.weight)
835
+
836
+ @classmethod
837
+ def from_pretrained(
838
+ cls, pretrained_model_path, subfolder=None, transformer_additional_kwargs={},
839
+ low_cpu_mem_usage=False, torch_dtype=torch.bfloat16
840
+ ):
841
+ if subfolder is not None:
842
+ pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
843
+ print(f"Loading 3D transformer's pretrained weights from {pretrained_model_path} ...")
844
+
845
+ config_file = os.path.join(pretrained_model_path, 'config.json')
846
+ if not os.path.isfile(config_file):
847
+ raise RuntimeError(f"{config_file} does not exist")
848
+ with open(config_file, "r") as f:
849
+ config = json.load(f)
850
+
851
+ from diffusers.utils import WEIGHTS_NAME
852
+ model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
853
+ model_file_safetensors = model_file.replace(".bin", ".safetensors")
854
+
855
+ if "dict_mapping" in transformer_additional_kwargs.keys():
856
+ for key in transformer_additional_kwargs["dict_mapping"]:
857
+ transformer_additional_kwargs[transformer_additional_kwargs["dict_mapping"][key]] = config[key]
858
+
859
+ if low_cpu_mem_usage:
860
+ try:
861
+ import re
862
+
863
+ from diffusers.models.modeling_utils import \
864
+ load_model_dict_into_meta
865
+ from diffusers.utils import is_accelerate_available
866
+ if is_accelerate_available():
867
+ import accelerate
868
+
869
+ # Instantiate model with empty weights
870
+ with accelerate.init_empty_weights():
871
+ model = cls.from_config(config, **transformer_additional_kwargs)
872
+
873
+ param_device = "cpu"
874
+ if os.path.exists(model_file):
875
+ state_dict = torch.load(model_file, map_location="cpu")
876
+ elif os.path.exists(model_file_safetensors):
877
+ from safetensors.torch import load_file, safe_open
878
+ state_dict = load_file(model_file_safetensors)
879
+ else:
880
+ from safetensors.torch import load_file, safe_open
881
+ model_files_safetensors = glob.glob(os.path.join(pretrained_model_path, "*.safetensors"))
882
+ state_dict = {}
883
+ print(model_files_safetensors)
884
+ for _model_file_safetensors in model_files_safetensors:
885
+ _state_dict = load_file(_model_file_safetensors)
886
+ for key in _state_dict:
887
+ state_dict[key] = _state_dict[key]
888
+ model._convert_deprecated_attention_blocks(state_dict)
889
+ # move the params from meta device to cpu
890
+ missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
891
+ if len(missing_keys) > 0:
892
+ raise ValueError(
893
+ f"Cannot load {cls} from {pretrained_model_path} because the following keys are"
894
+ f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass"
895
+ " `low_cpu_mem_usage=False` and `device_map=None` if you want to randomly initialize"
896
+ " those weights or else make sure your checkpoint file is correct."
897
+ )
898
+
899
+ unexpected_keys = load_model_dict_into_meta(
900
+ model,
901
+ state_dict,
902
+ device=param_device,
903
+ dtype=torch_dtype,
904
+ model_name_or_path=pretrained_model_path,
905
+ )
906
+
907
+ if cls._keys_to_ignore_on_load_unexpected is not None:
908
+ for pat in cls._keys_to_ignore_on_load_unexpected:
909
+ unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
910
+
911
+ if len(unexpected_keys) > 0:
912
+ print(
913
+ f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
914
+ )
915
+ return model
916
+ except Exception as e:
917
+ print(
918
+ f"The low_cpu_mem_usage mode did not work because {e}. Falling back to low_cpu_mem_usage=False."
919
+ )
920
+
921
+ model = cls.from_config(config, **transformer_additional_kwargs)
922
+ if os.path.exists(model_file):
923
+ state_dict = torch.load(model_file, map_location="cpu")
924
+ elif os.path.exists(model_file_safetensors):
925
+ from safetensors.torch import load_file, safe_open
926
+ state_dict = load_file(model_file_safetensors)
927
+ else:
928
+ from safetensors.torch import load_file, safe_open
929
+ model_files_safetensors = glob.glob(os.path.join(pretrained_model_path, "*.safetensors"))
930
+ state_dict = {}
931
+ for _model_file_safetensors in model_files_safetensors:
932
+ _state_dict = load_file(_model_file_safetensors)
933
+ for key in _state_dict:
934
+ state_dict[key] = _state_dict[key]
935
+
936
+ if model.state_dict()['patch_embedding.weight'].size() != state_dict['patch_embedding.weight'].size():
937
+ model.state_dict()['patch_embedding.weight'][:, :state_dict['patch_embedding.weight'].size()[1], :, :] = state_dict['patch_embedding.weight']
938
+ model.state_dict()['patch_embedding.weight'][:, state_dict['patch_embedding.weight'].size()[1]:, :, :] = 0
939
+ state_dict['patch_embedding.weight'] = model.state_dict()['patch_embedding.weight']
940
+
941
+ tmp_state_dict = {}
942
+ for key in state_dict:
943
+ if key in model.state_dict().keys() and model.state_dict()[key].size() == state_dict[key].size():
944
+ tmp_state_dict[key] = state_dict[key]
945
+ else:
946
+ print(key, "size mismatch, skipping")
947
+
948
+ state_dict = tmp_state_dict
949
+
950
+ m, u = model.load_state_dict(state_dict, strict=False)
951
+ print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
952
+ print(m)
953
+
954
+ params = [p.numel() if "." in n else 0 for n, p in model.named_parameters()]
955
+ print(f"### All Parameters: {sum(params) / 1e6} M")
956
+
957
+ params = [p.numel() if "attn1." in n else 0 for n, p in model.named_parameters()]
958
+ print(f"### attn1 Parameters: {sum(params) / 1e6} M")
959
+
960
+ model = model.to(torch_dtype)
961
+ return model
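For reference, a minimal loading sketch for the `from_pretrained` classmethod above. The class name and import path (`WanTransformer3DModel` in `cogvideox.models`) are assumptions used for illustration; only the keyword arguments come from the method itself.

import torch

# hypothetical export name for the 3D transformer defined in this file
from cogvideox.models import WanTransformer3DModel

transformer = WanTransformer3DModel.from_pretrained(
    "path/to/checkpoint",        # directory containing config.json and *.bin / *.safetensors weights
    subfolder="transformer",     # optional; joined onto the path exactly as in the method above
    low_cpu_mem_usage=True,      # meta-init + load_model_dict_into_meta when accelerate is available
    torch_dtype=torch.bfloat16,
)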
cogvideox/models/wan_vae.py ADDED
@@ -0,0 +1,706 @@
1
+ # Modified from https://github.com/Wan-Video/Wan2.1/blob/main/wan/modules/vae.py
2
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
3
+ from typing import Tuple, Union
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
9
+ from diffusers.loaders.single_file_model import FromOriginalModelMixin
10
+ from diffusers.models.autoencoders.vae import (DecoderOutput,
11
+ DiagonalGaussianDistribution)
12
+ from diffusers.models.modeling_outputs import AutoencoderKLOutput
13
+ from diffusers.models.modeling_utils import ModelMixin
14
+ from diffusers.utils.accelerate_utils import apply_forward_hook
15
+ from einops import rearrange
16
+
17
+ CACHE_T = 2
18
+
19
+
20
+ class CausalConv3d(nn.Conv3d):
21
+ """
22
+ Causal 3d convolusion.
23
+ """
24
+
25
+ def __init__(self, *args, **kwargs):
26
+ super().__init__(*args, **kwargs)
27
+ self._padding = (self.padding[2], self.padding[2], self.padding[1],
28
+ self.padding[1], 2 * self.padding[0], 0)
29
+ self.padding = (0, 0, 0)
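+ # all temporal padding (2 * padding[0]) is applied in front of the clip and none behind,
+ # so the convolution only ever sees past frames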
30
+
31
+ def forward(self, x, cache_x=None):
32
+ padding = list(self._padding)
33
+ if cache_x is not None and self._padding[4] > 0:
34
+ cache_x = cache_x.to(x.device)
35
+ x = torch.cat([cache_x, x], dim=2)
36
+ padding[4] -= cache_x.shape[2]
37
+ x = F.pad(x, padding)
38
+
39
+ return super().forward(x)
40
+
41
+
42
+ class RMS_norm(nn.Module):
43
+
44
+ def __init__(self, dim, channel_first=True, images=True, bias=False):
45
+ super().__init__()
46
+ broadcastable_dims = (1, 1, 1) if not images else (1, 1)
47
+ shape = (dim, *broadcastable_dims) if channel_first else (dim,)
48
+
49
+ self.channel_first = channel_first
50
+ self.scale = dim**0.5
51
+ self.gamma = nn.Parameter(torch.ones(shape))
52
+ self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.
53
+
54
+ def forward(self, x):
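+ # L2-normalize over the channel (or last) dim, then rescale by sqrt(dim), the learned gamma and optional bias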
55
+ return F.normalize(
56
+ x, dim=(1 if self.channel_first else
57
+ -1)) * self.scale * self.gamma + self.bias
58
+
59
+
60
+ class Upsample(nn.Upsample):
61
+
62
+ def forward(self, x):
63
+ """
64
+ Fix bfloat16 support for nearest neighbor interpolation.
65
+ """
66
+ return super().forward(x.float()).type_as(x)
67
+
68
+
69
+ class Resample(nn.Module):
70
+
71
+ def __init__(self, dim, mode):
72
+ assert mode in ('none', 'upsample2d', 'upsample3d', 'downsample2d',
73
+ 'downsample3d')
74
+ super().__init__()
75
+ self.dim = dim
76
+ self.mode = mode
77
+
78
+ # layers
79
+ if mode == 'upsample2d':
80
+ self.resample = nn.Sequential(
81
+ Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
82
+ nn.Conv2d(dim, dim // 2, 3, padding=1))
83
+ elif mode == 'upsample3d':
84
+ self.resample = nn.Sequential(
85
+ Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
86
+ nn.Conv2d(dim, dim // 2, 3, padding=1))
87
+ self.time_conv = CausalConv3d(
88
+ dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
89
+
90
+ elif mode == 'downsample2d':
91
+ self.resample = nn.Sequential(
92
+ nn.ZeroPad2d((0, 1, 0, 1)),
93
+ nn.Conv2d(dim, dim, 3, stride=(2, 2)))
94
+ elif mode == 'downsample3d':
95
+ self.resample = nn.Sequential(
96
+ nn.ZeroPad2d((0, 1, 0, 1)),
97
+ nn.Conv2d(dim, dim, 3, stride=(2, 2)))
98
+ self.time_conv = CausalConv3d(
99
+ dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
100
+
101
+ else:
102
+ self.resample = nn.Identity()
103
+
104
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
105
+ b, c, t, h, w = x.size()
106
+ if self.mode == 'upsample3d':
107
+ if feat_cache is not None:
108
+ idx = feat_idx[0]
109
+ if feat_cache[idx] is None:
110
+ feat_cache[idx] = 'Rep'
111
+ feat_idx[0] += 1
112
+ else:
113
+
114
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
115
+ if cache_x.shape[2] < 2 and feat_cache[
116
+ idx] is not None and feat_cache[idx] != 'Rep':
117
+ # cache last frame of last two chunk
118
+ cache_x = torch.cat([
119
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
120
+ cache_x.device), cache_x
121
+ ],
122
+ dim=2)
123
+ if cache_x.shape[2] < 2 and feat_cache[
124
+ idx] is not None and feat_cache[idx] == 'Rep':
125
+ cache_x = torch.cat([
126
+ torch.zeros_like(cache_x).to(cache_x.device),
127
+ cache_x
128
+ ],
129
+ dim=2)
130
+ if feat_cache[idx] == 'Rep':
131
+ x = self.time_conv(x)
132
+ else:
133
+ x = self.time_conv(x, feat_cache[idx])
134
+ feat_cache[idx] = cache_x
135
+ feat_idx[0] += 1
136
+
137
+ x = x.reshape(b, 2, c, t, h, w)
138
+ x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]),
139
+ 3)
140
+ x = x.reshape(b, c, t * 2, h, w)
141
+ t = x.shape[2]
142
+ x = rearrange(x, 'b c t h w -> (b t) c h w')
143
+ x = self.resample(x)
144
+ x = rearrange(x, '(b t) c h w -> b c t h w', t=t)
145
+
146
+ if self.mode == 'downsample3d':
147
+ if feat_cache is not None:
148
+ idx = feat_idx[0]
149
+ if feat_cache[idx] is None:
150
+ feat_cache[idx] = x.clone()
151
+ feat_idx[0] += 1
152
+ else:
153
+
154
+ cache_x = x[:, :, -1:, :, :].clone()
155
+ # if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx]!='Rep':
156
+ # # cache last frame of last two chunk
157
+ # cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
158
+
159
+ x = self.time_conv(
160
+ torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
161
+ feat_cache[idx] = cache_x
162
+ feat_idx[0] += 1
163
+ return x
164
+
165
+ def init_weight(self, conv):
166
+ conv_weight = conv.weight
167
+ nn.init.zeros_(conv_weight)
168
+ c1, c2, t, h, w = conv_weight.size()
169
+ one_matrix = torch.eye(c1, c2)
170
+ init_matrix = one_matrix
171
+ nn.init.zeros_(conv_weight)
172
+ #conv_weight.data[:,:,-1,1,1] = init_matrix * 0.5
173
+ conv_weight.data[:, :, 1, 0, 0] = init_matrix #* 0.5
174
+ conv.weight.data.copy_(conv_weight)
175
+ nn.init.zeros_(conv.bias.data)
176
+
177
+ def init_weight2(self, conv):
178
+ conv_weight = conv.weight.data
179
+ nn.init.zeros_(conv_weight)
180
+ c1, c2, t, h, w = conv_weight.size()
181
+ init_matrix = torch.eye(c1 // 2, c2)
182
+ #init_matrix = repeat(init_matrix, 'o ... -> (o 2) ...').permute(1,0,2).contiguous().reshape(c1,c2)
183
+ conv_weight[:c1 // 2, :, -1, 0, 0] = init_matrix
184
+ conv_weight[c1 // 2:, :, -1, 0, 0] = init_matrix
185
+ conv.weight.data.copy_(conv_weight)
186
+ nn.init.zeros_(conv.bias.data)
187
+
188
+
189
+ class ResidualBlock(nn.Module):
190
+
191
+ def __init__(self, in_dim, out_dim, dropout=0.0):
192
+ super().__init__()
193
+ self.in_dim = in_dim
194
+ self.out_dim = out_dim
195
+
196
+ # layers
197
+ self.residual = nn.Sequential(
198
+ RMS_norm(in_dim, images=False), nn.SiLU(),
199
+ CausalConv3d(in_dim, out_dim, 3, padding=1),
200
+ RMS_norm(out_dim, images=False), nn.SiLU(), nn.Dropout(dropout),
201
+ CausalConv3d(out_dim, out_dim, 3, padding=1))
202
+ self.shortcut = CausalConv3d(in_dim, out_dim, 1) \
203
+ if in_dim != out_dim else nn.Identity()
204
+
205
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
206
+ h = self.shortcut(x)
207
+ for layer in self.residual:
208
+ if isinstance(layer, CausalConv3d) and feat_cache is not None:
209
+ idx = feat_idx[0]
210
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
211
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
212
+ # cache last frame of last two chunk
213
+ cache_x = torch.cat([
214
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
215
+ cache_x.device), cache_x
216
+ ],
217
+ dim=2)
218
+ x = layer(x, feat_cache[idx])
219
+ feat_cache[idx] = cache_x
220
+ feat_idx[0] += 1
221
+ else:
222
+ x = layer(x)
223
+ return x + h
224
+
225
+
226
+ class AttentionBlock(nn.Module):
227
+ """
228
+ Causal self-attention with a single head.
229
+ """
230
+
231
+ def __init__(self, dim):
232
+ super().__init__()
233
+ self.dim = dim
234
+
235
+ # layers
236
+ self.norm = RMS_norm(dim)
237
+ self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
238
+ self.proj = nn.Conv2d(dim, dim, 1)
239
+
240
+ # zero out the last layer params
241
+ nn.init.zeros_(self.proj.weight)
242
+
243
+ def forward(self, x):
244
+ identity = x
245
+ b, c, t, h, w = x.size()
246
+ x = rearrange(x, 'b c t h w -> (b t) c h w')
247
+ x = self.norm(x)
248
+ # compute query, key, value
249
+ q, k, v = self.to_qkv(x).reshape(b * t, 1, c * 3,
250
+ -1).permute(0, 1, 3,
251
+ 2).contiguous().chunk(
252
+ 3, dim=-1)
253
+
254
+ # apply attention
255
+ x = F.scaled_dot_product_attention(
256
+ q,
257
+ k,
258
+ v,
259
+ )
260
+ x = x.squeeze(1).permute(0, 2, 1).reshape(b * t, c, h, w)
261
+
262
+ # output
263
+ x = self.proj(x)
264
+ x = rearrange(x, '(b t) c h w-> b c t h w', t=t)
265
+ return x + identity
266
+
267
+
268
+ class Encoder3d(nn.Module):
269
+
270
+ def __init__(self,
271
+ dim=128,
272
+ z_dim=4,
273
+ dim_mult=[1, 2, 4, 4],
274
+ num_res_blocks=2,
275
+ attn_scales=[],
276
+ temperal_downsample=[True, True, False],
277
+ dropout=0.0):
278
+ super().__init__()
279
+ self.dim = dim
280
+ self.z_dim = z_dim
281
+ self.dim_mult = dim_mult
282
+ self.num_res_blocks = num_res_blocks
283
+ self.attn_scales = attn_scales
284
+ self.temperal_downsample = temperal_downsample
285
+
286
+ # dimensions
287
+ dims = [dim * u for u in [1] + dim_mult]
288
+ scale = 1.0
289
+
290
+ # init block
291
+ self.conv1 = CausalConv3d(3, dims[0], 3, padding=1)
292
+
293
+ # downsample blocks
294
+ downsamples = []
295
+ for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
296
+ # residual (+attention) blocks
297
+ for _ in range(num_res_blocks):
298
+ downsamples.append(ResidualBlock(in_dim, out_dim, dropout))
299
+ if scale in attn_scales:
300
+ downsamples.append(AttentionBlock(out_dim))
301
+ in_dim = out_dim
302
+
303
+ # downsample block
304
+ if i != len(dim_mult) - 1:
305
+ mode = 'downsample3d' if temperal_downsample[
306
+ i] else 'downsample2d'
307
+ downsamples.append(Resample(out_dim, mode=mode))
308
+ scale /= 2.0
309
+ self.downsamples = nn.Sequential(*downsamples)
310
+
311
+ # middle blocks
312
+ self.middle = nn.Sequential(
313
+ ResidualBlock(out_dim, out_dim, dropout), AttentionBlock(out_dim),
314
+ ResidualBlock(out_dim, out_dim, dropout))
315
+
316
+ # output blocks
317
+ self.head = nn.Sequential(
318
+ RMS_norm(out_dim, images=False), nn.SiLU(),
319
+ CausalConv3d(out_dim, z_dim, 3, padding=1))
320
+
321
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
322
+ if feat_cache is not None:
323
+ idx = feat_idx[0]
324
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
325
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
326
+ # cache last frame of last two chunk
327
+ cache_x = torch.cat([
328
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
329
+ cache_x.device), cache_x
330
+ ],
331
+ dim=2)
332
+ x = self.conv1(x, feat_cache[idx])
333
+ feat_cache[idx] = cache_x
334
+ feat_idx[0] += 1
335
+ else:
336
+ x = self.conv1(x)
337
+
338
+ ## downsamples
339
+ for layer in self.downsamples:
340
+ if feat_cache is not None:
341
+ x = layer(x, feat_cache, feat_idx)
342
+ else:
343
+ x = layer(x)
344
+
345
+ ## middle
346
+ for layer in self.middle:
347
+ if isinstance(layer, ResidualBlock) and feat_cache is not None:
348
+ x = layer(x, feat_cache, feat_idx)
349
+ else:
350
+ x = layer(x)
351
+
352
+ ## head
353
+ for layer in self.head:
354
+ if isinstance(layer, CausalConv3d) and feat_cache is not None:
355
+ idx = feat_idx[0]
356
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
357
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
358
+ # cache last frame of last two chunk
359
+ cache_x = torch.cat([
360
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
361
+ cache_x.device), cache_x
362
+ ],
363
+ dim=2)
364
+ x = layer(x, feat_cache[idx])
365
+ feat_cache[idx] = cache_x
366
+ feat_idx[0] += 1
367
+ else:
368
+ x = layer(x)
369
+ return x
370
+
371
+
372
+ class Decoder3d(nn.Module):
373
+
374
+ def __init__(self,
375
+ dim=128,
376
+ z_dim=4,
377
+ dim_mult=[1, 2, 4, 4],
378
+ num_res_blocks=2,
379
+ attn_scales=[],
380
+ temperal_upsample=[False, True, True],
381
+ dropout=0.0):
382
+ super().__init__()
383
+ self.dim = dim
384
+ self.z_dim = z_dim
385
+ self.dim_mult = dim_mult
386
+ self.num_res_blocks = num_res_blocks
387
+ self.attn_scales = attn_scales
388
+ self.temperal_upsample = temperal_upsample
389
+
390
+ # dimensions
391
+ dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
392
+ scale = 1.0 / 2**(len(dim_mult) - 2)
393
+
394
+ # init block
395
+ self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)
396
+
397
+ # middle blocks
398
+ self.middle = nn.Sequential(
399
+ ResidualBlock(dims[0], dims[0], dropout), AttentionBlock(dims[0]),
400
+ ResidualBlock(dims[0], dims[0], dropout))
401
+
402
+ # upsample blocks
403
+ upsamples = []
404
+ for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
405
+ # residual (+attention) blocks
406
+ if i == 1 or i == 2 or i == 3:
407
+ in_dim = in_dim // 2
408
+ for _ in range(num_res_blocks + 1):
409
+ upsamples.append(ResidualBlock(in_dim, out_dim, dropout))
410
+ if scale in attn_scales:
411
+ upsamples.append(AttentionBlock(out_dim))
412
+ in_dim = out_dim
413
+
414
+ # upsample block
415
+ if i != len(dim_mult) - 1:
416
+ mode = 'upsample3d' if temperal_upsample[i] else 'upsample2d'
417
+ upsamples.append(Resample(out_dim, mode=mode))
418
+ scale *= 2.0
419
+ self.upsamples = nn.Sequential(*upsamples)
420
+
421
+ # output blocks
422
+ self.head = nn.Sequential(
423
+ RMS_norm(out_dim, images=False), nn.SiLU(),
424
+ CausalConv3d(out_dim, 3, 3, padding=1))
425
+
426
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
427
+ ## conv1
428
+ if feat_cache is not None:
429
+ idx = feat_idx[0]
430
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
431
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
432
+ # cache last frame of last two chunk
433
+ cache_x = torch.cat([
434
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
435
+ cache_x.device), cache_x
436
+ ],
437
+ dim=2)
438
+ x = self.conv1(x, feat_cache[idx])
439
+ feat_cache[idx] = cache_x
440
+ feat_idx[0] += 1
441
+ else:
442
+ x = self.conv1(x)
443
+
444
+ ## middle
445
+ for layer in self.middle:
446
+ if isinstance(layer, ResidualBlock) and feat_cache is not None:
447
+ x = layer(x, feat_cache, feat_idx)
448
+ else:
449
+ x = layer(x)
450
+
451
+ ## upsamples
452
+ for layer in self.upsamples:
453
+ if feat_cache is not None:
454
+ x = layer(x, feat_cache, feat_idx)
455
+ else:
456
+ x = layer(x)
457
+
458
+ ## head
459
+ for layer in self.head:
460
+ if isinstance(layer, CausalConv3d) and feat_cache is not None:
461
+ idx = feat_idx[0]
462
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
463
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
464
+ # cache last frame of last two chunk
465
+ cache_x = torch.cat([
466
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
467
+ cache_x.device), cache_x
468
+ ],
469
+ dim=2)
470
+ x = layer(x, feat_cache[idx])
471
+ feat_cache[idx] = cache_x
472
+ feat_idx[0] += 1
473
+ else:
474
+ x = layer(x)
475
+ return x
476
+
477
+
478
+ def count_conv3d(model):
479
+ count = 0
480
+ for m in model.modules():
481
+ if isinstance(m, CausalConv3d):
482
+ count += 1
483
+ return count
484
+
485
+
486
+ class AutoencoderKLWan_(nn.Module):
487
+
488
+ def __init__(self,
489
+ dim=128,
490
+ z_dim=4,
491
+ dim_mult=[1, 2, 4, 4],
492
+ num_res_blocks=2,
493
+ attn_scales=[],
494
+ temperal_downsample=[True, True, False],
495
+ dropout=0.0):
496
+ super().__init__()
497
+ self.dim = dim
498
+ self.z_dim = z_dim
499
+ self.dim_mult = dim_mult
500
+ self.num_res_blocks = num_res_blocks
501
+ self.attn_scales = attn_scales
502
+ self.temperal_downsample = temperal_downsample
503
+ self.temperal_upsample = temperal_downsample[::-1]
504
+
505
+ # modules
506
+ self.encoder = Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks,
507
+ attn_scales, self.temperal_downsample, dropout)
508
+ self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
509
+ self.conv2 = CausalConv3d(z_dim, z_dim, 1)
510
+ self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks,
511
+ attn_scales, self.temperal_upsample, dropout)
512
+
513
+ def forward(self, x):
514
+ mu, log_var = self.encode(x)
515
+ z = self.reparameterize(mu, log_var)
516
+ x_recon = self.decode(z)
517
+ return x_recon, mu, log_var
518
+
519
+ def encode(self, x, scale):
520
+ self.clear_cache()
521
+ ## cache
522
+ t = x.shape[2]
523
+ iter_ = 1 + (t - 1) // 4
524
+ scale = [item.to(x.device, x.dtype) for item in scale]
525
+ ## split the encoder input x along time into chunks of 1, 4, 4, 4, ... frames (e.g. t = 17 gives iter_ = 1 + 16 // 4 = 5 encoder passes)
526
+ for i in range(iter_):
527
+ self._enc_conv_idx = [0]
528
+ if i == 0:
529
+ out = self.encoder(
530
+ x[:, :, :1, :, :],
531
+ feat_cache=self._enc_feat_map,
532
+ feat_idx=self._enc_conv_idx)
533
+ else:
534
+ out_ = self.encoder(
535
+ x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
536
+ feat_cache=self._enc_feat_map,
537
+ feat_idx=self._enc_conv_idx)
538
+ out = torch.cat([out, out_], 2)
539
+ mu, log_var = self.conv1(out).chunk(2, dim=1)
540
+ if isinstance(scale[0], torch.Tensor):
541
+ mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(
542
+ 1, self.z_dim, 1, 1, 1)
543
+ else:
544
+ mu = (mu - scale[0]) * scale[1]
545
+ x = torch.cat([mu, log_var], dim = 1)
546
+ self.clear_cache()
547
+ return x
548
+
549
+ def decode(self, z, scale):
550
+ self.clear_cache()
551
+ # z: [b,c,t,h,w]
552
+ scale = [item.to(z.device, z.dtype) for item in scale]
553
+ if isinstance(scale[0], torch.Tensor):
554
+ z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(
555
+ 1, self.z_dim, 1, 1, 1)
556
+ else:
557
+ z = z / scale[1] + scale[0]
558
+ iter_ = z.shape[2]
559
+ x = self.conv2(z)
560
+ for i in range(iter_):
561
+ self._conv_idx = [0]
562
+ if i == 0:
563
+ out = self.decoder(
564
+ x[:, :, i:i + 1, :, :],
565
+ feat_cache=self._feat_map,
566
+ feat_idx=self._conv_idx)
567
+ else:
568
+ out_ = self.decoder(
569
+ x[:, :, i:i + 1, :, :],
570
+ feat_cache=self._feat_map,
571
+ feat_idx=self._conv_idx)
572
+ out = torch.cat([out, out_], 2)
573
+ self.clear_cache()
574
+ return out
575
+
576
+ def reparameterize(self, mu, log_var):
577
+ std = torch.exp(0.5 * log_var)
578
+ eps = torch.randn_like(std)
579
+ return eps * std + mu
580
+
581
+ def sample(self, imgs, deterministic=False):
582
+ mu, log_var = self.encode(imgs)
583
+ if deterministic:
584
+ return mu
585
+ std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0))
586
+ return mu + std * torch.randn_like(std)
587
+
588
+ def clear_cache(self):
589
+ self._conv_num = count_conv3d(self.decoder)
590
+ self._conv_idx = [0]
591
+ self._feat_map = [None] * self._conv_num
592
+ #cache encode
593
+ self._enc_conv_num = count_conv3d(self.encoder)
594
+ self._enc_conv_idx = [0]
595
+ self._enc_feat_map = [None] * self._enc_conv_num
596
+
597
+
598
+ def _video_vae(z_dim=None, **kwargs):
599
+ """
600
+ Autoencoder3d adapted from Stable Diffusion 1.x, 2.x and XL.
601
+ """
602
+ # params
603
+ cfg = dict(
604
+ dim=96,
605
+ z_dim=z_dim,
606
+ dim_mult=[1, 2, 4, 4],
607
+ num_res_blocks=2,
608
+ attn_scales=[],
609
+ temperal_downsample=[False, True, True],
610
+ dropout=0.0)
611
+ cfg.update(**kwargs)
612
+
613
+ # init model
614
+ model = AutoencoderKLWan_(**cfg)
615
+
616
+ return model
617
+
618
+
619
+ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
620
+
621
+ @register_to_config
622
+ def __init__(
623
+ self,
624
+ latent_channels=16,
625
+ temporal_compression_ratio=4,
626
+ spacial_compression_ratio=8
627
+ ):
628
+ super().__init__()
629
+ mean = [
630
+ -0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508,
631
+ 0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921
632
+ ]
633
+ std = [
634
+ 2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743,
635
+ 3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160
636
+ ]
637
+ self.mean = torch.tensor(mean, dtype=torch.float32)
638
+ self.std = torch.tensor(std, dtype=torch.float32)
639
+ self.scale = [self.mean, 1.0 / self.std]
640
+
641
+ # init model
642
+ self.model = _video_vae(
643
+ z_dim=latent_channels,
644
+ )
645
+
646
+ def _encode(self, x: torch.Tensor) -> torch.Tensor:
647
+ x = [
648
+ self.model.encode(u.unsqueeze(0), self.scale).squeeze(0)
649
+ for u in x
650
+ ]
651
+ x = torch.stack(x)
652
+ return x
653
+
654
+ @apply_forward_hook
655
+ def encode(
656
+ self, x: torch.Tensor, return_dict: bool = True
657
+ ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
658
+ h = self._encode(x)
659
+
660
+ posterior = DiagonalGaussianDistribution(h)
661
+
662
+ if not return_dict:
663
+ return (posterior,)
664
+ return AutoencoderKLOutput(latent_dist=posterior)
665
+
666
+ def _decode(self, zs):
667
+ dec = [
668
+ self.model.decode(u.unsqueeze(0), self.scale).clamp_(-1, 1).squeeze(0)
669
+ for u in zs
670
+ ]
671
+ dec = torch.stack(dec)
672
+
673
+ return DecoderOutput(sample=dec)
674
+
675
+ @apply_forward_hook
676
+ def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
677
+ print(z.size())
678
+ decoded = self._decode(z).sample
679
+
680
+ if not return_dict:
681
+ return (decoded,)
682
+ return DecoderOutput(sample=decoded)
683
+
684
+ @classmethod
685
+ def from_pretrained(cls, pretrained_model_path, additional_kwargs={}):
686
+ def filter_kwargs(cls, kwargs):
687
+ import inspect
688
+ sig = inspect.signature(cls.__init__)
689
+ valid_params = set(sig.parameters.keys()) - {'self', 'cls'}
690
+ filtered_kwargs = {k: v for k, v in kwargs.items() if k in valid_params}
691
+ return filtered_kwargs
692
+
693
+ model = cls(**filter_kwargs(cls, additional_kwargs))
694
+ if pretrained_model_path.endswith(".safetensors"):
695
+ from safetensors.torch import load_file, safe_open
696
+ state_dict = load_file(pretrained_model_path)
697
+ else:
698
+ state_dict = torch.load(pretrained_model_path, map_location="cpu")
699
+ tmp_state_dict = {}
700
+ for key in state_dict:
701
+ tmp_state_dict["model." + key] = state_dict[key]
702
+ state_dict = tmp_state_dict
703
+ m, u = model.load_state_dict(state_dict, strict=False)
704
+ print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
705
+ print(m, u)
706
+ return model
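A minimal encode/decode round-trip sketch for the wrapper above; the checkpoint path and tensor shapes are placeholders rather than values taken from this repository.

import torch

from cogvideox.models.wan_vae import AutoencoderKLWan

vae = AutoencoderKLWan.from_pretrained("path/to/vae_checkpoint.pth")
video = torch.randn(1, 3, 17, 256, 256)     # [B, C, T, H, W], values expected in [-1, 1]
posterior = vae.encode(video).latent_dist   # DiagonalGaussianDistribution over 2 * latent_channels
latents = posterior.sample()                # [1, 16, 5, 32, 32]: T -> 1 + (T - 1) / 4, H and W -> / 8
frames = vae.decode(latents).sample         # [1, 3, 17, 256, 256], clamped to [-1, 1]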
cogvideox/models/wan_xlm_roberta.py ADDED
@@ -0,0 +1,170 @@
1
+ # Modified from transformers.models.xlm_roberta.modeling_xlm_roberta
2
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+ __all__ = ['XLMRoberta', 'xlm_roberta_large']
8
+
9
+
10
+ class SelfAttention(nn.Module):
11
+
12
+ def __init__(self, dim, num_heads, dropout=0.1, eps=1e-5):
13
+ assert dim % num_heads == 0
14
+ super().__init__()
15
+ self.dim = dim
16
+ self.num_heads = num_heads
17
+ self.head_dim = dim // num_heads
18
+ self.eps = eps
19
+
20
+ # layers
21
+ self.q = nn.Linear(dim, dim)
22
+ self.k = nn.Linear(dim, dim)
23
+ self.v = nn.Linear(dim, dim)
24
+ self.o = nn.Linear(dim, dim)
25
+ self.dropout = nn.Dropout(dropout)
26
+
27
+ def forward(self, x, mask):
28
+ """
29
+ x: [B, L, C].
30
+ """
31
+ b, s, c, n, d = *x.size(), self.num_heads, self.head_dim
32
+
33
+ # compute query, key, value
34
+ q = self.q(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
35
+ k = self.k(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
36
+ v = self.v(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
37
+
38
+ # compute attention
39
+ p = self.dropout.p if self.training else 0.0
40
+ x = F.scaled_dot_product_attention(q, k, v, mask, p)
41
+ x = x.permute(0, 2, 1, 3).reshape(b, s, c)
42
+
43
+ # output
44
+ x = self.o(x)
45
+ x = self.dropout(x)
46
+ return x
47
+
48
+
49
+ class AttentionBlock(nn.Module):
50
+
51
+ def __init__(self, dim, num_heads, post_norm, dropout=0.1, eps=1e-5):
52
+ super().__init__()
53
+ self.dim = dim
54
+ self.num_heads = num_heads
55
+ self.post_norm = post_norm
56
+ self.eps = eps
57
+
58
+ # layers
59
+ self.attn = SelfAttention(dim, num_heads, dropout, eps)
60
+ self.norm1 = nn.LayerNorm(dim, eps=eps)
61
+ self.ffn = nn.Sequential(
62
+ nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim),
63
+ nn.Dropout(dropout))
64
+ self.norm2 = nn.LayerNorm(dim, eps=eps)
65
+
66
+ def forward(self, x, mask):
67
+ if self.post_norm:
68
+ x = self.norm1(x + self.attn(x, mask))
69
+ x = self.norm2(x + self.ffn(x))
70
+ else:
71
+ x = x + self.attn(self.norm1(x), mask)
72
+ x = x + self.ffn(self.norm2(x))
73
+ return x
74
+
75
+
76
+ class XLMRoberta(nn.Module):
77
+ """
78
+ XLMRobertaModel with no pooler and no LM head.
79
+ """
80
+
81
+ def __init__(self,
82
+ vocab_size=250002,
83
+ max_seq_len=514,
84
+ type_size=1,
85
+ pad_id=1,
86
+ dim=1024,
87
+ num_heads=16,
88
+ num_layers=24,
89
+ post_norm=True,
90
+ dropout=0.1,
91
+ eps=1e-5):
92
+ super().__init__()
93
+ self.vocab_size = vocab_size
94
+ self.max_seq_len = max_seq_len
95
+ self.type_size = type_size
96
+ self.pad_id = pad_id
97
+ self.dim = dim
98
+ self.num_heads = num_heads
99
+ self.num_layers = num_layers
100
+ self.post_norm = post_norm
101
+ self.eps = eps
102
+
103
+ # embeddings
104
+ self.token_embedding = nn.Embedding(vocab_size, dim, padding_idx=pad_id)
105
+ self.type_embedding = nn.Embedding(type_size, dim)
106
+ self.pos_embedding = nn.Embedding(max_seq_len, dim, padding_idx=pad_id)
107
+ self.dropout = nn.Dropout(dropout)
108
+
109
+ # blocks
110
+ self.blocks = nn.ModuleList([
111
+ AttentionBlock(dim, num_heads, post_norm, dropout, eps)
112
+ for _ in range(num_layers)
113
+ ])
114
+
115
+ # norm layer
116
+ self.norm = nn.LayerNorm(dim, eps=eps)
117
+
118
+ def forward(self, ids):
119
+ """
120
+ ids: [B, L] of torch.LongTensor.
121
+ """
122
+ b, s = ids.shape
123
+ mask = ids.ne(self.pad_id).long()
124
+
125
+ # embeddings
126
+ x = self.token_embedding(ids) + \
127
+ self.type_embedding(torch.zeros_like(ids)) + \
128
+ self.pos_embedding(self.pad_id + torch.cumsum(mask, dim=1) * mask)
129
+ if self.post_norm:
130
+ x = self.norm(x)
131
+ x = self.dropout(x)
132
+
133
+ # blocks
134
+ mask = torch.where(
135
+ mask.view(b, 1, 1, s).gt(0), 0.0,
136
+ torch.finfo(x.dtype).min)
137
+ for block in self.blocks:
138
+ x = block(x, mask)
139
+
140
+ # output
141
+ if not self.post_norm:
142
+ x = self.norm(x)
143
+ return x
144
+
145
+
146
+ def xlm_roberta_large(pretrained=False,
147
+ return_tokenizer=False,
148
+ device='cpu',
149
+ **kwargs):
150
+ """
151
+ XLMRobertaLarge adapted from Huggingface.
152
+ """
153
+ # params
154
+ cfg = dict(
155
+ vocab_size=250002,
156
+ max_seq_len=514,
157
+ type_size=1,
158
+ pad_id=1,
159
+ dim=1024,
160
+ num_heads=16,
161
+ num_layers=24,
162
+ post_norm=True,
163
+ dropout=0.1,
164
+ eps=1e-5)
165
+ cfg.update(**kwargs)
166
+
167
+ # init a model on device
168
+ with torch.device(device):
169
+ model = XLMRoberta(**cfg)
170
+ return model
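A minimal sketch of pushing token ids through the encoder above; the ids are dummies rather than a real XLM-R tokenization.

import torch

from cogvideox.models.wan_xlm_roberta import xlm_roberta_large

model = xlm_roberta_large(device="cpu")             # randomly initialised XLM-R-large sized model
ids = torch.tensor([[0, 1214, 88, 2, 1, 1, 1, 1]])  # [B, L]; pad_id = 1 marks the padded positions
hidden = model(ids)                                 # [1, 8, 1024] token-level hidden states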
cogvideox/pipeline/__init__.py ADDED
@@ -0,0 +1,8 @@
1
+ from .pipeline_cogvideox_fun import CogVideoXFunPipeline
2
+ from .pipeline_cogvideox_fun_control import CogVideoXFunControlPipeline
3
+ from .pipeline_cogvideox_fun_inpaint import CogVideoXFunInpaintPipeline
4
+ from .pipeline_wan_fun import WanFunPipeline
5
+ from .pipeline_wan_fun_inpaint import WanFunInpaintPipeline
6
+
7
+ WanPipeline = WanFunPipeline
8
+ WanI2VPipeline = WanFunInpaintPipeline
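The two aliases above expose the Fun pipelines under the plain Wan names as well; a minimal check:

from cogvideox.pipeline import WanFunPipeline, WanPipeline

assert WanPipeline is WanFunPipeline  # same class, two names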
cogvideox/pipeline/pipeline_cogvideox_fun.py ADDED
@@ -0,0 +1,877 @@
1
+ # Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import inspect
17
+ import math
18
+ from dataclasses import dataclass
19
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
20
+
21
+ import numpy as np
22
+ import torch
23
+ from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
24
+ from diffusers.models.embeddings import get_1d_rotary_pos_embed
25
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
26
+ from diffusers.schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
27
+ from diffusers.utils import BaseOutput, logging, replace_example_docstring
28
+ from diffusers.utils.torch_utils import randn_tensor
29
+ from diffusers.video_processor import VideoProcessor
30
+
31
+ from ..models import (AutoencoderKLCogVideoX,
32
+ CogVideoXTransformer3DModel, T5EncoderModel,
33
+ T5Tokenizer)
34
+
35
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
36
+
37
+
38
+ EXAMPLE_DOC_STRING = """
39
+ Examples:
40
+ ```python
41
+ >>> import torch
42
+ >>> from diffusers import CogVideoXFunPipeline
43
+ >>> from diffusers.utils import export_to_video
44
+
45
+ >>> # Models: "THUDM/CogVideoX-2b" or "THUDM/CogVideoX-5b"
46
+ >>> pipe = CogVideoXFunPipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=torch.float16).to("cuda")
47
+ >>> prompt = (
48
+ ... "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. "
49
+ ... "The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other "
50
+ ... "pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, "
51
+ ... "casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. "
52
+ ... "The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical "
53
+ ... "atmosphere of this unique musical performance."
54
+ ... )
55
+ >>> video = pipe(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0]
56
+ >>> export_to_video(video, "output.mp4", fps=8)
57
+ ```
58
+ """
59
+
60
+
61
+ # Copied from diffusers.models.embeddings.get_3d_rotary_pos_embed
62
+ def get_3d_rotary_pos_embed(
63
+ embed_dim,
64
+ crops_coords,
65
+ grid_size,
66
+ temporal_size,
67
+ theta: int = 10000,
68
+ use_real: bool = True,
69
+ grid_type: str = "linspace",
70
+ max_size: Optional[Tuple[int, int]] = None,
71
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
72
+ """
73
+ RoPE for video tokens with 3D structure.
74
+
75
+ Args:
76
+ embed_dim: (`int`):
77
+ The embedding dimension size, corresponding to hidden_size_head.
78
+ crops_coords (`Tuple[int]`):
79
+ The top-left and bottom-right coordinates of the crop.
80
+ grid_size (`Tuple[int]`):
81
+ The grid size of the spatial positional embedding (height, width).
82
+ temporal_size (`int`):
83
+ The size of the temporal dimension.
84
+ theta (`float`):
85
+ Scaling factor for frequency computation.
86
+ grid_type (`str`):
87
+ Whether to use "linspace" or "slice" to compute grids.
88
+
89
+ Returns:
90
+ `torch.Tensor`: positional embedding with shape `(temporal_size * grid_size[0] * grid_size[1], embed_dim/2)`.
91
+ """
92
+ if use_real is not True:
93
+ raise ValueError(" `use_real = False` is not currently supported for get_3d_rotary_pos_embed")
94
+
95
+ if grid_type == "linspace":
96
+ start, stop = crops_coords
97
+ grid_size_h, grid_size_w = grid_size
98
+ grid_h = np.linspace(start[0], stop[0], grid_size_h, endpoint=False, dtype=np.float32)
99
+ grid_w = np.linspace(start[1], stop[1], grid_size_w, endpoint=False, dtype=np.float32)
100
+ grid_t = np.arange(temporal_size, dtype=np.float32)
101
+ grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32)
102
+ elif grid_type == "slice":
103
+ max_h, max_w = max_size
104
+ grid_size_h, grid_size_w = grid_size
105
+ grid_h = np.arange(max_h, dtype=np.float32)
106
+ grid_w = np.arange(max_w, dtype=np.float32)
107
+ grid_t = np.arange(temporal_size, dtype=np.float32)
108
+ else:
109
+ raise ValueError("Invalid value passed for `grid_type`.")
110
+
111
+ # Compute dimensions for each axis
112
+ dim_t = embed_dim // 4
113
+ dim_h = embed_dim // 8 * 3
114
+ dim_w = embed_dim // 8 * 3
115
+
116
+ # Temporal frequencies
117
+ freqs_t = get_1d_rotary_pos_embed(dim_t, grid_t, use_real=True)
118
+ # Spatial frequencies for height and width
119
+ freqs_h = get_1d_rotary_pos_embed(dim_h, grid_h, use_real=True)
120
+ freqs_w = get_1d_rotary_pos_embed(dim_w, grid_w, use_real=True)
121
+
122
+ # Broadcast and concatenate temporal and spatial frequencies (height and width) into a 3D tensor
123
+ def combine_time_height_width(freqs_t, freqs_h, freqs_w):
124
+ freqs_t = freqs_t[:, None, None, :].expand(
125
+ -1, grid_size_h, grid_size_w, -1
126
+ ) # temporal_size, grid_size_h, grid_size_w, dim_t
127
+ freqs_h = freqs_h[None, :, None, :].expand(
128
+ temporal_size, -1, grid_size_w, -1
129
+ ) # temporal_size, grid_size_h, grid_size_w, dim_h
130
+ freqs_w = freqs_w[None, None, :, :].expand(
131
+ temporal_size, grid_size_h, -1, -1
132
+ ) # temporal_size, grid_size_h, grid_size_w, dim_w
133
+
134
+ freqs = torch.cat(
135
+ [freqs_t, freqs_h, freqs_w], dim=-1
136
+ ) # temporal_size, grid_size_h, grid_size_w, (dim_t + dim_h + dim_w)
137
+ freqs = freqs.view(
138
+ temporal_size * grid_size_h * grid_size_w, -1
139
+ ) # (temporal_size * grid_size_h * grid_size_w), (dim_t + dim_h + dim_w)
140
+ return freqs
141
+
142
+ t_cos, t_sin = freqs_t # both t_cos and t_sin have shape: temporal_size, dim_t
143
+ h_cos, h_sin = freqs_h # both h_cos and h_sin have shape: grid_size_h, dim_h
144
+ w_cos, w_sin = freqs_w # both w_cos and w_sin have shape: grid_size_w, dim_w
145
+
146
+ if grid_type == "slice":
147
+ t_cos, t_sin = t_cos[:temporal_size], t_sin[:temporal_size]
148
+ h_cos, h_sin = h_cos[:grid_size_h], h_sin[:grid_size_h]
149
+ w_cos, w_sin = w_cos[:grid_size_w], w_sin[:grid_size_w]
150
+
151
+ cos = combine_time_height_width(t_cos, h_cos, w_cos)
152
+ sin = combine_time_height_width(t_sin, h_sin, w_sin)
153
+ return cos, sin
154
+
155
+
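+ # shape sketch for get_3d_rotary_pos_embed (illustrative numbers): with embed_dim=64,
+ # grid_size=(30, 45) and temporal_size=13, dim_t=16 and dim_h=dim_w=24, so `cos` and
+ # `sin` each come out as (13 * 30 * 45, 64) = (17550, 64)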
156
+ # Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid
157
+ def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
158
+ tw = tgt_width
159
+ th = tgt_height
160
+ h, w = src
161
+ r = h / w
162
+ if r > (th / tw):
163
+ resize_height = th
164
+ resize_width = int(round(th / h * w))
165
+ else:
166
+ resize_width = tw
167
+ resize_height = int(round(tw / w * h))
168
+
169
+ crop_top = int(round((th - resize_height) / 2.0))
170
+ crop_left = int(round((tw - resize_width) / 2.0))
171
+
172
+ return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
173
+
174
+
175
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
176
+ def retrieve_timesteps(
177
+ scheduler,
178
+ num_inference_steps: Optional[int] = None,
179
+ device: Optional[Union[str, torch.device]] = None,
180
+ timesteps: Optional[List[int]] = None,
181
+ sigmas: Optional[List[float]] = None,
182
+ **kwargs,
183
+ ):
184
+ """
185
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
186
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
187
+
188
+ Args:
189
+ scheduler (`SchedulerMixin`):
190
+ The scheduler to get timesteps from.
191
+ num_inference_steps (`int`):
192
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
193
+ must be `None`.
194
+ device (`str` or `torch.device`, *optional*):
195
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
196
+ timesteps (`List[int]`, *optional*):
197
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
198
+ `num_inference_steps` and `sigmas` must be `None`.
199
+ sigmas (`List[float]`, *optional*):
200
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
201
+ `num_inference_steps` and `timesteps` must be `None`.
202
+
203
+ Returns:
204
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
205
+ second element is the number of inference steps.
206
+ """
207
+ if timesteps is not None and sigmas is not None:
208
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
209
+ if timesteps is not None:
210
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
211
+ if not accepts_timesteps:
212
+ raise ValueError(
213
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
214
+ f" timestep schedules. Please check whether you are using the correct scheduler."
215
+ )
216
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
217
+ timesteps = scheduler.timesteps
218
+ num_inference_steps = len(timesteps)
219
+ elif sigmas is not None:
220
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
221
+ if not accept_sigmas:
222
+ raise ValueError(
223
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
224
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
225
+ )
226
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
227
+ timesteps = scheduler.timesteps
228
+ num_inference_steps = len(timesteps)
229
+ else:
230
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
231
+ timesteps = scheduler.timesteps
232
+ return timesteps, num_inference_steps
233
+
234
+
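+ # usage sketch for retrieve_timesteps (illustrative): pass either a step count or an explicit
+ # schedule, never both, e.g.
+ #   timesteps, n = retrieve_timesteps(scheduler, num_inference_steps=50, device=device)
+ #   timesteps, n = retrieve_timesteps(scheduler, timesteps=[999, 749, 499, 249], device=device)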
235
+ @dataclass
236
+ class CogVideoXFunPipelineOutput(BaseOutput):
237
+ r"""
238
+ Output class for CogVideo pipelines.
239
+
240
+ Args:
241
+ videos (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
242
+ List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
243
+ denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
244
+ `(batch_size, num_frames, channels, height, width)`.
245
+ """
246
+
247
+ videos: torch.Tensor
248
+
249
+
250
+ class CogVideoXFunPipeline(DiffusionPipeline):
251
+ r"""
252
+ Pipeline for text-to-video generation using CogVideoX_Fun.
253
+
254
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
255
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
256
+
257
+ Args:
258
+ vae ([`AutoencoderKL`]):
259
+ Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
260
+ text_encoder ([`T5EncoderModel`]):
261
+ Frozen text-encoder. CogVideoX uses
262
+ [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the
263
+ [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant.
264
+ tokenizer (`T5Tokenizer`):
265
+ Tokenizer of class
266
+ [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
267
+ transformer ([`CogVideoXTransformer3DModel`]):
268
+ A text conditioned `CogVideoXTransformer3DModel` to denoise the encoded video latents.
269
+ scheduler ([`SchedulerMixin`]):
270
+ A scheduler to be used in combination with `transformer` to denoise the encoded video latents.
271
+ """
272
+
273
+ _optional_components = []
274
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
275
+
276
+ _callback_tensor_inputs = [
277
+ "latents",
278
+ "prompt_embeds",
279
+ "negative_prompt_embeds",
280
+ ]
281
+
282
+ def __init__(
283
+ self,
284
+ tokenizer: T5Tokenizer,
285
+ text_encoder: T5EncoderModel,
286
+ vae: AutoencoderKLCogVideoX,
287
+ transformer: CogVideoXTransformer3DModel,
288
+ scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler],
289
+ ):
290
+ super().__init__()
291
+
292
+ self.register_modules(
293
+ tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
294
+ )
295
+ self.vae_scale_factor_spatial = (
296
+ 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
297
+ )
298
+ self.vae_scale_factor_temporal = (
299
+ self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4
300
+ )
301
+
302
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
303
+
304
+ def _get_t5_prompt_embeds(
305
+ self,
306
+ prompt: Union[str, List[str]] = None,
307
+ num_videos_per_prompt: int = 1,
308
+ max_sequence_length: int = 226,
309
+ device: Optional[torch.device] = None,
310
+ dtype: Optional[torch.dtype] = None,
311
+ ):
312
+ device = device or self._execution_device
313
+ dtype = dtype or self.text_encoder.dtype
314
+
315
+ prompt = [prompt] if isinstance(prompt, str) else prompt
316
+ batch_size = len(prompt)
317
+
318
+ text_inputs = self.tokenizer(
319
+ prompt,
320
+ padding="max_length",
321
+ max_length=max_sequence_length,
322
+ truncation=True,
323
+ add_special_tokens=True,
324
+ return_tensors="pt",
325
+ )
326
+ text_input_ids = text_inputs.input_ids
327
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
328
+
329
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
330
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
331
+ logger.warning(
332
+ "The following part of your input was truncated because `max_sequence_length` is set to "
333
+ f" {max_sequence_length} tokens: {removed_text}"
334
+ )
335
+
336
+ prompt_embeds = self.text_encoder(text_input_ids.to(device))[0]
337
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
338
+
339
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
340
+ _, seq_len, _ = prompt_embeds.shape
341
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
342
+ prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
343
+
344
+ return prompt_embeds
345
+
346
+ def encode_prompt(
347
+ self,
348
+ prompt: Union[str, List[str]],
349
+ negative_prompt: Optional[Union[str, List[str]]] = None,
350
+ do_classifier_free_guidance: bool = True,
351
+ num_videos_per_prompt: int = 1,
352
+ prompt_embeds: Optional[torch.Tensor] = None,
353
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
354
+ max_sequence_length: int = 226,
355
+ device: Optional[torch.device] = None,
356
+ dtype: Optional[torch.dtype] = None,
357
+ ):
358
+ r"""
359
+ Encodes the prompt into text encoder hidden states.
360
+
361
+ Args:
362
+ prompt (`str` or `List[str]`, *optional*):
363
+ prompt to be encoded
364
+ negative_prompt (`str` or `List[str]`, *optional*):
365
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
366
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
367
+ less than `1`).
368
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
369
+ Whether to use classifier free guidance or not.
370
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
371
+ Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
372
+ prompt_embeds (`torch.Tensor`, *optional*):
373
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
374
+ provided, text embeddings will be generated from `prompt` input argument.
375
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
376
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
377
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
378
+ argument.
379
+ device: (`torch.device`, *optional*):
380
+ torch device
381
+ dtype: (`torch.dtype`, *optional*):
382
+ torch dtype
383
+ """
384
+ device = device or self._execution_device
385
+
386
+ prompt = [prompt] if isinstance(prompt, str) else prompt
387
+ if prompt is not None:
388
+ batch_size = len(prompt)
389
+ else:
390
+ batch_size = prompt_embeds.shape[0]
391
+
392
+ if prompt_embeds is None:
393
+ prompt_embeds = self._get_t5_prompt_embeds(
394
+ prompt=prompt,
395
+ num_videos_per_prompt=num_videos_per_prompt,
396
+ max_sequence_length=max_sequence_length,
397
+ device=device,
398
+ dtype=dtype,
399
+ )
400
+
401
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
402
+ negative_prompt = negative_prompt or ""
403
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
404
+
405
+ if prompt is not None and type(prompt) is not type(negative_prompt):
406
+ raise TypeError(
407
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
408
+ f" {type(prompt)}."
409
+ )
410
+ elif batch_size != len(negative_prompt):
411
+ raise ValueError(
412
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
413
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
414
+ " the batch size of `prompt`."
415
+ )
416
+
417
+ negative_prompt_embeds = self._get_t5_prompt_embeds(
418
+ prompt=negative_prompt,
419
+ num_videos_per_prompt=num_videos_per_prompt,
420
+ max_sequence_length=max_sequence_length,
421
+ device=device,
422
+ dtype=dtype,
423
+ )
424
+
425
+ return prompt_embeds, negative_prompt_embeds
426
+
427
+ def prepare_latents(
428
+ self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
429
+ ):
430
+ if isinstance(generator, list) and len(generator) != batch_size:
431
+ raise ValueError(
432
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
433
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
434
+ )
435
+
436
+ shape = (
437
+ batch_size,
438
+ (num_frames - 1) // self.vae_scale_factor_temporal + 1,
439
+ num_channels_latents,
440
+ height // self.vae_scale_factor_spatial,
441
+ width // self.vae_scale_factor_spatial,
442
+ )
443
+
444
+ if latents is None:
445
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
446
+ else:
447
+ latents = latents.to(device)
448
+
449
+ # scale the initial noise by the standard deviation required by the scheduler
450
+ latents = latents * self.scheduler.init_noise_sigma
451
+ return latents
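As a worked example of the latent shape built above, assuming the usual 4x temporal and 8x spatial VAE compression and 16 latent channels (assumptions about the loaded VAE/transformer config, not values read from this diff):
```python
batch_size, num_frames, height, width = 1, 49, 480, 720
vae_scale_factor_temporal, vae_scale_factor_spatial = 4, 8  # assumed compression ratios
num_channels_latents = 16                                    # assumed transformer.config.in_channels

shape = (
    batch_size,
    (num_frames - 1) // vae_scale_factor_temporal + 1,  # 13 latent frames
    num_channels_latents,
    height // vae_scale_factor_spatial,                  # 60
    width // vae_scale_factor_spatial,                   # 90
)
print(shape)  # (1, 13, 16, 60, 90)
```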
452
+
453
+ def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
454
+ latents = latents.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width]
455
+ latents = 1 / self.vae.config.scaling_factor * latents
456
+
457
+ frames = self.vae.decode(latents).sample
458
+ frames = (frames / 2 + 0.5).clamp(0, 1)
459
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
460
+ frames = frames.cpu().float().numpy()
461
+ return frames
462
+
463
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
464
+ def prepare_extra_step_kwargs(self, generator, eta):
465
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
466
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
467
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
468
+ # and should be between [0, 1]
469
+
470
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
471
+ extra_step_kwargs = {}
472
+ if accepts_eta:
473
+ extra_step_kwargs["eta"] = eta
474
+
475
+ # check if the scheduler accepts generator
476
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
477
+ if accepts_generator:
478
+ extra_step_kwargs["generator"] = generator
479
+ return extra_step_kwargs
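The `inspect.signature` check above only forwards `eta`/`generator` when the scheduler's `step` actually accepts them; a standalone sketch of the same idea with a toy step function:
```python
import inspect

def step(model_output, timestep, sample, generator=None):
    """Toy scheduler step that accepts `generator` but not `eta`."""
    return sample

extra_step_kwargs = {}
params = inspect.signature(step).parameters
if "eta" in params:
    extra_step_kwargs["eta"] = 0.0
if "generator" in params:
    extra_step_kwargs["generator"] = None
print(extra_step_kwargs)  # {'generator': None}
```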
480
+
481
+ # Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs
482
+ def check_inputs(
483
+ self,
484
+ prompt,
485
+ height,
486
+ width,
487
+ negative_prompt,
488
+ callback_on_step_end_tensor_inputs,
489
+ prompt_embeds=None,
490
+ negative_prompt_embeds=None,
491
+ ):
492
+ if height % 8 != 0 or width % 8 != 0:
493
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
494
+
495
+ if callback_on_step_end_tensor_inputs is not None and not all(
496
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
497
+ ):
498
+ raise ValueError(
499
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
500
+ )
501
+ if prompt is not None and prompt_embeds is not None:
502
+ raise ValueError(
503
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
504
+ " only forward one of the two."
505
+ )
506
+ elif prompt is None and prompt_embeds is None:
507
+ raise ValueError(
508
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
509
+ )
510
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
511
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
512
+
513
+ if prompt is not None and negative_prompt_embeds is not None:
514
+ raise ValueError(
515
+ f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
516
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
517
+ )
518
+
519
+ if negative_prompt is not None and negative_prompt_embeds is not None:
520
+ raise ValueError(
521
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
522
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
523
+ )
524
+
525
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
526
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
527
+ raise ValueError(
528
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
529
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
530
+ f" {negative_prompt_embeds.shape}."
531
+ )
532
+
533
+ def fuse_qkv_projections(self) -> None:
534
+ r"""Enables fused QKV projections."""
535
+ self.fusing_transformer = True
536
+ self.transformer.fuse_qkv_projections()
537
+
538
+ def unfuse_qkv_projections(self) -> None:
539
+ r"""Disable QKV projection fusion if enabled."""
540
+ if not self.fusing_transformer:
541
+ logger.warning("The Transformer was not initially fused for QKV projections. Doing nothing.")
542
+ else:
543
+ self.transformer.unfuse_qkv_projections()
544
+ self.fusing_transformer = False
545
+
546
+ def _prepare_rotary_positional_embeddings(
547
+ self,
548
+ height: int,
549
+ width: int,
550
+ num_frames: int,
551
+ device: torch.device,
552
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
553
+ grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
554
+ grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
555
+
556
+ p = self.transformer.config.patch_size
557
+ p_t = self.transformer.config.patch_size_t
558
+
559
+ base_size_width = self.transformer.config.sample_width // p
560
+ base_size_height = self.transformer.config.sample_height // p
561
+
562
+ if p_t is None:
563
+ # CogVideoX 1.0
564
+ grid_crops_coords = get_resize_crop_region_for_grid(
565
+ (grid_height, grid_width), base_size_width, base_size_height
566
+ )
567
+ freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
568
+ embed_dim=self.transformer.config.attention_head_dim,
569
+ crops_coords=grid_crops_coords,
570
+ grid_size=(grid_height, grid_width),
571
+ temporal_size=num_frames,
572
+ )
573
+ else:
574
+ # CogVideoX 1.5
575
+ base_num_frames = (num_frames + p_t - 1) // p_t
576
+
577
+ freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
578
+ embed_dim=self.transformer.config.attention_head_dim,
579
+ crops_coords=None,
580
+ grid_size=(grid_height, grid_width),
581
+ temporal_size=base_num_frames,
582
+ grid_type="slice",
583
+ max_size=(base_size_height, base_size_width),
584
+ )
585
+
586
+ freqs_cos = freqs_cos.to(device=device)
587
+ freqs_sin = freqs_sin.to(device=device)
588
+ return freqs_cos, freqs_sin
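For concreteness, with the defaults used later in `__call__` (480x720, 49 frames) and assumed config values of `patch_size=2`, `patch_size_t=2` and 8x/4x VAE compression, the RoPE grid works out as follows (a sketch; the real values come from `transformer.config` and the VAE config):
```python
height, width, num_frames = 480, 720, 49
vae_scale_factor_spatial, vae_scale_factor_temporal = 8, 4  # assumed
patch_size, patch_size_t = 2, 2                             # assumed (CogVideoX 1.5 style)

grid_height = height // (vae_scale_factor_spatial * patch_size)       # 30
grid_width = width // (vae_scale_factor_spatial * patch_size)         # 45
latent_frames = (num_frames - 1) // vae_scale_factor_temporal + 1     # 13
base_num_frames = (latent_frames + patch_size_t - 1) // patch_size_t  # 7
print(grid_height, grid_width, latent_frames, base_num_frames)  # 30 45 13 7
```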
589
+
590
+ @property
591
+ def guidance_scale(self):
592
+ return self._guidance_scale
593
+
594
+ @property
595
+ def num_timesteps(self):
596
+ return self._num_timesteps
597
+
598
+ @property
599
+ def attention_kwargs(self):
600
+ return self._attention_kwargs
601
+
602
+ @property
603
+ def interrupt(self):
604
+ return self._interrupt
605
+
606
+ @torch.no_grad()
607
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
608
+ def __call__(
609
+ self,
610
+ prompt: Optional[Union[str, List[str]]] = None,
611
+ negative_prompt: Optional[Union[str, List[str]]] = None,
612
+ height: int = 480,
613
+ width: int = 720,
614
+ num_frames: int = 49,
615
+ num_inference_steps: int = 50,
616
+ timesteps: Optional[List[int]] = None,
617
+ guidance_scale: float = 6,
618
+ use_dynamic_cfg: bool = False,
619
+ num_videos_per_prompt: int = 1,
620
+ eta: float = 0.0,
621
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
622
+ latents: Optional[torch.FloatTensor] = None,
623
+ prompt_embeds: Optional[torch.FloatTensor] = None,
624
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
625
+ output_type: str = "numpy",
626
+ return_dict: bool = False,
627
+ callback_on_step_end: Optional[
628
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
629
+ ] = None,
630
+ attention_kwargs: Optional[Dict[str, Any]] = None,
631
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
632
+ max_sequence_length: int = 226,
633
+ ) -> Union[CogVideoXFunPipelineOutput, Tuple]:
634
+ """
635
+ Function invoked when calling the pipeline for generation.
636
+
637
+ Args:
638
+ prompt (`str` or `List[str]`, *optional*):
639
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
640
+ instead.
641
+ negative_prompt (`str` or `List[str]`, *optional*):
642
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
643
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
644
+ less than `1`).
645
+ height (`int`, *optional*, defaults to `480`):
646
+ The height in pixels of the generated video.
647
+ width (`int`, *optional*, defaults to `720`):
648
+ The width in pixels of the generated video.
649
+ num_frames (`int`, defaults to `49`):
650
+ Number of frames to generate. Must be divisible by self.vae_scale_factor_temporal. Generated video will
651
+ contain 1 extra frame because CogVideoX is conditioned with (num_seconds * fps + 1) frames where
652
+ num_seconds is 6 and fps is 4. However, since videos can be saved at any fps, the only condition that
653
+ needs to be satisfied is that of divisibility mentioned above.
654
+ num_inference_steps (`int`, *optional*, defaults to 50):
655
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
656
+ expense of slower inference.
657
+ timesteps (`List[int]`, *optional*):
658
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
659
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
660
+ passed will be used. Must be in descending order.
661
+ guidance_scale (`float`, *optional*, defaults to 6):
662
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
663
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
664
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
665
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
666
+ usually at the expense of lower image quality.
667
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
668
+ The number of videos to generate per prompt.
669
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
670
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
671
+ to make generation deterministic.
672
+ latents (`torch.FloatTensor`, *optional*):
673
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
674
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
675
+ tensor will be generated by sampling using the supplied random `generator`.
676
+ prompt_embeds (`torch.FloatTensor`, *optional*):
677
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
678
+ provided, text embeddings will be generated from `prompt` input argument.
679
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
680
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
681
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
682
+ argument.
683
+ output_type (`str`, *optional*, defaults to `"numpy"`):
684
+ The output format of the generate image. Choose between
685
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
686
+ return_dict (`bool`, *optional*, defaults to `False`):
687
+ Whether or not to return a [`CogVideoXFunPipelineOutput`] instead
688
+ of a plain tuple.
689
+ callback_on_step_end (`Callable`, *optional*):
690
+ A function that is called at the end of each denoising step during inference. The function is called
691
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
692
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
693
+ `callback_on_step_end_tensor_inputs`.
694
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
695
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
696
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
697
+ `._callback_tensor_inputs` attribute of your pipeline class.
698
+ max_sequence_length (`int`, defaults to `226`):
699
+ Maximum sequence length in encoded prompt. Must be consistent with
700
+ `self.transformer.config.max_text_seq_length` otherwise may lead to poor results.
701
+
702
+ Examples:
703
+
704
+ Returns:
705
+ [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXFunPipelineOutput`] or `tuple`:
706
+ [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXFunPipelineOutput`] if `return_dict` is True, otherwise a
707
+ `tuple`. When returning a tuple, the first element is a list with the generated images.
708
+ """
709
+
710
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
711
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
712
+
713
+ height = height or self.transformer.config.sample_height * self.vae_scale_factor_spatial
714
+ width = width or self.transformer.config.sample_width * self.vae_scale_factor_spatial
715
+ num_frames = num_frames or self.transformer.config.sample_frames
716
+
717
+ num_videos_per_prompt = 1
718
+
719
+ # 1. Check inputs. Raise error if not correct
720
+ self.check_inputs(
721
+ prompt,
722
+ height,
723
+ width,
724
+ negative_prompt,
725
+ callback_on_step_end_tensor_inputs,
726
+ prompt_embeds,
727
+ negative_prompt_embeds,
728
+ )
729
+ self._guidance_scale = guidance_scale
730
+ self._attention_kwargs = attention_kwargs
731
+ self._interrupt = False
732
+
733
+ # 2. Default call parameters
734
+ if prompt is not None and isinstance(prompt, str):
735
+ batch_size = 1
736
+ elif prompt is not None and isinstance(prompt, list):
737
+ batch_size = len(prompt)
738
+ else:
739
+ batch_size = prompt_embeds.shape[0]
740
+
741
+ device = self._execution_device
742
+
743
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
744
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
745
+ # corresponds to doing no classifier free guidance.
746
+ do_classifier_free_guidance = guidance_scale > 1.0
747
+
748
+ # 3. Encode input prompt
749
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
750
+ prompt,
751
+ negative_prompt,
752
+ do_classifier_free_guidance,
753
+ num_videos_per_prompt=num_videos_per_prompt,
754
+ prompt_embeds=prompt_embeds,
755
+ negative_prompt_embeds=negative_prompt_embeds,
756
+ max_sequence_length=max_sequence_length,
757
+ device=device,
758
+ )
759
+ if do_classifier_free_guidance:
760
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
761
+
762
+ # 4. Prepare timesteps
763
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
764
+ self._num_timesteps = len(timesteps)
765
+
766
+ # 5. Prepare latents
767
+ latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
768
+
769
+ # For CogVideoX 1.5, the latent frames should be padded to make it divisible by patch_size_t
770
+ patch_size_t = self.transformer.config.patch_size_t
771
+ additional_frames = 0
772
+ if num_frames != 1 and patch_size_t is not None and latent_frames % patch_size_t != 0:
773
+ additional_frames = patch_size_t - latent_frames % patch_size_t
774
+ num_frames += additional_frames * self.vae_scale_factor_temporal
775
+
776
+ latent_channels = self.transformer.config.in_channels
777
+ latents = self.prepare_latents(
778
+ batch_size * num_videos_per_prompt,
779
+ latent_channels,
780
+ num_frames,
781
+ height,
782
+ width,
783
+ prompt_embeds.dtype,
784
+ device,
785
+ generator,
786
+ latents,
787
+ )
788
+
789
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
790
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
791
+
792
+ # 7. Create rotary embeds if required
793
+ image_rotary_emb = (
794
+ self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)
795
+ if self.transformer.config.use_rotary_positional_embeddings
796
+ else None
797
+ )
798
+
799
+ # 8. Denoising loop
800
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
801
+
802
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
803
+ # for DPM-solver++
804
+ old_pred_original_sample = None
805
+ for i, t in enumerate(timesteps):
806
+ if self.interrupt:
807
+ continue
808
+
809
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
810
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
811
+
812
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
813
+ timestep = t.expand(latent_model_input.shape[0])
814
+
815
+ # predict noise model_output
816
+ noise_pred = self.transformer(
817
+ hidden_states=latent_model_input,
818
+ encoder_hidden_states=prompt_embeds,
819
+ timestep=timestep,
820
+ image_rotary_emb=image_rotary_emb,
821
+ return_dict=False,
822
+ )[0]
823
+ noise_pred = noise_pred.float()
824
+
825
+ # perform guidance
826
+ if use_dynamic_cfg:
827
+ self._guidance_scale = 1 + guidance_scale * (
828
+ (1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2
829
+ )
830
+ if do_classifier_free_guidance:
831
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
832
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
833
+
834
+ # compute the previous noisy sample x_t -> x_t-1
835
+ if not isinstance(self.scheduler, CogVideoXDPMScheduler):
836
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
837
+ else:
838
+ latents, old_pred_original_sample = self.scheduler.step(
839
+ noise_pred,
840
+ old_pred_original_sample,
841
+ t,
842
+ timesteps[i - 1] if i > 0 else None,
843
+ latents,
844
+ **extra_step_kwargs,
845
+ return_dict=False,
846
+ )
847
+ latents = latents.to(prompt_embeds.dtype)
848
+
849
+ # call the callback, if provided
850
+ if callback_on_step_end is not None:
851
+ callback_kwargs = {}
852
+ for k in callback_on_step_end_tensor_inputs:
853
+ callback_kwargs[k] = locals()[k]
854
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
855
+
856
+ latents = callback_outputs.pop("latents", latents)
857
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
858
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
859
+
860
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
861
+ progress_bar.update()
862
+
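The classifier-free guidance step inside the loop above chunks the doubled batch and blends the two predictions; a toy sketch of just that arithmetic (shapes are illustrative):
```python
import torch

guidance_scale = 6.0
# Doubled batch: [unconditional, text-conditional] stacked along dim 0.
noise_pred = torch.randn(2, 13, 16, 60, 90)
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
print(noise_pred.shape)  # torch.Size([1, 13, 16, 60, 90])
```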
863
+ if output_type == "numpy":
864
+ video = self.decode_latents(latents)
865
+ elif not output_type == "latent":
866
+ video = self.decode_latents(latents)
867
+ video = self.video_processor.postprocess_video(video=video, output_type=output_type)
868
+ else:
869
+ video = latents
870
+
871
+ # Offload all models
872
+ self.maybe_free_model_hooks()
873
+
874
+ if not return_dict:
875
+ video = torch.from_numpy(video)
876
+
877
+ return CogVideoXFunPipelineOutput(videos=video)
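A hedged usage sketch for this pipeline. The class name and import path are assumed from this commit's layout, and the checkpoint id mirrors the example docstring of the control pipeline below; whether a given checkpoint layout matches these custom classes is an assumption, not a tested recipe:
```python
import torch
# Assumes this repository is on PYTHONPATH; import path follows the package layout in this commit.
from cogvideox.pipeline.pipeline_cogvideox_fun import CogVideoXFunPipeline

pipe = CogVideoXFunPipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=torch.float16).to("cuda")
out = pipe(
    prompt="A panda playing a tiny guitar in a bamboo forest.",
    height=480,
    width=720,
    num_frames=49,
    num_inference_steps=50,
    guidance_scale=6.0,
    output_type="numpy",
)
video = out.videos  # decoded frames, see decode_latents and the return_dict handling above
print(getattr(video, "shape", None))
```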
cogvideox/pipeline/pipeline_cogvideox_fun_control.py ADDED
@@ -0,0 +1,971 @@
1
+ # Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import inspect
17
+ import math
18
+ from dataclasses import dataclass
19
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
20
+
21
+ import numpy as np
22
+ import torch
23
+ import torch.nn.functional as F
24
+ from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
25
+ from diffusers.image_processor import VaeImageProcessor
26
+ from diffusers.models.embeddings import (get_1d_rotary_pos_embed,
27
+ get_3d_rotary_pos_embed)
28
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
29
+ from diffusers.schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
30
+ from diffusers.utils import BaseOutput, logging, replace_example_docstring
31
+ from diffusers.utils.torch_utils import randn_tensor
32
+ from diffusers.video_processor import VideoProcessor
33
+ from einops import rearrange
34
+
35
+ from ..models import (AutoencoderKLCogVideoX,
36
+ CogVideoXTransformer3DModel, T5EncoderModel,
37
+ T5Tokenizer)
38
+
39
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
40
+
41
+
42
+ EXAMPLE_DOC_STRING = """
43
+ Examples:
44
+ ```python
45
+ >>> import torch
46
+ >>> from diffusers import CogVideoXFunPipeline
47
+ >>> from diffusers.utils import export_to_video
48
+
49
+ >>> # Models: "THUDM/CogVideoX-2b" or "THUDM/CogVideoX-5b"
50
+ >>> pipe = CogVideoXFunPipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=torch.float16).to("cuda")
51
+ >>> prompt = (
52
+ ... "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. "
53
+ ... "The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other "
54
+ ... "pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, "
55
+ ... "casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. "
56
+ ... "The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical "
57
+ ... "atmosphere of this unique musical performance."
58
+ ... )
59
+ >>> video = pipe(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0]
60
+ >>> export_to_video(video, "output.mp4", fps=8)
61
+ ```
62
+ """
63
+
64
+
65
+ # Copied from diffusers.models.embeddings.get_3d_rotary_pos_embed
66
+ def get_3d_rotary_pos_embed(
67
+ embed_dim,
68
+ crops_coords,
69
+ grid_size,
70
+ temporal_size,
71
+ theta: int = 10000,
72
+ use_real: bool = True,
73
+ grid_type: str = "linspace",
74
+ max_size: Optional[Tuple[int, int]] = None,
75
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
76
+ """
77
+ RoPE for video tokens with 3D structure.
78
+
79
+ Args:
80
+ embed_dim: (`int`):
81
+ The embedding dimension size, corresponding to hidden_size_head.
82
+ crops_coords (`Tuple[int]`):
83
+ The top-left and bottom-right coordinates of the crop.
84
+ grid_size (`Tuple[int]`):
85
+ The grid size of the spatial positional embedding (height, width).
86
+ temporal_size (`int`):
87
+ The size of the temporal dimension.
88
+ theta (`float`):
89
+ Scaling factor for frequency computation.
90
+ grid_type (`str`):
91
+ Whether to use "linspace" or "slice" to compute grids.
92
+
93
+ Returns:
94
+ `torch.Tensor`: positional embedding with shape `(temporal_size * grid_size[0] * grid_size[1], embed_dim/2)`.
95
+ """
96
+ if use_real is not True:
97
+ raise ValueError(" `use_real = False` is not currently supported for get_3d_rotary_pos_embed")
98
+
99
+ if grid_type == "linspace":
100
+ start, stop = crops_coords
101
+ grid_size_h, grid_size_w = grid_size
102
+ grid_h = np.linspace(start[0], stop[0], grid_size_h, endpoint=False, dtype=np.float32)
103
+ grid_w = np.linspace(start[1], stop[1], grid_size_w, endpoint=False, dtype=np.float32)
104
+ grid_t = np.arange(temporal_size, dtype=np.float32)
105
+ grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32)
106
+ elif grid_type == "slice":
107
+ max_h, max_w = max_size
108
+ grid_size_h, grid_size_w = grid_size
109
+ grid_h = np.arange(max_h, dtype=np.float32)
110
+ grid_w = np.arange(max_w, dtype=np.float32)
111
+ grid_t = np.arange(temporal_size, dtype=np.float32)
112
+ else:
113
+ raise ValueError("Invalid value passed for `grid_type`.")
114
+
115
+ # Compute dimensions for each axis
116
+ dim_t = embed_dim // 4
117
+ dim_h = embed_dim // 8 * 3
118
+ dim_w = embed_dim // 8 * 3
119
+
120
+ # Temporal frequencies
121
+ freqs_t = get_1d_rotary_pos_embed(dim_t, grid_t, use_real=True)
122
+ # Spatial frequencies for height and width
123
+ freqs_h = get_1d_rotary_pos_embed(dim_h, grid_h, use_real=True)
124
+ freqs_w = get_1d_rotary_pos_embed(dim_w, grid_w, use_real=True)
125
+
126
+ # Broadcast and concatenate temporal and spatial frequencies (height and width) into a 3D tensor
127
+ def combine_time_height_width(freqs_t, freqs_h, freqs_w):
128
+ freqs_t = freqs_t[:, None, None, :].expand(
129
+ -1, grid_size_h, grid_size_w, -1
130
+ ) # temporal_size, grid_size_h, grid_size_w, dim_t
131
+ freqs_h = freqs_h[None, :, None, :].expand(
132
+ temporal_size, -1, grid_size_w, -1
133
+ ) # temporal_size, grid_size_h, grid_size_w, dim_h
134
+ freqs_w = freqs_w[None, None, :, :].expand(
135
+ temporal_size, grid_size_h, -1, -1
136
+ ) # temporal_size, grid_size_h, grid_size_w, dim_w
137
+
138
+ freqs = torch.cat(
139
+ [freqs_t, freqs_h, freqs_w], dim=-1
140
+ ) # temporal_size, grid_size_h, grid_size_w, (dim_t + dim_h + dim_w)
141
+ freqs = freqs.view(
142
+ temporal_size * grid_size_h * grid_size_w, -1
143
+ ) # (temporal_size * grid_size_h * grid_size_w), (dim_t + dim_h + dim_w)
144
+ return freqs
145
+
146
+ t_cos, t_sin = freqs_t # both t_cos and t_sin have shape: temporal_size, dim_t
147
+ h_cos, h_sin = freqs_h # both h_cos and h_sin have shape: grid_size_h, dim_h
148
+ w_cos, w_sin = freqs_w # both w_cos and w_sin have shape: grid_size_w, dim_w
149
+
150
+ if grid_type == "slice":
151
+ t_cos, t_sin = t_cos[:temporal_size], t_sin[:temporal_size]
152
+ h_cos, h_sin = h_cos[:grid_size_h], h_sin[:grid_size_h]
153
+ w_cos, w_sin = w_cos[:grid_size_w], w_sin[:grid_size_w]
154
+
155
+ cos = combine_time_height_width(t_cos, h_cos, w_cos)
156
+ sin = combine_time_height_width(t_sin, h_sin, w_sin)
157
+ return cos, sin
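The per-axis split above always sums back to `embed_dim` (the attention head dimension); for example with `embed_dim=64` (using 64 here is an assumption about the transformer config):
```python
embed_dim = 64                # transformer.config.attention_head_dim (assumed)
dim_t = embed_dim // 4        # 16, temporal axis
dim_h = embed_dim // 8 * 3    # 24, height axis
dim_w = embed_dim // 8 * 3    # 24, width axis
print(dim_t, dim_h, dim_w, dim_t + dim_h + dim_w)  # 16 24 24 64
```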
158
+
159
+
160
+ # Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid
161
+ def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
162
+ tw = tgt_width
163
+ th = tgt_height
164
+ h, w = src
165
+ r = h / w
166
+ if r > (th / tw):
167
+ resize_height = th
168
+ resize_width = int(round(th / h * w))
169
+ else:
170
+ resize_width = tw
171
+ resize_height = int(round(tw / w * h))
172
+
173
+ crop_top = int(round((th - resize_height) / 2.0))
174
+ crop_left = int(round((tw - resize_width) / 2.0))
175
+
176
+ return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
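A quick check of the crop-region helper, assuming `get_resize_crop_region_for_grid` as defined above is in scope; when the source grid already matches the target aspect ratio there is no crop:
```python
# A 480x720 video with 8x VAE downsampling and patch_size 2 gives a 30x45 token grid.
src = (30, 45)                  # (grid_height, grid_width)
tgt_width, tgt_height = 45, 30  # base grid of the pretrained model (assumed)

print(get_resize_crop_region_for_grid(src, tgt_width, tgt_height))
# ((0, 0), (30, 45)) -- aspect ratios match, so the full base grid is used

# A square 45x45 grid against the same 45x30 base gets a centered region along the width.
print(get_resize_crop_region_for_grid((45, 45), tgt_width, tgt_height))
```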
177
+
178
+
179
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
180
+ def retrieve_timesteps(
181
+ scheduler,
182
+ num_inference_steps: Optional[int] = None,
183
+ device: Optional[Union[str, torch.device]] = None,
184
+ timesteps: Optional[List[int]] = None,
185
+ sigmas: Optional[List[float]] = None,
186
+ **kwargs,
187
+ ):
188
+ """
189
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
190
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
191
+
192
+ Args:
193
+ scheduler (`SchedulerMixin`):
194
+ The scheduler to get timesteps from.
195
+ num_inference_steps (`int`):
196
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
197
+ must be `None`.
198
+ device (`str` or `torch.device`, *optional*):
199
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
200
+ timesteps (`List[int]`, *optional*):
201
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
202
+ `num_inference_steps` and `sigmas` must be `None`.
203
+ sigmas (`List[float]`, *optional*):
204
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
205
+ `num_inference_steps` and `timesteps` must be `None`.
206
+
207
+ Returns:
208
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
209
+ second element is the number of inference steps.
210
+ """
211
+ if timesteps is not None and sigmas is not None:
212
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
213
+ if timesteps is not None:
214
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
215
+ if not accepts_timesteps:
216
+ raise ValueError(
217
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
218
+ f" timestep schedules. Please check whether you are using the correct scheduler."
219
+ )
220
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
221
+ timesteps = scheduler.timesteps
222
+ num_inference_steps = len(timesteps)
223
+ elif sigmas is not None:
224
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
225
+ if not accept_sigmas:
226
+ raise ValueError(
227
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
228
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
229
+ )
230
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
231
+ timesteps = scheduler.timesteps
232
+ num_inference_steps = len(timesteps)
233
+ else:
234
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
235
+ timesteps = scheduler.timesteps
236
+ return timesteps, num_inference_steps
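A small sketch of `retrieve_timesteps` with a default-constructed scheduler (using the scheduler's default config is an assumption made only for illustration; the pipeline normally loads it from the checkpoint):
```python
from diffusers import CogVideoXDDIMScheduler

scheduler = CogVideoXDDIMScheduler()  # default config, illustration only
timesteps, num_inference_steps = retrieve_timesteps(scheduler, num_inference_steps=50)
print(num_inference_steps, timesteps[:3])  # 50 timesteps, in descending order
```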
237
+
238
+
239
+ @dataclass
240
+ class CogVideoXFunPipelineOutput(BaseOutput):
241
+ r"""
242
+ Output class for CogVideo pipelines.
243
+
244
+ Args:
245
+ video (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
246
+ List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
247
+ denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
248
+ `(batch_size, num_frames, channels, height, width)`.
249
+ """
250
+
251
+ videos: torch.Tensor
252
+
253
+
254
+ class CogVideoXFunControlPipeline(DiffusionPipeline):
255
+ r"""
256
+ Pipeline for text-to-video generation using CogVideoX.
257
+
258
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
259
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
260
+
261
+ Args:
262
+ vae ([`AutoencoderKL`]):
263
+ Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
264
+ text_encoder ([`T5EncoderModel`]):
265
+ Frozen text-encoder. CogVideoX_Fun uses
266
+ [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the
267
+ [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant.
268
+ tokenizer (`T5Tokenizer`):
269
+ Tokenizer of class
270
+ [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
271
+ transformer ([`CogVideoXTransformer3DModel`]):
272
+ A text conditioned `CogVideoXTransformer3DModel` to denoise the encoded video latents.
273
+ scheduler ([`SchedulerMixin`]):
274
+ A scheduler to be used in combination with `transformer` to denoise the encoded video latents.
275
+ """
276
+
277
+ _optional_components = []
278
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
279
+
280
+ _callback_tensor_inputs = [
281
+ "latents",
282
+ "prompt_embeds",
283
+ "negative_prompt_embeds",
284
+ ]
285
+
286
+ def __init__(
287
+ self,
288
+ tokenizer: T5Tokenizer,
289
+ text_encoder: T5EncoderModel,
290
+ vae: AutoencoderKLCogVideoX,
291
+ transformer: CogVideoXTransformer3DModel,
292
+ scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler],
293
+ ):
294
+ super().__init__()
295
+
296
+ self.register_modules(
297
+ tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
298
+ )
299
+ self.vae_scale_factor_spatial = (
300
+ 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
301
+ )
302
+ self.vae_scale_factor_temporal = (
303
+ self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4
304
+ )
305
+
306
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
307
+
308
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
309
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
310
+ self.mask_processor = VaeImageProcessor(
311
+ vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True
312
+ )
313
+
314
+ def _get_t5_prompt_embeds(
315
+ self,
316
+ prompt: Union[str, List[str]] = None,
317
+ num_videos_per_prompt: int = 1,
318
+ max_sequence_length: int = 226,
319
+ device: Optional[torch.device] = None,
320
+ dtype: Optional[torch.dtype] = None,
321
+ ):
322
+ device = device or self._execution_device
323
+ dtype = dtype or self.text_encoder.dtype
324
+
325
+ prompt = [prompt] if isinstance(prompt, str) else prompt
326
+ batch_size = len(prompt)
327
+
328
+ text_inputs = self.tokenizer(
329
+ prompt,
330
+ padding="max_length",
331
+ max_length=max_sequence_length,
332
+ truncation=True,
333
+ add_special_tokens=True,
334
+ return_tensors="pt",
335
+ )
336
+ text_input_ids = text_inputs.input_ids
337
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
338
+
339
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
340
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
341
+ logger.warning(
342
+ "The following part of your input was truncated because `max_sequence_length` is set to "
343
+ f" {max_sequence_length} tokens: {removed_text}"
344
+ )
345
+
346
+ prompt_embeds = self.text_encoder(text_input_ids.to(device))[0]
347
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
348
+
349
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
350
+ _, seq_len, _ = prompt_embeds.shape
351
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
352
+ prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
353
+
354
+ return prompt_embeds
355
+
356
+ def encode_prompt(
357
+ self,
358
+ prompt: Union[str, List[str]],
359
+ negative_prompt: Optional[Union[str, List[str]]] = None,
360
+ do_classifier_free_guidance: bool = True,
361
+ num_videos_per_prompt: int = 1,
362
+ prompt_embeds: Optional[torch.Tensor] = None,
363
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
364
+ max_sequence_length: int = 226,
365
+ device: Optional[torch.device] = None,
366
+ dtype: Optional[torch.dtype] = None,
367
+ ):
368
+ r"""
369
+ Encodes the prompt into text encoder hidden states.
370
+
371
+ Args:
372
+ prompt (`str` or `List[str]`, *optional*):
373
+ prompt to be encoded
374
+ negative_prompt (`str` or `List[str]`, *optional*):
375
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
376
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
377
+ less than `1`).
378
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
379
+ Whether to use classifier free guidance or not.
380
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
381
+ Number of videos that should be generated per prompt.
382
+ prompt_embeds (`torch.Tensor`, *optional*):
383
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
384
+ provided, text embeddings will be generated from `prompt` input argument.
385
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
386
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
387
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
388
+ argument.
389
+ device: (`torch.device`, *optional*):
390
+ torch device
391
+ dtype: (`torch.dtype`, *optional*):
392
+ torch dtype
393
+ """
394
+ device = device or self._execution_device
395
+
396
+ prompt = [prompt] if isinstance(prompt, str) else prompt
397
+ if prompt is not None:
398
+ batch_size = len(prompt)
399
+ else:
400
+ batch_size = prompt_embeds.shape[0]
401
+
402
+ if prompt_embeds is None:
403
+ prompt_embeds = self._get_t5_prompt_embeds(
404
+ prompt=prompt,
405
+ num_videos_per_prompt=num_videos_per_prompt,
406
+ max_sequence_length=max_sequence_length,
407
+ device=device,
408
+ dtype=dtype,
409
+ )
410
+
411
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
412
+ negative_prompt = negative_prompt or ""
413
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
414
+
415
+ if prompt is not None and type(prompt) is not type(negative_prompt):
416
+ raise TypeError(
417
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
418
+ f" {type(prompt)}."
419
+ )
420
+ elif batch_size != len(negative_prompt):
421
+ raise ValueError(
422
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
423
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
424
+ " the batch size of `prompt`."
425
+ )
426
+
427
+ negative_prompt_embeds = self._get_t5_prompt_embeds(
428
+ prompt=negative_prompt,
429
+ num_videos_per_prompt=num_videos_per_prompt,
430
+ max_sequence_length=max_sequence_length,
431
+ device=device,
432
+ dtype=dtype,
433
+ )
434
+
435
+ return prompt_embeds, negative_prompt_embeds
436
+
437
+ def prepare_latents(
438
+ self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
439
+ ):
440
+ shape = (
441
+ batch_size,
442
+ (num_frames - 1) // self.vae_scale_factor_temporal + 1,
443
+ num_channels_latents,
444
+ height // self.vae_scale_factor_spatial,
445
+ width // self.vae_scale_factor_spatial,
446
+ )
447
+ if isinstance(generator, list) and len(generator) != batch_size:
448
+ raise ValueError(
449
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
450
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
451
+ )
452
+
453
+ if latents is None:
454
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
455
+ else:
456
+ latents = latents.to(device)
457
+
458
+ # scale the initial noise by the standard deviation required by the scheduler
459
+ latents = latents * self.scheduler.init_noise_sigma
460
+ return latents
461
+
462
+ def prepare_control_latents(
463
+ self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
464
+ ):
465
+ # resize the mask to latents shape as we concatenate the mask to the latents
466
+ # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
467
+ # and half precision
468
+
469
+ if mask is not None:
470
+ mask = mask.to(device=device, dtype=self.vae.dtype)
471
+ bs = 1
472
+ new_mask = []
473
+ for i in range(0, mask.shape[0], bs):
474
+ mask_bs = mask[i : i + bs]
475
+ mask_bs = self.vae.encode(mask_bs)[0]
476
+ mask_bs = mask_bs.mode()
477
+ new_mask.append(mask_bs)
478
+ mask = torch.cat(new_mask, dim = 0)
479
+ mask = mask * self.vae.config.scaling_factor
480
+
481
+ if masked_image is not None:
482
+ masked_image = masked_image.to(device=device, dtype=self.vae.dtype)
483
+ bs = 1
484
+ new_mask_pixel_values = []
485
+ for i in range(0, masked_image.shape[0], bs):
486
+ mask_pixel_values_bs = masked_image[i : i + bs]
487
+ mask_pixel_values_bs = self.vae.encode(mask_pixel_values_bs)[0]
488
+ mask_pixel_values_bs = mask_pixel_values_bs.mode()
489
+ new_mask_pixel_values.append(mask_pixel_values_bs)
490
+ masked_image_latents = torch.cat(new_mask_pixel_values, dim = 0)
491
+ masked_image_latents = masked_image_latents * self.vae.config.scaling_factor
492
+ else:
493
+ masked_image_latents = None
494
+
495
+ return mask, masked_image_latents
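A shape-only sketch of the chunked encoding pattern above, with a stand-in for `self.vae.encode(...)[0].mode()` and a made-up scaling factor (the real ratios and factor come from the VAE config):
```python
import torch

def fake_vae_encode(x):
    # Stand-in for vae.encode(x)[0].mode(): 3 RGB channels -> 16 latent channels,
    # 4x temporal and 8x spatial compression (assumed ratios).
    b, c, f, h, w = x.shape
    return torch.randn(b, 16, (f - 1) // 4 + 1, h // 8, w // 8)

control_video = torch.randn(2, 3, 9, 64, 64)  # [B, C, F, H, W] in pixel space
bs = 1
chunks = [fake_vae_encode(control_video[i : i + bs]) for i in range(0, control_video.shape[0], bs)]
control_latents = torch.cat(chunks, dim=0) * 0.7  # 0.7 stands in for vae.config.scaling_factor
print(control_latents.shape)  # torch.Size([2, 16, 3, 8, 8])
```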
496
+
497
+ def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
498
+ latents = latents.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width]
499
+ latents = 1 / self.vae.config.scaling_factor * latents
500
+
501
+ frames = self.vae.decode(latents).sample
502
+ frames = (frames / 2 + 0.5).clamp(0, 1)
503
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
504
+ frames = frames.cpu().float().numpy()
505
+ return frames
506
+
507
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
508
+ def prepare_extra_step_kwargs(self, generator, eta):
509
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
510
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
511
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
512
+ # and should be between [0, 1]
513
+
514
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
515
+ extra_step_kwargs = {}
516
+ if accepts_eta:
517
+ extra_step_kwargs["eta"] = eta
518
+
519
+ # check if the scheduler accepts generator
520
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
521
+ if accepts_generator:
522
+ extra_step_kwargs["generator"] = generator
523
+ return extra_step_kwargs
524
+
525
+ # Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs
526
+ def check_inputs(
527
+ self,
528
+ prompt,
529
+ height,
530
+ width,
531
+ negative_prompt,
532
+ callback_on_step_end_tensor_inputs,
533
+ prompt_embeds=None,
534
+ negative_prompt_embeds=None,
535
+ ):
536
+ if height % 8 != 0 or width % 8 != 0:
537
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
538
+
539
+ if callback_on_step_end_tensor_inputs is not None and not all(
540
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
541
+ ):
542
+ raise ValueError(
543
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
544
+ )
545
+ if prompt is not None and prompt_embeds is not None:
546
+ raise ValueError(
547
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
548
+ " only forward one of the two."
549
+ )
550
+ elif prompt is None and prompt_embeds is None:
551
+ raise ValueError(
552
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
553
+ )
554
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
555
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
556
+
557
+ if prompt is not None and negative_prompt_embeds is not None:
558
+ raise ValueError(
559
+ f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
560
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
561
+ )
562
+
563
+ if negative_prompt is not None and negative_prompt_embeds is not None:
564
+ raise ValueError(
565
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
566
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
567
+ )
568
+
569
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
570
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
571
+ raise ValueError(
572
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
573
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
574
+ f" {negative_prompt_embeds.shape}."
575
+ )
576
+
577
+ def fuse_qkv_projections(self) -> None:
578
+ r"""Enables fused QKV projections."""
579
+ self.fusing_transformer = True
580
+ self.transformer.fuse_qkv_projections()
581
+
582
+ def unfuse_qkv_projections(self) -> None:
583
+ r"""Disable QKV projection fusion if enabled."""
584
+ if not self.fusing_transformer:
585
+ logger.warning("The Transformer was not initially fused for QKV projections. Doing nothing.")
586
+ else:
587
+ self.transformer.unfuse_qkv_projections()
588
+ self.fusing_transformer = False
589
+
590
+ def _prepare_rotary_positional_embeddings(
591
+ self,
592
+ height: int,
593
+ width: int,
594
+ num_frames: int,
595
+ device: torch.device,
596
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
597
+ grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
598
+ grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
599
+
600
+ p = self.transformer.config.patch_size
601
+ p_t = self.transformer.config.patch_size_t
602
+
603
+ base_size_width = self.transformer.config.sample_width // p
604
+ base_size_height = self.transformer.config.sample_height // p
605
+
606
+ if p_t is None:
607
+ # CogVideoX 1.0
608
+ grid_crops_coords = get_resize_crop_region_for_grid(
609
+ (grid_height, grid_width), base_size_width, base_size_height
610
+ )
611
+ freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
612
+ embed_dim=self.transformer.config.attention_head_dim,
613
+ crops_coords=grid_crops_coords,
614
+ grid_size=(grid_height, grid_width),
615
+ temporal_size=num_frames,
616
+ )
617
+ else:
618
+ # CogVideoX 1.5
619
+ base_num_frames = (num_frames + p_t - 1) // p_t
620
+
621
+ freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
622
+ embed_dim=self.transformer.config.attention_head_dim,
623
+ crops_coords=None,
624
+ grid_size=(grid_height, grid_width),
625
+ temporal_size=base_num_frames,
626
+ grid_type="slice",
627
+ max_size=(base_size_height, base_size_width),
628
+ )
629
+
630
+ freqs_cos = freqs_cos.to(device=device)
631
+ freqs_sin = freqs_sin.to(device=device)
632
+ return freqs_cos, freqs_sin
633
+
634
+ @property
635
+ def guidance_scale(self):
636
+ return self._guidance_scale
637
+
638
+ @property
639
+ def num_timesteps(self):
640
+ return self._num_timesteps
641
+
642
+ @property
643
+ def interrupt(self):
644
+ return self._interrupt
645
+
646
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps
647
+ def get_timesteps(self, num_inference_steps, strength, device):
648
+ # get the original timestep using init_timestep
649
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
650
+
651
+ t_start = max(num_inference_steps - init_timestep, 0)
652
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
653
+
654
+ return timesteps, num_inference_steps - t_start
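Worked example of the strength-based truncation above: with 50 inference steps and `strength=0.6`, only the last 30 timesteps are kept:
```python
num_inference_steps, strength, scheduler_order = 50, 0.6, 1  # order 1 is typical for DDIM-style schedulers

init_timestep = min(int(num_inference_steps * strength), num_inference_steps)  # 30
t_start = max(num_inference_steps - init_timestep, 0)                          # 20
# timesteps = scheduler.timesteps[t_start * scheduler_order:]  -> 30 remaining steps
print(init_timestep, t_start, num_inference_steps - t_start)  # 30 20 30
```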
655
+
656
+ @torch.no_grad()
657
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
658
+ def __call__(
659
+ self,
660
+ prompt: Optional[Union[str, List[str]]] = None,
661
+ negative_prompt: Optional[Union[str, List[str]]] = None,
662
+ height: int = 480,
663
+ width: int = 720,
664
+ video: Union[torch.FloatTensor] = None,
665
+ control_video: Union[torch.FloatTensor] = None,
666
+ num_frames: int = 49,
667
+ num_inference_steps: int = 50,
668
+ timesteps: Optional[List[int]] = None,
669
+ guidance_scale: float = 6,
670
+ use_dynamic_cfg: bool = False,
671
+ num_videos_per_prompt: int = 1,
672
+ eta: float = 0.0,
673
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
674
+ latents: Optional[torch.FloatTensor] = None,
675
+ prompt_embeds: Optional[torch.FloatTensor] = None,
676
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
677
+ output_type: str = "numpy",
678
+ return_dict: bool = False,
679
+ callback_on_step_end: Optional[
680
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
681
+ ] = None,
682
+ attention_kwargs: Optional[Dict[str, Any]] = None,
683
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
684
+ max_sequence_length: int = 226,
685
+ comfyui_progressbar: bool = False,
686
+ ) -> Union[CogVideoXFunPipelineOutput, Tuple]:
687
+ """
688
+ Function invoked when calling the pipeline for generation.
689
+
690
+ Args:
691
+ prompt (`str` or `List[str]`, *optional*):
692
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
693
+ instead.
694
+ negative_prompt (`str` or `List[str]`, *optional*):
695
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
696
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
697
+ less than `1`).
698
+ height (`int`, *optional*, defaults to `480`):
699
+ The height in pixels of the generated video.
700
+ width (`int`, *optional*, defaults to `720`):
701
+ The width in pixels of the generated video.
702
+ num_frames (`int`, defaults to `49`):
703
+ Number of frames to generate. Must be divisible by self.vae_scale_factor_temporal. Generated video will
704
+ contain 1 extra frame because CogVideoX_Fun is conditioned with (num_seconds * fps + 1) frames where
705
+ num_seconds is 6 and fps is 4. However, since videos can be saved at any fps, the only condition that
706
+ needs to be satisfied is that of divisibility mentioned above.
707
+ num_inference_steps (`int`, *optional*, defaults to 50):
708
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
709
+ expense of slower inference.
710
+ timesteps (`List[int]`, *optional*):
711
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
712
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
713
+ passed will be used. Must be in descending order.
714
+ guidance_scale (`float`, *optional*, defaults to 6):
715
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
716
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
717
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
718
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
719
+ usually at the expense of lower image quality.
720
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
721
+ The number of videos to generate per prompt.
722
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
723
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
724
+ to make generation deterministic.
725
+ latents (`torch.FloatTensor`, *optional*):
726
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
727
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
728
+ tensor will ge generated by sampling using the supplied random `generator`.
729
+ prompt_embeds (`torch.FloatTensor`, *optional*):
730
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
731
+ provided, text embeddings will be generated from `prompt` input argument.
732
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
733
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
734
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
735
+ argument.
736
+ output_type (`str`, *optional*, defaults to `"numpy"`):
737
+ The output format of the generate image. Choose between
738
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
739
+ return_dict (`bool`, *optional*, defaults to `False`):
740
+ Whether or not to return a [`CogVideoXFunPipelineOutput`] instead
741
+ of a plain tuple.
742
+ callback_on_step_end (`Callable`, *optional*):
743
+ A function that calls at the end of each denoising steps during the inference. The function is called
744
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
745
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
746
+ `callback_on_step_end_tensor_inputs`.
747
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
748
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
749
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
750
+ `._callback_tensor_inputs` attribute of your pipeline class.
751
+ max_sequence_length (`int`, defaults to `226`):
752
+ Maximum sequence length in encoded prompt. Must be consistent with
753
+ `self.transformer.config.max_text_seq_length` otherwise may lead to poor results.
754
+
755
+ Examples:
756
+
757
+ Returns:
758
+ [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXFunPipelineOutput`] or `tuple`:
759
+ [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXFunPipelineOutput`] if `return_dict` is True, otherwise a
760
+ `tuple`. When returning a tuple, the first element is a list with the generated images.
761
+ """
762
+
763
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
764
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
765
+
766
+ height = height or self.transformer.config.sample_height * self.vae_scale_factor_spatial
767
+ width = width or self.transformer.config.sample_width * self.vae_scale_factor_spatial
768
+ num_frames = num_frames or self.transformer.config.sample_frames
769
+
770
+ num_videos_per_prompt = 1
771
+
772
+ # 1. Check inputs. Raise error if not correct
773
+ self.check_inputs(
774
+ prompt,
775
+ height,
776
+ width,
777
+ negative_prompt,
778
+ callback_on_step_end_tensor_inputs,
779
+ prompt_embeds,
780
+ negative_prompt_embeds,
781
+ )
782
+ self._guidance_scale = guidance_scale
783
+ self._attention_kwargs = attention_kwargs
784
+ self._interrupt = False
785
+
786
+ # 2. Default call parameters
787
+ if prompt is not None and isinstance(prompt, str):
788
+ batch_size = 1
789
+ elif prompt is not None and isinstance(prompt, list):
790
+ batch_size = len(prompt)
791
+ else:
792
+ batch_size = prompt_embeds.shape[0]
793
+
794
+ device = self._execution_device
795
+
796
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
797
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
798
+ # corresponds to doing no classifier free guidance.
799
+ do_classifier_free_guidance = guidance_scale > 1.0
800
+
801
+ # 3. Encode input prompt
802
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
803
+ prompt,
804
+ negative_prompt,
805
+ do_classifier_free_guidance,
806
+ num_videos_per_prompt=num_videos_per_prompt,
807
+ prompt_embeds=prompt_embeds,
808
+ negative_prompt_embeds=negative_prompt_embeds,
809
+ max_sequence_length=max_sequence_length,
810
+ device=device,
811
+ )
812
+ if do_classifier_free_guidance:
813
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
814
+
815
+ # 4. Prepare timesteps
816
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
817
+ self._num_timesteps = len(timesteps)
818
+ if comfyui_progressbar:
819
+ from comfy.utils import ProgressBar
820
+ pbar = ProgressBar(num_inference_steps + 2)
821
+
822
+ if control_video is not None:
823
+ video_length = control_video.shape[2]
824
+ control_video = self.image_processor.preprocess(rearrange(control_video, "b c f h w -> (b f) c h w"), height=height, width=width)
825
+ control_video = control_video.to(dtype=torch.float32)
826
+ control_video = rearrange(control_video, "(b f) c h w -> b c f h w", f=video_length)
827
+ else:
828
+ control_video = None
829
+
830
+ # Magvae needs the number of frames to be 4n + 1.
831
+ local_latent_length = (num_frames - 1) // self.vae_scale_factor_temporal + 1
832
+ # For CogVideoX 1.5, the latent frames should be clipped to make it divisible by patch_size_t
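+ # Worked example: with num_frames=49 and a temporal compression of 4, the latent length is (49 - 1) // 4 + 1 = 13;
+ # if patch_size_t is 2, then 13 % 2 = 1, so one latent frame (4 pixel frames) is dropped, leaving 45 frames
+ # and 12 latent frames, which is divisible by patch_size_t.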
833
+ patch_size_t = self.transformer.config.patch_size_t
834
+ additional_frames = 0
835
+ if patch_size_t is not None and local_latent_length % patch_size_t != 0:
836
+ additional_frames = local_latent_length % patch_size_t
837
+ num_frames -= additional_frames * self.vae_scale_factor_temporal
838
+ if num_frames <= 0:
839
+ num_frames = 1
840
+ if video_length > num_frames:
841
+ logger.warning("The length of condition video is not right, the latent frames should be clipped to make it divisible by patch_size_t. ")
842
+ video_length = num_frames
843
+ control_video = control_video[:, :, :video_length]
844
+
845
+ # 5. Prepare latents.
846
+ latent_channels = self.vae.config.latent_channels
847
+ latents = self.prepare_latents(
848
+ batch_size * num_videos_per_prompt,
849
+ latent_channels,
850
+ num_frames,
851
+ height,
852
+ width,
853
+ prompt_embeds.dtype,
854
+ device,
855
+ generator,
856
+ latents,
857
+ )
858
+ if comfyui_progressbar:
859
+ pbar.update(1)
860
+
861
+ control_video_latents = self.prepare_control_latents(
862
+ None,
863
+ control_video,
864
+ batch_size,
865
+ height,
866
+ width,
867
+ prompt_embeds.dtype,
868
+ device,
869
+ generator,
870
+ do_classifier_free_guidance
871
+ )[1]
872
+ control_video_latents_input = (
873
+ torch.cat([control_video_latents] * 2) if do_classifier_free_guidance else control_video_latents
874
+ )
875
+ control_latents = rearrange(control_video_latents_input, "b c f h w -> b f c h w")
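+ # The control latents are duplicated for classifier-free guidance and rearranged to
+ # (batch, frames, channels, height, width) to match the latent layout fed to the transformer.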
876
+
877
+ if comfyui_progressbar:
878
+ pbar.update(1)
879
+
880
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
881
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
882
+
883
+ # 7. Create rotary embeds if required
884
+ image_rotary_emb = (
885
+ self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)
886
+ if self.transformer.config.use_rotary_positional_embeddings
887
+ else None
888
+ )
889
+
890
+ # 8. Denoising loop
891
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
892
+
893
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
894
+ # for DPM-solver++
895
+ old_pred_original_sample = None
896
+ for i, t in enumerate(timesteps):
897
+ if self.interrupt:
898
+ continue
899
+
900
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
901
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
902
+
903
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
904
+ timestep = t.expand(latent_model_input.shape[0])
905
+
906
+ # predict noise model_output
907
+ noise_pred = self.transformer(
908
+ hidden_states=latent_model_input,
909
+ encoder_hidden_states=prompt_embeds,
910
+ timestep=timestep,
911
+ image_rotary_emb=image_rotary_emb,
912
+ return_dict=False,
913
+ control_latents=control_latents,
914
+ )[0]
915
+ noise_pred = noise_pred.float()
916
+
917
+ # perform guidance
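+ # When use_dynamic_cfg is enabled, the effective guidance scale is recomputed every step from a
+ # cosine curve of the current timestep, so guidance strength varies along the denoising trajectory
+ # instead of staying fixed at `guidance_scale`.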
918
+ if use_dynamic_cfg:
919
+ self._guidance_scale = 1 + guidance_scale * (
920
+ (1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2
921
+ )
922
+ if do_classifier_free_guidance:
923
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
924
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
925
+
926
+ # compute the previous noisy sample x_t -> x_t-1
927
+ if not isinstance(self.scheduler, CogVideoXDPMScheduler):
928
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
929
+ else:
930
+ latents, old_pred_original_sample = self.scheduler.step(
931
+ noise_pred,
932
+ old_pred_original_sample,
933
+ t,
934
+ timesteps[i - 1] if i > 0 else None,
935
+ latents,
936
+ **extra_step_kwargs,
937
+ return_dict=False,
938
+ )
939
+ latents = latents.to(prompt_embeds.dtype)
940
+
941
+ # call the callback, if provided
942
+ if callback_on_step_end is not None:
943
+ callback_kwargs = {}
944
+ for k in callback_on_step_end_tensor_inputs:
945
+ callback_kwargs[k] = locals()[k]
946
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
947
+
948
+ latents = callback_outputs.pop("latents", latents)
949
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
950
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
951
+
952
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
953
+ progress_bar.update()
954
+ if comfyui_progressbar:
955
+ pbar.update(1)
956
+
957
+ if output_type == "numpy":
958
+ video = self.decode_latents(latents)
959
+ elif not output_type == "latent":
960
+ video = self.decode_latents(latents)
961
+ video = self.video_processor.postprocess_video(video=video, output_type=output_type)
962
+ else:
963
+ video = latents
964
+
965
+ # Offload all models
966
+ self.maybe_free_model_hooks()
967
+
968
+ if not return_dict:
969
+ video = torch.from_numpy(video)
970
+
971
+ return CogVideoXFunPipelineOutput(videos=video)
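
For orientation, here is a minimal invocation sketch for the control pipeline defined above. It is not part of this commit: the class name `CogVideoXFunControlPipeline`, the import path, and the checkpoint directory are assumptions chosen for illustration, and the all-zeros `control_video` is a stand-in for a real control signal (e.g. a pose or depth video). The keyword arguments correspond to parameters used in the `__call__` body above.

```python
# A hedged sketch, not taken from this repository; names marked "assumed" are illustrative.
import torch

from cogvideox.pipeline import CogVideoXFunControlPipeline  # assumed class name / import path

pipe = CogVideoXFunControlPipeline.from_pretrained(
    "models/CogVideoX-Fun-Control",  # hypothetical local checkpoint directory
    torch_dtype=torch.bfloat16,
).to("cuda")

# Placeholder control video shaped (batch, channels, frames, height, width);
# 49 frames keeps the 4n + 1 frame count expected by the video VAE.
control_video = torch.zeros((1, 3, 49, 480, 720))

result = pipe(
    prompt="a person dancing in a sunlit park",
    control_video=control_video,
    height=480,
    width=720,
    num_frames=49,
    num_inference_steps=50,
    guidance_scale=6.0,
    generator=torch.Generator(device="cuda").manual_seed(42),
)
video = result.videos  # decoded frames wrapped in CogVideoXFunPipelineOutput
```

Since `guidance_scale > 1`, classifier-free guidance is enabled and negative and positive prompt embeddings are concatenated exactly as in the denoising loop above.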
cogvideox/pipeline/pipeline_cogvideox_fun_inpaint.py ADDED
@@ -0,0 +1,1151 @@
1
+ # Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import inspect
17
+ import math
18
+ from dataclasses import dataclass
19
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
20
+
21
+ import numpy as np
22
+ import torch
23
+ import torch.nn.functional as F
24
+ from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
25
+ from diffusers.image_processor import VaeImageProcessor
26
+ from diffusers.models.embeddings import get_1d_rotary_pos_embed
27
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
28
+ from diffusers.schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
29
+ from diffusers.utils import BaseOutput, logging, replace_example_docstring
30
+ from diffusers.utils.torch_utils import randn_tensor
31
+ from diffusers.video_processor import VideoProcessor
32
+ from einops import rearrange
33
+
34
+ from ..models import (AutoencoderKLCogVideoX,
35
+ CogVideoXTransformer3DModel, T5EncoderModel,
36
+ T5Tokenizer)
37
+
38
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
39
+
40
+
41
+ EXAMPLE_DOC_STRING = """
42
+ Examples:
43
+ ```python
44
+ >>> import torch
45
+ >>> from diffusers import CogVideoXFunPipeline
46
+ >>> from diffusers.utils import export_to_video
47
+
48
+ >>> # Models: "THUDM/CogVideoX-2b" or "THUDM/CogVideoX-5b"
49
+ >>> pipe = CogVideoXFunPipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=torch.float16).to("cuda")
50
+ >>> prompt = (
51
+ ... "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. "
52
+ ... "The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other "
53
+ ... "pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, "
54
+ ... "casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. "
55
+ ... "The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical "
56
+ ... "atmosphere of this unique musical performance."
57
+ ... )
58
+ >>> video = pipe(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0]
59
+ >>> export_to_video(video, "output.mp4", fps=8)
60
+ ```
61
+ """
62
+
63
+ # Copied from diffusers.models.embeddings.get_3d_rotary_pos_embed
64
+ def get_3d_rotary_pos_embed(
65
+ embed_dim,
66
+ crops_coords,
67
+ grid_size,
68
+ temporal_size,
69
+ theta: int = 10000,
70
+ use_real: bool = True,
71
+ grid_type: str = "linspace",
72
+ max_size: Optional[Tuple[int, int]] = None,
73
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
74
+ """
75
+ RoPE for video tokens with 3D structure.
76
+
77
+ Args:
78
+ embed_dim: (`int`):
79
+ The embedding dimension size, corresponding to hidden_size_head.
80
+ crops_coords (`Tuple[int]`):
81
+ The top-left and bottom-right coordinates of the crop.
82
+ grid_size (`Tuple[int]`):
83
+ The grid size of the spatial positional embedding (height, width).
84
+ temporal_size (`int`):
85
+ The size of the temporal dimension.
86
+ theta (`float`):
87
+ Scaling factor for frequency computation.
88
+ grid_type (`str`):
89
+ Whether to use "linspace" or "slice" to compute grids.
90
+
91
+ Returns:
92
+ `torch.Tensor`: positional embedding with shape `(temporal_size * grid_size[0] * grid_size[1], embed_dim/2)`.
93
+ """
94
+ if use_real is not True:
95
+ raise ValueError(" `use_real = False` is not currently supported for get_3d_rotary_pos_embed")
96
+
97
+ if grid_type == "linspace":
98
+ start, stop = crops_coords
99
+ grid_size_h, grid_size_w = grid_size
100
+ grid_h = np.linspace(start[0], stop[0], grid_size_h, endpoint=False, dtype=np.float32)
101
+ grid_w = np.linspace(start[1], stop[1], grid_size_w, endpoint=False, dtype=np.float32)
102
+ grid_t = np.arange(temporal_size, dtype=np.float32)
103
+ grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32)
104
+ elif grid_type == "slice":
105
+ max_h, max_w = max_size
106
+ grid_size_h, grid_size_w = grid_size
107
+ grid_h = np.arange(max_h, dtype=np.float32)
108
+ grid_w = np.arange(max_w, dtype=np.float32)
109
+ grid_t = np.arange(temporal_size, dtype=np.float32)
110
+ else:
111
+ raise ValueError("Invalid value passed for `grid_type`.")
112
+
113
+ # Compute dimensions for each axis
114
+ dim_t = embed_dim // 4
115
+ dim_h = embed_dim // 8 * 3
116
+ dim_w = embed_dim // 8 * 3
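+ # The per-axis dimensions add back up to embed_dim: embed_dim // 4 + 3 * embed_dim // 8 + 3 * embed_dim // 8
+ # (e.g. 64 -> 16 + 24 + 24).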
117
+
118
+ # Temporal frequencies
119
+ freqs_t = get_1d_rotary_pos_embed(dim_t, grid_t, use_real=True)
120
+ # Spatial frequencies for height and width
121
+ freqs_h = get_1d_rotary_pos_embed(dim_h, grid_h, use_real=True)
122
+ freqs_w = get_1d_rotary_pos_embed(dim_w, grid_w, use_real=True)
123
+
124
+ # Broadcast and concatenate temporal and spatial frequencies (height and width) into a 3D tensor
125
+ def combine_time_height_width(freqs_t, freqs_h, freqs_w):
126
+ freqs_t = freqs_t[:, None, None, :].expand(
127
+ -1, grid_size_h, grid_size_w, -1
128
+ ) # temporal_size, grid_size_h, grid_size_w, dim_t
129
+ freqs_h = freqs_h[None, :, None, :].expand(
130
+ temporal_size, -1, grid_size_w, -1
131
+ ) # temporal_size, grid_size_h, grid_size_w, dim_h
132
+ freqs_w = freqs_w[None, None, :, :].expand(
133
+ temporal_size, grid_size_h, -1, -1
134
+ ) # temporal_size, grid_size_h, grid_size_w, dim_w
135
+
136
+ freqs = torch.cat(
137
+ [freqs_t, freqs_h, freqs_w], dim=-1
138
+ ) # temporal_size, grid_size_h, grid_size_w, (dim_t + dim_h + dim_w)
139
+ freqs = freqs.view(
140
+ temporal_size * grid_size_h * grid_size_w, -1
141
+ ) # (temporal_size * grid_size_h * grid_size_w), (dim_t + dim_h + dim_w)
142
+ return freqs
143
+
144
+ t_cos, t_sin = freqs_t # both t_cos and t_sin have shape: temporal_size, dim_t
145
+ h_cos, h_sin = freqs_h # both h_cos and h_sin have shape: grid_size_h, dim_h
146
+ w_cos, w_sin = freqs_w # both w_cos and w_sin have shape: grid_size_w, dim_w
147
+
148
+ if grid_type == "slice":
149
+ t_cos, t_sin = t_cos[:temporal_size], t_sin[:temporal_size]
150
+ h_cos, h_sin = h_cos[:grid_size_h], h_sin[:grid_size_h]
151
+ w_cos, w_sin = w_cos[:grid_size_w], w_sin[:grid_size_w]
152
+
153
+ cos = combine_time_height_width(t_cos, h_cos, w_cos)
154
+ sin = combine_time_height_width(t_sin, h_sin, w_sin)
155
+ return cos, sin
156
+
157
+
158
+ # Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid
159
+ def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
160
+ tw = tgt_width
161
+ th = tgt_height
162
+ h, w = src
163
+ r = h / w
164
+ if r > (th / tw):
165
+ resize_height = th
166
+ resize_width = int(round(th / h * w))
167
+ else:
168
+ resize_width = tw
169
+ resize_height = int(round(tw / w * h))
170
+
171
+ crop_top = int(round((th - resize_height) / 2.0))
172
+ crop_left = int(round((tw - resize_width) / 2.0))
173
+
174
+ return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
175
+
176
+
177
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
178
+ def retrieve_timesteps(
179
+ scheduler,
180
+ num_inference_steps: Optional[int] = None,
181
+ device: Optional[Union[str, torch.device]] = None,
182
+ timesteps: Optional[List[int]] = None,
183
+ sigmas: Optional[List[float]] = None,
184
+ **kwargs,
185
+ ):
186
+ """
187
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
188
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
189
+
190
+ Args:
191
+ scheduler (`SchedulerMixin`):
192
+ The scheduler to get timesteps from.
193
+ num_inference_steps (`int`):
194
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
195
+ must be `None`.
196
+ device (`str` or `torch.device`, *optional*):
197
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
198
+ timesteps (`List[int]`, *optional*):
199
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
200
+ `num_inference_steps` and `sigmas` must be `None`.
201
+ sigmas (`List[float]`, *optional*):
202
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
203
+ `num_inference_steps` and `timesteps` must be `None`.
204
+
205
+ Returns:
206
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
207
+ second element is the number of inference steps.
208
+ """
209
+ if timesteps is not None and sigmas is not None:
210
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
211
+ if timesteps is not None:
212
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
213
+ if not accepts_timesteps:
214
+ raise ValueError(
215
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
216
+ f" timestep schedules. Please check whether you are using the correct scheduler."
217
+ )
218
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
219
+ timesteps = scheduler.timesteps
220
+ num_inference_steps = len(timesteps)
221
+ elif sigmas is not None:
222
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
223
+ if not accept_sigmas:
224
+ raise ValueError(
225
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
226
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
227
+ )
228
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
229
+ timesteps = scheduler.timesteps
230
+ num_inference_steps = len(timesteps)
231
+ else:
232
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
233
+ timesteps = scheduler.timesteps
234
+ return timesteps, num_inference_steps
235
+
236
+
237
+ def resize_mask(mask, latent, process_first_frame_only=True):
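+ # Resize a pixel-space mask to the latent grid. With process_first_frame_only, the first frame is
+ # interpolated to a single latent frame and the remaining frames fill the rest of the latent's temporal
+ # extent; otherwise the whole mask is interpolated in one call.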
238
+ latent_size = latent.size()
239
+ batch_size, channels, num_frames, height, width = mask.shape
240
+
241
+ if process_first_frame_only:
242
+ target_size = list(latent_size[2:])
243
+ target_size[0] = 1
244
+ first_frame_resized = F.interpolate(
245
+ mask[:, :, 0:1, :, :],
246
+ size=target_size,
247
+ mode='trilinear',
248
+ align_corners=False
249
+ )
250
+
251
+ target_size = list(latent_size[2:])
252
+ target_size[0] = target_size[0] - 1
253
+ if target_size[0] != 0:
254
+ remaining_frames_resized = F.interpolate(
255
+ mask[:, :, 1:, :, :],
256
+ size=target_size,
257
+ mode='trilinear',
258
+ align_corners=False
259
+ )
260
+ resized_mask = torch.cat([first_frame_resized, remaining_frames_resized], dim=2)
261
+ else:
262
+ resized_mask = first_frame_resized
263
+ else:
264
+ target_size = list(latent_size[2:])
265
+ resized_mask = F.interpolate(
266
+ mask,
267
+ size=target_size,
268
+ mode='trilinear',
269
+ align_corners=False
270
+ )
271
+ return resized_mask
272
+
273
+
274
+ def add_noise_to_reference_video(image, ratio=None):
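+ # Reference-noise augmentation for the conditioning video: sigma is either drawn per sample from a
+ # log-normal distribution (exp of N(-3, 0.5)) or fixed to `ratio`, and pixels equal to -1 (the padded /
+ # masked value) are left noise-free.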
275
+ if ratio is None:
276
+ sigma = torch.normal(mean=-3.0, std=0.5, size=(image.shape[0],)).to(image.device)
277
+ sigma = torch.exp(sigma).to(image.dtype)
278
+ else:
279
+ sigma = torch.ones((image.shape[0],)).to(image.device, image.dtype) * ratio
280
+
281
+ image_noise = torch.randn_like(image) * sigma[:, None, None, None, None]
282
+ image_noise = torch.where(image==-1, torch.zeros_like(image), image_noise)
283
+ image = image + image_noise
284
+ return image
285
+
286
+
287
+ @dataclass
288
+ class CogVideoXFunPipelineOutput(BaseOutput):
289
+ r"""
290
+ Output class for CogVideo pipelines.
291
+
292
+ Args:
293
+ videos (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
294
+ List of video outputs - It can be a nested list of length `batch_size`, with each sub-list containing
295
+ denoised PIL image sequences of length `num_frames`. It can also be a NumPy array or Torch tensor of shape
296
+ `(batch_size, num_frames, channels, height, width)`.
297
+ """
298
+
299
+ videos: torch.Tensor
300
+
301
+
302
+ class CogVideoXFunInpaintPipeline(DiffusionPipeline):
303
+ r"""
304
+ Pipeline for video inpainting and image/video-conditioned generation using CogVideoX-Fun.
305
+
306
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
307
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
308
+
309
+ Args:
310
+ vae ([`AutoencoderKL`]):
311
+ Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
312
+ text_encoder ([`T5EncoderModel`]):
313
+ Frozen text-encoder. CogVideoX_Fun uses
314
+ [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the
315
+ [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant.
316
+ tokenizer (`T5Tokenizer`):
317
+ Tokenizer of class
318
+ [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
319
+ transformer ([`CogVideoXTransformer3DModel`]):
320
+ A text conditioned `CogVideoXTransformer3DModel` to denoise the encoded video latents.
321
+ scheduler ([`SchedulerMixin`]):
322
+ A scheduler to be used in combination with `transformer` to denoise the encoded video latents.
323
+ """
324
+
325
+ _optional_components = []
326
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
327
+
328
+ _callback_tensor_inputs = [
329
+ "latents",
330
+ "prompt_embeds",
331
+ "negative_prompt_embeds",
332
+ ]
333
+
334
+ def __init__(
335
+ self,
336
+ tokenizer: T5Tokenizer,
337
+ text_encoder: T5EncoderModel,
338
+ vae: AutoencoderKLCogVideoX,
339
+ transformer: CogVideoXTransformer3DModel,
340
+ scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler],
341
+ ):
342
+ super().__init__()
343
+
344
+ self.register_modules(
345
+ tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
346
+ )
347
+ self.vae_scale_factor_spatial = (
348
+ 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
349
+ )
350
+ self.vae_scale_factor_temporal = (
351
+ self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4
352
+ )
353
+
354
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
355
+
356
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
357
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
358
+ self.mask_processor = VaeImageProcessor(
359
+ vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True
360
+ )
361
+
362
+ def _get_t5_prompt_embeds(
363
+ self,
364
+ prompt: Union[str, List[str]] = None,
365
+ num_videos_per_prompt: int = 1,
366
+ max_sequence_length: int = 226,
367
+ device: Optional[torch.device] = None,
368
+ dtype: Optional[torch.dtype] = None,
369
+ ):
370
+ device = device or self._execution_device
371
+ dtype = dtype or self.text_encoder.dtype
372
+
373
+ prompt = [prompt] if isinstance(prompt, str) else prompt
374
+ batch_size = len(prompt)
375
+
376
+ text_inputs = self.tokenizer(
377
+ prompt,
378
+ padding="max_length",
379
+ max_length=max_sequence_length,
380
+ truncation=True,
381
+ add_special_tokens=True,
382
+ return_tensors="pt",
383
+ )
384
+ text_input_ids = text_inputs.input_ids
385
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
386
+
387
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
388
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
389
+ logger.warning(
390
+ "The following part of your input was truncated because `max_sequence_length` is set to "
391
+ f" {max_sequence_length} tokens: {removed_text}"
392
+ )
393
+
394
+ prompt_embeds = self.text_encoder(text_input_ids.to(device))[0]
395
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
396
+
397
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
398
+ _, seq_len, _ = prompt_embeds.shape
399
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
400
+ prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
401
+
402
+ return prompt_embeds
403
+
404
+ def encode_prompt(
405
+ self,
406
+ prompt: Union[str, List[str]],
407
+ negative_prompt: Optional[Union[str, List[str]]] = None,
408
+ do_classifier_free_guidance: bool = True,
409
+ num_videos_per_prompt: int = 1,
410
+ prompt_embeds: Optional[torch.Tensor] = None,
411
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
412
+ max_sequence_length: int = 226,
413
+ device: Optional[torch.device] = None,
414
+ dtype: Optional[torch.dtype] = None,
415
+ ):
416
+ r"""
417
+ Encodes the prompt into text encoder hidden states.
418
+
419
+ Args:
420
+ prompt (`str` or `List[str]`, *optional*):
421
+ prompt to be encoded
422
+ negative_prompt (`str` or `List[str]`, *optional*):
423
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
424
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
425
+ less than `1`).
426
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
427
+ Whether to use classifier free guidance or not.
428
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
429
+ Number of videos that should be generated per prompt.
430
+ prompt_embeds (`torch.Tensor`, *optional*):
431
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
432
+ provided, text embeddings will be generated from `prompt` input argument.
433
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
434
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
435
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
436
+ argument.
437
+ device: (`torch.device`, *optional*):
438
+ torch device
439
+ dtype: (`torch.dtype`, *optional*):
440
+ torch dtype
441
+ """
442
+ device = device or self._execution_device
443
+
444
+ prompt = [prompt] if isinstance(prompt, str) else prompt
445
+ if prompt is not None:
446
+ batch_size = len(prompt)
447
+ else:
448
+ batch_size = prompt_embeds.shape[0]
449
+
450
+ if prompt_embeds is None:
451
+ prompt_embeds = self._get_t5_prompt_embeds(
452
+ prompt=prompt,
453
+ num_videos_per_prompt=num_videos_per_prompt,
454
+ max_sequence_length=max_sequence_length,
455
+ device=device,
456
+ dtype=dtype,
457
+ )
458
+
459
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
460
+ negative_prompt = negative_prompt or ""
461
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
462
+
463
+ if prompt is not None and type(prompt) is not type(negative_prompt):
464
+ raise TypeError(
465
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
466
+ f" {type(prompt)}."
467
+ )
468
+ elif batch_size != len(negative_prompt):
469
+ raise ValueError(
470
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
471
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
472
+ " the batch size of `prompt`."
473
+ )
474
+
475
+ negative_prompt_embeds = self._get_t5_prompt_embeds(
476
+ prompt=negative_prompt,
477
+ num_videos_per_prompt=num_videos_per_prompt,
478
+ max_sequence_length=max_sequence_length,
479
+ device=device,
480
+ dtype=dtype,
481
+ )
482
+
483
+ return prompt_embeds, negative_prompt_embeds
484
+
485
+ def prepare_latents(
486
+ self,
487
+ batch_size,
488
+ num_channels_latents,
489
+ height,
490
+ width,
491
+ video_length,
492
+ dtype,
493
+ device,
494
+ generator,
495
+ latents=None,
496
+ video=None,
497
+ timestep=None,
498
+ is_strength_max=True,
499
+ return_noise=False,
500
+ return_video_latents=False,
501
+ ):
502
+ shape = (
503
+ batch_size,
504
+ (video_length - 1) // self.vae_scale_factor_temporal + 1,
505
+ num_channels_latents,
506
+ height // self.vae_scale_factor_spatial,
507
+ width // self.vae_scale_factor_spatial,
508
+ )
509
+ if isinstance(generator, list) and len(generator) != batch_size:
510
+ raise ValueError(
511
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
512
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
513
+ )
514
+
515
+ if return_video_latents or (latents is None and not is_strength_max):
516
+ video = video.to(device=device, dtype=self.vae.dtype)
517
+
518
+ bs = 1
519
+ new_video = []
520
+ for i in range(0, video.shape[0], bs):
521
+ video_bs = video[i : i + bs]
522
+ video_bs = self.vae.encode(video_bs)[0]
523
+ video_bs = video_bs.sample()
524
+ new_video.append(video_bs)
525
+ video = torch.cat(new_video, dim = 0)
526
+ video = video * self.vae.config.scaling_factor
527
+
528
+ video_latents = video.repeat(batch_size // video.shape[0], 1, 1, 1, 1)
529
+ video_latents = video_latents.to(device=device, dtype=dtype)
530
+ video_latents = rearrange(video_latents, "b c f h w -> b f c h w")
531
+
532
+ if latents is None:
533
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
534
+ # if strength is 1.0 then initialize the latents to pure noise, else to the noised encoded video
535
+ latents = noise if is_strength_max else self.scheduler.add_noise(video_latents, noise, timestep)
536
+ # if pure noise then scale the initial latents by the Scheduler's init sigma
537
+ latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents
538
+ else:
539
+ noise = latents.to(device)
540
+ latents = noise * self.scheduler.init_noise_sigma
541
+
542
+ # scale the initial noise by the standard deviation required by the scheduler
543
+ outputs = (latents,)
544
+
545
+ if return_noise:
546
+ outputs += (noise,)
547
+
548
+ if return_video_latents:
549
+ outputs += (video_latents,)
550
+
551
+ return outputs
552
+
553
+ def prepare_mask_latents(
554
+ self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance, noise_aug_strength
555
+ ):
556
+ # encode the mask and the masked video into latent space with the VAE so they can be concatenated to the latents
557
+ # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
558
+ # and half precision
559
+
560
+ if mask is not None:
561
+ mask = mask.to(device=device, dtype=self.vae.dtype)
562
+ bs = 1
563
+ new_mask = []
564
+ for i in range(0, mask.shape[0], bs):
565
+ mask_bs = mask[i : i + bs]
566
+ mask_bs = self.vae.encode(mask_bs)[0]
567
+ mask_bs = mask_bs.mode()
568
+ new_mask.append(mask_bs)
569
+ mask = torch.cat(new_mask, dim = 0)
570
+ mask = mask * self.vae.config.scaling_factor
571
+
572
+ if masked_image is not None:
573
+ if self.transformer.config.add_noise_in_inpaint_model:
574
+ masked_image = add_noise_to_reference_video(masked_image, ratio=noise_aug_strength)
575
+ masked_image = masked_image.to(device=device, dtype=self.vae.dtype)
576
+ bs = 1
577
+ new_mask_pixel_values = []
578
+ for i in range(0, masked_image.shape[0], bs):
579
+ mask_pixel_values_bs = masked_image[i : i + bs]
580
+ mask_pixel_values_bs = self.vae.encode(mask_pixel_values_bs)[0]
581
+ mask_pixel_values_bs = mask_pixel_values_bs.mode()
582
+ new_mask_pixel_values.append(mask_pixel_values_bs)
583
+ masked_image_latents = torch.cat(new_mask_pixel_values, dim = 0)
584
+ masked_image_latents = masked_image_latents * self.vae.config.scaling_factor
585
+ else:
586
+ masked_image_latents = None
587
+
588
+ return mask, masked_image_latents
589
+
590
+ def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
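+ # Undo the VAE scaling factor, decode to pixel space, and return frames in [0, 1] as a float32
+ # NumPy array of shape (batch, channels, frames, height, width).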
591
+ latents = latents.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width]
592
+ latents = 1 / self.vae.config.scaling_factor * latents
593
+
594
+ frames = self.vae.decode(latents).sample
595
+ frames = (frames / 2 + 0.5).clamp(0, 1)
596
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
597
+ frames = frames.cpu().float().numpy()
598
+ return frames
599
+
600
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
601
+ def prepare_extra_step_kwargs(self, generator, eta):
602
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
603
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
604
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
605
+ # and should be between [0, 1]
606
+
607
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
608
+ extra_step_kwargs = {}
609
+ if accepts_eta:
610
+ extra_step_kwargs["eta"] = eta
611
+
612
+ # check if the scheduler accepts generator
613
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
614
+ if accepts_generator:
615
+ extra_step_kwargs["generator"] = generator
616
+ return extra_step_kwargs
617
+
618
+ # Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs
619
+ def check_inputs(
620
+ self,
621
+ prompt,
622
+ height,
623
+ width,
624
+ negative_prompt,
625
+ callback_on_step_end_tensor_inputs,
626
+ prompt_embeds=None,
627
+ negative_prompt_embeds=None,
628
+ ):
629
+ if height % 8 != 0 or width % 8 != 0:
630
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
631
+
632
+ if callback_on_step_end_tensor_inputs is not None and not all(
633
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
634
+ ):
635
+ raise ValueError(
636
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
637
+ )
638
+ if prompt is not None and prompt_embeds is not None:
639
+ raise ValueError(
640
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
641
+ " only forward one of the two."
642
+ )
643
+ elif prompt is None and prompt_embeds is None:
644
+ raise ValueError(
645
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
646
+ )
647
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
648
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
649
+
650
+ if prompt is not None and negative_prompt_embeds is not None:
651
+ raise ValueError(
652
+ f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
653
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
654
+ )
655
+
656
+ if negative_prompt is not None and negative_prompt_embeds is not None:
657
+ raise ValueError(
658
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
659
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
660
+ )
661
+
662
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
663
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
664
+ raise ValueError(
665
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
666
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
667
+ f" {negative_prompt_embeds.shape}."
668
+ )
669
+
670
+ def fuse_qkv_projections(self) -> None:
671
+ r"""Enables fused QKV projections."""
672
+ self.fusing_transformer = True
673
+ self.transformer.fuse_qkv_projections()
674
+
675
+ def unfuse_qkv_projections(self) -> None:
676
+ r"""Disable QKV projection fusion if enabled."""
677
+ if not self.fusing_transformer:
678
+ logger.warning("The Transformer was not initially fused for QKV projections. Doing nothing.")
679
+ else:
680
+ self.transformer.unfuse_qkv_projections()
681
+ self.fusing_transformer = False
682
+
683
+ def _prepare_rotary_positional_embeddings(
684
+ self,
685
+ height: int,
686
+ width: int,
687
+ num_frames: int,
688
+ device: torch.device,
689
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
690
+ grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
691
+ grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
692
+
693
+ p = self.transformer.config.patch_size
694
+ p_t = self.transformer.config.patch_size_t
695
+
696
+ base_size_width = self.transformer.config.sample_width // p
697
+ base_size_height = self.transformer.config.sample_height // p
698
+
699
+ if p_t is None:
700
+ # CogVideoX 1.0
701
+ grid_crops_coords = get_resize_crop_region_for_grid(
702
+ (grid_height, grid_width), base_size_width, base_size_height
703
+ )
704
+ freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
705
+ embed_dim=self.transformer.config.attention_head_dim,
706
+ crops_coords=grid_crops_coords,
707
+ grid_size=(grid_height, grid_width),
708
+ temporal_size=num_frames,
709
+ )
710
+ else:
711
+ # CogVideoX 1.5
712
+ base_num_frames = (num_frames + p_t - 1) // p_t
713
+
714
+ freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
715
+ embed_dim=self.transformer.config.attention_head_dim,
716
+ crops_coords=None,
717
+ grid_size=(grid_height, grid_width),
718
+ temporal_size=base_num_frames,
719
+ grid_type="slice",
720
+ max_size=(base_size_height, base_size_width),
721
+ )
722
+
723
+ freqs_cos = freqs_cos.to(device=device)
724
+ freqs_sin = freqs_sin.to(device=device)
725
+ return freqs_cos, freqs_sin
726
+
727
+ @property
728
+ def guidance_scale(self):
729
+ return self._guidance_scale
730
+
731
+ @property
732
+ def num_timesteps(self):
733
+ return self._num_timesteps
734
+
735
+ @property
736
+ def attention_kwargs(self):
737
+ return self._attention_kwargs
738
+
739
+ @property
740
+ def interrupt(self):
741
+ return self._interrupt
742
+
743
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps
744
+ def get_timesteps(self, num_inference_steps, strength, device):
745
+ # get the original timestep using init_timestep
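+ # e.g. with num_inference_steps=50 and strength=0.6: init_timestep=30 and t_start=20, so only the
+ # final 30 timesteps are run and the input video is noised to that level.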
746
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
747
+
748
+ t_start = max(num_inference_steps - init_timestep, 0)
749
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
750
+
751
+ return timesteps, num_inference_steps - t_start
752
+
753
+ @torch.no_grad()
754
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
755
+ def __call__(
756
+ self,
757
+ prompt: Optional[Union[str, List[str]]] = None,
758
+ negative_prompt: Optional[Union[str, List[str]]] = None,
759
+ height: int = 480,
760
+ width: int = 720,
761
+ video: Union[torch.FloatTensor] = None,
762
+ mask_video: Union[torch.FloatTensor] = None,
763
+ masked_video_latents: Union[torch.FloatTensor] = None,
764
+ num_frames: int = 49,
765
+ num_inference_steps: int = 50,
766
+ timesteps: Optional[List[int]] = None,
767
+ guidance_scale: float = 6,
768
+ use_dynamic_cfg: bool = False,
769
+ num_videos_per_prompt: int = 1,
770
+ eta: float = 0.0,
771
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
772
+ latents: Optional[torch.FloatTensor] = None,
773
+ prompt_embeds: Optional[torch.FloatTensor] = None,
774
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
775
+ output_type: str = "numpy",
776
+ return_dict: bool = False,
777
+ callback_on_step_end: Optional[
778
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
779
+ ] = None,
780
+ attention_kwargs: Optional[Dict[str, Any]] = None,
781
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
782
+ max_sequence_length: int = 226,
783
+ strength: float = 1,
784
+ noise_aug_strength: float = 0.0563,
785
+ comfyui_progressbar: bool = False,
786
+ ) -> Union[CogVideoXFunPipelineOutput, Tuple]:
787
+ """
788
+ Function invoked when calling the pipeline for generation.
789
+
790
+ Args:
791
+ prompt (`str` or `List[str]`, *optional*):
792
+ The prompt or prompts to guide the video generation. If not defined, one has to pass `prompt_embeds`
793
+ instead.
794
+ negative_prompt (`str` or `List[str]`, *optional*):
795
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
796
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
797
+ less than `1`).
798
+ height (`int`, *optional*, defaults to `480`):
799
+ The height in pixels of the generated video frames.
800
+ width (`int`, *optional*, defaults to `720`):
801
+ The width in pixels of the generated video frames.
802
+ num_frames (`int`, defaults to `49`):
803
+ Number of frames to generate. Must be divisible by self.vae_scale_factor_temporal. Generated video will
804
+ contain 1 extra frame because CogVideoX_Fun is conditioned with (num_seconds * fps + 1) frames where
805
+ num_seconds is 6 and fps is 8. However, since videos can be saved at any fps, the only condition that
806
+ needs to be satisfied is that of divisibility mentioned above.
807
+ num_inference_steps (`int`, *optional*, defaults to 50):
808
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
809
+ expense of slower inference.
810
+ timesteps (`List[int]`, *optional*):
811
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
812
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
813
+ passed will be used. Must be in descending order.
814
+ guidance_scale (`float`, *optional*, defaults to `6.0`):
815
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
816
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
817
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
818
+ 1`. A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
819
+ usually at the expense of lower image quality.
820
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
821
+ The number of videos to generate per prompt.
822
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
823
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
824
+ to make generation deterministic.
825
+ latents (`torch.FloatTensor`, *optional*):
826
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
827
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
828
+ tensor will be generated by sampling using the supplied random `generator`.
829
+ prompt_embeds (`torch.FloatTensor`, *optional*):
830
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
831
+ provided, text embeddings will be generated from `prompt` input argument.
832
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
833
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
834
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
835
+ argument.
836
+ output_type (`str`, *optional*, defaults to `"pil"`):
837
+ The output format of the generated image. Choose between
838
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
839
+ return_dict (`bool`, *optional*, defaults to `False`):
840
+ Whether or not to return a [`CogVideoXFunPipelineOutput`] instead
841
+ of a plain tuple.
842
+ callback_on_step_end (`Callable`, *optional*):
843
+ A function that is called at the end of each denoising step during inference. The function is called
844
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
845
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
846
+ `callback_on_step_end_tensor_inputs`.
847
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
848
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
849
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
850
+ `._callback_tensor_inputs` attribute of your pipeline class.
851
+ max_sequence_length (`int`, defaults to `226`):
852
+ Maximum sequence length in encoded prompt. Must be consistent with
853
+ `self.transformer.config.max_text_seq_length` otherwise may lead to poor results.
854
+
855
+ Examples:
856
+
857
+ Returns:
858
+ [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXFunPipelineOutput`] or `tuple`:
859
+ [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXFunPipelineOutput`] if `return_dict` is True, otherwise a
860
+ `tuple`. When returning a tuple, the first element is a list with the generated images.
861
+ """
862
+
863
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
864
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
865
+
866
+ height = height or self.transformer.config.sample_height * self.vae_scale_factor_spatial
867
+ width = width or self.transformer.config.sample_width * self.vae_scale_factor_spatial
868
+ num_frames = num_frames or self.transformer.config.sample_frames
869
+
870
+ num_videos_per_prompt = 1
871
+
872
+ # 1. Check inputs. Raise error if not correct
873
+ self.check_inputs(
874
+ prompt,
875
+ height,
876
+ width,
877
+ negative_prompt,
878
+ callback_on_step_end_tensor_inputs,
879
+ prompt_embeds,
880
+ negative_prompt_embeds,
881
+ )
882
+ self._guidance_scale = guidance_scale
883
+ self._attention_kwargs = attention_kwargs
884
+ self._interrupt = False
885
+
886
+ # 2. Default call parameters
887
+ if prompt is not None and isinstance(prompt, str):
888
+ batch_size = 1
889
+ elif prompt is not None and isinstance(prompt, list):
890
+ batch_size = len(prompt)
891
+ else:
892
+ batch_size = prompt_embeds.shape[0]
893
+
894
+ device = self._execution_device
895
+
896
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
897
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
898
+ # corresponds to doing no classifier free guidance.
899
+ do_classifier_free_guidance = guidance_scale > 1.0
900
+
901
+ # 3. Encode input prompt
902
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
903
+ prompt,
904
+ negative_prompt,
905
+ do_classifier_free_guidance,
906
+ num_videos_per_prompt=num_videos_per_prompt,
907
+ prompt_embeds=prompt_embeds,
908
+ negative_prompt_embeds=negative_prompt_embeds,
909
+ max_sequence_length=max_sequence_length,
910
+ device=device,
911
+ )
912
+ if do_classifier_free_guidance:
913
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
914
+
915
+ # 4. set timesteps
916
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
917
+ timesteps, num_inference_steps = self.get_timesteps(
918
+ num_inference_steps=num_inference_steps, strength=strength, device=device
919
+ )
920
+ self._num_timesteps = len(timesteps)
921
+ if comfyui_progressbar:
922
+ from comfy.utils import ProgressBar
923
+ pbar = ProgressBar(num_inference_steps + 2)
924
+ # at which timestep to set the initial noise (n.b. 50% if strength is 0.5)
925
+ latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt)
926
+ # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise
927
+ is_strength_max = strength == 1.0
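+ # strength == 1.0 starts from pure noise and ignores the encoded input video during initialization;
+ # lower values add noise to the encoded video at latent_timestep, preserving more of the original content.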
928
+
929
+ # 5. Prepare latents.
930
+ if video is not None:
931
+ video_length = video.shape[2]
932
+ init_video = self.image_processor.preprocess(rearrange(video, "b c f h w -> (b f) c h w"), height=height, width=width)
933
+ init_video = init_video.to(dtype=torch.float32)
934
+ init_video = rearrange(init_video, "(b f) c h w -> b c f h w", f=video_length)
935
+ else:
936
+ init_video = None
937
+
938
+ # Magvae needs the number of frames to be 4n + 1.
939
+ local_latent_length = (num_frames - 1) // self.vae_scale_factor_temporal + 1
940
+ # For CogVideoX 1.5, the latent frames should be clipped to make it divisible by patch_size_t
941
+ patch_size_t = self.transformer.config.patch_size_t
942
+ additional_frames = 0
943
+ if patch_size_t is not None and local_latent_length % patch_size_t != 0:
944
+ additional_frames = local_latent_length % patch_size_t
945
+ num_frames -= additional_frames * self.vae_scale_factor_temporal
946
+ if num_frames <= 0:
947
+ num_frames = 1
948
+ if video_length > num_frames:
949
+ logger.warning("The length of condition video is not right, the latent frames should be clipped to make it divisible by patch_size_t. ")
950
+ video_length = num_frames
951
+ video = video[:, :, :video_length]
952
+ init_video = init_video[:, :, :video_length]
953
+ mask_video = mask_video[:, :, :video_length]
954
+
955
+ num_channels_latents = self.vae.config.latent_channels
956
+ num_channels_transformer = self.transformer.config.in_channels
957
+ return_image_latents = num_channels_transformer == num_channels_latents
958
+
959
+ latents_outputs = self.prepare_latents(
960
+ batch_size * num_videos_per_prompt,
961
+ num_channels_latents,
962
+ height,
963
+ width,
964
+ video_length,
965
+ prompt_embeds.dtype,
966
+ device,
967
+ generator,
968
+ latents,
969
+ video=init_video,
970
+ timestep=latent_timestep,
971
+ is_strength_max=is_strength_max,
972
+ return_noise=True,
973
+ return_video_latents=return_image_latents,
974
+ )
975
+ if return_image_latents:
976
+ latents, noise, image_latents = latents_outputs
977
+ else:
978
+ latents, noise = latents_outputs
979
+ if comfyui_progressbar:
980
+ pbar.update(1)
981
+
982
+ if mask_video is not None:
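+ # A mask that is entirely 255 means "regenerate every pixel": zero mask / masked-video latents are used
+ # as the inpaint conditioning. Otherwise the mask is binarized and, when the transformer has extra input
+ # channels for inpainting, the masked video is VAE-encoded and concatenated with the resized mask along
+ # the channel dimension to form `inpaint_latents`.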
983
+ if (mask_video == 255).all():
984
+ mask_latents = torch.zeros_like(latents)[:, :, :1].to(latents.device, latents.dtype)
985
+ masked_video_latents = torch.zeros_like(latents).to(latents.device, latents.dtype)
986
+
987
+ mask_input = torch.cat([mask_latents] * 2) if do_classifier_free_guidance else mask_latents
988
+ masked_video_latents_input = (
989
+ torch.cat([masked_video_latents] * 2) if do_classifier_free_guidance else masked_video_latents
990
+ )
991
+ inpaint_latents = torch.cat([mask_input, masked_video_latents_input], dim=2).to(latents.dtype)
992
+ else:
993
+ # Prepare mask latent variables
994
+ video_length = video.shape[2]
995
+ mask_condition = self.mask_processor.preprocess(rearrange(mask_video, "b c f h w -> (b f) c h w"), height=height, width=width)
996
+ mask_condition = mask_condition.to(dtype=torch.float32)
997
+ mask_condition = rearrange(mask_condition, "(b f) c h w -> b c f h w", f=video_length)
998
+
999
+ if num_channels_transformer != num_channels_latents:
1000
+ mask_condition_tile = torch.tile(mask_condition, [1, 3, 1, 1, 1])
1001
+ if masked_video_latents is None:
1002
+ masked_video = init_video * (mask_condition_tile < 0.5) + torch.ones_like(init_video) * (mask_condition_tile > 0.5) * -1
1003
+ else:
1004
+ masked_video = masked_video_latents
1005
+
1006
+ _, masked_video_latents = self.prepare_mask_latents(
1007
+ None,
1008
+ masked_video,
1009
+ batch_size,
1010
+ height,
1011
+ width,
1012
+ prompt_embeds.dtype,
1013
+ device,
1014
+ generator,
1015
+ do_classifier_free_guidance,
1016
+ noise_aug_strength=noise_aug_strength,
1017
+ )
1018
+ mask_latents = resize_mask(1 - mask_condition, masked_video_latents)
1019
+ mask_latents = mask_latents.to(masked_video_latents.device) * self.vae.config.scaling_factor
1020
+
1021
+ mask = torch.tile(mask_condition, [1, num_channels_latents, 1, 1, 1])
1022
+ mask = F.interpolate(mask, size=latents.size()[-3:], mode='trilinear', align_corners=True).to(latents.device, latents.dtype)
1023
+
1024
+ mask_input = torch.cat([mask_latents] * 2) if do_classifier_free_guidance else mask_latents
1025
+ masked_video_latents_input = (
1026
+ torch.cat([masked_video_latents] * 2) if do_classifier_free_guidance else masked_video_latents
1027
+ )
1028
+
1029
+ mask = rearrange(mask, "b c f h w -> b f c h w")
1030
+ mask_input = rearrange(mask_input, "b c f h w -> b f c h w")
1031
+ masked_video_latents_input = rearrange(masked_video_latents_input, "b c f h w -> b f c h w")
1032
+
1033
+ inpaint_latents = torch.cat([mask_input, masked_video_latents_input], dim=2).to(latents.dtype)
1034
+ else:
1035
+ mask = torch.tile(mask_condition, [1, num_channels_latents, 1, 1, 1])
1036
+ mask = F.interpolate(mask, size=latents.size()[-3:], mode='trilinear', align_corners=True).to(latents.device, latents.dtype)
1037
+ mask = rearrange(mask, "b c f h w -> b f c h w")
1038
+
1039
+ inpaint_latents = None
1040
+ else:
1041
+ if num_channels_transformer != num_channels_latents:
1042
+ mask = torch.zeros_like(latents).to(latents.device, latents.dtype)
1043
+ masked_video_latents = torch.zeros_like(latents).to(latents.device, latents.dtype)
1044
+
1045
+ mask_input = torch.cat([mask] * 2) if do_classifier_free_guidance else mask
1046
+ masked_video_latents_input = (
1047
+ torch.cat([masked_video_latents] * 2) if do_classifier_free_guidance else masked_video_latents
1048
+ )
1049
+ inpaint_latents = torch.cat([mask_input, masked_video_latents_input], dim=1).to(latents.dtype)
1050
+ else:
1051
+ mask = torch.zeros_like(init_video[:, :1])
1052
+ mask = torch.tile(mask, [1, num_channels_latents, 1, 1, 1])
1053
+ mask = F.interpolate(mask, size=latents.size()[-3:], mode='trilinear', align_corners=True).to(latents.device, latents.dtype)
1054
+ mask = rearrange(mask, "b c f h w -> b f c h w")
1055
+
1056
+ inpaint_latents = None
1057
+ if comfyui_progressbar:
1058
+ pbar.update(1)
1059
+
1060
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
1061
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
1062
+
1063
+ # 7. Create rotary embeds if required
1064
+ image_rotary_emb = (
1065
+ self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)
1066
+ if self.transformer.config.use_rotary_positional_embeddings
1067
+ else None
1068
+ )
1069
+
1070
+ # 8. Denoising loop
1071
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
1072
+
1073
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
1074
+ # for DPM-solver++
1075
+ old_pred_original_sample = None
1076
+ for i, t in enumerate(timesteps):
1077
+ if self.interrupt:
1078
+ continue
1079
+
1080
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
1081
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
1082
+
1083
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
1084
+ timestep = t.expand(latent_model_input.shape[0])
1085
+
1086
+ # predict noise model_output
1087
+ noise_pred = self.transformer(
1088
+ hidden_states=latent_model_input,
1089
+ encoder_hidden_states=prompt_embeds,
1090
+ timestep=timestep,
1091
+ image_rotary_emb=image_rotary_emb,
1092
+ return_dict=False,
1093
+ inpaint_latents=inpaint_latents,
1094
+ )[0]
1095
+ noise_pred = noise_pred.float()
1096
+
1097
+ # perform guidance
1098
+ if use_dynamic_cfg:
1099
+ self._guidance_scale = 1 + guidance_scale * (
1100
+ (1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2
1101
+ )
1102
+ if do_classifier_free_guidance:
1103
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
1104
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
1105
+
1106
+ # compute the previous noisy sample x_t -> x_t-1
1107
+ if not isinstance(self.scheduler, CogVideoXDPMScheduler):
1108
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
1109
+ else:
1110
+ latents, old_pred_original_sample = self.scheduler.step(
1111
+ noise_pred,
1112
+ old_pred_original_sample,
1113
+ t,
1114
+ timesteps[i - 1] if i > 0 else None,
1115
+ latents,
1116
+ **extra_step_kwargs,
1117
+ return_dict=False,
1118
+ )
1119
+ latents = latents.to(prompt_embeds.dtype)
1120
+
1121
+ # call the callback, if provided
1122
+ if callback_on_step_end is not None:
1123
+ callback_kwargs = {}
1124
+ for k in callback_on_step_end_tensor_inputs:
1125
+ callback_kwargs[k] = locals()[k]
1126
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
1127
+
1128
+ latents = callback_outputs.pop("latents", latents)
1129
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
1130
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
1131
+
1132
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
1133
+ progress_bar.update()
1134
+ if comfyui_progressbar:
1135
+ pbar.update(1)
1136
+
1137
+ if output_type == "numpy":
1138
+ video = self.decode_latents(latents)
1139
+ elif not output_type == "latent":
1140
+ video = self.decode_latents(latents)
1141
+ video = self.video_processor.postprocess_video(video=video, output_type=output_type)
1142
+ else:
1143
+ video = latents
1144
+
1145
+ # Offload all models
1146
+ self.maybe_free_model_hooks()
1147
+
1148
+ if not return_dict:
1149
+ video = torch.from_numpy(video)
1150
+
1151
+ return CogVideoXFunPipelineOutput(videos=video)
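Note on the inpaint conditioning assembled above: the resized mask and the masked-video latents are concatenated along the channel axis (after optional CFG duplication of the batch) and rearranged into the transformer's `b f c h w` layout. A minimal shape-only sketch with toy dimensions, no VAE involved, to show what `inpaint_latents` ends up looking like:

```python
import torch
from einops import rearrange

# Toy latent sizes for illustration only (batch, channels, frames, height, width).
b, c, f, h, w = 1, 16, 13, 60, 90
do_classifier_free_guidance = True

mask_latents = torch.zeros(b, c, f, h, w)
masked_video_latents = torch.zeros(b, c, f, h, w)

# Duplicate along the batch axis for classifier-free guidance.
mask_input = torch.cat([mask_latents] * 2) if do_classifier_free_guidance else mask_latents
masked_video_latents_input = (
    torch.cat([masked_video_latents] * 2) if do_classifier_free_guidance else masked_video_latents
)

# Move channels after the frame axis, then stack mask and masked video on channels.
mask_input = rearrange(mask_input, "b c f h w -> b f c h w")
masked_video_latents_input = rearrange(masked_video_latents_input, "b c f h w -> b f c h w")
inpaint_latents = torch.cat([mask_input, masked_video_latents_input], dim=2)

print(inpaint_latents.shape)  # torch.Size([2, 13, 32, 60, 90])
```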
cogvideox/pipeline/pipeline_wan_fun.py ADDED
@@ -0,0 +1,562 @@
1
+ import inspect
2
+ import math
3
+ from dataclasses import dataclass
4
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
5
+
6
+ import numpy as np
7
+ import torch
8
+ from diffusers import FlowMatchEulerDiscreteScheduler
9
+ from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
10
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
11
+ from diffusers.utils import BaseOutput, logging, replace_example_docstring
12
+ from diffusers.utils.torch_utils import randn_tensor
13
+ from diffusers.video_processor import VideoProcessor
14
+
15
+ from ..models import (AutoencoderKLWan, AutoTokenizer,
16
+ WanT5EncoderModel, WanTransformer3DModel)
17
+
18
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
19
+
20
+
21
+ EXAMPLE_DOC_STRING = """
22
+ Examples:
23
+ ```python
24
+ pass
25
+ ```
26
+ """
27
+
28
+
29
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
30
+ def retrieve_timesteps(
31
+ scheduler,
32
+ num_inference_steps: Optional[int] = None,
33
+ device: Optional[Union[str, torch.device]] = None,
34
+ timesteps: Optional[List[int]] = None,
35
+ sigmas: Optional[List[float]] = None,
36
+ **kwargs,
37
+ ):
38
+ """
39
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
40
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
41
+
42
+ Args:
43
+ scheduler (`SchedulerMixin`):
44
+ The scheduler to get timesteps from.
45
+ num_inference_steps (`int`):
46
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
47
+ must be `None`.
48
+ device (`str` or `torch.device`, *optional*):
49
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
50
+ timesteps (`List[int]`, *optional*):
51
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
52
+ `num_inference_steps` and `sigmas` must be `None`.
53
+ sigmas (`List[float]`, *optional*):
54
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
55
+ `num_inference_steps` and `timesteps` must be `None`.
56
+
57
+ Returns:
58
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
59
+ second element is the number of inference steps.
60
+ """
61
+ if timesteps is not None and sigmas is not None:
62
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
63
+ if timesteps is not None:
64
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
65
+ if not accepts_timesteps:
66
+ raise ValueError(
67
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
68
+ f" timestep schedules. Please check whether you are using the correct scheduler."
69
+ )
70
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
71
+ timesteps = scheduler.timesteps
72
+ num_inference_steps = len(timesteps)
73
+ elif sigmas is not None:
74
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
75
+ if not accept_sigmas:
76
+ raise ValueError(
77
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
78
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
79
+ )
80
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
81
+ timesteps = scheduler.timesteps
82
+ num_inference_steps = len(timesteps)
83
+ else:
84
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
85
+ timesteps = scheduler.timesteps
86
+ return timesteps, num_inference_steps
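For orientation, the helper above can be exercised on its own; a minimal sketch, assuming this repository and its dependencies are importable (the 30-step count and CPU device are arbitrary example values):

```python
from diffusers import FlowMatchEulerDiscreteScheduler

# retrieve_timesteps is module-level in this file; importing it assumes the repo is on PYTHONPATH.
from cogvideox.pipeline.pipeline_wan_fun import retrieve_timesteps

scheduler = FlowMatchEulerDiscreteScheduler()

# Let the scheduler build its own schedule; no custom timesteps or sigmas passed.
timesteps, num_inference_steps = retrieve_timesteps(scheduler, num_inference_steps=30, device="cpu")
print(num_inference_steps, timesteps[:3])
```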
87
+
88
+
89
+ @dataclass
90
+ class WanPipelineOutput(BaseOutput):
91
+ r"""
92
+ Output class for Wan pipelines.
93
+
94
+ Args:
95
+ video (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
96
+ List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
97
+ denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
98
+ `(batch_size, num_frames, channels, height, width)`.
99
+ """
100
+
101
+ videos: torch.Tensor
102
+
103
+
104
+ class WanFunPipeline(DiffusionPipeline):
105
+ r"""
106
+ Pipeline for text-to-video generation using Wan.
107
+
108
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
109
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
110
+ """
111
+
112
+ _optional_components = []
113
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
114
+
115
+ _callback_tensor_inputs = [
116
+ "latents",
117
+ "prompt_embeds",
118
+ "negative_prompt_embeds",
119
+ ]
120
+
121
+ def __init__(
122
+ self,
123
+ tokenizer: AutoTokenizer,
124
+ text_encoder: WanT5EncoderModel,
125
+ vae: AutoencoderKLWan,
126
+ transformer: WanTransformer3DModel,
127
+ scheduler: FlowMatchEulerDiscreteScheduler,
128
+ ):
129
+ super().__init__()
130
+
131
+ self.register_modules(
132
+ tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
133
+ )
134
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae.spacial_compression_ratio)
135
+
136
+ def _get_t5_prompt_embeds(
137
+ self,
138
+ prompt: Union[str, List[str]] = None,
139
+ num_videos_per_prompt: int = 1,
140
+ max_sequence_length: int = 512,
141
+ device: Optional[torch.device] = None,
142
+ dtype: Optional[torch.dtype] = None,
143
+ ):
144
+ device = device or self._execution_device
145
+ dtype = dtype or self.text_encoder.dtype
146
+
147
+ prompt = [prompt] if isinstance(prompt, str) else prompt
148
+ batch_size = len(prompt)
149
+
150
+ text_inputs = self.tokenizer(
151
+ prompt,
152
+ padding="max_length",
153
+ max_length=max_sequence_length,
154
+ truncation=True,
155
+ add_special_tokens=True,
156
+ return_tensors="pt",
157
+ )
158
+ text_input_ids = text_inputs.input_ids
159
+ prompt_attention_mask = text_inputs.attention_mask
160
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
161
+
162
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
163
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
164
+ logger.warning(
165
+ "The following part of your input was truncated because `max_sequence_length` is set to "
166
+ f" {max_sequence_length} tokens: {removed_text}"
167
+ )
168
+
169
+ seq_lens = prompt_attention_mask.gt(0).sum(dim=1).long()
170
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask.to(device))[0]
171
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
172
+
173
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
174
+ _, seq_len, _ = prompt_embeds.shape
175
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
176
+ prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
177
+
178
+ return [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
179
+
180
+ def encode_prompt(
181
+ self,
182
+ prompt: Union[str, List[str]],
183
+ negative_prompt: Optional[Union[str, List[str]]] = None,
184
+ do_classifier_free_guidance: bool = True,
185
+ num_videos_per_prompt: int = 1,
186
+ prompt_embeds: Optional[torch.Tensor] = None,
187
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
188
+ max_sequence_length: int = 512,
189
+ device: Optional[torch.device] = None,
190
+ dtype: Optional[torch.dtype] = None,
191
+ ):
192
+ r"""
193
+ Encodes the prompt into text encoder hidden states.
194
+
195
+ Args:
196
+ prompt (`str` or `List[str]`, *optional*):
197
+ prompt to be encoded
198
+ negative_prompt (`str` or `List[str]`, *optional*):
199
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
200
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
201
+ less than `1`).
202
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
203
+ Whether to use classifier free guidance or not.
204
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
205
+ Number of videos that should be generated per prompt.
206
+ prompt_embeds (`torch.Tensor`, *optional*):
207
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
208
+ provided, text embeddings will be generated from `prompt` input argument.
209
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
210
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
211
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
212
+ argument.
213
+ device: (`torch.device`, *optional*):
214
+ torch device
215
+ dtype: (`torch.dtype`, *optional*):
216
+ torch dtype
217
+ """
218
+ device = device or self._execution_device
219
+
220
+ prompt = [prompt] if isinstance(prompt, str) else prompt
221
+ if prompt is not None:
222
+ batch_size = len(prompt)
223
+ else:
224
+ batch_size = prompt_embeds.shape[0]
225
+
226
+ if prompt_embeds is None:
227
+ prompt_embeds = self._get_t5_prompt_embeds(
228
+ prompt=prompt,
229
+ num_videos_per_prompt=num_videos_per_prompt,
230
+ max_sequence_length=max_sequence_length,
231
+ device=device,
232
+ dtype=dtype,
233
+ )
234
+
235
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
236
+ negative_prompt = negative_prompt or ""
237
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
238
+
239
+ if prompt is not None and type(prompt) is not type(negative_prompt):
240
+ raise TypeError(
241
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
242
+ f" {type(prompt)}."
243
+ )
244
+ elif batch_size != len(negative_prompt):
245
+ raise ValueError(
246
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
247
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
248
+ " the batch size of `prompt`."
249
+ )
250
+
251
+ negative_prompt_embeds = self._get_t5_prompt_embeds(
252
+ prompt=negative_prompt,
253
+ num_videos_per_prompt=num_videos_per_prompt,
254
+ max_sequence_length=max_sequence_length,
255
+ device=device,
256
+ dtype=dtype,
257
+ )
258
+
259
+ return prompt_embeds, negative_prompt_embeds
260
+
261
+ def prepare_latents(
262
+ self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
263
+ ):
264
+ if isinstance(generator, list) and len(generator) != batch_size:
265
+ raise ValueError(
266
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
267
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
268
+ )
269
+
270
+ shape = (
271
+ batch_size,
272
+ num_channels_latents,
273
+ (num_frames - 1) // self.vae.temporal_compression_ratio + 1,
274
+ height // self.vae.spacial_compression_ratio,
275
+ width // self.vae.spacial_compression_ratio,
276
+ )
277
+
278
+ if latents is None:
279
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
280
+ else:
281
+ latents = latents.to(device)
282
+
283
+ # scale the initial noise by the standard deviation required by the scheduler
284
+ if hasattr(self.scheduler, "init_noise_sigma"):
285
+ latents = latents * self.scheduler.init_noise_sigma
286
+ return latents
287
+
288
+ def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
289
+ frames = self.vae.decode(latents.to(self.vae.dtype)).sample
290
+ frames = (frames / 2 + 0.5).clamp(0, 1)
291
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
292
+ frames = frames.cpu().float().numpy()
293
+ return frames
294
+
295
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
296
+ def prepare_extra_step_kwargs(self, generator, eta):
297
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
298
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
299
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
300
+ # and should be between [0, 1]
301
+
302
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
303
+ extra_step_kwargs = {}
304
+ if accepts_eta:
305
+ extra_step_kwargs["eta"] = eta
306
+
307
+ # check if the scheduler accepts generator
308
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
309
+ if accepts_generator:
310
+ extra_step_kwargs["generator"] = generator
311
+ return extra_step_kwargs
312
+
313
+ # Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs
314
+ def check_inputs(
315
+ self,
316
+ prompt,
317
+ height,
318
+ width,
319
+ negative_prompt,
320
+ callback_on_step_end_tensor_inputs,
321
+ prompt_embeds=None,
322
+ negative_prompt_embeds=None,
323
+ ):
324
+ if height % 8 != 0 or width % 8 != 0:
325
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
326
+
327
+ if callback_on_step_end_tensor_inputs is not None and not all(
328
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
329
+ ):
330
+ raise ValueError(
331
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
332
+ )
333
+ if prompt is not None and prompt_embeds is not None:
334
+ raise ValueError(
335
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
336
+ " only forward one of the two."
337
+ )
338
+ elif prompt is None and prompt_embeds is None:
339
+ raise ValueError(
340
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
341
+ )
342
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
343
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
344
+
345
+ if prompt is not None and negative_prompt_embeds is not None:
346
+ raise ValueError(
347
+ f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
348
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
349
+ )
350
+
351
+ if negative_prompt is not None and negative_prompt_embeds is not None:
352
+ raise ValueError(
353
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
354
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
355
+ )
356
+
357
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
358
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
359
+ raise ValueError(
360
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
361
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
362
+ f" {negative_prompt_embeds.shape}."
363
+ )
364
+
365
+ @property
366
+ def guidance_scale(self):
367
+ return self._guidance_scale
368
+
369
+ @property
370
+ def num_timesteps(self):
371
+ return self._num_timesteps
372
+
373
+ @property
374
+ def attention_kwargs(self):
375
+ return self._attention_kwargs
376
+
377
+ @property
378
+ def interrupt(self):
379
+ return self._interrupt
380
+
381
+ @torch.no_grad()
382
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
383
+ def __call__(
384
+ self,
385
+ prompt: Optional[Union[str, List[str]]] = None,
386
+ negative_prompt: Optional[Union[str, List[str]]] = None,
387
+ height: int = 480,
388
+ width: int = 720,
389
+ num_frames: int = 49,
390
+ num_inference_steps: int = 50,
391
+ timesteps: Optional[List[int]] = None,
392
+ guidance_scale: float = 6,
393
+ num_videos_per_prompt: int = 1,
394
+ eta: float = 0.0,
395
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
396
+ latents: Optional[torch.FloatTensor] = None,
397
+ prompt_embeds: Optional[torch.FloatTensor] = None,
398
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
399
+ output_type: str = "numpy",
400
+ return_dict: bool = False,
401
+ callback_on_step_end: Optional[
402
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
403
+ ] = None,
404
+ attention_kwargs: Optional[Dict[str, Any]] = None,
405
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
406
+ max_sequence_length: int = 512,
407
+ comfyui_progressbar: bool = False,
408
+ ) -> Union[WanPipelineOutput, Tuple]:
409
+ """
410
+ Function invoked when calling the pipeline for generation.
411
+ Args:
412
+
413
+ Examples:
414
+
415
+ Returns:
416
+
417
+ """
418
+
419
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
420
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
421
+ num_videos_per_prompt = 1
422
+
423
+ # 1. Check inputs. Raise error if not correct
424
+ self.check_inputs(
425
+ prompt,
426
+ height,
427
+ width,
428
+ negative_prompt,
429
+ callback_on_step_end_tensor_inputs,
430
+ prompt_embeds,
431
+ negative_prompt_embeds,
432
+ )
433
+ self._guidance_scale = guidance_scale
434
+ self._attention_kwargs = attention_kwargs
435
+ self._interrupt = False
436
+
437
+ # 2. Default call parameters
438
+ if prompt is not None and isinstance(prompt, str):
439
+ batch_size = 1
440
+ elif prompt is not None and isinstance(prompt, list):
441
+ batch_size = len(prompt)
442
+ else:
443
+ batch_size = prompt_embeds.shape[0]
444
+
445
+ device = self._execution_device
446
+ weight_dtype = self.text_encoder.dtype
447
+
448
+ # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
449
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
450
+ # corresponds to doing no classifier free guidance.
451
+ do_classifier_free_guidance = guidance_scale > 1.0
452
+
453
+ # 3. Encode input prompt
454
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
455
+ prompt,
456
+ negative_prompt,
457
+ do_classifier_free_guidance,
458
+ num_videos_per_prompt=num_videos_per_prompt,
459
+ prompt_embeds=prompt_embeds,
460
+ negative_prompt_embeds=negative_prompt_embeds,
461
+ max_sequence_length=max_sequence_length,
462
+ device=device,
463
+ )
464
+ if do_classifier_free_guidance:
465
+ prompt_embeds = negative_prompt_embeds + prompt_embeds
466
+
467
+ # 4. Prepare timesteps
468
+ if isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler):
469
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps, mu=1)
470
+ else:
471
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
472
+ self._num_timesteps = len(timesteps)
473
+ if comfyui_progressbar:
474
+ from comfy.utils import ProgressBar
475
+ pbar = ProgressBar(num_inference_steps + 1)
476
+
477
+ # 5. Prepare latents
478
+ latent_channels = self.transformer.config.in_channels
479
+ latents = self.prepare_latents(
480
+ batch_size * num_videos_per_prompt,
481
+ latent_channels,
482
+ num_frames,
483
+ height,
484
+ width,
485
+ weight_dtype,
486
+ device,
487
+ generator,
488
+ latents,
489
+ )
490
+ if comfyui_progressbar:
491
+ pbar.update(1)
492
+
493
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
494
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
495
+
496
+ target_shape = (self.vae.latent_channels, (num_frames - 1) // self.vae.temporal_compression_ratio + 1, width // self.vae.spacial_compression_ratio, height // self.vae.spacial_compression_ratio)
497
+ seq_len = math.ceil((target_shape[2] * target_shape[3]) / (self.transformer.config.patch_size[1] * self.transformer.config.patch_size[2]) * target_shape[1])
498
+ # 7. Denoising loop
499
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
500
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
501
+ for i, t in enumerate(timesteps):
502
+ if self.interrupt:
503
+ continue
504
+
505
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
506
+ if hasattr(self.scheduler, "scale_model_input"):
507
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
508
+
509
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
510
+ timestep = t.expand(latent_model_input.shape[0])
511
+
512
+ # predict noise model_output
513
+ with torch.cuda.amp.autocast(dtype=weight_dtype):
514
+ noise_pred = self.transformer(
515
+ x=latent_model_input,
516
+ context=prompt_embeds,
517
+ t=timestep,
518
+ seq_len=seq_len,
519
+ )
520
+
521
+ # perform guidance
522
+ if do_classifier_free_guidance:
523
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
524
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
525
+
526
+ # compute the previous noisy sample x_t -> x_t-1
527
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
528
+
529
+ if callback_on_step_end is not None:
530
+ callback_kwargs = {}
531
+ for k in callback_on_step_end_tensor_inputs:
532
+ callback_kwargs[k] = locals()[k]
533
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
534
+
535
+ latents = callback_outputs.pop("latents", latents)
536
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
537
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
542
+
543
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
544
+ progress_bar.update()
545
+ if comfyui_progressbar:
546
+ pbar.update(1)
547
+
548
+ if output_type == "numpy":
549
+ video = self.decode_latents(latents)
550
+ elif not output_type == "latent":
551
+ video = self.decode_latents(latents)
552
+ video = self.video_processor.postprocess_video(video=video, output_type=output_type)
553
+ else:
554
+ video = latents
555
+
556
+ # Offload all models
557
+ self.maybe_free_model_hooks()
558
+
559
+ if not return_dict:
560
+ video = torch.from_numpy(video)
561
+
562
+ return WanPipelineOutput(videos=video)
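One quantity in the pipeline above worth making concrete is `seq_len`, the number of patch tokens handed to the transformer: latent frames times latent height times latent width, divided by the spatial patch area. A small worked sketch, assuming the usual Wan2.1-Fun constants of temporal compression 4, spatial compression 8, and patch size (1, 2, 2); at runtime these come from the VAE and transformer configs rather than being hard-coded:

```python
import math

# Illustrative generation settings.
num_frames, height, width = 81, 480, 832

# Assumed model constants (read from the vae/transformer configs at runtime).
temporal_compression_ratio = 4
spacial_compression_ratio = 8   # spelling follows the repo's attribute name
patch_size = (1, 2, 2)          # (t, h, w)

lat_frames = (num_frames - 1) // temporal_compression_ratio + 1   # 21
lat_width = width // spacial_compression_ratio                    # 104
lat_height = height // spacial_compression_ratio                  # 60

seq_len = math.ceil((lat_width * lat_height) / (patch_size[1] * patch_size[2]) * lat_frames)
print(seq_len)  # 32760 tokens per sample
```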
cogvideox/pipeline/pipeline_wan_fun_inpaint.py ADDED
@@ -0,0 +1,729 @@
1
+ import inspect
2
+ import math
3
+ from dataclasses import dataclass
4
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
5
+
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn.functional as F
9
+ import torchvision.transforms.functional as TF
10
+ from diffusers import FlowMatchEulerDiscreteScheduler
11
+ from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
12
+ from diffusers.image_processor import VaeImageProcessor
13
+ from diffusers.models.embeddings import get_1d_rotary_pos_embed
14
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
16
+ from diffusers.utils import BaseOutput, logging, replace_example_docstring
17
+ from diffusers.utils.torch_utils import randn_tensor
18
+ from diffusers.video_processor import VideoProcessor
19
+ from einops import rearrange
20
+ from PIL import Image
21
+ from transformers import T5Tokenizer
22
+
23
+ from ..models import (AutoencoderKLWan, AutoTokenizer, CLIPModel,
24
+ WanT5EncoderModel, WanTransformer3DModel)
25
+
26
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
27
+
28
+
29
+ EXAMPLE_DOC_STRING = """
30
+ Examples:
31
+ ```python
32
+ pass
33
+ ```
34
+ """
35
+
36
+
37
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
38
+ def retrieve_timesteps(
39
+ scheduler,
40
+ num_inference_steps: Optional[int] = None,
41
+ device: Optional[Union[str, torch.device]] = None,
42
+ timesteps: Optional[List[int]] = None,
43
+ sigmas: Optional[List[float]] = None,
44
+ **kwargs,
45
+ ):
46
+ """
47
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
48
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
49
+
50
+ Args:
51
+ scheduler (`SchedulerMixin`):
52
+ The scheduler to get timesteps from.
53
+ num_inference_steps (`int`):
54
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
55
+ must be `None`.
56
+ device (`str` or `torch.device`, *optional*):
57
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
58
+ timesteps (`List[int]`, *optional*):
59
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
60
+ `num_inference_steps` and `sigmas` must be `None`.
61
+ sigmas (`List[float]`, *optional*):
62
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
63
+ `num_inference_steps` and `timesteps` must be `None`.
64
+
65
+ Returns:
66
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
67
+ second element is the number of inference steps.
68
+ """
69
+ if timesteps is not None and sigmas is not None:
70
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
71
+ if timesteps is not None:
72
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
73
+ if not accepts_timesteps:
74
+ raise ValueError(
75
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
76
+ f" timestep schedules. Please check whether you are using the correct scheduler."
77
+ )
78
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
79
+ timesteps = scheduler.timesteps
80
+ num_inference_steps = len(timesteps)
81
+ elif sigmas is not None:
82
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
83
+ if not accept_sigmas:
84
+ raise ValueError(
85
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
86
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
87
+ )
88
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
89
+ timesteps = scheduler.timesteps
90
+ num_inference_steps = len(timesteps)
91
+ else:
92
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
93
+ timesteps = scheduler.timesteps
94
+ return timesteps, num_inference_steps
95
+
96
+
97
+ def resize_mask(mask, latent, process_first_frame_only=True):
98
+ latent_size = latent.size()
99
+ batch_size, channels, num_frames, height, width = mask.shape
100
+
101
+ if process_first_frame_only:
102
+ target_size = list(latent_size[2:])
103
+ target_size[0] = 1
104
+ first_frame_resized = F.interpolate(
105
+ mask[:, :, 0:1, :, :],
106
+ size=target_size,
107
+ mode='trilinear',
108
+ align_corners=False
109
+ )
110
+
111
+ target_size = list(latent_size[2:])
112
+ target_size[0] = target_size[0] - 1
113
+ if target_size[0] != 0:
114
+ remaining_frames_resized = F.interpolate(
115
+ mask[:, :, 1:, :, :],
116
+ size=target_size,
117
+ mode='trilinear',
118
+ align_corners=False
119
+ )
120
+ resized_mask = torch.cat([first_frame_resized, remaining_frames_resized], dim=2)
121
+ else:
122
+ resized_mask = first_frame_resized
123
+ else:
124
+ target_size = list(latent_size[2:])
125
+ resized_mask = F.interpolate(
126
+ mask,
127
+ size=target_size,
128
+ mode='trilinear',
129
+ align_corners=False
130
+ )
131
+ return resized_mask
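The helper above downsamples a pixel-space mask to the latent grid while treating the first frame separately, mirroring how the Wan VAE compresses the first frame on its own during temporal compression. A minimal shape check under assumed toy sizes; the import works only if this repository is on the Python path:

```python
import torch

# resize_mask is module-level in this file; the import assumes the repo is installed or on PYTHONPATH.
from cogvideox.pipeline.pipeline_wan_fun_inpaint import resize_mask

# Toy sizes: 17 pixel frames map to 1 + (17 - 1) // 4 = 5 latent frames.
mask = torch.rand(1, 1, 17, 64, 64)     # (b, c, f, h, w) pixel-space mask
latent = torch.zeros(1, 16, 5, 8, 8)    # latent tensor whose grid the mask must match

resized = resize_mask(mask, latent)     # first frame resized alone, remaining frames together
print(resized.shape)                    # torch.Size([1, 1, 5, 8, 8])
```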
132
+
133
+
134
+ @dataclass
135
+ class WanPipelineOutput(BaseOutput):
136
+ r"""
137
+ Output class for Wan pipelines.
138
+
139
+ Args:
140
+ video (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
141
+ List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
142
+ denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
143
+ `(batch_size, num_frames, channels, height, width)`.
144
+ """
145
+
146
+ videos: torch.Tensor
147
+
148
+
149
+ class WanFunInpaintPipeline(DiffusionPipeline):
150
+ r"""
151
+ Pipeline for image-to-video generation and video inpainting using Wan.
152
+
153
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
154
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
155
+ """
156
+
157
+ _optional_components = []
158
+ model_cpu_offload_seq = "text_encoder->clip_image_encoder->transformer->vae"
159
+
160
+ _callback_tensor_inputs = [
161
+ "latents",
162
+ "prompt_embeds",
163
+ "negative_prompt_embeds",
164
+ ]
165
+
166
+ def __init__(
167
+ self,
168
+ tokenizer: AutoTokenizer,
169
+ text_encoder: WanT5EncoderModel,
170
+ vae: AutoencoderKLWan,
171
+ transformer: WanTransformer3DModel,
172
+ clip_image_encoder: CLIPModel,
173
+ scheduler: FlowMatchEulerDiscreteScheduler,
174
+ ):
175
+ super().__init__()
176
+
177
+ self.register_modules(
178
+ tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, clip_image_encoder=clip_image_encoder, scheduler=scheduler
179
+ )
180
+
181
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae.spacial_compression_ratio)
182
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae.spacial_compression_ratio)
183
+ self.mask_processor = VaeImageProcessor(
184
+ vae_scale_factor=self.vae.spacial_compression_ratio, do_normalize=False, do_binarize=True, do_convert_grayscale=True
185
+ )
186
+
187
+ def _get_t5_prompt_embeds(
188
+ self,
189
+ prompt: Union[str, List[str]] = None,
190
+ num_videos_per_prompt: int = 1,
191
+ max_sequence_length: int = 512,
192
+ device: Optional[torch.device] = None,
193
+ dtype: Optional[torch.dtype] = None,
194
+ ):
195
+ device = device or self._execution_device
196
+ dtype = dtype or self.text_encoder.dtype
197
+
198
+ prompt = [prompt] if isinstance(prompt, str) else prompt
199
+ batch_size = len(prompt)
200
+
201
+ text_inputs = self.tokenizer(
202
+ prompt,
203
+ padding="max_length",
204
+ max_length=max_sequence_length,
205
+ truncation=True,
206
+ add_special_tokens=True,
207
+ return_tensors="pt",
208
+ )
209
+ text_input_ids = text_inputs.input_ids
210
+ prompt_attention_mask = text_inputs.attention_mask
211
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
212
+
213
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
214
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
215
+ logger.warning(
216
+ "The following part of your input was truncated because `max_sequence_length` is set to "
217
+ f" {max_sequence_length} tokens: {removed_text}"
218
+ )
219
+
220
+ seq_lens = prompt_attention_mask.gt(0).sum(dim=1).long()
221
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask.to(device))[0]
222
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
223
+
224
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
225
+ _, seq_len, _ = prompt_embeds.shape
226
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
227
+ prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
228
+
229
+ return [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
230
+
231
+ def encode_prompt(
232
+ self,
233
+ prompt: Union[str, List[str]],
234
+ negative_prompt: Optional[Union[str, List[str]]] = None,
235
+ do_classifier_free_guidance: bool = True,
236
+ num_videos_per_prompt: int = 1,
237
+ prompt_embeds: Optional[torch.Tensor] = None,
238
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
239
+ max_sequence_length: int = 512,
240
+ device: Optional[torch.device] = None,
241
+ dtype: Optional[torch.dtype] = None,
242
+ ):
243
+ r"""
244
+ Encodes the prompt into text encoder hidden states.
245
+
246
+ Args:
247
+ prompt (`str` or `List[str]`, *optional*):
248
+ prompt to be encoded
249
+ negative_prompt (`str` or `List[str]`, *optional*):
250
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
251
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
252
+ less than `1`).
253
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
254
+ Whether to use classifier free guidance or not.
255
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
256
+ Number of videos that should be generated per prompt.
257
+ prompt_embeds (`torch.Tensor`, *optional*):
258
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
259
+ provided, text embeddings will be generated from `prompt` input argument.
260
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
261
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
262
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
263
+ argument.
264
+ device: (`torch.device`, *optional*):
265
+ torch device
266
+ dtype: (`torch.dtype`, *optional*):
267
+ torch dtype
268
+ """
269
+ device = device or self._execution_device
270
+
271
+ prompt = [prompt] if isinstance(prompt, str) else prompt
272
+ if prompt is not None:
273
+ batch_size = len(prompt)
274
+ else:
275
+ batch_size = prompt_embeds.shape[0]
276
+
277
+ if prompt_embeds is None:
278
+ prompt_embeds = self._get_t5_prompt_embeds(
279
+ prompt=prompt,
280
+ num_videos_per_prompt=num_videos_per_prompt,
281
+ max_sequence_length=max_sequence_length,
282
+ device=device,
283
+ dtype=dtype,
284
+ )
285
+
286
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
287
+ negative_prompt = negative_prompt or ""
288
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
289
+
290
+ if prompt is not None and type(prompt) is not type(negative_prompt):
291
+ raise TypeError(
292
+ f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
293
+ f" {type(prompt)}."
294
+ )
295
+ elif batch_size != len(negative_prompt):
296
+ raise ValueError(
297
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
298
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
299
+ " the batch size of `prompt`."
300
+ )
301
+
302
+ negative_prompt_embeds = self._get_t5_prompt_embeds(
303
+ prompt=negative_prompt,
304
+ num_videos_per_prompt=num_videos_per_prompt,
305
+ max_sequence_length=max_sequence_length,
306
+ device=device,
307
+ dtype=dtype,
308
+ )
309
+
310
+ return prompt_embeds, negative_prompt_embeds
311
+
312
+ def prepare_latents(
313
+ self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
314
+ ):
315
+ if isinstance(generator, list) and len(generator) != batch_size:
316
+ raise ValueError(
317
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
318
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
319
+ )
320
+
321
+ shape = (
322
+ batch_size,
323
+ num_channels_latents,
324
+ (num_frames - 1) // self.vae.temporal_compression_ratio + 1,
325
+ height // self.vae.spacial_compression_ratio,
326
+ width // self.vae.spacial_compression_ratio,
327
+ )
328
+
329
+ if latents is None:
330
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
331
+ else:
332
+ latents = latents.to(device)
333
+
334
+ # scale the initial noise by the standard deviation required by the scheduler
335
+ if hasattr(self.scheduler, "init_noise_sigma"):
336
+ latents = latents * self.scheduler.init_noise_sigma
337
+ return latents
338
+
339
+ def prepare_mask_latents(
340
+ self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance, noise_aug_strength
341
+ ):
342
+ # resize the mask to latents shape as we concatenate the mask to the latents
343
+ # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
344
+ # and half precision
345
+
346
+ if mask is not None:
347
+ mask = mask.to(device=device, dtype=self.vae.dtype)
348
+ bs = 1
349
+ new_mask = []
350
+ for i in range(0, mask.shape[0], bs):
351
+ mask_bs = mask[i : i + bs]
352
+ mask_bs = self.vae.encode(mask_bs)[0]
353
+ mask_bs = mask_bs.mode()
354
+ new_mask.append(mask_bs)
355
+ mask = torch.cat(new_mask, dim = 0)
356
+ # mask = mask * self.vae.config.scaling_factor
357
+
358
+ if masked_image is not None:
359
+ masked_image = masked_image.to(device=device, dtype=self.vae.dtype)
360
+ bs = 1
361
+ new_mask_pixel_values = []
362
+ for i in range(0, masked_image.shape[0], bs):
363
+ mask_pixel_values_bs = masked_image[i : i + bs]
364
+ mask_pixel_values_bs = self.vae.encode(mask_pixel_values_bs)[0]
365
+ mask_pixel_values_bs = mask_pixel_values_bs.mode()
366
+ new_mask_pixel_values.append(mask_pixel_values_bs)
367
+ masked_image_latents = torch.cat(new_mask_pixel_values, dim = 0)
368
+ # masked_image_latents = masked_image_latents * self.vae.config.scaling_factor
369
+ else:
370
+ masked_image_latents = None
371
+
372
+ return mask, masked_image_latents
373
+
374
+ def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
375
+ frames = self.vae.decode(latents.to(self.vae.dtype)).sample
376
+ frames = (frames / 2 + 0.5).clamp(0, 1)
377
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
378
+ frames = frames.cpu().float().numpy()
379
+ return frames
380
+
381
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
382
+ def prepare_extra_step_kwargs(self, generator, eta):
383
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
384
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
385
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
386
+ # and should be between [0, 1]
387
+
388
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
389
+ extra_step_kwargs = {}
390
+ if accepts_eta:
391
+ extra_step_kwargs["eta"] = eta
392
+
393
+ # check if the scheduler accepts generator
394
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
395
+ if accepts_generator:
396
+ extra_step_kwargs["generator"] = generator
397
+ return extra_step_kwargs
398
+
399
+ # Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs
400
+ def check_inputs(
401
+ self,
402
+ prompt,
403
+ height,
404
+ width,
405
+ negative_prompt,
406
+ callback_on_step_end_tensor_inputs,
407
+ prompt_embeds=None,
408
+ negative_prompt_embeds=None,
409
+ ):
410
+ if height % 8 != 0 or width % 8 != 0:
411
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
412
+
413
+ if callback_on_step_end_tensor_inputs is not None and not all(
414
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
415
+ ):
416
+ raise ValueError(
417
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
418
+ )
419
+ if prompt is not None and prompt_embeds is not None:
420
+ raise ValueError(
421
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
422
+ " only forward one of the two."
423
+ )
424
+ elif prompt is None and prompt_embeds is None:
425
+ raise ValueError(
426
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
427
+ )
428
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
429
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
430
+
431
+ if prompt is not None and negative_prompt_embeds is not None:
432
+ raise ValueError(
433
+ f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
434
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
435
+ )
436
+
437
+ if negative_prompt is not None and negative_prompt_embeds is not None:
438
+ raise ValueError(
439
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
440
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
441
+ )
442
+
443
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
444
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
445
+ raise ValueError(
446
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
447
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
448
+ f" {negative_prompt_embeds.shape}."
449
+ )
450
+
451
+ @property
452
+ def guidance_scale(self):
453
+ return self._guidance_scale
454
+
455
+ @property
456
+ def num_timesteps(self):
457
+ return self._num_timesteps
458
+
459
+ @property
460
+ def attention_kwargs(self):
461
+ return self._attention_kwargs
462
+
463
+ @property
464
+ def interrupt(self):
465
+ return self._interrupt
466
+
467
+ @torch.no_grad()
468
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
469
+ def __call__(
470
+ self,
471
+ prompt: Optional[Union[str, List[str]]] = None,
472
+ negative_prompt: Optional[Union[str, List[str]]] = None,
473
+ height: int = 480,
474
+ width: int = 720,
475
+ video: Union[torch.FloatTensor] = None,
476
+ mask_video: Union[torch.FloatTensor] = None,
477
+ num_frames: int = 49,
478
+ num_inference_steps: int = 50,
479
+ timesteps: Optional[List[int]] = None,
480
+ guidance_scale: float = 6,
481
+ num_videos_per_prompt: int = 1,
482
+ eta: float = 0.0,
483
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
484
+ latents: Optional[torch.FloatTensor] = None,
485
+ prompt_embeds: Optional[torch.FloatTensor] = None,
486
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
487
+ output_type: str = "numpy",
488
+ return_dict: bool = False,
489
+ callback_on_step_end: Optional[
490
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
491
+ ] = None,
492
+ attention_kwargs: Optional[Dict[str, Any]] = None,
493
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
494
+ clip_image: Image = None,
495
+ max_sequence_length: int = 512,
496
+ comfyui_progressbar: bool = False,
497
+ ) -> Union[WanPipelineOutput, Tuple]:
498
+ """
499
+ Function invoked when calling the pipeline for generation.
500
+ Args:
501
+
502
+ Examples:
503
+
504
+ Returns:
505
+
506
+ """
507
+
508
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
509
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
510
+ num_videos_per_prompt = 1
511
+
512
+ # 1. Check inputs. Raise error if not correct
513
+ self.check_inputs(
514
+ prompt,
515
+ height,
516
+ width,
517
+ negative_prompt,
518
+ callback_on_step_end_tensor_inputs,
519
+ prompt_embeds,
520
+ negative_prompt_embeds,
521
+ )
522
+ self._guidance_scale = guidance_scale
523
+ self._attention_kwargs = attention_kwargs
524
+ self._interrupt = False
525
+
526
+ # 2. Default call parameters
527
+ if prompt is not None and isinstance(prompt, str):
528
+ batch_size = 1
529
+ elif prompt is not None and isinstance(prompt, list):
530
+ batch_size = len(prompt)
531
+ else:
532
+ batch_size = prompt_embeds.shape[0]
533
+
534
+ device = self._execution_device
535
+ weight_dtype = self.text_encoder.dtype
536
+
537
+ # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
538
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
539
+ # corresponds to doing no classifier free guidance.
540
+ do_classifier_free_guidance = guidance_scale > 1.0
541
+
542
+ # 3. Encode input prompt
543
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
544
+ prompt,
545
+ negative_prompt,
546
+ do_classifier_free_guidance,
547
+ num_videos_per_prompt=num_videos_per_prompt,
548
+ prompt_embeds=prompt_embeds,
549
+ negative_prompt_embeds=negative_prompt_embeds,
550
+ max_sequence_length=max_sequence_length,
551
+ device=device,
552
+ )
553
+ if do_classifier_free_guidance:
554
+ prompt_embeds = negative_prompt_embeds + prompt_embeds
555
+
556
+ # 4. Prepare timesteps
557
+ if isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler):
558
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps, mu=1)
559
+ else:
560
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
561
+ self._num_timesteps = len(timesteps)
562
+ if comfyui_progressbar:
563
+ from comfy.utils import ProgressBar
564
+ pbar = ProgressBar(num_inference_steps + 2)
565
+
566
+ # 5. Prepare latents.
567
+ if video is not None:
568
+ video_length = video.shape[2]
569
+ init_video = self.image_processor.preprocess(rearrange(video, "b c f h w -> (b f) c h w"), height=height, width=width)
570
+ init_video = init_video.to(dtype=torch.float32)
571
+ init_video = rearrange(init_video, "(b f) c h w -> b c f h w", f=video_length)
572
+ else:
573
+ init_video = None
574
+
575
+ latent_channels = self.vae.config.latent_channels
576
+ latents = self.prepare_latents(
577
+ batch_size * num_videos_per_prompt,
578
+ latent_channels,
579
+ num_frames,
580
+ height,
581
+ width,
582
+ weight_dtype,
583
+ device,
584
+ generator,
585
+ latents,
586
+ )
587
+ if comfyui_progressbar:
588
+ pbar.update(1)
589
+
590
+ # Prepare mask latent variables
591
+ if init_video is not None:
592
+ if (mask_video == 255).all():
593
+ mask_latents = torch.tile(
594
+ torch.zeros_like(latents)[:, :1].to(device, weight_dtype), [1, 4, 1, 1, 1]
595
+ )
596
+ masked_video_latents = torch.zeros_like(latents).to(device, weight_dtype)
597
+
598
+ mask_input = torch.cat([mask_latents] * 2) if do_classifier_free_guidance else mask_latents
599
+ masked_video_latents_input = (
600
+ torch.cat([masked_video_latents] * 2) if do_classifier_free_guidance else masked_video_latents
601
+ )
602
+ y = torch.cat([mask_input, masked_video_latents_input], dim=1).to(device, weight_dtype)
603
+ else:
604
+ bs, _, video_length, height, width = video.size()
605
+ mask_condition = self.mask_processor.preprocess(rearrange(mask_video, "b c f h w -> (b f) c h w"), height=height, width=width)
606
+ mask_condition = mask_condition.to(dtype=torch.float32)
607
+ mask_condition = rearrange(mask_condition, "(b f) c h w -> b c f h w", f=video_length)
608
+
609
+ masked_video = init_video * (torch.tile(mask_condition, [1, 3, 1, 1, 1]) < 0.5)
610
+ _, masked_video_latents = self.prepare_mask_latents(
611
+ None,
612
+ masked_video,
613
+ batch_size,
614
+ height,
615
+ width,
616
+ weight_dtype,
617
+ device,
618
+ generator,
619
+ do_classifier_free_guidance,
620
+ noise_aug_strength=None,
621
+ )
622
+
623
+ mask_condition = torch.concat(
624
+ [
625
+ torch.repeat_interleave(mask_condition[:, :, 0:1], repeats=4, dim=2),
626
+ mask_condition[:, :, 1:]
627
+ ], dim=2
628
+ )
629
+ mask_condition = mask_condition.view(bs, mask_condition.shape[2] // 4, 4, height, width)
630
+ mask_condition = mask_condition.transpose(1, 2)
631
+ mask_latents = resize_mask(1 - mask_condition, masked_video_latents, True).to(device, weight_dtype)
632
+
633
+ mask_input = torch.cat([mask_latents] * 2) if do_classifier_free_guidance else mask_latents
634
+ masked_video_latents_input = (
635
+ torch.cat([masked_video_latents] * 2) if do_classifier_free_guidance else masked_video_latents
636
+ )
637
+
638
+ y = torch.cat([mask_input, masked_video_latents_input], dim=1).to(device, weight_dtype)
639
+
640
+ # Prepare clip latent variables
641
+ if clip_image is not None:
642
+ clip_image = TF.to_tensor(clip_image).sub_(0.5).div_(0.5).to(device, weight_dtype)
643
+ clip_context = self.clip_image_encoder([clip_image[:, None, :, :]])
644
+ clip_context = (
645
+ torch.cat([clip_context] * 2) if do_classifier_free_guidance else clip_context
646
+ )
647
+ else:
648
+ clip_image = Image.new("RGB", (512, 512), color=(0, 0, 0))
649
+ clip_image = TF.to_tensor(clip_image).sub_(0.5).div_(0.5).to(device, weight_dtype)
650
+ clip_context = self.clip_image_encoder([clip_image[:, None, :, :]])
651
+ clip_context = (
652
+ torch.cat([clip_context] * 2) if do_classifier_free_guidance else clip_context
653
+ )
654
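+ # Without a reference image, a black placeholder is encoded only to obtain a correctly shaped tensor, which is then zeroed so the image conditioning is a no-op.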
+ clip_context = torch.zeros_like(clip_context)
655
+ if comfyui_progressbar:
656
+ pbar.update(1)
657
+
658
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
659
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
660
+
661
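+ # seq_len is the number of patch tokens the transformer processes: latent frames times the latent area divided by the spatial patch size.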
+ target_shape = (self.vae.latent_channels, (num_frames - 1) // self.vae.temporal_compression_ratio + 1, width // self.vae.spacial_compression_ratio, height // self.vae.spacial_compression_ratio)
662
+ seq_len = math.ceil((target_shape[2] * target_shape[3]) / (self.transformer.config.patch_size[1] * self.transformer.config.patch_size[2]) * target_shape[1])
663
+ # 7. Denoising loop
664
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
665
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
666
+ for i, t in enumerate(timesteps):
667
+ if self.interrupt:
668
+ continue
669
+
670
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
671
+ if hasattr(self.scheduler, "scale_model_input"):
672
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
673
+
674
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
675
+ timestep = t.expand(latent_model_input.shape[0])
676
+
677
+ # predict noise model_output
678
+ with torch.cuda.amp.autocast(dtype=weight_dtype):
679
+ noise_pred = self.transformer(
680
+ x=latent_model_input,
681
+ context=prompt_embeds,
682
+ t=timestep,
683
+ seq_len=seq_len,
684
+ y=y,
685
+ clip_fea=clip_context,
686
+ )
687
+
688
+ # perform guidance
689
+ if do_classifier_free_guidance:
690
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
691
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
692
+
693
+ # compute the previous noisy sample x_t -> x_t-1
694
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
695
+
696
+ if callback_on_step_end is not None:
697
+ callback_kwargs = {}
698
+ for k in callback_on_step_end_tensor_inputs:
699
+ callback_kwargs[k] = locals()[k]
700
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
701
+
702
+ latents = callback_outputs.pop("latents", latents)
703
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
704
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
705
+ prompt_embeds_2 = callback_outputs.pop("prompt_embeds_2", prompt_embeds_2)
706
+ negative_prompt_embeds_2 = callback_outputs.pop(
707
+ "negative_prompt_embeds_2", negative_prompt_embeds_2
708
+ )
709
+
710
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
711
+ progress_bar.update()
712
+ if comfyui_progressbar:
713
+ pbar.update(1)
714
+
715
+ if output_type == "numpy":
716
+ video = self.decode_latents(latents)
717
+ elif not output_type == "latent":
718
+ video = self.decode_latents(latents)
719
+ video = self.video_processor.postprocess_video(video=video, output_type=output_type)
720
+ else:
721
+ video = latents
722
+
723
+ # Offload all models
724
+ self.maybe_free_model_hooks()
725
+
726
+ if not return_dict:
727
+ video = torch.from_numpy(video)
728
+
729
+ return WanPipelineOutput(videos=video)
cogvideox/ui/cogvideox_fun_ui.py ADDED
@@ -0,0 +1,722 @@
1
+ """Modified from https://github.com/guoyww/AnimateDiff/blob/main/app.py
2
+ """
3
+ import os
4
+ import random
5
+
6
+ import cv2
7
+ import gradio as gr
8
+ import numpy as np
9
+ import torch
10
+ from PIL import Image
11
+ from safetensors import safe_open
12
+
13
+ from ..data.bucket_sampler import ASPECT_RATIO_512, get_closest_ratio
14
+ from ..models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel, T5Tokenizer, T5EncoderModel
15
+ from ..pipeline import (CogVideoXFunPipeline, CogVideoXFunControlPipeline,
16
+ CogVideoXFunInpaintPipeline)
17
+ from ..utils.fp8_optimization import convert_weight_dtype_wrapper
18
+ from ..utils.lora_utils import merge_lora, unmerge_lora
19
+ from ..utils.utils import (get_image_to_video_latent,
20
+ get_video_to_video_latent, save_videos_grid)
21
+ from .controller import (Fun_Controller, Fun_Controller_EAS, all_cheduler_dict,
22
+ css, ddpm_scheduler_dict, flow_scheduler_dict,
23
+ gradio_version, gradio_version_is_above_4)
24
+ from .ui import (create_cfg_and_seedbox,
25
+ create_fake_finetune_models_checkpoints,
26
+ create_fake_height_width, create_fake_model_checkpoints,
27
+ create_fake_model_type, create_finetune_models_checkpoints,
28
+ create_generation_method,
29
+ create_generation_methods_and_video_length,
30
+ create_height_width, create_model_checkpoints,
31
+ create_model_type, create_prompts, create_samplers,
32
+ create_ui_outputs)
33
+
34
+
35
+ class CogVideoXFunController(Fun_Controller):
36
+ def update_diffusion_transformer(self, diffusion_transformer_dropdown):
37
+ print("Update diffusion transformer")
38
+ self.diffusion_transformer_dropdown = diffusion_transformer_dropdown
39
+ if diffusion_transformer_dropdown == "none":
40
+ return gr.update()
41
+ self.vae = AutoencoderKLCogVideoX.from_pretrained(
42
+ diffusion_transformer_dropdown,
43
+ subfolder="vae",
44
+ ).to(self.weight_dtype)
45
+
46
+ # Get Transformer
47
+ self.transformer = CogVideoXTransformer3DModel.from_pretrained(
48
+ diffusion_transformer_dropdown,
49
+ subfolder="transformer",
50
+ low_cpu_mem_usage=True,
51
+ ).to(self.weight_dtype)
52
+
53
+ # Get tokenizer and text_encoder
54
+ tokenizer = T5Tokenizer.from_pretrained(
55
+ diffusion_transformer_dropdown, subfolder="tokenizer"
56
+ )
57
+ text_encoder = T5EncoderModel.from_pretrained(
58
+ diffusion_transformer_dropdown, subfolder="text_encoder", torch_dtype=self.weight_dtype
59
+ )
60
+
61
+ # Get pipeline
62
+ if self.model_type == "Inpaint":
63
+ if self.transformer.config.in_channels != self.vae.config.latent_channels:
64
+ self.pipeline = CogVideoXFunInpaintPipeline.from_pretrained(
65
+ tokenizer=tokenizer,
66
+ text_encoder=text_encoder,
67
+ vae=self.vae,
68
+ transformer=self.transformer,
69
+ scheduler=self.scheduler_dict[list(self.scheduler_dict.keys())[0]].from_pretrained(diffusion_transformer_dropdown, subfolder="scheduler"),
70
+ )
71
+ else:
72
+ self.pipeline = CogVideoXFunPipeline.from_pretrained(
73
+ tokenizer=tokenizer,
74
+ text_encoder=text_encoder,
75
+ vae=self.vae,
76
+ transformer=self.transformer,
77
+ scheduler=self.scheduler_dict[list(self.scheduler_dict.keys())[0]].from_pretrained(diffusion_transformer_dropdown, subfolder="scheduler"),
78
+ )
79
+ else:
80
+ self.pipeline = CogVideoXFunControlPipeline.from_pretrained(
81
+ diffusion_transformer_dropdown,
82
+ vae=self.vae,
83
+ transformer=self.transformer,
84
+ scheduler=self.scheduler_dict[list(self.scheduler_dict.keys())[0]].from_pretrained(diffusion_transformer_dropdown, subfolder="scheduler"),
85
+ torch_dtype=self.weight_dtype
86
+ )
87
+
88
+ if self.GPU_memory_mode == "sequential_cpu_offload":
89
+ self.pipeline.enable_sequential_cpu_offload()
90
+ elif self.GPU_memory_mode == "model_cpu_offload_and_qfloat8":
91
+ convert_weight_dtype_wrapper(self.pipeline.transformer, self.weight_dtype)
92
+ self.pipeline.enable_model_cpu_offload()
93
+ else:
94
+ self.pipeline.enable_model_cpu_offload()
95
+ print("Update diffusion transformer done")
96
+ return gr.update()
97
+
98
+ def generate(
99
+ self,
100
+ diffusion_transformer_dropdown,
101
+ base_model_dropdown,
102
+ lora_model_dropdown,
103
+ lora_alpha_slider,
104
+ prompt_textbox,
105
+ negative_prompt_textbox,
106
+ sampler_dropdown,
107
+ sample_step_slider,
108
+ resize_method,
109
+ width_slider,
110
+ height_slider,
111
+ base_resolution,
112
+ generation_method,
113
+ length_slider,
114
+ overlap_video_length,
115
+ partial_video_length,
116
+ cfg_scale_slider,
117
+ start_image,
118
+ end_image,
119
+ validation_video,
120
+ validation_video_mask,
121
+ control_video,
122
+ denoise_strength,
123
+ seed_textbox,
124
+ is_api = False,
125
+ ):
126
+ self.clear_cache()
127
+
128
+ self.input_check(
129
+ resize_method, generation_method, start_image, end_image, validation_video, control_video, is_api
130
+ )
131
+ is_image = True if generation_method == "Image Generation" else False
132
+
133
+ if self.base_model_path != base_model_dropdown:
134
+ self.update_base_model(base_model_dropdown)
135
+
136
+ if self.lora_model_path != lora_model_dropdown:
137
+ self.update_lora_model(lora_model_dropdown)
138
+
139
+ self.pipeline.scheduler = self.scheduler_dict[sampler_dropdown].from_config(self.pipeline.scheduler.config)
140
+
141
+ if resize_method == "Resize according to Reference":
142
+ height_slider, width_slider = self.get_height_width_from_reference(
143
+ base_resolution, start_image, validation_video, control_video,
144
+ )
145
+ if self.lora_model_path != "none":
146
+ # lora part
147
+ self.pipeline = merge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider)
148
+
149
+ if int(seed_textbox) != -1 and seed_textbox != "": torch.manual_seed(int(seed_textbox))
150
+ else: seed_textbox = np.random.randint(0, 1e10)
151
+ generator = torch.Generator(device="cuda").manual_seed(int(seed_textbox))
152
+
153
+ try:
154
+ if self.model_type == "Inpaint":
155
+ if self.transformer.config.in_channels != self.vae.config.latent_channels:
156
+ if generation_method == "Long Video Generation":
157
+ if validation_video is not None:
158
+ raise gr.Error(f"Video to Video does not support Long Video Generation yet.")
159
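+ # Long Video Generation runs in overlapping windows: each pass produces up to partial_video_length frames, which are blended with the previous window over overlap_video_length frames.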
+ init_frames = 0
160
+ last_frames = init_frames + partial_video_length
161
+ while init_frames < length_slider:
162
+ if last_frames >= length_slider:
163
+ _partial_video_length = length_slider - init_frames
164
+ _partial_video_length = int((_partial_video_length - 1) // self.vae.config.temporal_compression_ratio * self.vae.config.temporal_compression_ratio) + 1
165
+
166
+ if _partial_video_length <= 0:
167
+ break
168
+ else:
169
+ _partial_video_length = partial_video_length
170
+
171
+ if last_frames >= length_slider:
172
+ input_video, input_video_mask, clip_image = get_image_to_video_latent(start_image, end_image, video_length=_partial_video_length, sample_size=(height_slider, width_slider))
173
+ else:
174
+ input_video, input_video_mask, clip_image = get_image_to_video_latent(start_image, None, video_length=_partial_video_length, sample_size=(height_slider, width_slider))
175
+
176
+ with torch.no_grad():
177
+ sample = self.pipeline(
178
+ prompt_textbox,
179
+ negative_prompt = negative_prompt_textbox,
180
+ num_inference_steps = sample_step_slider,
181
+ guidance_scale = cfg_scale_slider,
182
+ width = width_slider,
183
+ height = height_slider,
184
+ num_frames = _partial_video_length,
185
+ generator = generator,
186
+
187
+ video = input_video,
188
+ mask_video = input_video_mask,
189
+ strength = 1,
190
+ ).videos
191
+
192
+ if init_frames != 0:
193
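+ # Linearly cross-fade the overlapping frames between the tail of the previous window and the head of the new one.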
+ mix_ratio = torch.from_numpy(
194
+ np.array([float(_index) / float(overlap_video_length) for _index in range(overlap_video_length)], np.float32)
195
+ ).unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
196
+
197
+ new_sample[:, :, -overlap_video_length:] = new_sample[:, :, -overlap_video_length:] * (1 - mix_ratio) + \
198
+ sample[:, :, :overlap_video_length] * mix_ratio
199
+ new_sample = torch.cat([new_sample, sample[:, :, overlap_video_length:]], dim = 2)
200
+
201
+ sample = new_sample
202
+ else:
203
+ new_sample = sample
204
+
205
+ if last_frames >= length_slider:
206
+ break
207
+
208
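+ # The last overlap_video_length frames of the stitched result seed the next window as its start images.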
+ start_image = [
209
+ Image.fromarray(
210
+ (sample[0, :, _index].transpose(0, 1).transpose(1, 2) * 255).numpy().astype(np.uint8)
211
+ ) for _index in range(-overlap_video_length, 0)
212
+ ]
213
+
214
+ init_frames = init_frames + _partial_video_length - overlap_video_length
215
+ last_frames = init_frames + _partial_video_length
216
+ else:
217
+ if validation_video is not None:
218
+ input_video, input_video_mask, clip_image = get_video_to_video_latent(validation_video, length_slider if not is_image else 1, sample_size=(height_slider, width_slider), validation_video_mask=validation_video_mask, fps=8)
219
+ strength = denoise_strength
220
+ else:
221
+ input_video, input_video_mask, clip_image = get_image_to_video_latent(start_image, end_image, length_slider if not is_image else 1, sample_size=(height_slider, width_slider))
222
+ strength = 1
223
+
224
+ sample = self.pipeline(
225
+ prompt_textbox,
226
+ negative_prompt = negative_prompt_textbox,
227
+ num_inference_steps = sample_step_slider,
228
+ guidance_scale = cfg_scale_slider,
229
+ width = width_slider,
230
+ height = height_slider,
231
+ num_frames = length_slider if not is_image else 1,
232
+ generator = generator,
233
+
234
+ video = input_video,
235
+ mask_video = input_video_mask,
236
+ strength = strength,
237
+ ).videos
238
+ else:
239
+ sample = self.pipeline(
240
+ prompt_textbox,
241
+ negative_prompt = negative_prompt_textbox,
242
+ num_inference_steps = sample_step_slider,
243
+ guidance_scale = cfg_scale_slider,
244
+ width = width_slider,
245
+ height = height_slider,
246
+ num_frames = length_slider if not is_image else 1,
247
+ generator = generator
248
+ ).videos
249
+ else:
250
+ input_video, input_video_mask, clip_image = get_video_to_video_latent(control_video, length_slider if not is_image else 1, sample_size=(height_slider, width_slider), fps=8)
251
+
252
+ sample = self.pipeline(
253
+ prompt_textbox,
254
+ negative_prompt = negative_prompt_textbox,
255
+ num_inference_steps = sample_step_slider,
256
+ guidance_scale = cfg_scale_slider,
257
+ width = width_slider,
258
+ height = height_slider,
259
+ num_frames = length_slider if not is_image else 1,
260
+ generator = generator,
261
+
262
+ control_video = input_video,
263
+ ).videos
264
+ except Exception as e:
265
+ self.clear_cache()
266
+ if self.lora_model_path != "none":
267
+ self.pipeline = unmerge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider)
268
+ if is_api:
269
+ return "", f"Error. Error information: {str(e)}"
270
+ else:
271
+ return gr.update(), gr.update(), f"Error. Error information: {str(e)}"
272
+
273
+ self.clear_cache()
274
+ # lora part
275
+ if self.lora_model_path != "none":
276
+ self.pipeline = unmerge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider)
277
+
278
+ save_sample_path = self.save_outputs(
279
+ is_image, length_slider, sample, fps=8
280
+ )
281
+
282
+ if is_image or length_slider == 1:
283
+ if is_api:
284
+ return save_sample_path, "Success"
285
+ else:
286
+ if gradio_version_is_above_4:
287
+ return gr.Image(value=save_sample_path, visible=True), gr.Video(value=None, visible=False), "Success"
288
+ else:
289
+ return gr.Image.update(value=save_sample_path, visible=True), gr.Video.update(value=None, visible=False), "Success"
290
+ else:
291
+ if is_api:
292
+ return save_sample_path, "Success"
293
+ else:
294
+ if gradio_version_is_above_4:
295
+ return gr.Image(visible=False, value=None), gr.Video(value=save_sample_path, visible=True), "Success"
296
+ else:
297
+ return gr.Image.update(visible=False, value=None), gr.Video.update(value=save_sample_path, visible=True), "Success"
298
+
299
+
300
+ class CogVideoXFunController_Modelscope(CogVideoXFunController):
301
+ def __init__(self, model_name, model_type, savedir_sample, GPU_memory_mode, scheduler_dict, weight_dtype):
302
+ # Basic dir
303
+ self.basedir = os.getcwd()
304
+ self.personalized_model_dir = os.path.join(self.basedir, "models", "Personalized_Model")
305
+ self.lora_model_path = "none"
306
+ self.base_model_path = "none"
307
+ self.savedir_sample = savedir_sample
308
+ self.scheduler_dict = scheduler_dict
309
+ self.refresh_personalized_model()
310
+ os.makedirs(self.savedir_sample, exist_ok=True)
311
+
312
+ # model path
313
+ self.model_type = model_type
314
+ self.weight_dtype = weight_dtype
315
+
316
+ self.vae = AutoencoderKLCogVideoX.from_pretrained(
317
+ model_name,
318
+ subfolder="vae",
319
+ ).to(self.weight_dtype)
320
+
321
+ # Get Transformer
322
+ self.transformer = CogVideoXTransformer3DModel.from_pretrained(
323
+ model_name,
324
+ subfolder="transformer",
325
+ low_cpu_mem_usage=True,
326
+ ).to(self.weight_dtype)
327
+
328
+ # Get tokenizer and text_encoder
329
+ tokenizer = T5Tokenizer.from_pretrained(
330
+ model_name, subfolder="tokenizer"
331
+ )
332
+ text_encoder = T5EncoderModel.from_pretrained(
333
+ model_name, subfolder="text_encoder", torch_dtype=self.weight_dtype
334
+ )
335
+
336
+ # Get pipeline
337
+ if model_type == "Inpaint":
338
+ if self.transformer.config.in_channels != self.vae.config.latent_channels:
339
+ self.pipeline = CogVideoXFunInpaintPipeline(
340
+ tokenizer=tokenizer,
341
+ text_encoder=text_encoder,
342
+ vae=self.vae,
343
+ transformer=self.transformer,
344
+ scheduler=self.scheduler_dict["Euler"].from_pretrained(model_name, subfolder="scheduler"),
345
+ )
346
+ else:
347
+ self.pipeline = CogVideoXFunPipeline(
348
+ tokenizer=tokenizer,
349
+ text_encoder=text_encoder,
350
+ vae=self.vae,
351
+ transformer=self.transformer,
352
+ scheduler=self.scheduler_dict["Euler"].from_pretrained(model_name, subfolder="scheduler"),
353
+ )
354
+ else:
355
+ self.pipeline = CogVideoXFunControlPipeline(
356
+ tokenizer=tokenizer,
357
+ text_encoder=text_encoder,
358
+ vae=self.vae,
359
+ transformer=self.transformer,
360
+ scheduler=self.scheduler_dict["Euler"].from_pretrained(model_name, subfolder="scheduler"),
361
+ )
362
+
363
+ if GPU_memory_mode == "sequential_cpu_offload":
364
+ self.pipeline.enable_sequential_cpu_offload()
365
+ elif GPU_memory_mode == "model_cpu_offload_and_qfloat8":
366
+ convert_weight_dtype_wrapper(self.pipeline.transformer, self.weight_dtype)
367
+ self.pipeline.enable_model_cpu_offload()
368
+ else:
369
+ self.pipeline.enable_model_cpu_offload()
370
+ print("Update diffusion transformer done")
371
+
372
+ CogVideoXFunController_EAS = Fun_Controller_EAS
373
+
374
+ def ui(GPU_memory_mode, scheduler_dict, weight_dtype):
375
+ controller = CogVideoXFunController(GPU_memory_mode, scheduler_dict, weight_dtype)
376
+
377
+ with gr.Blocks(css=css) as demo:
378
+ gr.Markdown(
379
+ """
380
+ # CogVideoX-Fun:
381
+
382
+ A CogVideoX model with more flexible generation conditions, capable of producing videos at different resolutions, around 6 seconds long at 8 fps (1 to 49 frames), as well as image-to-video generation.
383
+
384
+ [Github](https://github.com/aigc-apps/CogVideoX-Fun/)
385
+ """
386
+ )
387
+ with gr.Column(variant="panel"):
388
+ model_type = create_model_type(visible=True)
389
+ diffusion_transformer_dropdown, diffusion_transformer_refresh_button = \
390
+ create_model_checkpoints(controller, visible=True)
391
+ base_model_dropdown, lora_model_dropdown, lora_alpha_slider, personalized_refresh_button = \
392
+ create_finetune_models_checkpoints(controller, visible=True)
393
+
394
+ with gr.Column(variant="panel"):
395
+ prompt_textbox, negative_prompt_textbox = create_prompts()
396
+
397
+ with gr.Row():
398
+ with gr.Column():
399
+ sampler_dropdown, sample_step_slider = create_samplers(controller)
400
+
401
+ resize_method, width_slider, height_slider, base_resolution = create_height_width(
402
+ default_height = 672, default_width = 384, maximum_height = 1344,
403
+ maximum_width = 1344,
404
+ )
405
+ gr.Markdown(
406
+ """
407
+ V1.0 and V1.1 support up to 49 frames of video generation, while V1.5 supports up to 85 frames.
408
+ (V1.0和V1.1支持最大49帧视频生成,V1.5支持最大85帧视频生成。)
409
+ """
410
+ )
411
+ generation_method, length_slider, overlap_video_length, partial_video_length = \
412
+ create_generation_methods_and_video_length(
413
+ ["Video Generation", "Image Generation", "Long Video Generation"],
414
+ default_video_length=49,
415
+ maximum_video_length=85,
416
+ )
417
+ image_to_video_col, video_to_video_col, control_video_col, source_method, start_image, template_gallery, end_image, validation_video, validation_video_mask, denoise_strength, control_video = create_generation_method(
418
+ ["Text to Video (文本到视频)", "Image to Video (图片到视频)", "Video to Video (视频到视频)", "Video Control (视频控制)"], prompt_textbox
419
+ )
420
+ cfg_scale_slider, seed_textbox, seed_button = create_cfg_and_seedbox(gradio_version_is_above_4)
421
+
422
+ generate_button = gr.Button(value="Generate (生成)", variant='primary')
423
+
424
+ result_image, result_video, infer_progress = create_ui_outputs()
425
+
426
+ model_type.change(
427
+ fn=controller.update_model_type,
428
+ inputs=[model_type],
429
+ outputs=[]
430
+ )
431
+
432
+ def upload_generation_method(generation_method):
433
+ if generation_method == "Video Generation":
434
+ return [gr.update(visible=True, maximum=85, value=49, interactive=True), gr.update(visible=False), gr.update(visible=False)]
435
+ elif generation_method == "Image Generation":
436
+ return [gr.update(minimum=1, maximum=1, value=1, interactive=False), gr.update(visible=False), gr.update(visible=False)]
437
+ else:
438
+ return [gr.update(visible=True, maximum=1344), gr.update(visible=True), gr.update(visible=True)]
439
+ generation_method.change(
440
+ upload_generation_method, generation_method, [length_slider, overlap_video_length, partial_video_length]
441
+ )
442
+
443
+ def upload_source_method(source_method):
444
+ if source_method == "Text to Video (文本到视频)":
445
+ return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None)]
446
+ elif source_method == "Image to Video (图片到视频)":
447
+ return [gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(), gr.update(), gr.update(value=None), gr.update(value=None), gr.update(value=None)]
448
+ elif source_method == "Video to Video (视频到视频)":
449
+ return [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(value=None), gr.update(value=None), gr.update(), gr.update(), gr.update(value=None)]
450
+ else:
451
+ return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update()]
452
+ source_method.change(
453
+ upload_source_method, source_method, [
454
+ image_to_video_col, video_to_video_col, control_video_col, start_image, end_image,
455
+ validation_video, validation_video_mask, control_video
456
+ ]
457
+ )
458
+
459
+ def upload_resize_method(resize_method):
460
+ if resize_method == "Generate by":
461
+ return [gr.update(visible=True), gr.update(visible=True), gr.update(visible=False)]
462
+ else:
463
+ return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)]
464
+ resize_method.change(
465
+ upload_resize_method, resize_method, [width_slider, height_slider, base_resolution]
466
+ )
467
+
468
+ generate_button.click(
469
+ fn=controller.generate,
470
+ inputs=[
471
+ diffusion_transformer_dropdown,
472
+ base_model_dropdown,
473
+ lora_model_dropdown,
474
+ lora_alpha_slider,
475
+ prompt_textbox,
476
+ negative_prompt_textbox,
477
+ sampler_dropdown,
478
+ sample_step_slider,
479
+ resize_method,
480
+ width_slider,
481
+ height_slider,
482
+ base_resolution,
483
+ generation_method,
484
+ length_slider,
485
+ overlap_video_length,
486
+ partial_video_length,
487
+ cfg_scale_slider,
488
+ start_image,
489
+ end_image,
490
+ validation_video,
491
+ validation_video_mask,
492
+ control_video,
493
+ denoise_strength,
494
+ seed_textbox,
495
+ ],
496
+ outputs=[result_image, result_video, infer_progress]
497
+ )
498
+ return demo, controller
499
+
500
+ def ui_modelscope(model_name, model_type, savedir_sample, GPU_memory_mode, scheduler_dict, weight_dtype):
501
+ controller = CogVideoXFunController_Modelscope(model_name, model_type, savedir_sample, GPU_memory_mode, scheduler_dict, weight_dtype)
502
+
503
+ with gr.Blocks(css=css) as demo:
504
+ gr.Markdown(
505
+ """
506
+ # CogVideoX-Fun
507
+
508
+ A CogVideoX model with more flexible generation conditions, capable of producing videos at different resolutions, around 6 seconds long at 8 fps (1 to 49 frames), as well as image-to-video generation.
509
+
510
+ [Github](https://github.com/aigc-apps/CogVideoX-Fun/)
511
+ """
512
+ )
513
+ with gr.Column(variant="panel"):
514
+ model_type = create_fake_model_type(visible=True)
515
+ diffusion_transformer_dropdown = create_fake_model_checkpoints(model_name, visible=True)
516
+ base_model_dropdown, lora_model_dropdown, lora_alpha_slider = create_fake_finetune_models_checkpoints(visible=True)
517
+
518
+ with gr.Column(variant="panel"):
519
+ prompt_textbox, negative_prompt_textbox = create_prompts()
520
+
521
+ with gr.Row():
522
+ with gr.Column():
523
+ sampler_dropdown, sample_step_slider = create_samplers(controller)
524
+
525
+ resize_method, width_slider, height_slider, base_resolution = create_height_width(
526
+ default_height = 672, default_width = 384, maximum_height = 1344,
527
+ maximum_width = 1344,
528
+ )
529
+ gr.Markdown(
530
+ """
531
+ V1.0 and V1.1 support up to 49 frames of video generation, while V1.5 supports up to 85 frames.
532
+ (V1.0和V1.1支持最大49帧视频生成,V1.5支持最大85帧视频生成。)
533
+ """
534
+ )
535
+ generation_method, length_slider, overlap_video_length, partial_video_length = \
536
+ create_generation_methods_and_video_length(
537
+ ["Video Generation", "Image Generation"],
538
+ default_video_length=49,
539
+ maximum_video_length=85,
540
+ )
541
+ image_to_video_col, video_to_video_col, control_video_col, source_method, start_image, template_gallery, end_image, validation_video, validation_video_mask, denoise_strength, control_video = create_generation_method(
542
+ ["Text to Video (文本到视频)", "Image to Video (图片到视频)", "Video to Video (视频到视频)", "Video Control (视频控制)"], prompt_textbox
543
+ )
544
+ cfg_scale_slider, seed_textbox, seed_button = create_cfg_and_seedbox(gradio_version_is_above_4)
545
+
546
+ generate_button = gr.Button(value="Generate (生成)", variant='primary')
547
+
548
+ result_image, result_video, infer_progress = create_ui_outputs()
549
+
550
+ def upload_generation_method(generation_method):
551
+ if generation_method == "Video Generation":
552
+ return gr.update(visible=True, minimum=8, maximum=85, value=49, interactive=True)
553
+ elif generation_method == "Image Generation":
554
+ return gr.update(minimum=1, maximum=1, value=1, interactive=False)
555
+ generation_method.change(
556
+ upload_generation_method, generation_method, [length_slider]
557
+ )
558
+
559
+ def upload_source_method(source_method):
560
+ if source_method == "Text to Video (文本到视频)":
561
+ return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None)]
562
+ elif source_method == "Image to Video (图片到视频)":
563
+ return [gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(), gr.update(), gr.update(value=None), gr.update(value=None), gr.update(value=None)]
564
+ elif source_method == "Video to Video (视频到视频)":
565
+ return [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(value=None), gr.update(value=None), gr.update(), gr.update(), gr.update(value=None)]
566
+ else:
567
+ return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update()]
568
+ source_method.change(
569
+ upload_source_method, source_method, [
570
+ image_to_video_col, video_to_video_col, control_video_col, start_image, end_image,
571
+ validation_video, validation_video_mask, control_video
572
+ ]
573
+ )
574
+
575
+ def upload_resize_method(resize_method):
576
+ if resize_method == "Generate by":
577
+ return [gr.update(visible=True), gr.update(visible=True), gr.update(visible=False)]
578
+ else:
579
+ return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)]
580
+ resize_method.change(
581
+ upload_resize_method, resize_method, [width_slider, height_slider, base_resolution]
582
+ )
583
+
584
+ generate_button.click(
585
+ fn=controller.generate,
586
+ inputs=[
587
+ diffusion_transformer_dropdown,
588
+ base_model_dropdown,
589
+ lora_model_dropdown,
590
+ lora_alpha_slider,
591
+ prompt_textbox,
592
+ negative_prompt_textbox,
593
+ sampler_dropdown,
594
+ sample_step_slider,
595
+ resize_method,
596
+ width_slider,
597
+ height_slider,
598
+ base_resolution,
599
+ generation_method,
600
+ length_slider,
601
+ overlap_video_length,
602
+ partial_video_length,
603
+ cfg_scale_slider,
604
+ start_image,
605
+ end_image,
606
+ validation_video,
607
+ validation_video_mask,
608
+ control_video,
609
+ denoise_strength,
610
+ seed_textbox,
611
+ ],
612
+ outputs=[result_image, result_video, infer_progress]
613
+ )
614
+ return demo, controller
615
+
616
+ def ui_eas(model_name, scheduler_dict, savedir_sample):
617
+ controller = CogVideoXFunController_EAS(model_name, scheduler_dict, savedir_sample)
618
+
619
+ with gr.Blocks(css=css) as demo:
620
+ gr.Markdown(
621
+ """
622
+ # CogVideoX-Fun
623
+
624
+ A CogVideoX model with more flexible generation conditions, capable of producing videos at different resolutions, around 6 seconds long at 8 fps (1 to 49 frames), as well as image-to-video generation.
625
+
626
+ [Github](https://github.com/aigc-apps/CogVideoX-Fun/)
627
+ """
628
+ )
629
+ with gr.Column(variant="panel"):
630
+ diffusion_transformer_dropdown = create_fake_model_checkpoints(model_name, visible=True)
631
+ base_model_dropdown, lora_model_dropdown, lora_alpha_slider = create_fake_finetune_models_checkpoints(visible=True)
632
+
633
+ with gr.Column(variant="panel"):
634
+ prompt_textbox, negative_prompt_textbox = create_prompts()
635
+
636
+ with gr.Row():
637
+ with gr.Column():
638
+ sampler_dropdown, sample_step_slider = create_samplers(controller, maximum_step=50)
639
+
640
+ resize_method, width_slider, height_slider, base_resolution = create_fake_height_width(
641
+ default_height = 672, default_width = 384, maximum_height = 1344,
642
+ maximum_width = 1344,
643
+ )
644
+ gr.Markdown(
645
+ """
646
+ V1.0 and V1.1 support up to 49 frames of video generation, while V1.5 supports up to 85 frames.
647
+ (V1.0和V1.1支持最大49帧视频生成,V1.5支持最大85帧视频生成。)
648
+ """
649
+ )
650
+ generation_method, length_slider, overlap_video_length, partial_video_length = \
651
+ create_generation_methods_and_video_length(
652
+ ["Video Generation", "Image Generation"],
653
+ default_video_length=49,
654
+ maximum_video_length=85,
655
+ )
656
+ image_to_video_col, video_to_video_col, control_video_col, source_method, start_image, template_gallery, end_image, validation_video, validation_video_mask, denoise_strength, control_video = create_generation_method(
657
+ ["Text to Video (文本到视频)", "Image to Video (图片到视频)", "Video to Video (视频到视频)"], prompt_textbox
658
+ )
659
+
660
+ cfg_scale_slider, seed_textbox, seed_button = create_cfg_and_seedbox(gradio_version_is_above_4)
661
+
662
+ generate_button = gr.Button(value="Generate (生成)", variant='primary')
663
+
664
+ result_image, result_video, infer_progress = create_ui_outputs()
665
+
666
+ def upload_generation_method(generation_method):
667
+ if generation_method == "Video Generation":
668
+ return gr.update(visible=True, minimum=5, maximum=85, value=49, interactive=True)
669
+ elif generation_method == "Image Generation":
670
+ return gr.update(minimum=1, maximum=1, value=1, interactive=False)
671
+ generation_method.change(
672
+ upload_generation_method, generation_method, [length_slider]
673
+ )
674
+
675
+ def upload_source_method(source_method):
676
+ if source_method == "Text to Video (文本到视频)":
677
+ return [gr.update(visible=False), gr.update(visible=False), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None)]
678
+ elif source_method == "Image to Video (图片到视频)":
679
+ return [gr.update(visible=True), gr.update(visible=False), gr.update(), gr.update(), gr.update(value=None), gr.update(value=None)]
680
+ else:
681
+ return [gr.update(visible=False), gr.update(visible=True), gr.update(value=None), gr.update(value=None), gr.update(), gr.update()]
682
+ source_method.change(
683
+ upload_source_method, source_method, [image_to_video_col, video_to_video_col, start_image, end_image, validation_video, validation_video_mask]
684
+ )
685
+
686
+ def upload_resize_method(resize_method):
687
+ if resize_method == "Generate by":
688
+ return [gr.update(visible=True), gr.update(visible=True), gr.update(visible=False)]
689
+ else:
690
+ return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)]
691
+ resize_method.change(
692
+ upload_resize_method, resize_method, [width_slider, height_slider, base_resolution]
693
+ )
694
+
695
+ generate_button.click(
696
+ fn=controller.generate,
697
+ inputs=[
698
+ diffusion_transformer_dropdown,
699
+ base_model_dropdown,
700
+ lora_model_dropdown,
701
+ lora_alpha_slider,
702
+ prompt_textbox,
703
+ negative_prompt_textbox,
704
+ sampler_dropdown,
705
+ sample_step_slider,
706
+ resize_method,
707
+ width_slider,
708
+ height_slider,
709
+ base_resolution,
710
+ generation_method,
711
+ length_slider,
712
+ cfg_scale_slider,
713
+ start_image,
714
+ end_image,
715
+ validation_video,
716
+ validation_video_mask,
717
+ denoise_strength,
718
+ seed_textbox,
719
+ ],
720
+ outputs=[result_image, result_video, infer_progress]
721
+ )
722
+ return demo, controller
cogvideox/ui/controller.py ADDED
@@ -0,0 +1,390 @@
1
+ """Modified from https://github.com/guoyww/AnimateDiff/blob/main/app.py
2
+ """
3
+ import base64
4
+ import gc
5
+ import json
6
+ import os
7
+ import random
8
+ from datetime import datetime
9
+ from glob import glob
10
+ from omegaconf import OmegaConf
11
+
12
+ import cv2
13
+ import gradio as gr
14
+ import numpy as np
15
+ import pkg_resources
16
+ import requests
17
+ import torch
18
+ from diffusers import (CogVideoXDDIMScheduler, FlowMatchEulerDiscreteScheduler,
19
+ DDIMScheduler, DPMSolverMultistepScheduler,
20
+ EulerAncestralDiscreteScheduler, EulerDiscreteScheduler,
21
+ PNDMScheduler)
22
+ from PIL import Image
23
+ from safetensors import safe_open
24
+
25
+ from ..data.bucket_sampler import ASPECT_RATIO_512, get_closest_ratio
26
+ from ..utils.utils import save_videos_grid
27
+
28
+ gradio_version = pkg_resources.get_distribution("gradio").version
29
+ gradio_version_is_above_4 = True if int(gradio_version.split('.')[0]) >= 4 else False
30
+
31
+ css = """
32
+ .toolbutton {
33
+ margin-bottom: 0em;
34
+ max-width: 2.5em;
35
+ min-width: 2.5em !important;
36
+ height: 2.5em;
37
+ }
38
+ """
39
+
40
+ ddpm_scheduler_dict = {
41
+ "Euler": EulerDiscreteScheduler,
42
+ "Euler A": EulerAncestralDiscreteScheduler,
43
+ "DPM++": DPMSolverMultistepScheduler,
44
+ "PNDM": PNDMScheduler,
45
+ "DDIM": DDIMScheduler,
46
+ "DDIM_Origin": DDIMScheduler,
47
+ "DDIM_Cog": CogVideoXDDIMScheduler,
48
+ }
49
+ flow_scheduler_dict = {
50
+ "Flow": FlowMatchEulerDiscreteScheduler,
51
+ }
52
+ all_cheduler_dict = {**ddpm_scheduler_dict, **flow_scheduler_dict}
53
+
54
+ class Fun_Controller:
55
+ def __init__(self, GPU_memory_mode, scheduler_dict, weight_dtype, config_path=None):
56
+ # config dirs
57
+ self.basedir = os.getcwd()
58
+ self.config_dir = os.path.join(self.basedir, "config")
59
+ self.diffusion_transformer_dir = os.path.join(self.basedir, "models", "Diffusion_Transformer")
60
+ self.motion_module_dir = os.path.join(self.basedir, "models", "Motion_Module")
61
+ self.personalized_model_dir = os.path.join(self.basedir, "models", "Personalized_Model")
62
+ self.savedir = os.path.join(self.basedir, "samples", datetime.now().strftime("Gradio-%Y-%m-%dT%H-%M-%S"))
63
+ self.savedir_sample = os.path.join(self.savedir, "sample")
64
+ self.model_type = "Inpaint"
65
+ os.makedirs(self.savedir, exist_ok=True)
66
+
67
+ self.diffusion_transformer_list = []
68
+ self.motion_module_list = []
69
+ self.personalized_model_list = []
70
+
71
+ self.refresh_diffusion_transformer()
72
+ self.refresh_motion_module()
73
+ self.refresh_personalized_model()
74
+
75
+ # config models
76
+ self.tokenizer = None
77
+ self.text_encoder = None
78
+ self.vae = None
79
+ self.transformer = None
80
+ self.pipeline = None
81
+ self.motion_module_path = "none"
82
+ self.base_model_path = "none"
83
+ self.lora_model_path = "none"
84
+ self.GPU_memory_mode = GPU_memory_mode
85
+
86
+ self.weight_dtype = weight_dtype
87
+ self.scheduler_dict = scheduler_dict
88
+ if config_path is not None:
89
+ self.config = OmegaConf.load(config_path)
90
+
91
+ def refresh_diffusion_transformer(self):
92
+ self.diffusion_transformer_list = sorted(glob(os.path.join(self.diffusion_transformer_dir, "*/")))
93
+
94
+ def refresh_motion_module(self):
95
+ motion_module_list = sorted(glob(os.path.join(self.motion_module_dir, "*.safetensors")))
96
+ self.motion_module_list = [os.path.basename(p) for p in motion_module_list]
97
+
98
+ def refresh_personalized_model(self):
99
+ personalized_model_list = sorted(glob(os.path.join(self.personalized_model_dir, "*.safetensors")))
100
+ self.personalized_model_list = [os.path.basename(p) for p in personalized_model_list]
101
+
102
+ def update_model_type(self, model_type):
103
+ self.model_type = model_type
104
+
105
+ def update_diffusion_transformer(self, diffusion_transformer_dropdown):
106
+ pass
107
+
108
+ def update_base_model(self, base_model_dropdown):
109
+ self.base_model_path = base_model_dropdown
110
+ print("Update base model")
111
+ if base_model_dropdown == "none":
112
+ return gr.update()
113
+ if self.transformer is None:
114
+ gr.Info(f"Please select a pretrained model path.")
115
+ return gr.update(value=None)
116
+ else:
117
+ base_model_dropdown = os.path.join(self.personalized_model_dir, base_model_dropdown)
118
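+ # Load the selected checkpoint from safetensors and merge its weights into the transformer; strict=False tolerates missing or extra keys.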
+ base_model_state_dict = {}
119
+ with safe_open(base_model_dropdown, framework="pt", device="cpu") as f:
120
+ for key in f.keys():
121
+ base_model_state_dict[key] = f.get_tensor(key)
122
+ self.transformer.load_state_dict(base_model_state_dict, strict=False)
123
+ print("Update base done")
124
+ return gr.update()
125
+
126
+ def update_lora_model(self, lora_model_dropdown):
127
+ print("Update lora model")
128
+ if lora_model_dropdown == "none":
129
+ self.lora_model_path = "none"
130
+ return gr.update()
131
+ lora_model_dropdown = os.path.join(self.personalized_model_dir, lora_model_dropdown)
132
+ self.lora_model_path = lora_model_dropdown
133
+ return gr.update()
134
+
135
+ def clear_cache(self,):
136
+ gc.collect()
137
+ torch.cuda.empty_cache()
138
+ torch.cuda.ipc_collect()
139
+
140
+ def input_check(self,
141
+ resize_method,
142
+ generation_method,
143
+ start_image,
144
+ end_image,
145
+ validation_video,
146
+ control_video,
147
+ is_api = False,
148
+ ):
149
+ if self.transformer is None:
150
+ raise gr.Error(f"Please select a pretrained model path.")
151
+
152
+ if control_video is not None and self.model_type == "Inpaint":
153
+ if is_api:
154
+ return "", f"If a control video is specified, please set model_type to \"Control\"."
155
+ else:
156
+ raise gr.Error(f"If a control video is specified, please set model_type to \"Control\".")
157
+
158
+ if control_video is None and self.model_type == "Control":
159
+ if is_api:
160
+ return "", f"If model_type is \"Control\", please specify a control video."
161
+ else:
162
+ raise gr.Error(f"If model_type is \"Control\", please specify a control video.")
163
+
164
+ if resize_method == "Resize according to Reference":
165
+ if start_image is None and validation_video is None and control_video is None:
166
+ if is_api:
167
+ return "", f"Please upload an image when using \"Resize according to Reference\"."
168
+ else:
169
+ raise gr.Error(f"Please upload an image when using \"Resize according to Reference\".")
170
+
171
+ if self.transformer.config.in_channels == self.vae.config.latent_channels and start_image is not None:
172
+ if is_api:
173
+ return "", f"Please select an image-to-video pretrained model when using image-to-video generation."
174
+ else:
175
+ raise gr.Error(f"Please select an image-to-video pretrained model when using image-to-video generation.")
176
+
177
+ if self.transformer.config.in_channels == self.vae.config.latent_channels and generation_method == "Long Video Generation":
178
+ if is_api:
179
+ return "", f"Please select an image-to-video pretrained model when using long video generation."
180
+ else:
181
+ raise gr.Error(f"Please select an image-to-video pretrained model when using long video generation.")
182
+
183
+ if start_image is None and end_image is not None:
184
+ if is_api:
185
+ return "", f"If an ending image is specified, please also specify a starting image."
186
+ else:
187
+ raise gr.Error(f"If an ending image is specified, please also specify a starting image.")
188
+
189
+ def get_height_width_from_reference(
190
+ self,
191
+ base_resolution,
192
+ start_image,
193
+ validation_video,
194
+ control_video,
195
+ ):
196
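+ # Scale the 512-based aspect-ratio buckets to the requested base resolution, pick the bucket closest to the reference's aspect ratio, and snap height/width to multiples of 16.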
+ aspect_ratio_sample_size = {key : [x / 512 * base_resolution for x in ASPECT_RATIO_512[key]] for key in ASPECT_RATIO_512.keys()}
197
+ if self.model_type == "Inpaint":
198
+ if validation_video is not None:
199
+ original_width, original_height = Image.fromarray(cv2.VideoCapture(validation_video).read()[1]).size
200
+ else:
201
+ original_width, original_height = start_image[0].size if type(start_image) is list else Image.open(start_image).size
202
+ else:
203
+ original_width, original_height = Image.fromarray(cv2.VideoCapture(control_video).read()[1]).size
204
+ closest_size, closest_ratio = get_closest_ratio(original_height, original_width, ratios=aspect_ratio_sample_size)
205
+ height_slider, width_slider = [int(x / 16) * 16 for x in closest_size]
206
+ return height_slider, width_slider
207
+
208
+ def save_outputs(self, is_image, length_slider, sample, fps):
209
+ if not os.path.exists(self.savedir_sample):
210
+ os.makedirs(self.savedir_sample, exist_ok=True)
211
+ index = len([path for path in os.listdir(self.savedir_sample)]) + 1
212
+ prefix = str(index).zfill(3)
213
+
214
+ if is_image or length_slider == 1:
215
+ save_sample_path = os.path.join(self.savedir_sample, prefix + f".png")
216
+
217
+ image = sample[0, :, 0]
218
+ image = image.transpose(0, 1).transpose(1, 2)
219
+ image = (image * 255).numpy().astype(np.uint8)
220
+ image = Image.fromarray(image)
221
+ image.save(save_sample_path)
222
+
223
+ else:
224
+ save_sample_path = os.path.join(self.savedir_sample, prefix + f".mp4")
225
+ save_videos_grid(sample, save_sample_path, fps=fps)
226
+ return save_sample_path
227
+
228
+ def generate(
229
+ self,
230
+ diffusion_transformer_dropdown,
231
+ base_model_dropdown,
232
+ lora_model_dropdown,
233
+ lora_alpha_slider,
234
+ prompt_textbox,
235
+ negative_prompt_textbox,
236
+ sampler_dropdown,
237
+ sample_step_slider,
238
+ resize_method,
239
+ width_slider,
240
+ height_slider,
241
+ base_resolution,
242
+ generation_method,
243
+ length_slider,
244
+ overlap_video_length,
245
+ partial_video_length,
246
+ cfg_scale_slider,
247
+ start_image,
248
+ end_image,
249
+ validation_video,
250
+ validation_video_mask,
251
+ control_video,
252
+ denoise_strength,
253
+ seed_textbox,
254
+ is_api = False,
255
+ ):
256
+ pass
257
+
258
+ def post_eas(
259
+ diffusion_transformer_dropdown,
260
+ base_model_dropdown, lora_model_dropdown, lora_alpha_slider,
261
+ prompt_textbox, negative_prompt_textbox,
262
+ sampler_dropdown, sample_step_slider, resize_method, width_slider, height_slider,
263
+ base_resolution, generation_method, length_slider, cfg_scale_slider,
264
+ start_image, end_image, validation_video, validation_video_mask, denoise_strength, seed_textbox,
265
+ ):
266
+ if start_image is not None:
267
+ with open(start_image, 'rb') as file:
268
+ file_content = file.read()
269
+ start_image_encoded_content = base64.b64encode(file_content)
270
+ start_image = start_image_encoded_content.decode('utf-8')
271
+
272
+ if end_image is not None:
273
+ with open(end_image, 'rb') as file:
274
+ file_content = file.read()
275
+ end_image_encoded_content = base64.b64encode(file_content)
276
+ end_image = end_image_encoded_content.decode('utf-8')
277
+
278
+ if validation_video is not None:
279
+ with open(validation_video, 'rb') as file:
280
+ file_content = file.read()
281
+ validation_video_encoded_content = base64.b64encode(file_content)
282
+ validation_video = validation_video_encoded_content.decode('utf-8')
283
+
284
+ if validation_video_mask is not None:
285
+ with open(validation_video_mask, 'rb') as file:
286
+ file_content = file.read()
287
+ validation_video_mask_encoded_content = base64.b64encode(file_content)
288
+ validation_video_mask = validation_video_mask_encoded_content.decode('utf-8')
289
+
290
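+ # The media inputs were base64-encoded above; bundle all generation parameters into a JSON payload for the remote EAS service.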
+ datas = {
291
+ "base_model_path": base_model_dropdown,
292
+ "lora_model_path": lora_model_dropdown,
293
+ "lora_alpha_slider": lora_alpha_slider,
294
+ "prompt_textbox": prompt_textbox,
295
+ "negative_prompt_textbox": negative_prompt_textbox,
296
+ "sampler_dropdown": sampler_dropdown,
297
+ "sample_step_slider": sample_step_slider,
298
+ "resize_method": resize_method,
299
+ "width_slider": width_slider,
300
+ "height_slider": height_slider,
301
+ "base_resolution": base_resolution,
302
+ "generation_method": generation_method,
303
+ "length_slider": length_slider,
304
+ "cfg_scale_slider": cfg_scale_slider,
305
+ "start_image": start_image,
306
+ "end_image": end_image,
307
+ "validation_video": validation_video,
308
+ "validation_video_mask": validation_video_mask,
309
+ "denoise_strength": denoise_strength,
310
+ "seed_textbox": seed_textbox,
311
+ }
312
+
313
+ session = requests.session()
314
+ session.headers.update({"Authorization": os.environ.get("EAS_TOKEN")})
315
+
316
+ response = session.post(url=f'{os.environ.get("EAS_URL")}/cogvideox_fun/infer_forward', json=datas, timeout=300)
317
+
318
+ outputs = response.json()
319
+ return outputs
320
+
321
+
322
+ class Fun_Controller_EAS:
323
+ def __init__(self, model_name, scheduler_dict, savedir_sample):
324
+ self.savedir_sample = savedir_sample
325
+ self.scheduler_dict = scheduler_dict
326
+ os.makedirs(self.savedir_sample, exist_ok=True)
327
+
328
+ def generate(
329
+ self,
330
+ diffusion_transformer_dropdown,
331
+ base_model_dropdown,
332
+ lora_model_dropdown,
333
+ lora_alpha_slider,
334
+ prompt_textbox,
335
+ negative_prompt_textbox,
336
+ sampler_dropdown,
337
+ sample_step_slider,
338
+ resize_method,
339
+ width_slider,
340
+ height_slider,
341
+ base_resolution,
342
+ generation_method,
343
+ length_slider,
344
+ cfg_scale_slider,
345
+ start_image,
346
+ end_image,
347
+ validation_video,
348
+ validation_video_mask,
349
+ denoise_strength,
350
+ seed_textbox
351
+ ):
352
+ is_image = True if generation_method == "Image Generation" else False
353
+
354
+ outputs = post_eas(
355
+ diffusion_transformer_dropdown,
356
+ base_model_dropdown, lora_model_dropdown, lora_alpha_slider,
357
+ prompt_textbox, negative_prompt_textbox,
358
+ sampler_dropdown, sample_step_slider, resize_method, width_slider, height_slider,
359
+ base_resolution, generation_method, length_slider, cfg_scale_slider,
360
+ start_image, end_image, validation_video, validation_video_mask, denoise_strength,
361
+ seed_textbox
362
+ )
363
+ try:
364
+ base64_encoding = outputs["base64_encoding"]
365
+ except:
366
+ return gr.Image(visible=False, value=None), gr.Video(None, visible=True), outputs["message"]
367
+
368
+ decoded_data = base64.b64decode(base64_encoding)
369
+
370
+ if not os.path.exists(self.savedir_sample):
371
+ os.makedirs(self.savedir_sample, exist_ok=True)
372
+ index = len([path for path in os.listdir(self.savedir_sample)]) + 1
373
+ prefix = str(index).zfill(3)
374
+
375
+ if is_image or length_slider == 1:
376
+ save_sample_path = os.path.join(self.savedir_sample, prefix + f".png")
377
+ with open(save_sample_path, "wb") as file:
378
+ file.write(decoded_data)
379
+ if gradio_version_is_above_4:
380
+ return gr.Image(value=save_sample_path, visible=True), gr.Video(value=None, visible=False), "Success"
381
+ else:
382
+ return gr.Image.update(value=save_sample_path, visible=True), gr.Video.update(value=None, visible=False), "Success"
383
+ else:
384
+ save_sample_path = os.path.join(self.savedir_sample, prefix + f".mp4")
385
+ with open(save_sample_path, "wb") as file:
386
+ file.write(decoded_data)
387
+ if gradio_version_is_above_4:
388
+ return gr.Image(visible=False, value=None), gr.Video(value=save_sample_path, visible=True), "Success"
389
+ else:
390
+ return gr.Image.update(visible=False, value=None), gr.Video.update(value=save_sample_path, visible=True), "Success"
cogvideox/ui/ui.py ADDED
@@ -0,0 +1,288 @@
1
+ import gradio as gr
2
+ import random
3
+
4
+ def create_model_type(visible):
5
+ gr.Markdown(
6
+ """
7
+ ### Model Type (模型的种类,正常模型还是控制模型).
8
+ """,
9
+ visible=visible,
10
+ )
11
+ with gr.Row():
12
+ model_type = gr.Dropdown(
13
+ label="The model type of the model (模型的种类,正常模型还是控制模型)",
14
+ choices=["Inpaint", "Control"],
15
+ value="Inpaint",
16
+ visible=visible,
17
+ interactive=True,
18
+ )
19
+ return model_type
20
+
21
+ def create_fake_model_type(visible):
22
+ gr.Markdown(
23
+ """
24
+ ### Model Type (模型的种类,正常模型还是控制模型).
25
+ """,
26
+ visible=visible,
27
+ )
28
+ with gr.Row():
29
+ model_type = gr.Dropdown(
30
+ label="The model type of the model (模型的种类,正常模型还是控制模型)",
31
+ choices=["Inpaint", "Control"],
32
+ value="Inpaint",
33
+ interactive=False,
34
+ visible=visible,
35
+ )
36
+ return model_type
37
+
38
+ def create_model_checkpoints(controller, visible):
39
+ gr.Markdown(
40
+ """
41
+ ### Model checkpoints (模型路径).
42
+ """
43
+ )
44
+ with gr.Row(visible=visible):
45
+ diffusion_transformer_dropdown = gr.Dropdown(
46
+ label="Pretrained Model Path (预训练模型路径)",
47
+ choices=controller.diffusion_transformer_list,
48
+ value="none",
49
+ interactive=True,
50
+ )
51
+ diffusion_transformer_dropdown.change(
52
+ fn=controller.update_diffusion_transformer,
53
+ inputs=[diffusion_transformer_dropdown],
54
+ outputs=[diffusion_transformer_dropdown]
55
+ )
56
+
57
+ diffusion_transformer_refresh_button = gr.Button(value="\U0001F503", elem_classes="toolbutton")
58
+ def refresh_diffusion_transformer():
59
+ controller.refresh_diffusion_transformer()
60
+ return gr.update(choices=controller.diffusion_transformer_list)
61
+ diffusion_transformer_refresh_button.click(fn=refresh_diffusion_transformer, inputs=[], outputs=[diffusion_transformer_dropdown])
62
+
63
+ return diffusion_transformer_dropdown, diffusion_transformer_refresh_button
64
+
65
+ def create_fake_model_checkpoints(model_name, visible):
66
+ gr.Markdown(
67
+ """
68
+ ### Model checkpoints (模型路径).
69
+ """
70
+ )
71
+ with gr.Row(visible=visible):
72
+ diffusion_transformer_dropdown = gr.Dropdown(
73
+ label="Pretrained Model Path (预训练模型路径)",
74
+ choices=[model_name],
75
+ value=model_name,
76
+ interactive=False,
77
+ )
78
+ return diffusion_transformer_dropdown
79
+
80
+ def create_finetune_models_checkpoints(controller, visible):
81
+ with gr.Row(visible=visible):
82
+ base_model_dropdown = gr.Dropdown(
83
+ label="Select base Dreambooth model (选择基模型[非必需])",
84
+ choices=controller.personalized_model_list,
85
+ value="none",
86
+ interactive=True,
87
+ )
88
+
89
+ lora_model_dropdown = gr.Dropdown(
90
+ label="Select LoRA model (选择LoRA模型[非必需])",
91
+ choices=["none"] + controller.personalized_model_list,
92
+ value="none",
93
+ interactive=True,
94
+ )
95
+
96
+ lora_alpha_slider = gr.Slider(label="LoRA alpha (LoRA权重)", value=0.55, minimum=0, maximum=2, interactive=True)
97
+
98
+ personalized_refresh_button = gr.Button(value="\U0001F503", elem_classes="toolbutton")
99
+ def update_personalized_model():
100
+ controller.refresh_personalized_model()
101
+ return [
102
+ gr.update(choices=controller.personalized_model_list),
103
+ gr.update(choices=["none"] + controller.personalized_model_list)
104
+ ]
105
+ personalized_refresh_button.click(fn=update_personalized_model, inputs=[], outputs=[base_model_dropdown, lora_model_dropdown])
106
+
107
+ return base_model_dropdown, lora_model_dropdown, lora_alpha_slider, personalized_refresh_button
108
+
109
+ def create_fake_finetune_models_checkpoints(visible):
110
+ with gr.Row():
111
+ base_model_dropdown = gr.Dropdown(
112
+ label="Select base Dreambooth model (选择基模型[非必需])",
113
+ choices=["none"],
114
+ value="none",
115
+ interactive=False,
116
+ visible=False
117
+ )
118
+ with gr.Column(visible=False):
119
+ gr.Markdown(
120
+ """
121
+ ### Minimalism is an example portrait of Lora, triggered by specific prompt words. More details can be found on [Wiki](https://github.com/aigc-apps/CogVideoX-Fun/wiki/Training-Lora).
122
+ """
123
+ )
124
+ with gr.Row():
125
+ lora_model_dropdown = gr.Dropdown(
126
+ label="Select LoRA model",
127
+ choices=["none"],
128
+ value="none",
129
+ interactive=True,
130
+ )
131
+
132
+ lora_alpha_slider = gr.Slider(label="LoRA alpha (LoRA权重)", value=0.55, minimum=0, maximum=2, interactive=True)
133
+
134
+ return base_model_dropdown, lora_model_dropdown, lora_alpha_slider
135
+
136
+ def create_prompts(
137
+ prompt="A young woman with beautiful and clear eyes and blonde hair standing and white dress in a forest wearing a crown. She seems to be lost in thought, and the camera focuses on her face. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
138
+ negative_prompt="The video is not of a high quality, it has a low resolution. Watermark present in each frame. The background is solid. Strange body and strange trajectory. Distortion. "
139
+ ):
140
+ gr.Markdown(
141
+ """
142
+ ### Configs for Generation (生成参数配置).
143
+ """
144
+ )
145
+
146
+ prompt_textbox = gr.Textbox(label="Prompt (正向提示词)", lines=2, value=prompt)
147
+ negative_prompt_textbox = gr.Textbox(label="Negative prompt (负向提示词)", lines=2, value=negative_prompt)
148
+ return prompt_textbox, negative_prompt_textbox
149
+
150
+ def create_samplers(controller, maximum_step=100):
151
+ with gr.Row():
152
+ sampler_dropdown = gr.Dropdown(label="Sampling method (采样器种类)", choices=list(controller.scheduler_dict.keys()), value=list(controller.scheduler_dict.keys())[0])
153
+ sample_step_slider = gr.Slider(label="Sampling steps (生成步数)", value=maximum_step, minimum=10, maximum=maximum_step, step=1)
154
+
155
+ return sampler_dropdown, sample_step_slider
156
+
157
+ def create_height_width(default_height, default_width, maximum_height, maximum_width):
158
+ resize_method = gr.Radio(
159
+ ["Generate by", "Resize according to Reference"],
160
+ value="Generate by",
161
+ show_label=False,
162
+ )
163
+ width_slider = gr.Slider(label="Width (视频宽度)", value=default_width, minimum=128, maximum=maximum_width, step=16)
164
+ height_slider = gr.Slider(label="Height (视频高度)", value=default_height, minimum=128, maximum=maximum_height, step=16)
165
+ base_resolution = gr.Radio(label="Base Resolution of Pretrained Models", value=512, choices=[512, 768, 960], visible=False)
166
+
167
+ return resize_method, width_slider, height_slider, base_resolution
168
+
169
+ def create_fake_height_width(default_height, default_width, maximum_height, maximum_width):
170
+ resize_method = gr.Radio(
171
+ ["Generate by", "Resize according to Reference"],
172
+ value="Generate by",
173
+ show_label=False,
174
+ )
175
+ width_slider = gr.Slider(label="Width (视频宽度)", value=default_width, minimum=128, maximum=maximum_width, step=16, interactive=False)
176
+ height_slider = gr.Slider(label="Height (视频高度)", value=default_height, minimum=128, maximum=maximum_height, step=16, interactive=False)
177
+ base_resolution = gr.Radio(label="Base Resolution of Pretrained Models", value=512, choices=[512, 768, 960], interactive=False, visible=False)
178
+
179
+ return resize_method, width_slider, height_slider, base_resolution
180
+
181
+ def create_generation_methods_and_video_length(
182
+ generation_method_options,
183
+ default_video_length,
184
+ maximum_video_length
185
+ ):
186
+ with gr.Group():
187
+ generation_method = gr.Radio(
188
+ generation_method_options,
189
+ value="Video Generation",
190
+ show_label=False,
191
+ )
192
+ with gr.Row():
193
+ length_slider = gr.Slider(label="Animation length (视频帧数)", value=default_video_length, minimum=1, maximum=maximum_video_length, step=4)
194
+ overlap_video_length = gr.Slider(label="Overlap length (视频续写的重叠帧数)", value=4, minimum=1, maximum=4, step=1, visible=False)
195
+ partial_video_length = gr.Slider(label="Partial video generation length (每个部分的视频生成帧数)", value=25, minimum=5, maximum=maximum_video_length, step=4, visible=False)
196
+
197
+ return generation_method, length_slider, overlap_video_length, partial_video_length
198
+
199
+ def create_generation_method(source_method_options, prompt_textbox, support_end_image=True):
200
+ source_method = gr.Radio(
201
+ source_method_options,
202
+ value="Text to Video (文本到视频)",
203
+ show_label=False,
204
+ )
205
+ with gr.Column(visible = False) as image_to_video_col:
206
+ start_image = gr.Image(
207
+ label="The image at the beginning of the video (图片到视频的开始图片)", show_label=True,
208
+ elem_id="i2v_start", sources="upload", type="filepath",
209
+ )
210
+
211
+ template_gallery_path = ["asset/1.png", "asset/2.png", "asset/3.png", "asset/4.png", "asset/5.png"]
212
+ def select_template(evt: gr.SelectData):
213
+ text = {
214
+ "asset/1.png": "A brown dog is shaking its head and sitting on a light colored sofa in a comfortable room. Behind the dog, there is a framed painting on the shelf surrounded by pink flowers. The soft and warm lighting in the room creates a comfortable atmosphere.",
215
+ "asset/2.png": "A sailboat navigates through moderately rough seas, with waves and ocean spray visible. The sailboat features a white hull and sails, accompanied by an orange sail catching the wind. The sky above shows dramatic, cloudy formations with a sunset or sunrise backdrop, casting warm colors across the scene. The water reflects the golden light, enhancing the visual contrast between the dark ocean and the bright horizon. The camera captures the scene with a dynamic and immersive angle, showcasing the movement of the boat and the energy of the ocean.",
216
+ "asset/3.png": "A stunningly beautiful woman with flowing long hair stands gracefully, her elegant dress rippling and billowing in the gentle wind. Petals falling off. Her serene expression and the natural movement of her attire create an enchanting and captivating scene, full of ethereal charm.",
217
+ "asset/4.png": "An astronaut, clad in a full space suit with a helmet, plays an electric guitar while floating in a cosmic environment filled with glowing particles and rocky textures. The scene is illuminated by a warm light source, creating dramatic shadows and contrasts. The background features a complex geometry, similar to a space station or an alien landscape, indicating a futuristic or otherworldly setting.",
218
+ "asset/5.png": "Fireworks light up the evening sky over a sprawling cityscape with gothic-style buildings featuring pointed towers and clock faces. The city is lit by both artificial lights from the buildings and the colorful bursts of the fireworks. The scene is viewed from an elevated angle, showcasing a vibrant urban environment set against a backdrop of a dramatic, partially cloudy sky at dusk.",
219
+ }[template_gallery_path[evt.index]]
220
+ return template_gallery_path[evt.index], text
221
+
222
+ template_gallery = gr.Gallery(
223
+ template_gallery_path,
224
+ columns=5, rows=1,
225
+ height=140,
226
+ allow_preview=False,
227
+ container=False,
228
+ label="Template Examples",
229
+ )
230
+ template_gallery.select(select_template, None, [start_image, prompt_textbox])
231
+
232
+ with gr.Accordion("The image at the ending of the video (图片到视频的结束图片[非必需, Optional])", open=False, visible=support_end_image):
233
+ end_image = gr.Image(label="The image at the ending of the video (图片到视频的结束图片[非必需, Optional])", show_label=False, elem_id="i2v_end", sources="upload", type="filepath")
234
+
235
+ with gr.Column(visible = False) as video_to_video_col:
236
+ with gr.Row():
237
+ validation_video = gr.Video(
238
+ label="The video to convert (视频转视频的参考视频)", show_label=True,
239
+ elem_id="v2v", sources="upload",
240
+ )
241
+ with gr.Accordion("The mask of the video to inpaint (视频重新绘制的mask[非必需, Optional])", open=False):
242
+ gr.Markdown(
243
+ """
244
+ - Please set a larger denoise_strength when using validation_video_mask, such as 1.00 instead of 0.70
245
+ (请设置更大的denoise_strength,当使用validation_video_mask的时候,比如1而不是0.70)
246
+ """
247
+ )
248
+ validation_video_mask = gr.Image(
249
+ label="The mask of the video to inpaint (视频重新绘制的mask[非必需, Optional])",
250
+ show_label=False, elem_id="v2v_mask", sources="upload", type="filepath"
251
+ )
252
+ denoise_strength = gr.Slider(label="Denoise strength (重绘系数)", value=0.70, minimum=0.10, maximum=1.00, step=0.01)
253
+
254
+ with gr.Column(visible = False) as control_video_col:
255
+ gr.Markdown(
256
+ """
257
+ Demo pose control video can be downloaded here [URL](https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/cogvideox_fun/asset/v1.1/pose.mp4).
258
+ """
259
+ )
260
+ control_video = gr.Video(
261
+ label="The control video (用于提供控制信号的video)", show_label=True,
262
+ elem_id="v2v_control", sources="upload",
263
+ )
264
+ return image_to_video_col, video_to_video_col, control_video_col, source_method, start_image, template_gallery, end_image, validation_video, validation_video_mask, denoise_strength, control_video
265
+
266
+ def create_cfg_and_seedbox(gradio_version_is_above_4):
267
+ cfg_scale_slider = gr.Slider(label="CFG Scale (引导系数)", value=6.0, minimum=0, maximum=20)
268
+
269
+ with gr.Row():
270
+ seed_textbox = gr.Textbox(label="Seed (随机种子)", value=43)
271
+ seed_button = gr.Button(value="\U0001F3B2", elem_classes="toolbutton")
272
+ seed_button.click(
273
+ fn=lambda: gr.Textbox(value=random.randint(1, 1e8)) if gradio_version_is_above_4 else gr.Textbox.update(value=random.randint(1, 1e8)),
274
+ inputs=[],
275
+ outputs=[seed_textbox]
276
+ )
277
+ return cfg_scale_slider, seed_textbox, seed_button
278
+
279
+ def create_ui_outputs():
280
+ with gr.Column():
281
+ result_image = gr.Image(label="Generated Image (生成图片)", interactive=False, visible=False)
282
+ result_video = gr.Video(label="Generated Animation (生成视频)", interactive=False)
283
+ infer_progress = gr.Textbox(
284
+ label="Generation Info (生成信息)",
285
+ value="No task currently",
286
+ interactive=False
287
+ )
288
+ return result_image, result_video, infer_progress
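
The helpers above are plain builder functions that return Gradio components and are meant to be called inside a gr.Blocks context. A minimal composition sketch follows (not part of this commit): the `controller` argument is assumed to be a Fun_Controller-style object from cogvideox/ui/controller.py, and gradio_version_is_above_4 is assumed to be True.

import gradio as gr

def build_minimal_demo(controller):
    # Sketch only: wires the builder helpers defined above into one page.
    with gr.Blocks() as demo:
        prompt_textbox, negative_prompt_textbox = create_prompts()
        sampler_dropdown, sample_step_slider = create_samplers(controller, maximum_step=50)
        resize_method, width_slider, height_slider, base_resolution = create_height_width(
            default_height=480, default_width=832, maximum_height=1344, maximum_width=1344
        )
        cfg_scale_slider, seed_textbox, seed_button = create_cfg_and_seedbox(gradio_version_is_above_4=True)
        result_image, result_video, infer_progress = create_ui_outputs()
    return demo

# Usage sketch: demo = build_minimal_demo(controller); demo.launch()
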
cogvideox/ui/wan_fun_ui.py ADDED
@@ -0,0 +1,742 @@
1
+ """Modified from https://github.com/guoyww/AnimateDiff/blob/main/app.py
2
+ """
3
+ import os
4
+ import random
5
+
6
+ import cv2
7
+ import gradio as gr
8
+ import numpy as np
9
+ import torch
10
+ from omegaconf import OmegaConf
11
+ from PIL import Image
12
+ from safetensors import safe_open
13
+
14
+ from ..data.bucket_sampler import ASPECT_RATIO_512, get_closest_ratio
15
+ from ..models import (AutoencoderKLWan, AutoTokenizer, CLIPModel,
16
+ WanT5EncoderModel, WanTransformer3DModel)
17
+ from ..pipeline import WanFunInpaintPipeline, WanFunPipeline
18
+ from ..utils.fp8_optimization import (convert_model_weight_to_float8,
19
+ convert_weight_dtype_wrapper,
20
+ replace_parameters_by_name)
21
+ from ..utils.lora_utils import merge_lora, unmerge_lora
22
+ from ..utils.utils import (filter_kwargs, get_image_to_video_latent,
23
+ get_video_to_video_latent, save_videos_grid)
24
+ from .controller import (Fun_Controller, Fun_Controller_EAS, all_cheduler_dict,
25
+ css, ddpm_scheduler_dict, flow_scheduler_dict,
26
+ gradio_version, gradio_version_is_above_4)
27
+ from .ui import (create_cfg_and_seedbox,
28
+ create_fake_finetune_models_checkpoints,
29
+ create_fake_height_width, create_fake_model_checkpoints,
30
+ create_fake_model_type, create_finetune_models_checkpoints,
31
+ create_generation_method,
32
+ create_generation_methods_and_video_length,
33
+ create_height_width, create_model_checkpoints,
34
+ create_model_type, create_prompts, create_samplers,
35
+ create_ui_outputs)
36
+
37
+
38
+ class Wan_Fun_Controller(Fun_Controller):
39
+ def update_diffusion_transformer(self, diffusion_transformer_dropdown):
40
+ print("Update diffusion transformer")
41
+ self.diffusion_transformer_dropdown = diffusion_transformer_dropdown
42
+ if diffusion_transformer_dropdown == "none":
43
+ return gr.update()
44
+ self.vae = AutoencoderKLWan.from_pretrained(
45
+ os.path.join(diffusion_transformer_dropdown, self.config['vae_kwargs'].get('vae_subpath', 'vae')),
46
+ additional_kwargs=OmegaConf.to_container(self.config['vae_kwargs']),
47
+ ).to(self.weight_dtype)
48
+
49
+ # Get Transformer
50
+ self.transformer = WanTransformer3DModel.from_pretrained(
51
+ os.path.join(diffusion_transformer_dropdown, self.config['transformer_additional_kwargs'].get('transformer_subpath', 'transformer')),
52
+ transformer_additional_kwargs=OmegaConf.to_container(self.config['transformer_additional_kwargs']),
53
+ low_cpu_mem_usage=True,
54
+ torch_dtype=self.weight_dtype,
55
+ )
56
+
57
+ # Get Tokenizer
58
+ self.tokenizer = AutoTokenizer.from_pretrained(
59
+ os.path.join(diffusion_transformer_dropdown, self.config['text_encoder_kwargs'].get('tokenizer_subpath', 'tokenizer')),
60
+ )
61
+
62
+ # Get Text encoder
63
+ self.text_encoder = WanT5EncoderModel.from_pretrained(
64
+ os.path.join(diffusion_transformer_dropdown, self.config['text_encoder_kwargs'].get('text_encoder_subpath', 'text_encoder')),
65
+ additional_kwargs=OmegaConf.to_container(self.config['text_encoder_kwargs']),
66
+ ).to(self.weight_dtype)
67
+ self.text_encoder = self.text_encoder.eval()
68
+
69
+ if self.transformer.config.in_channels != self.vae.config.latent_channels:
70
+ # Get Clip Image Encoder
71
+ self.clip_image_encoder = CLIPModel.from_pretrained(
72
+ os.path.join(diffusion_transformer_dropdown, self.config['image_encoder_kwargs'].get('image_encoder_subpath', 'image_encoder')),
73
+ ).to(self.weight_dtype)
74
+ self.clip_image_encoder = self.clip_image_encoder.eval()
75
+ else:
76
+ self.clip_image_encoder = None
77
+
78
+ Choosen_Scheduler = self.scheduler_dict[list(self.scheduler_dict.keys())[0]]
79
+ self.scheduler = Choosen_Scheduler(
80
+ **filter_kwargs(Choosen_Scheduler, OmegaConf.to_container(self.config['scheduler_kwargs']))
81
+ )
82
+
83
+ # Get pipeline
84
+ if self.model_type == "Inpaint":
85
+ if self.transformer.config.in_channels != self.vae.config.latent_channels:
86
+ self.pipeline = WanFunInpaintPipeline(
87
+ vae=self.vae,
88
+ tokenizer=self.tokenizer,
89
+ text_encoder=self.text_encoder,
90
+ transformer=self.transformer,
91
+ scheduler=self.scheduler,
92
+ clip_image_encoder=self.clip_image_encoder,
93
+ )
94
+ else:
95
+ self.pipeline = WanFunPipeline(
96
+ vae=self.vae,
97
+ tokenizer=self.tokenizer,
98
+ text_encoder=self.text_encoder,
99
+ transformer=self.transformer,
100
+ scheduler=self.scheduler,
101
+ )
102
+ else:
103
+ raise ValueError("Not supported yet")
104
+
105
+ if self.GPU_memory_mode == "sequential_cpu_offload":
106
+ replace_parameters_by_name(self.transformer, ["modulation",], device="cuda")
107
+ self.transformer.freqs = self.transformer.freqs.to(device="cuda")
108
+ self.pipeline.enable_sequential_cpu_offload()
109
+ elif self.GPU_memory_mode == "model_cpu_offload_and_qfloat8":
110
+ convert_model_weight_to_float8(self.transformer, exclude_module_name=["modulation",])
111
+ convert_weight_dtype_wrapper(self.transformer, self.weight_dtype)
112
+ self.pipeline.enable_model_cpu_offload()
113
+ else:
114
+ self.pipeline.enable_model_cpu_offload()
115
+ print("Update diffusion transformer done")
116
+ return gr.update()
117
+
118
+ def generate(
119
+ self,
120
+ diffusion_transformer_dropdown,
121
+ base_model_dropdown,
122
+ lora_model_dropdown,
123
+ lora_alpha_slider,
124
+ prompt_textbox,
125
+ negative_prompt_textbox,
126
+ sampler_dropdown,
127
+ sample_step_slider,
128
+ resize_method,
129
+ width_slider,
130
+ height_slider,
131
+ base_resolution,
132
+ generation_method,
133
+ length_slider,
134
+ overlap_video_length,
135
+ partial_video_length,
136
+ cfg_scale_slider,
137
+ start_image,
138
+ end_image,
139
+ validation_video,
140
+ validation_video_mask,
141
+ control_video,
142
+ denoise_strength,
143
+ seed_textbox,
144
+ is_api = False,
145
+ ):
146
+ self.clear_cache()
147
+
148
+ self.input_check(
149
+ resize_method, generation_method, start_image, end_image, validation_video, control_video, is_api
150
+ )
151
+ is_image = True if generation_method == "Image Generation" else False
152
+
153
+ if self.base_model_path != base_model_dropdown:
154
+ self.update_base_model(base_model_dropdown)
155
+
156
+ if self.lora_model_path != lora_model_dropdown:
157
+ self.update_lora_model(lora_model_dropdown)
158
+
159
+ self.pipeline.scheduler = self.scheduler_dict[sampler_dropdown].from_config(self.pipeline.scheduler.config)
160
+
161
+ if resize_method == "Resize according to Reference":
162
+ height_slider, width_slider = self.get_height_width_from_reference(
163
+ base_resolution, start_image, validation_video, control_video,
164
+ )
165
+ if self.lora_model_path != "none":
166
+ # lora part
167
+ self.pipeline = merge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider)
168
+
169
+ if int(seed_textbox) != -1 and seed_textbox != "": torch.manual_seed(int(seed_textbox))
170
+ else: seed_textbox = np.random.randint(0, 1e10)
171
+ generator = torch.Generator(device="cuda").manual_seed(int(seed_textbox))
172
+
173
+ try:
174
+ if self.model_type == "Inpaint":
175
+ if self.transformer.config.in_channels != self.vae.config.latent_channels:
176
+ if generation_method == "Long Video Generation":
177
+ if validation_video is not None:
178
+ raise gr.Error("Video to Video does not support Long Video Generation yet.")
179
+ init_frames = 0
180
+ last_frames = init_frames + partial_video_length
181
+ while init_frames < length_slider:
182
+ if last_frames >= length_slider:
183
+ _partial_video_length = length_slider - init_frames
184
+ _partial_video_length = int((_partial_video_length - 1) // self.vae.config.temporal_compression_ratio * self.vae.config.temporal_compression_ratio) + 1
185
+
186
+ if _partial_video_length <= 0:
187
+ break
188
+ else:
189
+ _partial_video_length = partial_video_length
190
+
191
+ if last_frames >= length_slider:
192
+ input_video, input_video_mask, clip_image = get_image_to_video_latent(start_image, end_image, video_length=_partial_video_length, sample_size=(height_slider, width_slider))
193
+ else:
194
+ input_video, input_video_mask, clip_image = get_image_to_video_latent(start_image, None, video_length=_partial_video_length, sample_size=(height_slider, width_slider))
195
+
196
+ with torch.no_grad():
197
+ sample = self.pipeline(
198
+ prompt_textbox,
199
+ negative_prompt = negative_prompt_textbox,
200
+ num_inference_steps = sample_step_slider,
201
+ guidance_scale = cfg_scale_slider,
202
+ width = width_slider,
203
+ height = height_slider,
204
+ num_frames = _partial_video_length,
205
+ generator = generator,
206
+
207
+ video = input_video,
208
+ mask_video = input_video_mask,
209
+ clip_image = clip_image
210
+ ).videos
211
+
212
+ if init_frames != 0:
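+ # Crossfade between clips: mix_ratio ramps linearly from 0 toward 1 over the overlap_video_length frames, blending the tail of the previous result into the head of the newly generated clip before the remaining frames are concatenated.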
213
+ mix_ratio = torch.from_numpy(
214
+ np.array([float(_index) / float(overlap_video_length) for _index in range(overlap_video_length)], np.float32)
215
+ ).unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
216
+
217
+ new_sample[:, :, -overlap_video_length:] = new_sample[:, :, -overlap_video_length:] * (1 - mix_ratio) + \
218
+ sample[:, :, :overlap_video_length] * mix_ratio
219
+ new_sample = torch.cat([new_sample, sample[:, :, overlap_video_length:]], dim = 2)
220
+
221
+ sample = new_sample
222
+ else:
223
+ new_sample = sample
224
+
225
+ if last_frames >= length_slider:
226
+ break
227
+
228
+ start_image = [
229
+ Image.fromarray(
230
+ (sample[0, :, _index].transpose(0, 1).transpose(1, 2) * 255).numpy().astype(np.uint8)
231
+ ) for _index in range(-overlap_video_length, 0)
232
+ ]
233
+
234
+ init_frames = init_frames + _partial_video_length - overlap_video_length
235
+ last_frames = init_frames + _partial_video_length
236
+ else:
237
+ if validation_video is not None:
238
+ input_video, input_video_mask, clip_image = get_video_to_video_latent(validation_video, length_slider if not is_image else 1, sample_size=(height_slider, width_slider), validation_video_mask=validation_video_mask, fps=16)
239
+ strength = denoise_strength
240
+ else:
241
+ input_video, input_video_mask, clip_image = get_image_to_video_latent(start_image, end_image, length_slider if not is_image else 1, sample_size=(height_slider, width_slider))
242
+ strength = 1
243
+
244
+ sample = self.pipeline(
245
+ prompt_textbox,
246
+ negative_prompt = negative_prompt_textbox,
247
+ num_inference_steps = sample_step_slider,
248
+ guidance_scale = cfg_scale_slider,
249
+ width = width_slider,
250
+ height = height_slider,
251
+ num_frames = length_slider if not is_image else 1,
252
+ generator = generator,
253
+
254
+ video = input_video,
255
+ mask_video = input_video_mask,
256
+ clip_image = clip_image
257
+ ).videos
258
+ else:
259
+ sample = self.pipeline(
260
+ prompt_textbox,
261
+ negative_prompt = negative_prompt_textbox,
262
+ num_inference_steps = sample_step_slider,
263
+ guidance_scale = cfg_scale_slider,
264
+ width = width_slider,
265
+ height = height_slider,
266
+ num_frames = length_slider if not is_image else 1,
267
+ generator = generator
268
+ ).videos
269
+ else:
270
+ input_video, input_video_mask, clip_image = get_video_to_video_latent(control_video, length_slider if not is_image else 1, sample_size=(height_slider, width_slider), fps=16)
271
+
272
+ sample = self.pipeline(
273
+ prompt_textbox,
274
+ negative_prompt = negative_prompt_textbox,
275
+ num_inference_steps = sample_step_slider,
276
+ guidance_scale = cfg_scale_slider,
277
+ width = width_slider,
278
+ height = height_slider,
279
+ num_frames = length_slider if not is_image else 1,
280
+ generator = generator,
281
+
282
+ control_video = input_video,
283
+ ).videos
284
+ except Exception as e:
285
+ self.clear_cache()
286
+ if self.lora_model_path != "none":
287
+ self.pipeline = unmerge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider)
288
+ if is_api:
289
+ return "", f"Error. error information is {str(e)}"
290
+ else:
291
+ return gr.update(), gr.update(), f"Error. Error information: {str(e)}"
292
+
293
+ self.clear_cache()
294
+ # lora part
295
+ if self.lora_model_path != "none":
296
+ self.pipeline = unmerge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider)
297
+
298
+ save_sample_path = self.save_outputs(
299
+ is_image, length_slider, sample, fps=16
300
+ )
301
+
302
+ if is_image or length_slider == 1:
303
+ if is_api:
304
+ return save_sample_path, "Success"
305
+ else:
306
+ if gradio_version_is_above_4:
307
+ return gr.Image(value=save_sample_path, visible=True), gr.Video(value=None, visible=False), "Success"
308
+ else:
309
+ return gr.Image.update(value=save_sample_path, visible=True), gr.Video.update(value=None, visible=False), "Success"
310
+ else:
311
+ if is_api:
312
+ return save_sample_path, "Success"
313
+ else:
314
+ if gradio_version_is_above_4:
315
+ return gr.Image(visible=False, value=None), gr.Video(value=save_sample_path, visible=True), "Success"
316
+ else:
317
+ return gr.Image.update(visible=False, value=None), gr.Video.update(value=save_sample_path, visible=True), "Success"
318
+
319
+
320
+ class Wan_Fun_Controller_Modelscope(Wan_Fun_Controller):
321
+ def __init__(self, model_name, model_type, savedir_sample, GPU_memory_mode, scheduler_dict, weight_dtype, config_path):
322
+ # Basic dir
323
+ self.basedir = os.getcwd()
324
+ self.personalized_model_dir = os.path.join(self.basedir, "models", "Personalized_Model")
325
+ self.lora_model_path = "none"
326
+ self.base_model_path = "none"
327
+ self.savedir_sample = savedir_sample
328
+ self.scheduler_dict = scheduler_dict
329
+ self.config = OmegaConf.load(config_path)
330
+ self.refresh_personalized_model()
331
+ os.makedirs(self.savedir_sample, exist_ok=True)
332
+
333
+ # model path
334
+ self.model_type = model_type
335
+ self.weight_dtype = weight_dtype
336
+
337
+ self.vae = AutoencoderKLWan.from_pretrained(
338
+ os.path.join(model_name, self.config['vae_kwargs'].get('vae_subpath', 'vae')),
339
+ additional_kwargs=OmegaConf.to_container(self.config['vae_kwargs']),
340
+ ).to(self.weight_dtype)
341
+
342
+ # Get Transformer
343
+ self.transformer = WanTransformer3DModel.from_pretrained(
344
+ os.path.join(model_name, self.config['transformer_additional_kwargs'].get('transformer_subpath', 'transformer')),
345
+ transformer_additional_kwargs=OmegaConf.to_container(self.config['transformer_additional_kwargs']),
346
+ low_cpu_mem_usage=True,
347
+ torch_dtype=self.weight_dtype,
348
+ )
349
+
350
+ # Get Tokenizer
351
+ self.tokenizer = AutoTokenizer.from_pretrained(
352
+ os.path.join(model_name, self.config['text_encoder_kwargs'].get('tokenizer_subpath', 'tokenizer')),
353
+ )
354
+
355
+ # Get Text encoder
356
+ self.text_encoder = WanT5EncoderModel.from_pretrained(
357
+ os.path.join(model_name, self.config['text_encoder_kwargs'].get('text_encoder_subpath', 'text_encoder')),
358
+ additional_kwargs=OmegaConf.to_container(self.config['text_encoder_kwargs']),
359
+ ).to(self.weight_dtype)
360
+ self.text_encoder = self.text_encoder.eval()
361
+
362
+ if self.transformer.config.in_channels != self.vae.config.latent_channels:
363
+ # Get Clip Image Encoder
364
+ self.clip_image_encoder = CLIPModel.from_pretrained(
365
+ os.path.join(model_name, self.config['image_encoder_kwargs'].get('image_encoder_subpath', 'image_encoder')),
366
+ ).to(self.weight_dtype)
367
+ self.clip_image_encoder = self.clip_image_encoder.eval()
368
+ else:
369
+ self.clip_image_encoder = None
370
+
371
+ Choosen_Scheduler = self.scheduler_dict[list(self.scheduler_dict.keys())[0]]
372
+ self.scheduler = Choosen_Scheduler(
373
+ **filter_kwargs(Choosen_Scheduler, OmegaConf.to_container(self.config['scheduler_kwargs']))
374
+ )
375
+
376
+ # Get pipeline
377
+ if self.model_type == "Inpaint":
378
+ if self.transformer.config.in_channels != self.vae.config.latent_channels:
379
+ self.pipeline = WanFunInpaintPipeline(
380
+ vae=self.vae,
381
+ tokenizer=self.tokenizer,
382
+ text_encoder=self.text_encoder,
383
+ transformer=self.transformer,
384
+ scheduler=self.scheduler,
385
+ clip_image_encoder=self.clip_image_encoder,
386
+ )
387
+ else:
388
+ self.pipeline = WanFunPipeline(
389
+ vae=self.vae,
390
+ tokenizer=self.tokenizer,
391
+ text_encoder=self.text_encoder,
392
+ transformer=self.transformer,
393
+ scheduler=self.scheduler,
394
+ )
395
+ else:
396
+ raise ValueError("Not supported yet")
397
+
398
+ if GPU_memory_mode == "sequential_cpu_offload":
399
+ replace_parameters_by_name(self.transformer, ["modulation",], device="cuda")
400
+ self.transformer.freqs = self.transformer.freqs.to(device="cuda")
401
+ self.pipeline.enable_sequential_cpu_offload()
402
+ elif GPU_memory_mode == "model_cpu_offload_and_qfloat8":
403
+ convert_model_weight_to_float8(self.transformer, exclude_module_name=["modulation",])
404
+ convert_weight_dtype_wrapper(self.transformer, self.weight_dtype)
405
+ self.pipeline.enable_model_cpu_offload()
406
+ else:
407
+ self.pipeline.enable_model_cpu_offload()
408
+ print("Update diffusion transformer done")
409
+
410
+ Wan_Fun_Controller_EAS = Fun_Controller_EAS
411
+
412
+ def ui(GPU_memory_mode, scheduler_dict, weight_dtype, config_path):
413
+ controller = Wan_Fun_Controller(GPU_memory_mode, scheduler_dict, weight_dtype, config_path)
414
+
415
+ with gr.Blocks(css=css) as demo:
416
+ gr.Markdown(
417
+ """
418
+ # Wan-Fun:
419
+
420
+ A Wan model with more flexible generation conditions, capable of producing videos at different resolutions, around 5 seconds long at 16 fps (1 to 81 frames), as well as videos generated from images.
421
+
422
+ [Github](https://github.com/aigc-apps/CogVideoX-Fun/)
423
+ """
424
+ )
425
+ with gr.Column(variant="panel"):
426
+ model_type = create_model_type(visible=True)
427
+ diffusion_transformer_dropdown, diffusion_transformer_refresh_button = \
428
+ create_model_checkpoints(controller, visible=True)
429
+ base_model_dropdown, lora_model_dropdown, lora_alpha_slider, personalized_refresh_button = \
430
+ create_finetune_models_checkpoints(controller, visible=True)
431
+
432
+ with gr.Column(variant="panel"):
433
+ prompt_textbox, negative_prompt_textbox = create_prompts(negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走")
434
+
435
+ with gr.Row():
436
+ with gr.Column():
437
+ sampler_dropdown, sample_step_slider = create_samplers(controller)
438
+
439
+ resize_method, width_slider, height_slider, base_resolution = create_height_width(
440
+ default_height = 480, default_width = 832, maximum_height = 1344,
441
+ maximum_width = 1344,
442
+ )
443
+ generation_method, length_slider, overlap_video_length, partial_video_length = \
444
+ create_generation_methods_and_video_length(
445
+ ["Video Generation", "Image Generation"],
446
+ default_video_length=81,
447
+ maximum_video_length=81,
448
+ )
449
+ image_to_video_col, video_to_video_col, control_video_col, source_method, start_image, template_gallery, end_image, validation_video, validation_video_mask, denoise_strength, control_video = create_generation_method(
450
+ ["Text to Video (文本到视频)", "Image to Video (图片到视频)"], prompt_textbox
451
+ )
452
+ cfg_scale_slider, seed_textbox, seed_button = create_cfg_and_seedbox(gradio_version_is_above_4)
453
+
454
+ generate_button = gr.Button(value="Generate (生成)", variant='primary')
455
+
456
+ result_image, result_video, infer_progress = create_ui_outputs()
457
+
458
+ model_type.change(
459
+ fn=controller.update_model_type,
460
+ inputs=[model_type],
461
+ outputs=[]
462
+ )
463
+
464
+ def upload_generation_method(generation_method):
465
+ if generation_method == "Video Generation":
466
+ return [gr.update(visible=True, maximum=81, value=81, interactive=True), gr.update(visible=False), gr.update(visible=False)]
467
+ elif generation_method == "Image Generation":
468
+ return [gr.update(minimum=1, maximum=1, value=1, interactive=False), gr.update(visible=False), gr.update(visible=False)]
469
+ else:
470
+ return [gr.update(visible=True, maximum=1344), gr.update(visible=True), gr.update(visible=True)]
471
+ generation_method.change(
472
+ upload_generation_method, generation_method, [length_slider, overlap_video_length, partial_video_length]
473
+ )
474
+
475
+ def upload_source_method(source_method):
476
+ if source_method == "Text to Video (文本到视频)":
477
+ return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None)]
478
+ elif source_method == "Image to Video (图片到视频)":
479
+ return [gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(), gr.update(), gr.update(value=None), gr.update(value=None), gr.update(value=None)]
480
+ elif source_method == "Video to Video (视频到视频)":
481
+ return [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(value=None), gr.update(value=None), gr.update(), gr.update(), gr.update(value=None)]
482
+ else:
483
+ return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update()]
484
+ source_method.change(
485
+ upload_source_method, source_method, [
486
+ image_to_video_col, video_to_video_col, control_video_col, start_image, end_image,
487
+ validation_video, validation_video_mask, control_video
488
+ ]
489
+ )
490
+
491
+ def upload_resize_method(resize_method):
492
+ if resize_method == "Generate by":
493
+ return [gr.update(visible=True), gr.update(visible=True), gr.update(visible=False)]
494
+ else:
495
+ return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)]
496
+ resize_method.change(
497
+ upload_resize_method, resize_method, [width_slider, height_slider, base_resolution]
498
+ )
499
+
500
+ generate_button.click(
501
+ fn=controller.generate,
502
+ inputs=[
503
+ diffusion_transformer_dropdown,
504
+ base_model_dropdown,
505
+ lora_model_dropdown,
506
+ lora_alpha_slider,
507
+ prompt_textbox,
508
+ negative_prompt_textbox,
509
+ sampler_dropdown,
510
+ sample_step_slider,
511
+ resize_method,
512
+ width_slider,
513
+ height_slider,
514
+ base_resolution,
515
+ generation_method,
516
+ length_slider,
517
+ overlap_video_length,
518
+ partial_video_length,
519
+ cfg_scale_slider,
520
+ start_image,
521
+ end_image,
522
+ validation_video,
523
+ validation_video_mask,
524
+ control_video,
525
+ denoise_strength,
526
+ seed_textbox,
527
+ ],
528
+ outputs=[result_image, result_video, infer_progress]
529
+ )
530
+ return demo, controller
531
+
532
+ def ui_modelscope(model_name, model_type, savedir_sample, GPU_memory_mode, scheduler_dict, weight_dtype, config_path):
533
+ controller = Wan_Fun_Controller_Modelscope(model_name, model_type, savedir_sample, GPU_memory_mode, scheduler_dict, weight_dtype, config_path)
534
+
535
+ with gr.Blocks(css=css) as demo:
536
+ gr.Markdown(
537
+ """
538
+ # Wan-Fun:
539
+
540
+ A Wan model with more flexible generation conditions, capable of producing videos at different resolutions, around 5 seconds long at 16 fps (1 to 81 frames), as well as videos generated from images.
541
+
542
+ [Github](https://github.com/aigc-apps/CogVideoX-Fun/)
543
+ """
544
+ )
545
+ with gr.Column(variant="panel"):
546
+ model_type = create_fake_model_type(visible=True)
547
+ diffusion_transformer_dropdown = create_fake_model_checkpoints(model_name, visible=True)
548
+ base_model_dropdown, lora_model_dropdown, lora_alpha_slider = create_fake_finetune_models_checkpoints(visible=True)
549
+
550
+ with gr.Column(variant="panel"):
551
+ prompt_textbox, negative_prompt_textbox = create_prompts(negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走")
552
+
553
+ with gr.Row():
554
+ with gr.Column():
555
+ sampler_dropdown, sample_step_slider = create_samplers(controller)
556
+
557
+ resize_method, width_slider, height_slider, base_resolution = create_height_width(
558
+ default_height = 480, default_width = 832, maximum_height = 1344,
559
+ maximum_width = 1344,
560
+ )
561
+ generation_method, length_slider, overlap_video_length, partial_video_length = \
562
+ create_generation_methods_and_video_length(
563
+ ["Video Generation", "Image Generation"],
564
+ default_video_length=81,
565
+ maximum_video_length=81,
566
+ )
567
+ image_to_video_col, video_to_video_col, control_video_col, source_method, start_image, template_gallery, end_image, validation_video, validation_video_mask, denoise_strength, control_video = create_generation_method(
568
+ ["Text to Video (文本到视频)", "Image to Video (图片到视频)"], prompt_textbox
569
+ )
570
+ cfg_scale_slider, seed_textbox, seed_button = create_cfg_and_seedbox(gradio_version_is_above_4)
571
+
572
+ generate_button = gr.Button(value="Generate (生成)", variant='primary')
573
+
574
+ result_image, result_video, infer_progress = create_ui_outputs()
575
+
576
+ def upload_generation_method(generation_method):
577
+ if generation_method == "Video Generation":
578
+ return gr.update(visible=True, minimum=1, maximum=81, value=81, interactive=True)
579
+ elif generation_method == "Image Generation":
580
+ return gr.update(minimum=1, maximum=1, value=1, interactive=False)
581
+ generation_method.change(
582
+ upload_generation_method, generation_method, [length_slider]
583
+ )
584
+
585
+ def upload_source_method(source_method):
586
+ if source_method == "Text to Video (文本到视频)":
587
+ return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None)]
588
+ elif source_method == "Image to Video (图片到视频)":
589
+ return [gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(), gr.update(), gr.update(value=None), gr.update(value=None), gr.update(value=None)]
590
+ elif source_method == "Video to Video (视频到视频)":
591
+ return [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(value=None), gr.update(value=None), gr.update(), gr.update(), gr.update(value=None)]
592
+ else:
593
+ return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update()]
594
+ source_method.change(
595
+ upload_source_method, source_method, [
596
+ image_to_video_col, video_to_video_col, control_video_col, start_image, end_image,
597
+ validation_video, validation_video_mask, control_video
598
+ ]
599
+ )
600
+
601
+ def upload_resize_method(resize_method):
602
+ if resize_method == "Generate by":
603
+ return [gr.update(visible=True), gr.update(visible=True), gr.update(visible=False)]
604
+ else:
605
+ return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)]
606
+ resize_method.change(
607
+ upload_resize_method, resize_method, [width_slider, height_slider, base_resolution]
608
+ )
609
+
610
+ generate_button.click(
611
+ fn=controller.generate,
612
+ inputs=[
613
+ diffusion_transformer_dropdown,
614
+ base_model_dropdown,
615
+ lora_model_dropdown,
616
+ lora_alpha_slider,
617
+ prompt_textbox,
618
+ negative_prompt_textbox,
619
+ sampler_dropdown,
620
+ sample_step_slider,
621
+ resize_method,
622
+ width_slider,
623
+ height_slider,
624
+ base_resolution,
625
+ generation_method,
626
+ length_slider,
627
+ overlap_video_length,
628
+ partial_video_length,
629
+ cfg_scale_slider,
630
+ start_image,
631
+ end_image,
632
+ validation_video,
633
+ validation_video_mask,
634
+ control_video,
635
+ denoise_strength,
636
+ seed_textbox,
637
+ ],
638
+ outputs=[result_image, result_video, infer_progress]
639
+ )
640
+ return demo, controller
641
+
642
+ def ui_eas(model_name, scheduler_dict, savedir_sample, config_path):
643
+ controller = Wan_Fun_Controller_EAS(model_name, scheduler_dict, savedir_sample)
644
+
645
+ with gr.Blocks(css=css) as demo:
646
+ gr.Markdown(
647
+ """
648
+ # Wan-Fun:
649
+
650
+ A Wan model with more flexible generation conditions, capable of producing videos at different resolutions, around 5 seconds long at 16 fps (1 to 81 frames), as well as videos generated from images.
651
+
652
+ [Github](https://github.com/aigc-apps/CogVideoX-Fun/)
653
+ """
654
+ )
655
+ with gr.Column(variant="panel"):
656
+ diffusion_transformer_dropdown = create_fake_model_checkpoints(model_name, visible=True)
657
+ base_model_dropdown, lora_model_dropdown, lora_alpha_slider = create_fake_finetune_models_checkpoints(visible=True)
658
+
659
+ with gr.Column(variant="panel"):
660
+ prompt_textbox, negative_prompt_textbox = create_prompts(negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走")
661
+
662
+ with gr.Row():
663
+ with gr.Column():
664
+ sampler_dropdown, sample_step_slider = create_samplers(controller, maximum_step=40)
665
+
666
+ resize_method, width_slider, height_slider, base_resolution = create_fake_height_width(
667
+ default_height = 480, default_width = 832, maximum_height = 1344,
668
+ maximum_width = 1344,
669
+ )
670
+ generation_method, length_slider, overlap_video_length, partial_video_length = \
671
+ create_generation_methods_and_video_length(
672
+ ["Video Generation", "Image Generation"],
673
+ default_video_length=29,
674
+ maximum_video_length=29,
675
+ )
676
+ image_to_video_col, video_to_video_col, control_video_col, source_method, start_image, template_gallery, end_image, validation_video, validation_video_mask, denoise_strength, control_video = create_generation_method(
677
+ ["Text to Video (文本到视频)", "Image to Video (图片到视频)"], prompt_textbox
678
+ )
679
+
680
+ cfg_scale_slider, seed_textbox, seed_button = create_cfg_and_seedbox(gradio_version_is_above_4)
681
+
682
+ generate_button = gr.Button(value="Generate (生成)", variant='primary')
683
+
684
+ result_image, result_video, infer_progress = create_ui_outputs()
685
+
686
+ def upload_generation_method(generation_method):
687
+ if generation_method == "Video Generation":
688
+ return gr.update(visible=True, minimum=5, maximum=29, value=29, interactive=True)
689
+ elif generation_method == "Image Generation":
690
+ return gr.update(minimum=1, maximum=1, value=1, interactive=False)
691
+ generation_method.change(
692
+ upload_generation_method, generation_method, [length_slider]
693
+ )
694
+
695
+ def upload_source_method(source_method):
696
+ if source_method == "Text to Video (文本到视频)":
697
+ return [gr.update(visible=False), gr.update(visible=False), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None)]
698
+ elif source_method == "Image to Video (图片到视频)":
699
+ return [gr.update(visible=True), gr.update(visible=False), gr.update(), gr.update(), gr.update(value=None), gr.update(value=None)]
700
+ else:
701
+ return [gr.update(visible=False), gr.update(visible=True), gr.update(value=None), gr.update(value=None), gr.update(), gr.update()]
702
+ source_method.change(
703
+ upload_source_method, source_method, [image_to_video_col, video_to_video_col, start_image, end_image, validation_video, validation_video_mask]
704
+ )
705
+
706
+ def upload_resize_method(resize_method):
707
+ if resize_method == "Generate by":
708
+ return [gr.update(visible=True), gr.update(visible=True), gr.update(visible=False)]
709
+ else:
710
+ return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)]
711
+ resize_method.change(
712
+ upload_resize_method, resize_method, [width_slider, height_slider, base_resolution]
713
+ )
714
+
715
+ generate_button.click(
716
+ fn=controller.generate,
717
+ inputs=[
718
+ diffusion_transformer_dropdown,
719
+ base_model_dropdown,
720
+ lora_model_dropdown,
721
+ lora_alpha_slider,
722
+ prompt_textbox,
723
+ negative_prompt_textbox,
724
+ sampler_dropdown,
725
+ sample_step_slider,
726
+ resize_method,
727
+ width_slider,
728
+ height_slider,
729
+ base_resolution,
730
+ generation_method,
731
+ length_slider,
732
+ cfg_scale_slider,
733
+ start_image,
734
+ end_image,
735
+ validation_video,
736
+ validation_video_mask,
737
+ denoise_strength,
738
+ seed_textbox,
739
+ ],
740
+ outputs=[result_image, result_video, infer_progress]
741
+ )
742
+ return demo, controller
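
For reference, each of the three entry points above (ui, ui_modelscope, ui_eas) returns a (demo, controller) pair. A hedged launch sketch follows; the GPU_memory_mode string, dtype, and config path are assumptions chosen to match the branches and files seen in this commit rather than values mandated by it, and flow_scheduler_dict is the mapping imported at the top of this file.

import torch

# Assumed configuration; examples/wan2.1_fun/app.py is where this is actually wired up.
GPU_memory_mode = "model_cpu_offload"          # any value other than the two special modes falls through to enable_model_cpu_offload()
weight_dtype = torch.bfloat16
config_path = "config/wan2.1/wan_civitai.yaml"

demo, controller = ui(GPU_memory_mode, flow_scheduler_dict, weight_dtype, config_path)
demo.launch(server_name="0.0.0.0", server_port=7860)
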
cogvideox/ui/wan_ui.py ADDED
@@ -0,0 +1,730 @@
1
+ """Modified from https://github.com/guoyww/AnimateDiff/blob/main/app.py
2
+ """
3
+ import os
4
+ import random
5
+
6
+ import cv2
7
+ import gradio as gr
8
+ import numpy as np
9
+ import torch
10
+ from omegaconf import OmegaConf
11
+ from PIL import Image
12
+ from safetensors import safe_open
13
+
14
+ from ..data.bucket_sampler import ASPECT_RATIO_512, get_closest_ratio
15
+ from ..models import (AutoencoderKLWan, AutoTokenizer, CLIPModel,
16
+ WanT5EncoderModel, WanTransformer3DModel)
17
+ from ..pipeline import WanI2VPipeline, WanPipeline
18
+ from ..utils.fp8_optimization import (convert_model_weight_to_float8,
19
+ convert_weight_dtype_wrapper,
20
+ replace_parameters_by_name)
21
+ from ..utils.lora_utils import merge_lora, unmerge_lora
22
+ from ..utils.utils import (filter_kwargs, get_image_to_video_latent,
23
+ get_video_to_video_latent, save_videos_grid)
24
+ from .controller import (Fun_Controller, Fun_Controller_EAS, all_cheduler_dict,
25
+ css, ddpm_scheduler_dict, flow_scheduler_dict,
26
+ gradio_version, gradio_version_is_above_4)
27
+ from .ui import (create_cfg_and_seedbox,
28
+ create_fake_finetune_models_checkpoints,
29
+ create_fake_height_width, create_fake_model_checkpoints,
30
+ create_fake_model_type, create_finetune_models_checkpoints,
31
+ create_generation_method,
32
+ create_generation_methods_and_video_length,
33
+ create_height_width, create_model_checkpoints,
34
+ create_model_type, create_prompts, create_samplers,
35
+ create_ui_outputs)
36
+
37
+
38
+ class Wan_Controller(Fun_Controller):
39
+ def update_diffusion_transformer(self, diffusion_transformer_dropdown):
40
+ print("Update diffusion transformer")
41
+ self.diffusion_transformer_dropdown = diffusion_transformer_dropdown
42
+ if diffusion_transformer_dropdown == "none":
43
+ return gr.update()
44
+ self.vae = AutoencoderKLWan.from_pretrained(
45
+ os.path.join(diffusion_transformer_dropdown, self.config['vae_kwargs'].get('vae_subpath', 'vae')),
46
+ additional_kwargs=OmegaConf.to_container(self.config['vae_kwargs']),
47
+ ).to(self.weight_dtype)
48
+
49
+ # Get Transformer
50
+ self.transformer = WanTransformer3DModel.from_pretrained(
51
+ os.path.join(diffusion_transformer_dropdown, self.config['transformer_additional_kwargs'].get('transformer_subpath', 'transformer')),
52
+ transformer_additional_kwargs=OmegaConf.to_container(self.config['transformer_additional_kwargs']),
53
+ low_cpu_mem_usage=True,
54
+ torch_dtype=self.weight_dtype,
55
+ )
56
+
57
+ # Get Tokenizer
58
+ self.tokenizer = AutoTokenizer.from_pretrained(
59
+ os.path.join(diffusion_transformer_dropdown, self.config['text_encoder_kwargs'].get('tokenizer_subpath', 'tokenizer')),
60
+ )
61
+
62
+ # Get Text encoder
63
+ self.text_encoder = WanT5EncoderModel.from_pretrained(
64
+ os.path.join(diffusion_transformer_dropdown, self.config['text_encoder_kwargs'].get('text_encoder_subpath', 'text_encoder')),
65
+ additional_kwargs=OmegaConf.to_container(self.config['text_encoder_kwargs']),
66
+ ).to(self.weight_dtype)
67
+ self.text_encoder = self.text_encoder.eval()
68
+
69
+ if self.transformer.config.in_channels != self.vae.config.latent_channels:
70
+ # Get Clip Image Encoder
71
+ self.clip_image_encoder = CLIPModel.from_pretrained(
72
+ os.path.join(diffusion_transformer_dropdown, self.config['image_encoder_kwargs'].get('image_encoder_subpath', 'image_encoder')),
73
+ ).to(self.weight_dtype)
74
+ self.clip_image_encoder = self.clip_image_encoder.eval()
75
+ else:
76
+ self.clip_image_encoder = None
77
+
78
+ Choosen_Scheduler = self.scheduler_dict[list(self.scheduler_dict.keys())[0]]
79
+ self.scheduler = Choosen_Scheduler(
80
+ **filter_kwargs(Choosen_Scheduler, OmegaConf.to_container(self.config['scheduler_kwargs']))
81
+ )
82
+
83
+ # Get pipeline
84
+ if self.model_type == "Inpaint":
85
+ if self.transformer.config.in_channels != self.vae.config.latent_channels:
86
+ self.pipeline = WanI2VPipeline(
87
+ vae=self.vae,
88
+ tokenizer=self.tokenizer,
89
+ text_encoder=self.text_encoder,
90
+ transformer=self.transformer,
91
+ scheduler=self.scheduler,
92
+ clip_image_encoder=self.clip_image_encoder,
93
+ )
94
+ else:
95
+ self.pipeline = WanPipeline(
96
+ vae=self.vae,
97
+ tokenizer=self.tokenizer,
98
+ text_encoder=self.text_encoder,
99
+ transformer=self.transformer,
100
+ scheduler=self.scheduler,
101
+ )
102
+ else:
103
+ raise ValueError("Not support now")
104
+
105
+ if self.GPU_memory_mode == "sequential_cpu_offload":
106
+ replace_parameters_by_name(self.transformer, ["modulation",], device="cuda")
107
+ self.transformer.freqs = self.transformer.freqs.to(device="cuda")
108
+ self.pipeline.enable_sequential_cpu_offload()
109
+ elif self.GPU_memory_mode == "model_cpu_offload_and_qfloat8":
110
+ convert_model_weight_to_float8(self.transformer, exclude_module_name=["modulation",])
111
+ convert_weight_dtype_wrapper(self.transformer, self.weight_dtype)
112
+ self.pipeline.enable_model_cpu_offload()
113
+ else:
114
+ self.pipeline.enable_model_cpu_offload()
115
+ print("Update diffusion transformer done")
116
+ return gr.update()
117
+
118
+ def generate(
119
+ self,
120
+ diffusion_transformer_dropdown,
121
+ base_model_dropdown,
122
+ lora_model_dropdown,
123
+ lora_alpha_slider,
124
+ prompt_textbox,
125
+ negative_prompt_textbox,
126
+ sampler_dropdown,
127
+ sample_step_slider,
128
+ resize_method,
129
+ width_slider,
130
+ height_slider,
131
+ base_resolution,
132
+ generation_method,
133
+ length_slider,
134
+ overlap_video_length,
135
+ partial_video_length,
136
+ cfg_scale_slider,
137
+ start_image,
138
+ end_image,
139
+ validation_video,
140
+ validation_video_mask,
141
+ control_video,
142
+ denoise_strength,
143
+ seed_textbox,
144
+ is_api = False,
145
+ ):
146
+ self.clear_cache()
147
+
148
+ self.input_check(
149
+ resize_method, generation_method, start_image, end_image, validation_video,control_video, is_api
150
+ )
151
+ is_image = True if generation_method == "Image Generation" else False
152
+
153
+ if self.base_model_path != base_model_dropdown:
154
+ self.update_base_model(base_model_dropdown)
155
+
156
+ if self.lora_model_path != lora_model_dropdown:
157
+ self.update_lora_model(lora_model_dropdown)
158
+
159
+ self.pipeline.scheduler = self.scheduler_dict[sampler_dropdown].from_config(self.pipeline.scheduler.config)
160
+
161
+ if resize_method == "Resize according to Reference":
162
+ height_slider, width_slider = self.get_height_width_from_reference(
163
+ base_resolution, start_image, validation_video, control_video,
164
+ )
165
+ if self.lora_model_path != "none":
166
+ # lora part
167
+ self.pipeline = merge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider)
168
+
169
+ if int(seed_textbox) != -1 and seed_textbox != "": torch.manual_seed(int(seed_textbox))
170
+ else: seed_textbox = np.random.randint(0, 1e10)
171
+ generator = torch.Generator(device="cuda").manual_seed(int(seed_textbox))
172
+
173
+ try:
174
+ if self.model_type == "Inpaint":
175
+ if self.transformer.config.in_channels != self.vae.config.latent_channels:
176
+ if generation_method == "Long Video Generation":
177
+ if validation_video is not None:
178
+ raise gr.Error(f"Video to Video is not Support Long Video Generation now.")
179
+ init_frames = 0
180
+ last_frames = init_frames + partial_video_length
181
+ while init_frames < length_slider:
182
+ if last_frames >= length_slider:
183
+ _partial_video_length = length_slider - init_frames
184
+ _partial_video_length = int((_partial_video_length - 1) // self.vae.config.temporal_compression_ratio * self.vae.config.temporal_compression_ratio) + 1
185
+
186
+ if _partial_video_length <= 0:
187
+ break
188
+ else:
189
+ _partial_video_length = partial_video_length
190
+
191
+ if last_frames >= length_slider:
192
+ input_video, input_video_mask, clip_image = get_image_to_video_latent(start_image, end_image, video_length=_partial_video_length, sample_size=(height_slider, width_slider))
193
+ else:
194
+ input_video, input_video_mask, clip_image = get_image_to_video_latent(start_image, None, video_length=_partial_video_length, sample_size=(height_slider, width_slider))
195
+
196
+ with torch.no_grad():
197
+ sample = self.pipeline(
198
+ prompt_textbox,
199
+ negative_prompt = negative_prompt_textbox,
200
+ num_inference_steps = sample_step_slider,
201
+ guidance_scale = cfg_scale_slider,
202
+ width = width_slider,
203
+ height = height_slider,
204
+ num_frames = _partial_video_length,
205
+ generator = generator,
206
+
207
+ video = input_video,
208
+ mask_video = input_video_mask,
209
+ clip_image = clip_image
210
+ ).videos
211
+
212
+ if init_frames != 0:
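+ # Crossfade between clips: mix_ratio ramps linearly from 0 toward 1 over the overlap_video_length frames, blending the tail of the previous result into the head of the newly generated clip before the remaining frames are concatenated.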
213
+ mix_ratio = torch.from_numpy(
214
+ np.array([float(_index) / float(overlap_video_length) for _index in range(overlap_video_length)], np.float32)
215
+ ).unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
216
+
217
+ new_sample[:, :, -overlap_video_length:] = new_sample[:, :, -overlap_video_length:] * (1 - mix_ratio) + \
218
+ sample[:, :, :overlap_video_length] * mix_ratio
219
+ new_sample = torch.cat([new_sample, sample[:, :, overlap_video_length:]], dim = 2)
220
+
221
+ sample = new_sample
222
+ else:
223
+ new_sample = sample
224
+
225
+ if last_frames >= length_slider:
226
+ break
227
+
228
+ start_image = [
229
+ Image.fromarray(
230
+ (sample[0, :, _index].transpose(0, 1).transpose(1, 2) * 255).numpy().astype(np.uint8)
231
+ ) for _index in range(-overlap_video_length, 0)
232
+ ]
233
+
234
+ init_frames = init_frames + _partial_video_length - overlap_video_length
235
+ last_frames = init_frames + _partial_video_length
236
+ else:
237
+ if validation_video is not None:
238
+ input_video, input_video_mask, clip_image = get_video_to_video_latent(validation_video, length_slider if not is_image else 1, sample_size=(height_slider, width_slider), validation_video_mask=validation_video_mask, fps=16)
239
+ strength = denoise_strength
240
+ else:
241
+ input_video, input_video_mask, clip_image = get_image_to_video_latent(start_image, end_image, length_slider if not is_image else 1, sample_size=(height_slider, width_slider))
242
+ strength = 1
243
+
244
+ sample = self.pipeline(
245
+ prompt_textbox,
246
+ negative_prompt = negative_prompt_textbox,
247
+ num_inference_steps = sample_step_slider,
248
+ guidance_scale = cfg_scale_slider,
249
+ width = width_slider,
250
+ height = height_slider,
251
+ num_frames = length_slider if not is_image else 1,
252
+ generator = generator,
253
+
254
+ video = input_video,
255
+ mask_video = input_video_mask,
256
+ clip_image = clip_image
257
+ ).videos
258
+ else:
259
+ sample = self.pipeline(
260
+ prompt_textbox,
261
+ negative_prompt = negative_prompt_textbox,
262
+ num_inference_steps = sample_step_slider,
263
+ guidance_scale = cfg_scale_slider,
264
+ width = width_slider,
265
+ height = height_slider,
266
+ num_frames = length_slider if not is_image else 1,
267
+ generator = generator
268
+ ).videos
269
+ else:
270
+ input_video, input_video_mask, clip_image = get_video_to_video_latent(control_video, length_slider if not is_image else 1, sample_size=(height_slider, width_slider), fps=16)
271
+
272
+ sample = self.pipeline(
273
+ prompt_textbox,
274
+ negative_prompt = negative_prompt_textbox,
275
+ num_inference_steps = sample_step_slider,
276
+ guidance_scale = cfg_scale_slider,
277
+ width = width_slider,
278
+ height = height_slider,
279
+ num_frames = length_slider if not is_image else 1,
280
+ generator = generator,
281
+
282
+ control_video = input_video,
283
+ ).videos
284
+ except Exception as e:
285
+ self.clear_cache()
286
+ if self.lora_model_path != "none":
287
+ self.pipeline = unmerge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider)
288
+ if is_api:
289
+ return "", f"Error. error information is {str(e)}"
290
+ else:
291
+ return gr.update(), gr.update(), f"Error. Error information: {str(e)}"
292
+
293
+ self.clear_cache()
294
+ # lora part
295
+ if self.lora_model_path != "none":
296
+ self.pipeline = unmerge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider)
297
+
298
+ save_sample_path = self.save_outputs(
299
+ is_image, length_slider, sample, fps=16
300
+ )
301
+
302
+ if is_image or length_slider == 1:
303
+ if is_api:
304
+ return save_sample_path, "Success"
305
+ else:
306
+ if gradio_version_is_above_4:
307
+ return gr.Image(value=save_sample_path, visible=True), gr.Video(value=None, visible=False), "Success"
308
+ else:
309
+ return gr.Image.update(value=save_sample_path, visible=True), gr.Video.update(value=None, visible=False), "Success"
310
+ else:
311
+ if is_api:
312
+ return save_sample_path, "Success"
313
+ else:
314
+ if gradio_version_is_above_4:
315
+ return gr.Image(visible=False, value=None), gr.Video(value=save_sample_path, visible=True), "Success"
316
+ else:
317
+ return gr.Image.update(visible=False, value=None), gr.Video.update(value=save_sample_path, visible=True), "Success"
318
+
319
+
320
+ class Wan_Controller_Modelscope(Wan_Controller):
321
+ def __init__(self, model_name, model_type, savedir_sample, GPU_memory_mode, scheduler_dict, weight_dtype, config_path):
322
+ # Basic dir
323
+ self.basedir = os.getcwd()
324
+ self.personalized_model_dir = os.path.join(self.basedir, "models", "Personalized_Model")
325
+ self.lora_model_path = "none"
326
+ self.base_model_path = "none"
327
+ self.savedir_sample = savedir_sample
328
+ self.scheduler_dict = scheduler_dict
329
+ self.config = OmegaConf.load(config_path)
330
+ self.refresh_personalized_model()
331
+ os.makedirs(self.savedir_sample, exist_ok=True)
332
+
333
+ # model path
334
+ self.model_type = model_type
335
+ self.weight_dtype = weight_dtype
336
+
337
+ self.vae = AutoencoderKLWan.from_pretrained(
338
+ os.path.join(model_name, self.config['vae_kwargs'].get('vae_subpath', 'vae')),
339
+ additional_kwargs=OmegaConf.to_container(self.config['vae_kwargs']),
340
+ ).to(self.weight_dtype)
341
+
342
+ # Get Transformer
343
+ self.transformer = WanTransformer3DModel.from_pretrained(
344
+ os.path.join(model_name, self.config['transformer_additional_kwargs'].get('transformer_subpath', 'transformer')),
345
+ transformer_additional_kwargs=OmegaConf.to_container(self.config['transformer_additional_kwargs']),
346
+ low_cpu_mem_usage=True,
347
+ torch_dtype=self.weight_dtype,
348
+ )
349
+
350
+ # Get Tokenizer
351
+ self.tokenizer = AutoTokenizer.from_pretrained(
352
+ os.path.join(model_name, self.config['text_encoder_kwargs'].get('tokenizer_subpath', 'tokenizer')),
353
+ )
354
+
355
+ # Get Text encoder
356
+ self.text_encoder = WanT5EncoderModel.from_pretrained(
357
+ os.path.join(model_name, self.config['text_encoder_kwargs'].get('text_encoder_subpath', 'text_encoder')),
358
+ additional_kwargs=OmegaConf.to_container(self.config['text_encoder_kwargs']),
359
+ ).to(self.weight_dtype)
360
+ self.text_encoder = self.text_encoder.eval()
361
+
362
+ if self.transformer.config.in_channels != self.vae.config.latent_channels:
363
+ # Get Clip Image Encoder
364
+ self.clip_image_encoder = CLIPModel.from_pretrained(
365
+ os.path.join(model_name, self.config['image_encoder_kwargs'].get('image_encoder_subpath', 'image_encoder')),
366
+ ).to(self.weight_dtype)
367
+ self.clip_image_encoder = self.clip_image_encoder.eval()
368
+ else:
369
+ self.clip_image_encoder = None
370
+
371
+ Choosen_Scheduler = self.scheduler_dict[list(self.scheduler_dict.keys())[0]]
372
+ self.scheduler = Choosen_Scheduler(
373
+ **filter_kwargs(Choosen_Scheduler, OmegaConf.to_container(self.config['scheduler_kwargs']))
374
+ )
375
+
376
+ # Get pipeline
377
+ if self.model_type == "Inpaint":
378
+ if self.transformer.config.in_channels != self.vae.config.latent_channels:
379
+ self.pipeline = WanPipeline(
380
+ vae=self.vae,
381
+ tokenizer=self.tokenizer,
382
+ text_encoder=self.text_encoder,
383
+ transformer=self.transformer,
384
+ scheduler=self.scheduler,
385
+ clip_image_encoder=self.clip_image_encoder,
386
+ )
387
+ else:
388
+ self.pipeline = WanPipeline(
389
+ vae=self.vae,
390
+ tokenizer=self.tokenizer,
391
+ text_encoder=self.text_encoder,
392
+ transformer=self.transformer,
393
+ scheduler=self.scheduler,
394
+ )
395
+ else:
396
+ raise ValueError("Not support now")
397
+
398
+ if GPU_memory_mode == "sequential_cpu_offload":
399
+ replace_parameters_by_name(self.transformer, ["modulation",], device="cuda")
400
+ self.transformer.freqs = self.transformer.freqs.to(device="cuda")
401
+ self.pipeline.enable_sequential_cpu_offload()
402
+ elif GPU_memory_mode == "model_cpu_offload_and_qfloat8":
403
+ convert_model_weight_to_float8(self.transformer, exclude_module_name=["modulation",])
404
+ convert_weight_dtype_wrapper(self.transformer, self.weight_dtype)
405
+ self.pipeline.enable_model_cpu_offload()
406
+ else:
407
+ self.pipeline.enable_model_cpu_offload()
408
+ print("Update diffusion transformer done")
409
+
410
+ Wan_Controller_EAS = Fun_Controller_EAS
411
+
412
+ def ui(GPU_memory_mode, scheduler_dict, weight_dtype, config_path):
413
+ controller = Wan_Controller(GPU_memory_mode, scheduler_dict, weight_dtype, config_path)
414
+
415
+ with gr.Blocks(css=css) as demo:
416
+ gr.Markdown(
417
+ """
418
+ # Wan:
419
+ """
420
+ )
421
+ with gr.Column(variant="panel"):
422
+ model_type = create_model_type(visible=True)
423
+ diffusion_transformer_dropdown, diffusion_transformer_refresh_button = \
424
+ create_model_checkpoints(controller, visible=True)
425
+ base_model_dropdown, lora_model_dropdown, lora_alpha_slider, personalized_refresh_button = \
426
+ create_finetune_models_checkpoints(controller, visible=True)
427
+
428
+ with gr.Column(variant="panel"):
429
+ prompt_textbox, negative_prompt_textbox = create_prompts(negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走")
430
+
431
+ with gr.Row():
432
+ with gr.Column():
433
+ sampler_dropdown, sample_step_slider = create_samplers(controller)
434
+
435
+ resize_method, width_slider, height_slider, base_resolution = create_height_width(
436
+ default_height = 480, default_width = 832, maximum_height = 1344,
437
+ maximum_width = 1344,
438
+ )
439
+ generation_method, length_slider, overlap_video_length, partial_video_length = \
440
+ create_generation_methods_and_video_length(
441
+ ["Video Generation", "Image Generation"],
442
+ default_video_length=81,
443
+ maximum_video_length=81,
444
+ )
445
+ image_to_video_col, video_to_video_col, control_video_col, source_method, start_image, template_gallery, end_image, validation_video, validation_video_mask, denoise_strength, control_video = create_generation_method(
446
+ ["Text to Video (文本到视频)", "Image to Video (图片到视频)"], prompt_textbox, support_end_image=False
447
+ )
448
+ cfg_scale_slider, seed_textbox, seed_button = create_cfg_and_seedbox(gradio_version_is_above_4)
449
+
450
+ generate_button = gr.Button(value="Generate (生成)", variant='primary')
451
+
452
+ result_image, result_video, infer_progress = create_ui_outputs()
453
+
454
+ model_type.change(
455
+ fn=controller.update_model_type,
456
+ inputs=[model_type],
457
+ outputs=[]
458
+ )
459
+
460
+ def upload_generation_method(generation_method):
461
+ if generation_method == "Video Generation":
462
+ return [gr.update(visible=True, maximum=81, value=81, interactive=True), gr.update(visible=False), gr.update(visible=False)]
463
+ elif generation_method == "Image Generation":
464
+ return [gr.update(minimum=1, maximum=1, value=1, interactive=False), gr.update(visible=False), gr.update(visible=False)]
465
+ else:
466
+ return [gr.update(visible=True, maximum=1344), gr.update(visible=True), gr.update(visible=True)]
467
+ generation_method.change(
468
+ upload_generation_method, generation_method, [length_slider, overlap_video_length, partial_video_length]
469
+ )
470
+
471
+ def upload_source_method(source_method):
472
+ if source_method == "Text to Video (文本到视频)":
473
+ return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None)]
474
+ elif source_method == "Image to Video (图片到视频)":
475
+ return [gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(), gr.update(), gr.update(value=None), gr.update(value=None), gr.update(value=None)]
476
+ elif source_method == "Video to Video (视频到视频)":
477
+ return [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(value=None), gr.update(value=None), gr.update(), gr.update(), gr.update(value=None)]
478
+ else:
479
+ return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update()]
480
+ source_method.change(
481
+ upload_source_method, source_method, [
482
+ image_to_video_col, video_to_video_col, control_video_col, start_image, end_image,
483
+ validation_video, validation_video_mask, control_video
484
+ ]
485
+ )
486
+
487
+ def upload_resize_method(resize_method):
488
+ if resize_method == "Generate by":
489
+ return [gr.update(visible=True), gr.update(visible=True), gr.update(visible=False)]
490
+ else:
491
+ return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)]
492
+ resize_method.change(
493
+ upload_resize_method, resize_method, [width_slider, height_slider, base_resolution]
494
+ )
495
+
496
+ generate_button.click(
497
+ fn=controller.generate,
498
+ inputs=[
499
+ diffusion_transformer_dropdown,
500
+ base_model_dropdown,
501
+ lora_model_dropdown,
502
+ lora_alpha_slider,
503
+ prompt_textbox,
504
+ negative_prompt_textbox,
505
+ sampler_dropdown,
506
+ sample_step_slider,
507
+ resize_method,
508
+ width_slider,
509
+ height_slider,
510
+ base_resolution,
511
+ generation_method,
512
+ length_slider,
513
+ overlap_video_length,
514
+ partial_video_length,
515
+ cfg_scale_slider,
516
+ start_image,
517
+ end_image,
518
+ validation_video,
519
+ validation_video_mask,
520
+ control_video,
521
+ denoise_strength,
522
+ seed_textbox,
523
+ ],
524
+ outputs=[result_image, result_video, infer_progress]
525
+ )
526
+ return demo, controller
527
+
528
+ def ui_modelscope(model_name, model_type, savedir_sample, GPU_memory_mode, scheduler_dict, weight_dtype, config_path):
529
+ controller = Wan_Controller_Modelscope(model_name, model_type, savedir_sample, GPU_memory_mode, scheduler_dict, weight_dtype, config_path)
530
+
531
+ with gr.Blocks(css=css) as demo:
532
+ gr.Markdown(
533
+ """
534
+ # Wan:
535
+ """
536
+ )
537
+ with gr.Column(variant="panel"):
538
+ model_type = create_fake_model_type(visible=True)
539
+ diffusion_transformer_dropdown = create_fake_model_checkpoints(model_name, visible=True)
540
+ base_model_dropdown, lora_model_dropdown, lora_alpha_slider = create_fake_finetune_models_checkpoints(visible=True)
541
+
542
+ with gr.Column(variant="panel"):
543
+ prompt_textbox, negative_prompt_textbox = create_prompts(negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走")
544
+
545
+ with gr.Row():
546
+ with gr.Column():
547
+ sampler_dropdown, sample_step_slider = create_samplers(controller)
548
+
549
+ resize_method, width_slider, height_slider, base_resolution = create_height_width(
550
+ default_height = 480, default_width = 832, maximum_height = 1344,
551
+ maximum_width = 1344,
552
+ )
553
+ generation_method, length_slider, overlap_video_length, partial_video_length = \
554
+ create_generation_methods_and_video_length(
555
+ ["Video Generation", "Image Generation"],
556
+ default_video_length=81,
557
+ maximum_video_length=81,
558
+ )
559
+ image_to_video_col, video_to_video_col, control_video_col, source_method, start_image, template_gallery, end_image, validation_video, validation_video_mask, denoise_strength, control_video = create_generation_method(
560
+ ["Text to Video (文本到视频)", "Image to Video (图片到视频)"], prompt_textbox
561
+ )
562
+ cfg_scale_slider, seed_textbox, seed_button = create_cfg_and_seedbox(gradio_version_is_above_4)
563
+
564
+ generate_button = gr.Button(value="Generate (生成)", variant='primary')
565
+
566
+ result_image, result_video, infer_progress = create_ui_outputs()
567
+
568
+ def upload_generation_method(generation_method):
569
+ if generation_method == "Video Generation":
570
+ return gr.update(visible=True, minimum=1, maximum=81, value=81, interactive=True)
571
+ elif generation_method == "Image Generation":
572
+ return gr.update(minimum=1, maximum=1, value=1, interactive=False)
573
+ generation_method.change(
574
+ upload_generation_method, generation_method, [length_slider]
575
+ )
576
+
577
+ def upload_source_method(source_method):
578
+ if source_method == "Text to Video (文本到视频)":
579
+ return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None)]
580
+ elif source_method == "Image to Video (图片到视频)":
581
+ return [gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(), gr.update(), gr.update(value=None), gr.update(value=None), gr.update(value=None)]
582
+ elif source_method == "Video to Video (视频到视频)":
583
+ return [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(value=None), gr.update(value=None), gr.update(), gr.update(), gr.update(value=None)]
584
+ else:
585
+ return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update()]
586
+ source_method.change(
587
+ upload_source_method, source_method, [
588
+ image_to_video_col, video_to_video_col, control_video_col, start_image, end_image,
589
+ validation_video, validation_video_mask, control_video
590
+ ]
591
+ )
592
+
593
+ def upload_resize_method(resize_method):
594
+ if resize_method == "Generate by":
595
+ return [gr.update(visible=True), gr.update(visible=True), gr.update(visible=False)]
596
+ else:
597
+ return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)]
598
+ resize_method.change(
599
+ upload_resize_method, resize_method, [width_slider, height_slider, base_resolution]
600
+ )
601
+
602
+ generate_button.click(
603
+ fn=controller.generate,
604
+ inputs=[
605
+ diffusion_transformer_dropdown,
606
+ base_model_dropdown,
607
+ lora_model_dropdown,
608
+ lora_alpha_slider,
609
+ prompt_textbox,
610
+ negative_prompt_textbox,
611
+ sampler_dropdown,
612
+ sample_step_slider,
613
+ resize_method,
614
+ width_slider,
615
+ height_slider,
616
+ base_resolution,
617
+ generation_method,
618
+ length_slider,
619
+ overlap_video_length,
620
+ partial_video_length,
621
+ cfg_scale_slider,
622
+ start_image,
623
+ end_image,
624
+ validation_video,
625
+ validation_video_mask,
626
+ control_video,
627
+ denoise_strength,
628
+ seed_textbox,
629
+ ],
630
+ outputs=[result_image, result_video, infer_progress]
631
+ )
632
+ return demo, controller
633
+
634
+ def ui_eas(model_name, scheduler_dict, savedir_sample, config_path):
635
+ controller = Wan_Controller_EAS(model_name, scheduler_dict, savedir_sample)
636
+
637
+ with gr.Blocks(css=css) as demo:
638
+ gr.Markdown(
639
+ """
640
+ # Wan:
641
+ """
642
+ )
643
+ with gr.Column(variant="panel"):
644
+ diffusion_transformer_dropdown = create_fake_model_checkpoints(model_name, visible=True)
645
+ base_model_dropdown, lora_model_dropdown, lora_alpha_slider = create_fake_finetune_models_checkpoints(visible=True)
646
+
647
+ with gr.Column(variant="panel"):
648
+ prompt_textbox, negative_prompt_textbox = create_prompts(negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走")
649
+
650
+ with gr.Row():
651
+ with gr.Column():
652
+ sampler_dropdown, sample_step_slider = create_samplers(controller, maximum_step=50)
653
+
654
+ resize_method, width_slider, height_slider, base_resolution = create_fake_height_width(
655
+ default_height = 480, default_width = 832, maximum_height = 1344,
656
+ maximum_width = 1344,
657
+ )
658
+ generation_method, length_slider, overlap_video_length, partial_video_length = \
659
+ create_generation_methods_and_video_length(
660
+ ["Video Generation", "Image Generation"],
661
+ default_video_length=81,
662
+ maximum_video_length=81,
663
+ )
664
+ image_to_video_col, video_to_video_col, control_video_col, source_method, start_image, template_gallery, end_image, validation_video, validation_video_mask, denoise_strength, control_video = create_generation_method(
665
+ ["Text to Video (文本到视频)", "Image to Video (图片到视频)"], prompt_textbox
666
+ )
667
+
668
+ cfg_scale_slider, seed_textbox, seed_button = create_cfg_and_seedbox(gradio_version_is_above_4)
669
+
670
+ generate_button = gr.Button(value="Generate (生成)", variant='primary')
671
+
672
+ result_image, result_video, infer_progress = create_ui_outputs()
673
+
674
+ def upload_generation_method(generation_method):
675
+ if generation_method == "Video Generation":
676
+ return gr.update(visible=True, minimum=5, maximum=85, value=49, interactive=True)
677
+ elif generation_method == "Image Generation":
678
+ return gr.update(minimum=1, maximum=1, value=1, interactive=False)
679
+ generation_method.change(
680
+ upload_generation_method, generation_method, [length_slider]
681
+ )
682
+
683
+ def upload_source_method(source_method):
684
+ if source_method == "Text to Video (文本到视频)":
685
+ return [gr.update(visible=False), gr.update(visible=False), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None)]
686
+ elif source_method == "Image to Video (图片到视频)":
687
+ return [gr.update(visible=True), gr.update(visible=False), gr.update(), gr.update(), gr.update(value=None), gr.update(value=None)]
688
+ else:
689
+ return [gr.update(visible=False), gr.update(visible=True), gr.update(value=None), gr.update(value=None), gr.update(), gr.update()]
690
+ source_method.change(
691
+ upload_source_method, source_method, [image_to_video_col, video_to_video_col, start_image, end_image, validation_video, validation_video_mask]
692
+ )
693
+
694
+ def upload_resize_method(resize_method):
695
+ if resize_method == "Generate by":
696
+ return [gr.update(visible=True), gr.update(visible=True), gr.update(visible=False)]
697
+ else:
698
+ return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)]
699
+ resize_method.change(
700
+ upload_resize_method, resize_method, [width_slider, height_slider, base_resolution]
701
+ )
702
+
703
+ generate_button.click(
704
+ fn=controller.generate,
705
+ inputs=[
706
+ diffusion_transformer_dropdown,
707
+ base_model_dropdown,
708
+ lora_model_dropdown,
709
+ lora_alpha_slider,
710
+ prompt_textbox,
711
+ negative_prompt_textbox,
712
+ sampler_dropdown,
713
+ sample_step_slider,
714
+ resize_method,
715
+ width_slider,
716
+ height_slider,
717
+ base_resolution,
718
+ generation_method,
719
+ length_slider,
720
+ cfg_scale_slider,
721
+ start_image,
722
+ end_image,
723
+ validation_video,
724
+ validation_video_mask,
725
+ denoise_strength,
726
+ seed_textbox,
727
+ ],
728
+ outputs=[result_image, result_video, infer_progress]
729
+ )
730
+ return demo, controller
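Note on the long-video branch above: the clip is generated in overlapping windows, each new window starting from the last `overlap_video_length` frames of the previous result, and the overlap is cross-faded with a linearly increasing `mix_ratio` before the remaining frames are concatenated. A minimal sketch of that blending step, assuming `prev` and `new` are pixel tensors shaped [B, C, T, H, W]; the function and variable names here are illustrative, not part of the repository:

import numpy as np
import torch

def blend_windows(prev: torch.Tensor, new: torch.Tensor, overlap: int) -> torch.Tensor:
    # weights ramp from 0 (keep the old window) to ~1 (keep the new window) across the overlap
    mix_ratio = torch.from_numpy(
        np.array([i / float(overlap) for i in range(overlap)], np.float32)
    ).reshape(1, 1, overlap, 1, 1)
    prev = prev.clone()
    prev[:, :, -overlap:] = prev[:, :, -overlap:] * (1 - mix_ratio) + new[:, :, :overlap] * mix_ratio
    return torch.cat([prev, new[:, :, overlap:]], dim=2)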
cogvideox/utils/__init__.py ADDED
File without changes
cogvideox/utils/discrete_sampler.py ADDED
@@ -0,0 +1,46 @@
1
+ """Modified from https://github.com/THUDM/CogVideo/blob/3710a612d8760f5cdb1741befeebb65b9e0f2fe0/sat/sgm/modules/diffusionmodules/sigma_sampling.py
2
+ """
3
+ import torch
4
+
5
+ class DiscreteSampling:
6
+ def __init__(self, num_idx, uniform_sampling=False):
7
+ self.num_idx = num_idx
8
+ self.uniform_sampling = uniform_sampling
9
+ self.is_distributed = torch.distributed.is_available() and torch.distributed.is_initialized()
10
+
11
+ if self.is_distributed and self.uniform_sampling:
12
+ world_size = torch.distributed.get_world_size()
13
+ self.rank = torch.distributed.get_rank()
14
+
15
+ i = 1
16
+ while True:
17
+ if world_size % i != 0 or num_idx % (world_size // i) != 0:
18
+ i += 1
19
+ else:
20
+ self.group_num = world_size // i
21
+ break
22
+ assert self.group_num > 0
23
+ assert world_size % self.group_num == 0
24
+ # the number of rank in one group
25
+ self.group_width = world_size // self.group_num
26
+ self.sigma_interval = self.num_idx // self.group_num
27
+ print('rank=%d world_size=%d group_num=%d group_width=%d sigma_interval=%s' % (
28
+ self.rank, world_size, self.group_num,
29
+ self.group_width, self.sigma_interval))
30
+
31
+ def __call__(self, n_samples, generator=None, device=None):
32
+ if self.is_distributed and self.uniform_sampling:
33
+ group_index = self.rank // self.group_width
34
+ idx = torch.randint(
35
+ group_index * self.sigma_interval,
36
+ (group_index + 1) * self.sigma_interval,
37
+ (n_samples,),
38
+ generator=generator, device=device,
39
+ )
40
+ print('proc[%d] idx=%s' % (self.rank, idx))
41
+ else:
42
+ idx = torch.randint(
43
+ 0, self.num_idx, (n_samples,),
44
+ generator=generator, device=device,
45
+ )
46
+ return idx
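Usage note: `DiscreteSampling` draws integer timestep indices for training; with `uniform_sampling=True` in a `torch.distributed` run, ranks are split into groups and each group samples only its own slice of [0, num_idx), so a global batch covers the timestep range more evenly. A minimal single-process sketch (the values are illustrative):

import torch
# assumes: from cogvideox.utils.discrete_sampler import DiscreteSampling

sampler = DiscreteSampling(num_idx=1000)                # e.g. 1000 diffusion timesteps
generator = torch.Generator().manual_seed(42)
timesteps = sampler(n_samples=4, generator=generator)   # LongTensor of shape (4,), values in [0, 1000)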
cogvideox/utils/fp8_optimization.py ADDED
@@ -0,0 +1,56 @@
1
+ """Modified from https://github.com/kijai/ComfyUI-MochiWrapper
2
+ """
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+ def autocast_model_forward(cls, origin_dtype, *inputs, **kwargs):
7
+ weight_dtype = cls.weight.dtype
8
+ cls.to(origin_dtype)
9
+
10
+ # Convert all inputs to the original dtype
11
+ inputs = [input.to(origin_dtype) for input in inputs]
12
+ out = cls.original_forward(*inputs, **kwargs)
13
+
14
+ cls.to(weight_dtype)
15
+ return out
16
+
17
+ def replace_parameters_by_name(module, name_keywords, device):
18
+ from torch import nn
19
+ for name, param in list(module.named_parameters(recurse=False)):
20
+ if any(keyword in name for keyword in name_keywords):
21
+ if isinstance(param, nn.Parameter):
22
+ tensor = param.data
23
+ delattr(module, name)
24
+ setattr(module, name, tensor.to(device=device))
25
+ for child_name, child_module in module.named_children():
26
+ replace_parameters_by_name(child_module, name_keywords, device)
27
+
28
+ def convert_model_weight_to_float8(model, exclude_module_name=['embed_tokens']):
29
+ for name, module in model.named_modules():
30
+ flag = False
31
+ for _exclude_module_name in exclude_module_name:
32
+ if _exclude_module_name in name:
33
+ flag = True
34
+ if flag:
35
+ continue
36
+ for param_name, param in module.named_parameters():
37
+ flag = False
38
+ for _exclude_module_name in exclude_module_name:
39
+ if _exclude_module_name in param_name:
40
+ flag = True
41
+ if flag:
42
+ continue
43
+ param.data = param.data.to(torch.float8_e4m3fn)
44
+
45
+ def convert_weight_dtype_wrapper(module, origin_dtype):
46
+ for name, module in module.named_modules():
47
+ if name == "" or "embed_tokens" in name:
48
+ continue
49
+ original_forward = module.forward
50
+ if hasattr(module, "weight") and module.weight is not None:
51
+ setattr(module, "original_forward", original_forward)
52
+ setattr(
53
+ module,
54
+ "forward",
55
+ lambda *inputs, m=module, **kwargs: autocast_model_forward(m, origin_dtype, *inputs, **kwargs)
56
+ )
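Usage note: as in the controllers above for the `model_cpu_offload_and_qfloat8` mode, the intended order is to first cast the transformer weights to `float8_e4m3fn` (skipping excluded modules such as `modulation`), then wrap each module's `forward` so its weights are temporarily cast back to the compute dtype. A sketch, assuming `transformer` and `pipeline` have already been built as in the UI code:

import torch
# assumes: from cogvideox.utils.fp8_optimization import (
#     convert_model_weight_to_float8, convert_weight_dtype_wrapper)

weight_dtype = torch.bfloat16
convert_model_weight_to_float8(transformer, exclude_module_name=["modulation"])  # store most weights in float8
convert_weight_dtype_wrapper(transformer, weight_dtype)                          # autocast back to bf16 per forward
pipeline.enable_model_cpu_offload()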
cogvideox/utils/lora_utils.py ADDED
@@ -0,0 +1,502 @@
1
+ # LoRA network module
2
+ # reference:
3
+ # https://github.com/microsoft/LoRA/blob/main/loralib/layers.py
4
+ # https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py
5
+ # https://github.com/bmaltais/kohya_ss
6
+
7
+ import hashlib
8
+ import math
9
+ import os
10
+ from collections import defaultdict
11
+ from io import BytesIO
12
+ from typing import List, Optional, Type, Union
13
+
14
+ import safetensors.torch
15
+ import torch
16
+ import torch.utils.checkpoint
17
+ from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear
18
+ from safetensors.torch import load_file
19
+ from transformers import T5EncoderModel
20
+
21
+
22
+ class LoRAModule(torch.nn.Module):
23
+ """
24
+ replaces forward method of the original Linear, instead of replacing the original Linear module.
25
+ """
26
+
27
+ def __init__(
28
+ self,
29
+ lora_name,
30
+ org_module: torch.nn.Module,
31
+ multiplier=1.0,
32
+ lora_dim=4,
33
+ alpha=1,
34
+ dropout=None,
35
+ rank_dropout=None,
36
+ module_dropout=None,
37
+ ):
38
+ """if alpha == 0 or None, alpha is rank (no scaling)."""
39
+ super().__init__()
40
+ self.lora_name = lora_name
41
+
42
+ if org_module.__class__.__name__ == "Conv2d":
43
+ in_dim = org_module.in_channels
44
+ out_dim = org_module.out_channels
45
+ else:
46
+ in_dim = org_module.in_features
47
+ out_dim = org_module.out_features
48
+
49
+ self.lora_dim = lora_dim
50
+ if org_module.__class__.__name__ == "Conv2d":
51
+ kernel_size = org_module.kernel_size
52
+ stride = org_module.stride
53
+ padding = org_module.padding
54
+ self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False)
55
+ self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False)
56
+ else:
57
+ self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False)
58
+ self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False)
59
+
60
+ if type(alpha) == torch.Tensor:
61
+ alpha = alpha.detach().float().numpy() # without casting, bf16 causes error
62
+ alpha = self.lora_dim if alpha is None or alpha == 0 else alpha
63
+ self.scale = alpha / self.lora_dim
64
+ self.register_buffer("alpha", torch.tensor(alpha))
65
+
66
+ # same as microsoft's
67
+ torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
68
+ torch.nn.init.zeros_(self.lora_up.weight)
69
+
70
+ self.multiplier = multiplier
71
+ self.org_module = org_module # remove in applying
72
+ self.dropout = dropout
73
+ self.rank_dropout = rank_dropout
74
+ self.module_dropout = module_dropout
75
+
76
+ def apply_to(self):
77
+ self.org_forward = self.org_module.forward
78
+ self.org_module.forward = self.forward
79
+ del self.org_module
80
+
81
+ def forward(self, x, *args, **kwargs):
82
+ weight_dtype = x.dtype
83
+ org_forwarded = self.org_forward(x)
84
+
85
+ # module dropout
86
+ if self.module_dropout is not None and self.training:
87
+ if torch.rand(1) < self.module_dropout:
88
+ return org_forwarded
89
+
90
+ lx = self.lora_down(x.to(self.lora_down.weight.dtype))
91
+
92
+ # normal dropout
93
+ if self.dropout is not None and self.training:
94
+ lx = torch.nn.functional.dropout(lx, p=self.dropout)
95
+
96
+ # rank dropout
97
+ if self.rank_dropout is not None and self.training:
98
+ mask = torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout
99
+ if len(lx.size()) == 3:
100
+ mask = mask.unsqueeze(1) # for Text Encoder
101
+ elif len(lx.size()) == 4:
102
+ mask = mask.unsqueeze(-1).unsqueeze(-1) # for Conv2d
103
+ lx = lx * mask
104
+
105
+ # scaling for rank dropout: treat as if the rank is changed
106
+ scale = self.scale * (1.0 / (1.0 - self.rank_dropout)) # redundant for readability
107
+ else:
108
+ scale = self.scale
109
+
110
+ lx = self.lora_up(lx)
111
+
112
+ return org_forwarded.to(weight_dtype) + lx.to(weight_dtype) * self.multiplier * scale
113
+
114
+
115
+ def addnet_hash_legacy(b):
116
+ """Old model hash used by sd-webui-additional-networks for .safetensors format files"""
117
+ m = hashlib.sha256()
118
+
119
+ b.seek(0x100000)
120
+ m.update(b.read(0x10000))
121
+ return m.hexdigest()[0:8]
122
+
123
+
124
+ def addnet_hash_safetensors(b):
125
+ """New model hash used by sd-webui-additional-networks for .safetensors format files"""
126
+ hash_sha256 = hashlib.sha256()
127
+ blksize = 1024 * 1024
128
+
129
+ b.seek(0)
130
+ header = b.read(8)
131
+ n = int.from_bytes(header, "little")
132
+
133
+ offset = n + 8
134
+ b.seek(offset)
135
+ for chunk in iter(lambda: b.read(blksize), b""):
136
+ hash_sha256.update(chunk)
137
+
138
+ return hash_sha256.hexdigest()
139
+
140
+
141
+ def precalculate_safetensors_hashes(tensors, metadata):
142
+ """Precalculate the model hashes needed by sd-webui-additional-networks to
143
+ save time on indexing the model later."""
144
+
145
+ # Because writing user metadata to the file can change the result of
146
+ # sd_models.model_hash(), only retain the training metadata for purposes of
147
+ # calculating the hash, as they are meant to be immutable
148
+ metadata = {k: v for k, v in metadata.items() if k.startswith("ss_")}
149
+
150
+ bytes = safetensors.torch.save(tensors, metadata)
151
+ b = BytesIO(bytes)
152
+
153
+ model_hash = addnet_hash_safetensors(b)
154
+ legacy_hash = addnet_hash_legacy(b)
155
+ return model_hash, legacy_hash
156
+
157
+
158
+ class LoRANetwork(torch.nn.Module):
159
+ TRANSFORMER_TARGET_REPLACE_MODULE = ["CogVideoXTransformer3DModel", "WanTransformer3DModel"]
160
+ TEXT_ENCODER_TARGET_REPLACE_MODULE = ["T5LayerSelfAttention", "T5LayerFF", "BertEncoder", "T5SelfAttention", "T5CrossAttention"]
161
+ LORA_PREFIX_TRANSFORMER = "lora_unet"
162
+ LORA_PREFIX_TEXT_ENCODER = "lora_te"
163
+ def __init__(
164
+ self,
165
+ text_encoder: Union[List[T5EncoderModel], T5EncoderModel],
166
+ unet,
167
+ multiplier: float = 1.0,
168
+ lora_dim: int = 4,
169
+ alpha: float = 1,
170
+ dropout: Optional[float] = None,
171
+ module_class: Type[object] = LoRAModule,
172
+ add_lora_in_attn_temporal: bool = False,
173
+ varbose: Optional[bool] = False,
174
+ ) -> None:
175
+ super().__init__()
176
+ self.multiplier = multiplier
177
+
178
+ self.lora_dim = lora_dim
179
+ self.alpha = alpha
180
+ self.dropout = dropout
181
+
182
+ print(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
183
+ print(f"neuron dropout: p={self.dropout}")
184
+
185
+ # create module instances
186
+ def create_modules(
187
+ is_unet: bool,
188
+ root_module: torch.nn.Module,
189
+ target_replace_modules: List[torch.nn.Module],
190
+ ) -> List[LoRAModule]:
191
+ prefix = (
192
+ self.LORA_PREFIX_TRANSFORMER
193
+ if is_unet
194
+ else self.LORA_PREFIX_TEXT_ENCODER
195
+ )
196
+ loras = []
197
+ skipped = []
198
+ for name, module in root_module.named_modules():
199
+ if module.__class__.__name__ in target_replace_modules:
200
+ for child_name, child_module in module.named_modules():
201
+ is_linear = child_module.__class__.__name__ == "Linear" or child_module.__class__.__name__ == "LoRACompatibleLinear"
202
+ is_conv2d = child_module.__class__.__name__ == "Conv2d" or child_module.__class__.__name__ == "LoRACompatibleConv"
203
+ is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1)
204
+
205
+ if not add_lora_in_attn_temporal:
206
+ if "attn_temporal" in child_name:
207
+ continue
208
+
209
+ if is_linear or is_conv2d:
210
+ lora_name = prefix + "." + name + "." + child_name
211
+ lora_name = lora_name.replace(".", "_")
212
+
213
+ dim = None
214
+ alpha = None
215
+
216
+ if is_linear or is_conv2d_1x1:
217
+ dim = self.lora_dim
218
+ alpha = self.alpha
219
+
220
+ if dim is None or dim == 0:
221
+ if is_linear or is_conv2d_1x1:
222
+ skipped.append(lora_name)
223
+ continue
224
+
225
+ lora = module_class(
226
+ lora_name,
227
+ child_module,
228
+ self.multiplier,
229
+ dim,
230
+ alpha,
231
+ dropout=dropout,
232
+ )
233
+ loras.append(lora)
234
+ return loras, skipped
235
+
236
+ text_encoders = text_encoder if type(text_encoder) == list else [text_encoder]
237
+
238
+ self.text_encoder_loras = []
239
+ skipped_te = []
240
+ for i, text_encoder in enumerate(text_encoders):
241
+ if text_encoder is not None:
242
+ text_encoder_loras, skipped = create_modules(False, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
243
+ self.text_encoder_loras.extend(text_encoder_loras)
244
+ skipped_te += skipped
245
+ print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
246
+
247
+ self.unet_loras, skipped_un = create_modules(True, unet, LoRANetwork.TRANSFORMER_TARGET_REPLACE_MODULE)
248
+ print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
249
+
250
+ # assertion
251
+ names = set()
252
+ for lora in self.text_encoder_loras + self.unet_loras:
253
+ assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
254
+ names.add(lora.lora_name)
255
+
256
+ def apply_to(self, text_encoder, unet, apply_text_encoder=True, apply_unet=True):
257
+ if apply_text_encoder:
258
+ print("enable LoRA for text encoder")
259
+ else:
260
+ self.text_encoder_loras = []
261
+
262
+ if apply_unet:
263
+ print("enable LoRA for U-Net")
264
+ else:
265
+ self.unet_loras = []
266
+
267
+ for lora in self.text_encoder_loras + self.unet_loras:
268
+ lora.apply_to()
269
+ self.add_module(lora.lora_name, lora)
270
+
271
+ def set_multiplier(self, multiplier):
272
+ self.multiplier = multiplier
273
+ for lora in self.text_encoder_loras + self.unet_loras:
274
+ lora.multiplier = self.multiplier
275
+
276
+ def load_weights(self, file):
277
+ if os.path.splitext(file)[1] == ".safetensors":
278
+ from safetensors.torch import load_file
279
+
280
+ weights_sd = load_file(file)
281
+ else:
282
+ weights_sd = torch.load(file, map_location="cpu")
283
+ info = self.load_state_dict(weights_sd, False)
284
+ return info
285
+
286
+ def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
287
+ self.requires_grad_(True)
288
+ all_params = []
289
+
290
+ def enumerate_params(loras):
291
+ params = []
292
+ for lora in loras:
293
+ params.extend(lora.parameters())
294
+ return params
295
+
296
+ if self.text_encoder_loras:
297
+ param_data = {"params": enumerate_params(self.text_encoder_loras)}
298
+ if text_encoder_lr is not None:
299
+ param_data["lr"] = text_encoder_lr
300
+ all_params.append(param_data)
301
+
302
+ if self.unet_loras:
303
+ param_data = {"params": enumerate_params(self.unet_loras)}
304
+ if unet_lr is not None:
305
+ param_data["lr"] = unet_lr
306
+ all_params.append(param_data)
307
+
308
+ return all_params
309
+
310
+ def enable_gradient_checkpointing(self):
311
+ pass
312
+
313
+ def get_trainable_params(self):
314
+ return self.parameters()
315
+
316
+ def save_weights(self, file, dtype, metadata):
317
+ if metadata is not None and len(metadata) == 0:
318
+ metadata = None
319
+
320
+ state_dict = self.state_dict()
321
+
322
+ if dtype is not None:
323
+ for key in list(state_dict.keys()):
324
+ v = state_dict[key]
325
+ v = v.detach().clone().to("cpu").to(dtype)
326
+ state_dict[key] = v
327
+
328
+ if os.path.splitext(file)[1] == ".safetensors":
329
+ from safetensors.torch import save_file
330
+
331
+ # Precalculate model hashes to save time on indexing
332
+ if metadata is None:
333
+ metadata = {}
334
+ model_hash, legacy_hash = precalculate_safetensors_hashes(state_dict, metadata)
335
+ metadata["sshs_model_hash"] = model_hash
336
+ metadata["sshs_legacy_hash"] = legacy_hash
337
+
338
+ save_file(state_dict, file, metadata)
339
+ else:
340
+ torch.save(state_dict, file)
341
+
342
+ def create_network(
343
+ multiplier: float,
344
+ network_dim: Optional[int],
345
+ network_alpha: Optional[float],
346
+ text_encoder: Union[T5EncoderModel, List[T5EncoderModel]],
347
+ transformer,
348
+ neuron_dropout: Optional[float] = None,
349
+ add_lora_in_attn_temporal: bool = False,
350
+ **kwargs,
351
+ ):
352
+ if network_dim is None:
353
+ network_dim = 4 # default
354
+ if network_alpha is None:
355
+ network_alpha = 1.0
356
+
357
+ network = LoRANetwork(
358
+ text_encoder,
359
+ transformer,
360
+ multiplier=multiplier,
361
+ lora_dim=network_dim,
362
+ alpha=network_alpha,
363
+ dropout=neuron_dropout,
364
+ add_lora_in_attn_temporal=add_lora_in_attn_temporal,
365
+ varbose=True,
366
+ )
367
+ return network
368
+
369
+ def merge_lora(pipeline, lora_path, multiplier, device='cpu', dtype=torch.float32, state_dict=None, transformer_only=False):
370
+ LORA_PREFIX_TRANSFORMER = "lora_unet"
371
+ LORA_PREFIX_TEXT_ENCODER = "lora_te"
372
+ if state_dict is None:
373
+ state_dict = load_file(lora_path, device=device)
374
+ else:
375
+ state_dict = state_dict
376
+ updates = defaultdict(dict)
377
+ for key, value in state_dict.items():
378
+ layer, elem = key.split('.', 1)
379
+ updates[layer][elem] = value
380
+
381
+ for layer, elems in updates.items():
382
+
383
+ if "lora_te" in layer:
384
+ if transformer_only:
385
+ continue
386
+ else:
387
+ layer_infos = layer.split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_")
388
+ curr_layer = pipeline.text_encoder
389
+ else:
390
+ layer_infos = layer.split(LORA_PREFIX_TRANSFORMER + "_")[-1].split("_")
391
+ curr_layer = pipeline.transformer
392
+
393
+ try:
394
+ curr_layer = curr_layer.__getattr__("_".join(layer_infos[1:]))
395
+ except Exception:
396
+ temp_name = layer_infos.pop(0)
397
+ while len(layer_infos) > -1:
398
+ try:
399
+ curr_layer = curr_layer.__getattr__(temp_name + "_" + "_".join(layer_infos))
400
+ break
401
+ except Exception:
402
+ try:
403
+ curr_layer = curr_layer.__getattr__(temp_name)
404
+ if len(layer_infos) > 0:
405
+ temp_name = layer_infos.pop(0)
406
+ elif len(layer_infos) == 0:
407
+ break
408
+ except Exception:
409
+ if len(layer_infos) == 0:
410
+ print('Error loading layer')
411
+ if len(temp_name) > 0:
412
+ temp_name += "_" + layer_infos.pop(0)
413
+ else:
414
+ temp_name = layer_infos.pop(0)
415
+
416
+ origin_dtype = curr_layer.weight.data.dtype
417
+ origin_device = curr_layer.weight.data.device
418
+
419
+ curr_layer = curr_layer.to(device, dtype)
420
+ weight_up = elems['lora_up.weight'].to(device, dtype)
421
+ weight_down = elems['lora_down.weight'].to(device, dtype)
422
+
423
+ if 'alpha' in elems.keys():
424
+ alpha = elems['alpha'].item() / weight_up.shape[1]
425
+ else:
426
+ alpha = 1.0
427
+
428
+ if len(weight_up.shape) == 4:
429
+ curr_layer.weight.data += multiplier * alpha * torch.mm(
430
+ weight_up.squeeze(3).squeeze(2), weight_down.squeeze(3).squeeze(2)
431
+ ).unsqueeze(2).unsqueeze(3)
432
+ else:
433
+ curr_layer.weight.data += multiplier * alpha * torch.mm(weight_up, weight_down)
434
+ curr_layer = curr_layer.to(origin_device, origin_dtype)
435
+
436
+ return pipeline
437
+
438
+ # TODO: Refactor with merge_lora.
439
+ def unmerge_lora(pipeline, lora_path, multiplier=1, device="cpu", dtype=torch.float32):
440
+ """Unmerge state_dict in LoRANetwork from the pipeline in diffusers."""
441
+ LORA_PREFIX_UNET = "lora_unet"
442
+ LORA_PREFIX_TEXT_ENCODER = "lora_te"
443
+ state_dict = load_file(lora_path, device=device)
444
+
445
+ updates = defaultdict(dict)
446
+ for key, value in state_dict.items():
447
+ layer, elem = key.split('.', 1)
448
+ updates[layer][elem] = value
449
+
450
+ for layer, elems in updates.items():
451
+
452
+ if "lora_te" in layer:
453
+ layer_infos = layer.split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_")
454
+ curr_layer = pipeline.text_encoder
455
+ else:
456
+ layer_infos = layer.split(LORA_PREFIX_UNET + "_")[-1].split("_")
457
+ curr_layer = pipeline.transformer
458
+
459
+ try:
460
+ curr_layer = curr_layer.__getattr__("_".join(layer_infos[1:]))
461
+ except Exception:
462
+ temp_name = layer_infos.pop(0)
463
+ while len(layer_infos) > -1:
464
+ try:
465
+ curr_layer = curr_layer.__getattr__(temp_name + "_" + "_".join(layer_infos))
466
+ break
467
+ except Exception:
468
+ try:
469
+ curr_layer = curr_layer.__getattr__(temp_name)
470
+ if len(layer_infos) > 0:
471
+ temp_name = layer_infos.pop(0)
472
+ elif len(layer_infos) == 0:
473
+ break
474
+ except Exception:
475
+ if len(layer_infos) == 0:
476
+ print('Error loading layer')
477
+ if len(temp_name) > 0:
478
+ temp_name += "_" + layer_infos.pop(0)
479
+ else:
480
+ temp_name = layer_infos.pop(0)
481
+
482
+ origin_dtype = curr_layer.weight.data.dtype
483
+ origin_device = curr_layer.weight.data.device
484
+
485
+ curr_layer = curr_layer.to(device, dtype)
486
+ weight_up = elems['lora_up.weight'].to(device, dtype)
487
+ weight_down = elems['lora_down.weight'].to(device, dtype)
488
+
489
+ if 'alpha' in elems.keys():
490
+ alpha = elems['alpha'].item() / weight_up.shape[1]
491
+ else:
492
+ alpha = 1.0
493
+
494
+ if len(weight_up.shape) == 4:
495
+ curr_layer.weight.data -= multiplier * alpha * torch.mm(
496
+ weight_up.squeeze(3).squeeze(2), weight_down.squeeze(3).squeeze(2)
497
+ ).unsqueeze(2).unsqueeze(3)
498
+ else:
499
+ curr_layer.weight.data -= multiplier * alpha * torch.mm(weight_up, weight_down)
500
+ curr_layer = curr_layer.to(origin_device, origin_dtype)
501
+
502
+ return pipeline
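Usage note: `merge_lora` folds each LoRA delta (`multiplier * alpha * up @ down`) directly into the matching transformer or text-encoder weight, and `unmerge_lora` subtracts it again, which is how the controllers above apply a LoRA only for the duration of one generation. A sketch, assuming a pipeline built as in the UI code and a `prompt` string; the checkpoint path is illustrative:

import torch
# assumes: from cogvideox.utils.lora_utils import merge_lora, unmerge_lora

lora_path, lora_weight = "models/Personalized_Model/my_lora.safetensors", 0.8
pipeline = merge_lora(pipeline, lora_path, multiplier=lora_weight)
with torch.no_grad():
    sample = pipeline(prompt, num_frames=49, generator=torch.Generator().manual_seed(43)).videos
pipeline = unmerge_lora(pipeline, lora_path, multiplier=lora_weight)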
cogvideox/utils/utils.py ADDED
@@ -0,0 +1,215 @@
1
+ import os
2
+ import gc
3
+ import imageio
4
+ import inspect
5
+ import numpy as np
6
+ import torch
7
+ import torchvision
8
+ import cv2
9
+ from einops import rearrange
10
+ from PIL import Image
11
+
12
+ def filter_kwargs(cls, kwargs):
13
+ sig = inspect.signature(cls.__init__)
14
+ valid_params = set(sig.parameters.keys()) - {'self', 'cls'}
15
+ filtered_kwargs = {k: v for k, v in kwargs.items() if k in valid_params}
16
+ return filtered_kwargs
17
+
18
+ def get_width_and_height_from_image_and_base_resolution(image, base_resolution):
19
+ target_pixels = int(base_resolution) * int(base_resolution)
20
+ original_width, original_height = Image.open(image).size
21
+ ratio = (target_pixels / (original_width * original_height)) ** 0.5
22
+ width_slider = round(original_width * ratio)
23
+ height_slider = round(original_height * ratio)
24
+ return height_slider, width_slider
25
+
26
+ def color_transfer(sc, dc):
27
+ """
28
+ Transfer the color distribution of the reference image dc onto sc.
29
+
30
+ Args:
31
+ sc (numpy.ndarray): input image to be transferred.
32
+ dc (numpy.ndarray): reference image
33
+
34
+ Returns:
35
+ numpy.ndarray: sc with the color distribution transferred from dc.
36
+ """
37
+
38
+ def get_mean_and_std(img):
39
+ x_mean, x_std = cv2.meanStdDev(img)
40
+ x_mean = np.hstack(np.around(x_mean, 2))
41
+ x_std = np.hstack(np.around(x_std, 2))
42
+ return x_mean, x_std
43
+
44
+ sc = cv2.cvtColor(sc, cv2.COLOR_RGB2LAB)
45
+ s_mean, s_std = get_mean_and_std(sc)
46
+ dc = cv2.cvtColor(dc, cv2.COLOR_RGB2LAB)
47
+ t_mean, t_std = get_mean_and_std(dc)
48
+ img_n = ((sc - s_mean) * (t_std / s_std)) + t_mean
49
+ np.putmask(img_n, img_n > 255, 255)
50
+ np.putmask(img_n, img_n < 0, 0)
51
+ dst = cv2.cvtColor(cv2.convertScaleAbs(img_n), cv2.COLOR_LAB2RGB)
52
+ return dst
53
+
54
+ def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=12, imageio_backend=True, color_transfer_post_process=False):
55
+ videos = rearrange(videos, "b c t h w -> t b c h w")
56
+ outputs = []
57
+ for x in videos:
58
+ x = torchvision.utils.make_grid(x, nrow=n_rows)
59
+ x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
60
+ if rescale:
61
+ x = (x + 1.0) / 2.0 # -1,1 -> 0,1
62
+ x = (x * 255).numpy().astype(np.uint8)
63
+ outputs.append(Image.fromarray(x))
64
+
65
+ if color_transfer_post_process:
66
+ for i in range(1, len(outputs)):
67
+ outputs[i] = Image.fromarray(color_transfer(np.uint8(outputs[i]), np.uint8(outputs[0])))
68
+
69
+ os.makedirs(os.path.dirname(path), exist_ok=True)
70
+ if imageio_backend:
71
+ if path.endswith("mp4"):
72
+ imageio.mimsave(path, outputs, fps=fps)
73
+ else:
74
+ imageio.mimsave(path, outputs, duration=(1000 * 1/fps))
75
+ else:
76
+ if path.endswith("mp4"):
77
+ path = path.replace('.mp4', '.gif')
78
+ outputs[0].save(path, format='GIF', append_images=outputs, save_all=True, duration=100, loop=0)
79
+
80
+ def get_image_to_video_latent(validation_image_start, validation_image_end, video_length, sample_size):
81
+ if validation_image_start is not None and validation_image_end is not None:
82
+ if type(validation_image_start) is str and os.path.isfile(validation_image_start):
83
+ image_start = clip_image = Image.open(validation_image_start).convert("RGB")
84
+ image_start = image_start.resize([sample_size[1], sample_size[0]])
85
+ clip_image = clip_image.resize([sample_size[1], sample_size[0]])
86
+ else:
87
+ image_start = clip_image = validation_image_start
88
+ image_start = [_image_start.resize([sample_size[1], sample_size[0]]) for _image_start in image_start]
89
+ clip_image = [_clip_image.resize([sample_size[1], sample_size[0]]) for _clip_image in clip_image]
90
+
91
+ if type(validation_image_end) is str and os.path.isfile(validation_image_end):
92
+ image_end = Image.open(validation_image_end).convert("RGB")
93
+ image_end = image_end.resize([sample_size[1], sample_size[0]])
94
+ else:
95
+ image_end = validation_image_end
96
+ image_end = [_image_end.resize([sample_size[1], sample_size[0]]) for _image_end in image_end]
97
+
98
+ if type(image_start) is list:
99
+ clip_image = clip_image[0]
100
+ start_video = torch.cat(
101
+ [torch.from_numpy(np.array(_image_start)).permute(2, 0, 1).unsqueeze(1).unsqueeze(0) for _image_start in image_start],
102
+ dim=2
103
+ )
104
+ input_video = torch.tile(start_video[:, :, :1], [1, 1, video_length, 1, 1])
105
+ input_video[:, :, :len(image_start)] = start_video
106
+
107
+ input_video_mask = torch.zeros_like(input_video[:, :1])
108
+ input_video_mask[:, :, len(image_start):] = 255
109
+ else:
110
+ input_video = torch.tile(
111
+ torch.from_numpy(np.array(image_start)).permute(2, 0, 1).unsqueeze(1).unsqueeze(0),
112
+ [1, 1, video_length, 1, 1]
113
+ )
114
+ input_video_mask = torch.zeros_like(input_video[:, :1])
115
+ input_video_mask[:, :, 1:] = 255
116
+
117
+ if type(image_end) is list:
118
+ image_end = [_image_end.resize(image_start[0].size if type(image_start) is list else image_start.size) for _image_end in image_end]
119
+ end_video = torch.cat(
120
+ [torch.from_numpy(np.array(_image_end)).permute(2, 0, 1).unsqueeze(1).unsqueeze(0) for _image_end in image_end],
121
+ dim=2
122
+ )
123
+ input_video[:, :, -len(end_video):] = end_video
124
+
125
+ input_video_mask[:, :, -len(image_end):] = 0
126
+ else:
127
+ image_end = image_end.resize(image_start[0].size if type(image_start) is list else image_start.size)
128
+ input_video[:, :, -1:] = torch.from_numpy(np.array(image_end)).permute(2, 0, 1).unsqueeze(1).unsqueeze(0)
129
+ input_video_mask[:, :, -1:] = 0
130
+
131
+ input_video = input_video / 255
132
+
133
+ elif validation_image_start is not None:
134
+ if type(validation_image_start) is str and os.path.isfile(validation_image_start):
135
+ image_start = clip_image = Image.open(validation_image_start).convert("RGB")
136
+ image_start = image_start.resize([sample_size[1], sample_size[0]])
137
+ clip_image = clip_image.resize([sample_size[1], sample_size[0]])
138
+ else:
139
+ image_start = clip_image = validation_image_start
140
+ image_start = [_image_start.resize([sample_size[1], sample_size[0]]) for _image_start in image_start]
141
+ clip_image = [_clip_image.resize([sample_size[1], sample_size[0]]) for _clip_image in clip_image]
142
+ image_end = None
143
+
144
+ if type(image_start) is list:
145
+ clip_image = clip_image[0]
146
+ start_video = torch.cat(
147
+ [torch.from_numpy(np.array(_image_start)).permute(2, 0, 1).unsqueeze(1).unsqueeze(0) for _image_start in image_start],
148
+ dim=2
149
+ )
150
+ input_video = torch.tile(start_video[:, :, :1], [1, 1, video_length, 1, 1])
151
+ input_video[:, :, :len(image_start)] = start_video
152
+ input_video = input_video / 255
153
+
154
+ input_video_mask = torch.zeros_like(input_video[:, :1])
155
+ input_video_mask[:, :, len(image_start):] = 255
156
+ else:
157
+ input_video = torch.tile(
158
+ torch.from_numpy(np.array(image_start)).permute(2, 0, 1).unsqueeze(1).unsqueeze(0),
159
+ [1, 1, video_length, 1, 1]
160
+ ) / 255
161
+ input_video_mask = torch.zeros_like(input_video[:, :1])
162
+ input_video_mask[:, :, 1:, ] = 255
163
+ else:
164
+ image_start = None
165
+ image_end = None
166
+ input_video = torch.zeros([1, 3, video_length, sample_size[0], sample_size[1]])
167
+ input_video_mask = torch.ones([1, 1, video_length, sample_size[0], sample_size[1]]) * 255
168
+ clip_image = None
169
+
170
+ del image_start
171
+ del image_end
172
+ gc.collect()
173
+
174
+ return input_video, input_video_mask, clip_image
175
+
176
+ def get_video_to_video_latent(input_video_path, video_length, sample_size, fps=None, validation_video_mask=None):
177
+ if isinstance(input_video_path, str):
178
+ cap = cv2.VideoCapture(input_video_path)
179
+ input_video = []
180
+
181
+ original_fps = cap.get(cv2.CAP_PROP_FPS)
182
+ frame_skip = 1 if fps is None else int(original_fps // fps)
183
+
184
+ frame_count = 0
185
+
186
+ while True:
187
+ ret, frame = cap.read()
188
+ if not ret:
189
+ break
190
+
191
+ if frame_count % frame_skip == 0:
192
+ frame = cv2.resize(frame, (sample_size[1], sample_size[0]))
193
+ input_video.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
194
+
195
+ frame_count += 1
196
+
197
+ cap.release()
198
+ else:
199
+ input_video = input_video_path
200
+
201
+ input_video = torch.from_numpy(np.array(input_video))[:video_length]
202
+ input_video = input_video.permute([3, 0, 1, 2]).unsqueeze(0) / 255
203
+
204
+ if validation_video_mask is not None:
205
+ validation_video_mask = Image.open(validation_video_mask).convert('L').resize((sample_size[1], sample_size[0]))
206
+ input_video_mask = np.where(np.array(validation_video_mask) < 240, 0, 255)
207
+
208
+ input_video_mask = torch.from_numpy(np.array(input_video_mask)).unsqueeze(0).unsqueeze(-1).permute([3, 0, 1, 2]).unsqueeze(0)
209
+ input_video_mask = torch.tile(input_video_mask, [1, 1, input_video.size()[2], 1, 1])
210
+ input_video_mask = input_video_mask.to(input_video.device, input_video.dtype)
211
+ else:
212
+ input_video_mask = torch.zeros_like(input_video[:, :1])
213
+ input_video_mask[:, :, :] = 255
214
+
215
+ return input_video, input_video_mask, None
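Usage note: `get_image_to_video_latent` returns a [1, 3, T, H, W] pixel tensor scaled to [0, 1], a mask that is 255 on frames to be generated (0 on frames that are kept fixed), and a CLIP reference image; `save_videos_grid` writes the decoded frames to an mp4 or gif. A minimal image-to-video sketch matching the predict script below; the paths are illustrative:

# assumes: from cogvideox.utils.utils import get_image_to_video_latent, save_videos_grid

sample_size, video_length = [384, 672], 49
input_video, input_video_mask, clip_image = get_image_to_video_latent(
    "asset/1.png", None, video_length, sample_size=sample_size
)
# run an inpaint pipeline with video=input_video, mask_video=input_video_mask, clip_image=clip_image,
# then: save_videos_grid(sample, "samples/output.mp4", fps=8)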
config/wan2.1/wan_civitai.yaml ADDED
@@ -0,0 +1,39 @@
1
+ format: civitai
2
+ pipeline: Wan
3
+ transformer_additional_kwargs:
4
+ transformer_subpath: ./
5
+ dict_mapping:
6
+ in_dim: in_channels
7
+ dim: hidden_size
8
+
9
+ vae_kwargs:
10
+ vae_subpath: Wan2.1_VAE.pth
11
+ temporal_compression_ratio: 4
12
+ spatial_compression_ratio: 8
13
+
14
+ text_encoder_kwargs:
15
+ text_encoder_subpath: models_t5_umt5-xxl-enc-bf16.pth
16
+ tokenizer_subpath: google/umt5-xxl
17
+ text_length: 512
18
+ vocab: 256384
19
+ dim: 4096
20
+ dim_attn: 4096
21
+ dim_ffn: 10240
22
+ num_heads: 64
23
+ num_layers: 24
24
+ num_buckets: 32
25
+ shared_pos: False
26
+ dropout: 0.0
27
+
28
+ scheduler_kwargs:
29
+ scheduler_subpath: null
30
+ num_train_timesteps: 1000
31
+ shift: 5.0
32
+ use_dynamic_shifting: false
33
+ base_shift: 0.5
34
+ max_shift: 1.15
35
+ base_image_seq_len: 256
36
+ max_image_seq_len: 4096
37
+
38
+ image_encoder_kwargs:
39
+ image_encoder_subpath: models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth
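Usage note: the controllers above consume this file by joining each `*_subpath` onto the model directory, passing the remaining keys through as `additional_kwargs`, and filtering `scheduler_kwargs` with `filter_kwargs` before instantiating the sampler. A sketch of that loading pattern; the local checkpoint directory is an assumed example:

import os
from omegaconf import OmegaConf

config = OmegaConf.load("config/wan2.1/wan_civitai.yaml")
model_name = "models/Diffusion_Transformer/Wan2.1-Fun-InP"   # hypothetical local checkpoint dir
vae_path = os.path.join(model_name, config['vae_kwargs'].get('vae_subpath', 'vae'))
vae_kwargs = OmegaConf.to_container(config['vae_kwargs'])
# vae = AutoencoderKLWan.from_pretrained(vae_path, additional_kwargs=vae_kwargs)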
config/zero_stage2_config.json ADDED
@@ -0,0 +1,16 @@
1
+ {
2
+ "bf16": {
3
+ "enabled": true
4
+ },
5
+ "train_micro_batch_size_per_gpu": 1,
6
+ "train_batch_size": "auto",
7
+ "gradient_accumulation_steps": "auto",
8
+ "dump_state": true,
9
+ "zero_optimization": {
10
+ "stage": 2,
11
+ "overlap_comm": true,
12
+ "contiguous_gradients": true,
13
+ "sub_group_size": 1e9,
14
+ "reduce_bucket_size": 5e8
15
+ }
16
+ }
config/zero_stage3_config.json ADDED
@@ -0,0 +1,28 @@
1
+ {
2
+ "bf16": {
3
+ "enabled": true
4
+ },
5
+ "train_micro_batch_size_per_gpu": 1,
6
+ "train_batch_size": "auto",
7
+ "gradient_accumulation_steps": "auto",
8
+ "gradient_clipping": "auto",
9
+ "steps_per_print": 2000,
10
+ "wall_clock_breakdown": false,
11
+ "zero_optimization": {
12
+ "stage": 3,
13
+ "overlap_comm": true,
14
+ "contiguous_gradients": true,
15
+ "reduce_bucket_size": 5e8,
16
+ "sub_group_size": 1e9,
17
+ "stage3_max_live_parameters": 1e9,
18
+ "stage3_max_reuse_distance": 1e9,
19
+ "stage3_gather_16bit_weights_on_model_save": "auto",
20
+ "offload_optimizer": {
21
+ "device": "none"
22
+ },
23
+ "offload_param": {
24
+ "device": "none"
25
+ }
26
+ }
27
+ }
28
+
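Note: both DeepSpeed configs enable bf16 and leave `train_batch_size` / `gradient_accumulation_steps` as "auto", to be resolved by the training launcher at startup; stage 2 shards optimizer state and gradients, while stage 3 additionally shards the parameters. A small sanity-check sketch of how the files can be inspected:

import json

with open("config/zero_stage3_config.json") as f:
    ds_cfg = json.load(f)

assert ds_cfg["bf16"]["enabled"] is True
assert ds_cfg["zero_optimization"]["stage"] == 3
print(ds_cfg["train_batch_size"])   # "auto": filled in by the launcher, not by this repo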
examples/cogvideox_fun/app.py ADDED
@@ -0,0 +1,67 @@
1
+ import os
2
+ import sys
3
+ import time
4
+
5
+ import torch
6
+
7
+ current_file_path = os.path.abspath(__file__)
8
+ project_roots = [os.path.dirname(current_file_path), os.path.dirname(os.path.dirname(current_file_path)), os.path.dirname(os.path.dirname(os.path.dirname(current_file_path)))]
9
+ for project_root in project_roots:
10
+ sys.path.insert(0, project_root) if project_root not in sys.path else None
11
+
12
+ from cogvideox.api.api import (infer_forward_api,
13
+ update_diffusion_transformer_api,
14
+ update_edition_api)
15
+ from cogvideox.ui.controller import ddpm_scheduler_dict
16
+ from cogvideox.ui.cogvideox_fun_ui import ui, ui_eas, ui_modelscope
17
+
18
+ if __name__ == "__main__":
19
+ # Choose the ui mode
20
+ ui_mode = "modelscope"
21
+
22
+ # GPU memory mode, which can be chosen in [model_cpu_offload, model_cpu_offload_and_qfloat8, sequential_cpu_offload].
23
+ # model_cpu_offload means that the entire model will be moved to the CPU after use, which can save some GPU memory.
24
+ #
25
+ # model_cpu_offload_and_qfloat8 indicates that the entire model will be moved to the CPU after use,
26
+ # and the transformer model has been quantized to float8, which can save more GPU memory.
27
+ #
28
+ # sequential_cpu_offload means that each layer of the model will be moved to the CPU after use,
29
+ # resulting in slower speeds but saving a large amount of GPU memory.
30
+ GPU_memory_mode = "model_cpu_offload_and_qfloat8"
31
+ # Use torch.float16 if GPU does not support torch.bfloat16
32
+ # Some graphics cards, such as V100 and 2080Ti, do not support torch.bfloat16.
33
+ weight_dtype = torch.bfloat16
34
+
35
+ # Server ip
36
+ server_name = "0.0.0.0"
37
+ server_port = 7861
38
+
39
+ # Params below are used when ui_mode = "modelscope"
40
+ model_name = "models/Diffusion_Transformer/CogVideoX-Fun-V1.1-5b-InP"
41
+ # "Inpaint" or "Control"
42
+ model_type = "Inpaint"
43
+ # Save dir of this model
44
+ savedir_sample = "samples"
45
+
46
+ if ui_mode == "modelscope":
47
+ demo, controller = ui_modelscope(model_name, model_type, savedir_sample, GPU_memory_mode, ddpm_scheduler_dict, weight_dtype)
48
+ elif ui_mode == "eas":
49
+ demo, controller = ui_eas(model_name, ddpm_scheduler_dict, savedir_sample)
50
+ else:
51
+ demo, controller = ui(GPU_memory_mode, ddpm_scheduler_dict, weight_dtype)
52
+
53
+ # launch gradio
54
+ app, _, _ = demo.queue(status_update_rate=1).launch(
55
+ server_name=server_name,
56
+ server_port=server_port,
57
+ prevent_thread_lock=True
58
+ )
59
+
60
+ # launch api
61
+ infer_forward_api(None, app, controller)
62
+ update_diffusion_transformer_api(None, app, controller)
63
+ update_edition_api(None, app, controller)
64
+
65
+ # keep the Python process alive
66
+ while True:
67
+ time.sleep(5)
examples/cogvideox_fun/predict_i2v.py ADDED
@@ -0,0 +1,268 @@
1
+ import os
2
+ import sys
3
+
4
+ import numpy as np
5
+ import torch
6
+ from diffusers import (CogVideoXDDIMScheduler, DDIMScheduler,
7
+ DPMSolverMultistepScheduler,
8
+ EulerAncestralDiscreteScheduler, EulerDiscreteScheduler,
9
+ PNDMScheduler)
10
+ from PIL import Image
11
+
12
+ current_file_path = os.path.abspath(__file__)
13
+ project_roots = [os.path.dirname(current_file_path), os.path.dirname(os.path.dirname(current_file_path)), os.path.dirname(os.path.dirname(os.path.dirname(current_file_path)))]
14
+ for project_root in project_roots:
15
+ sys.path.insert(0, project_root) if project_root not in sys.path else None
16
+
17
+ from cogvideox.models import (AutoencoderKLCogVideoX,
18
+ CogVideoXTransformer3DModel, T5EncoderModel,
19
+ T5Tokenizer)
20
+ from cogvideox.pipeline import (CogVideoXFunPipeline,
21
+ CogVideoXFunInpaintPipeline)
22
+ from cogvideox.utils.fp8_optimization import convert_weight_dtype_wrapper
23
+ from cogvideox.utils.lora_utils import merge_lora, unmerge_lora
24
+ from cogvideox.utils.utils import get_image_to_video_latent, save_videos_grid
25
+
26
+ # GPU memory mode, which can be choosen in [model_cpu_offload, model_cpu_offload_and_qfloat8, sequential_cpu_offload].
27
+ # model_cpu_offload means that the entire model will be moved to the CPU after use, which can save some GPU memory.
28
+ #
29
+ # model_cpu_offload_and_qfloat8 indicates that the entire model will be moved to the CPU after use,
30
+ # and the transformer model has been quantized to float8, which can save more GPU memory.
31
+ #
32
+ # sequential_cpu_offload means that each layer of the model will be moved to the CPU after use,
33
+ # resulting in slower speeds but saving a large amount of GPU memory.
34
+ GPU_memory_mode = "model_cpu_offload_and_qfloat8"
35
+
36
+ # Config and model path
37
+ model_name = "models/Diffusion_Transformer/CogVideoX-Fun-V1.1-2b-InP"
38
+
39
+ # Choose the sampler in "Euler" "Euler A" "DPM++" "PNDM" "DDIM_Cog" and "DDIM_Origin"
40
+ sampler_name = "DDIM_Origin"
41
+
42
+ # Load pretrained model if need
43
+ transformer_path = None
44
+ vae_path = None
45
+ lora_path = None
46
+
47
+ # Other params
48
+ sample_size = [384, 672]
49
+ # V1.0 and V1.1 support up to 49 frames of video generation,
50
+ # while V1.5 supports up to 85 frames.
51
+ video_length = 49
52
+ fps = 8
53
+
54
+ # If you want to generate ultra long videos, please set partial_video_length as the length of each sub video segment
55
+ partial_video_length = None
56
+ overlap_video_length = 4
57
+
58
+ # Use torch.float16 if GPU does not support torch.bfloat16
59
+ # ome graphics cards, such as v100, 2080ti, do not support torch.bfloat16
60
+ weight_dtype = torch.bfloat16
61
+ # If you want to generate from text, please set the validation_image_start = None and validation_image_end = None
62
+ validation_image_start = "asset/1.png"
63
+ validation_image_end = None
64
+
65
+ # prompts
66
+ prompt = "The dog is shaking head. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic."
67
+ negative_prompt = "The video is not of a high quality, it has a low resolution. Watermark present in each frame. The background is solid. Strange body and strange trajectory. Distortion. "
68
+ guidance_scale = 6.0
69
+ seed = 43
70
+ num_inference_steps = 50
71
+ lora_weight = 0.55
72
+ save_path = "samples/cogvideox-fun-videos_i2v"
73
+
74
+ transformer = CogVideoXTransformer3DModel.from_pretrained(
75
+ model_name,
76
+ subfolder="transformer",
77
+ low_cpu_mem_usage=True,
78
+ torch_dtype=torch.float8_e4m3fn if GPU_memory_mode == "model_cpu_offload_and_qfloat8" else weight_dtype,
79
+ ).to(weight_dtype)
80
+
81
+ if transformer_path is not None:
82
+ print(f"From checkpoint: {transformer_path}")
83
+ if transformer_path.endswith("safetensors"):
84
+ from safetensors.torch import load_file, safe_open
85
+ state_dict = load_file(transformer_path)
86
+ else:
87
+ state_dict = torch.load(transformer_path, map_location="cpu")
88
+ state_dict = state_dict["state_dict"] if "state_dict" in state_dict else state_dict
89
+
90
+ m, u = transformer.load_state_dict(state_dict, strict=False)
91
+ print(f"missing keys: {len(m)}, unexpected keys: {len(u)}")
92
+
93
+ # Get Vae
94
+ vae = AutoencoderKLCogVideoX.from_pretrained(
95
+ model_name,
96
+ subfolder="vae"
97
+ ).to(weight_dtype)
98
+
99
+ if vae_path is not None:
100
+ print(f"From checkpoint: {vae_path}")
101
+ if vae_path.endswith("safetensors"):
102
+ from safetensors.torch import load_file, safe_open
103
+ state_dict = load_file(vae_path)
104
+ else:
105
+ state_dict = torch.load(vae_path, map_location="cpu")
106
+ state_dict = state_dict["state_dict"] if "state_dict" in state_dict else state_dict
107
+
108
+ m, u = vae.load_state_dict(state_dict, strict=False)
109
+ print(f"missing keys: {len(m)}, unexpected keys: {len(u)}")
110
+
111
+ # Get tokenizer and text_encoder
112
+ tokenizer = T5Tokenizer.from_pretrained(
113
+ model_name, subfolder="tokenizer"
114
+ )
115
+ text_encoder = T5EncoderModel.from_pretrained(
116
+ model_name, subfolder="text_encoder", torch_dtype=weight_dtype
117
+ )
118
+
119
+ # Get Scheduler
120
+ Choosen_Scheduler = scheduler_dict = {
121
+ "Euler": EulerDiscreteScheduler,
122
+ "Euler A": EulerAncestralDiscreteScheduler,
123
+ "DPM++": DPMSolverMultistepScheduler,
124
+ "PNDM": PNDMScheduler,
125
+ "DDIM_Cog": CogVideoXDDIMScheduler,
126
+ "DDIM_Origin": DDIMScheduler,
127
+ }[sampler_name]
128
+ scheduler = Choosen_Scheduler.from_pretrained(
129
+ model_name,
130
+ subfolder="scheduler"
131
+ )
132
+
133
+ if transformer.config.in_channels != vae.config.latent_channels:
134
+ pipeline = CogVideoXFunInpaintPipeline(
135
+ vae=vae,
136
+ tokenizer=tokenizer,
137
+ text_encoder=text_encoder,
138
+ transformer=transformer,
139
+ scheduler=scheduler,
140
+ )
141
+ else:
142
+ pipeline = CogVideoXFunPipeline(
143
+ vae=vae,
144
+ tokenizer=tokenizer,
145
+ text_encoder=text_encoder,
146
+ transformer=transformer,
147
+ scheduler=scheduler,
148
+ )
149
+ if GPU_memory_mode == "sequential_cpu_offload":
150
+ pipeline.enable_sequential_cpu_offload()
151
+ elif GPU_memory_mode == "model_cpu_offload_and_qfloat8":
152
+ convert_weight_dtype_wrapper(transformer, weight_dtype)
153
+ pipeline.enable_model_cpu_offload()
154
+ else:
155
+ pipeline.enable_model_cpu_offload()
156
+
157
+ generator = torch.Generator(device="cuda").manual_seed(seed)
158
+
159
+ if lora_path is not None:
160
+ pipeline = merge_lora(pipeline, lora_path, lora_weight)
161
+
162
+ if partial_video_length is not None:
163
+ partial_video_length = int((partial_video_length - 1) // vae.config.temporal_compression_ratio * vae.config.temporal_compression_ratio) + 1 if video_length != 1 else 1
164
+ latent_frames = (partial_video_length - 1) // vae.config.temporal_compression_ratio + 1
165
+ if partial_video_length != 1 and transformer.config.patch_size_t is not None and latent_frames % transformer.config.patch_size_t != 0:
166
+ additional_frames = transformer.config.patch_size_t - latent_frames % transformer.config.patch_size_t
167
+ partial_video_length += additional_frames * vae.config.temporal_compression_ratio
168
+
169
+ init_frames = 0
170
+ last_frames = init_frames + partial_video_length
171
+ while init_frames < video_length:
172
+ if last_frames >= video_length:
173
+ _partial_video_length = video_length - init_frames
174
+ _partial_video_length = int((_partial_video_length - 1) // vae.config.temporal_compression_ratio * vae.config.temporal_compression_ratio) + 1
175
+ latent_frames = (_partial_video_length - 1) // vae.config.temporal_compression_ratio + 1
176
+ if _partial_video_length != 1 and transformer.config.patch_size_t is not None and latent_frames % transformer.config.patch_size_t != 0:
177
+ additional_frames = transformer.config.patch_size_t - latent_frames % transformer.config.patch_size_t
178
+ _partial_video_length += additional_frames * vae.config.temporal_compression_ratio
179
+
180
+ if _partial_video_length <= 0:
181
+ break
182
+ else:
183
+ _partial_video_length = partial_video_length
184
+
185
+ input_video, input_video_mask, clip_image = get_image_to_video_latent(validation_image, None, video_length=_partial_video_length, sample_size=sample_size)
186
+
187
+ with torch.no_grad():
188
+ sample = pipeline(
189
+ prompt,
190
+ num_frames = _partial_video_length,
191
+ negative_prompt = negative_prompt,
192
+ height = sample_size[0],
193
+ width = sample_size[1],
194
+ generator = generator,
195
+ guidance_scale = guidance_scale,
196
+ num_inference_steps = num_inference_steps,
197
+
198
+ video = input_video,
199
+ mask_video = input_video_mask
200
+ ).videos
201
+
202
+ if init_frames != 0:
203
+ mix_ratio = torch.from_numpy(
204
+ np.array([float(_index) / float(overlap_video_length) for _index in range(overlap_video_length)], np.float32)
205
+ ).unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
206
+
207
+ new_sample[:, :, -overlap_video_length:] = new_sample[:, :, -overlap_video_length:] * (1 - mix_ratio) + \
208
+ sample[:, :, :overlap_video_length] * mix_ratio
209
+ new_sample = torch.cat([new_sample, sample[:, :, overlap_video_length:]], dim = 2)
210
+
211
+ sample = new_sample
212
+ else:
213
+ new_sample = sample
214
+
215
+ if last_frames >= video_length:
216
+ break
217
+
218
+ validation_image = [
219
+ Image.fromarray(
220
+ (sample[0, :, _index].transpose(0, 1).transpose(1, 2) * 255).numpy().astype(np.uint8)
221
+ ) for _index in range(-overlap_video_length, 0)
222
+ ]
223
+
224
+ init_frames = init_frames + _partial_video_length - overlap_video_length
225
+ last_frames = init_frames + _partial_video_length
226
+ else:
227
+ video_length = int((video_length - 1) // vae.config.temporal_compression_ratio * vae.config.temporal_compression_ratio) + 1 if video_length != 1 else 1
228
+ latent_frames = (video_length - 1) // vae.config.temporal_compression_ratio + 1
229
+ if video_length != 1 and transformer.config.patch_size_t is not None and latent_frames % transformer.config.patch_size_t != 0:
230
+ additional_frames = transformer.config.patch_size_t - latent_frames % transformer.config.patch_size_t
231
+ video_length += additional_frames * vae.config.temporal_compression_ratio
232
+ input_video, input_video_mask, clip_image = get_image_to_video_latent(validation_image_start, validation_image_end, video_length=video_length, sample_size=sample_size)
233
+
234
+ with torch.no_grad():
235
+ sample = pipeline(
236
+ prompt,
237
+ num_frames = video_length,
238
+ negative_prompt = negative_prompt,
239
+ height = sample_size[0],
240
+ width = sample_size[1],
241
+ generator = generator,
242
+ guidance_scale = guidance_scale,
243
+ num_inference_steps = num_inference_steps,
244
+
245
+ video = input_video,
246
+ mask_video = input_video_mask
247
+ ).videos
248
+
249
+ if lora_path is not None:
250
+ pipeline = unmerge_lora(pipeline, lora_path, lora_weight)
251
+
252
+ if not os.path.exists(save_path):
253
+ os.makedirs(save_path, exist_ok=True)
254
+
255
+ index = len([path for path in os.listdir(save_path)]) + 1
256
+ prefix = str(index).zfill(8)
257
+
258
+ if video_length == 1:
259
+ video_path = os.path.join(save_path, prefix + ".png")
260
+
261
+ image = sample[0, :, 0]
262
+ image = image.transpose(0, 1).transpose(1, 2)
263
+ image = (image * 255).numpy().astype(np.uint8)
264
+ image = Image.fromarray(image)
265
+ image.save(video_path)
266
+ else:
267
+ video_path = os.path.join(save_path, prefix + ".mp4")
268
+ save_videos_grid(sample, video_path, fps=fps)
examples/cogvideox_fun/predict_t2v.py ADDED
@@ -0,0 +1,208 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+
4
+ import numpy as np
5
+ import torch
6
+ from diffusers import (CogVideoXDDIMScheduler, DDIMScheduler,
7
+ DPMSolverMultistepScheduler,
8
+ EulerAncestralDiscreteScheduler, EulerDiscreteScheduler,
9
+ PNDMScheduler)
10
+ from PIL import Image
11
+ from transformers import T5EncoderModel
12
+
13
+ current_file_path = os.path.abspath(__file__)
14
+ project_roots = [os.path.dirname(current_file_path), os.path.dirname(os.path.dirname(current_file_path)), os.path.dirname(os.path.dirname(os.path.dirname(current_file_path)))]
15
+ for project_root in project_roots:
16
+ sys.path.insert(0, project_root) if project_root not in sys.path else None
17
+
18
+ from cogvideox.models import (AutoencoderKLCogVideoX,
19
+ CogVideoXTransformer3DModel, T5EncoderModel,
20
+ T5Tokenizer)
21
+ from cogvideox.pipeline import (CogVideoXFunPipeline,
22
+ CogVideoXFunInpaintPipeline)
23
+ from cogvideox.utils.fp8_optimization import convert_weight_dtype_wrapper
24
+ from cogvideox.utils.lora_utils import merge_lora, unmerge_lora
25
+ from cogvideox.utils.utils import get_image_to_video_latent, save_videos_grid
26
+
27
+ # GPU memory mode, which can be choosen in [model_cpu_offload, model_cpu_offload_and_qfloat8, sequential_cpu_offload].
28
+ # model_cpu_offload means that the entire model will be moved to the CPU after use, which can save some GPU memory.
29
+ #
30
+ # model_cpu_offload_and_qfloat8 indicates that the entire model will be moved to the CPU after use,
31
+ # and the transformer model has been quantized to float8, which can save more GPU memory.
32
+ #
33
+ # sequential_cpu_offload means that each layer of the model will be moved to the CPU after use,
34
+ # resulting in slower speeds but saving a large amount of GPU memory.
35
+ GPU_memory_mode = "model_cpu_offload_and_qfloat8"
36
+
37
+ # model path
38
+ model_name = "models/Diffusion_Transformer/CogVideoX-Fun-V1.1-2b-InP"
39
+
40
+ # Choose the sampler in "Euler" "Euler A" "DPM++" "PNDM" "DDIM_Cog" and "DDIM_Origin"
41
+ sampler_name = "DDIM_Origin"
42
+
43
+ # Load pretrained model if need
44
+ transformer_path = None
45
+ vae_path = None
46
+ lora_path = None
47
+
48
+ # Other params
49
+ sample_size = [384, 672]
50
+ # V1.0 and V1.1 support up to 49 frames of video generation,
51
+ # while V1.5 supports up to 85 frames.
52
+ video_length = 49
53
+ fps = 8
54
+
55
+ # Use torch.float16 if GPU does not support torch.bfloat16
56
+ # ome graphics cards, such as v100, 2080ti, do not support torch.bfloat16
57
+ weight_dtype = torch.bfloat16
58
+ prompt = "A young woman with beautiful and clear eyes and blonde hair standing and white dress in a forest wearing a crown. She seems to be lost in thought, and the camera focuses on her face. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic."
59
+ negative_prompt = "The video is not of a high quality, it has a low resolution. Watermark present in each frame. The background is solid. Strange body and strange trajectory. Distortion. "
60
+ guidance_scale = 6.0
61
+ seed = 43
62
+ num_inference_steps = 50
63
+ lora_weight = 0.55
64
+ save_path = "samples/cogvideox-fun-videos-t2v"
65
+
66
+ transformer = CogVideoXTransformer3DModel.from_pretrained(
67
+ model_name,
68
+ subfolder="transformer",
69
+ low_cpu_mem_usage=True,
70
+ torch_dtype=torch.float8_e4m3fn if GPU_memory_mode == "model_cpu_offload_and_qfloat8" else weight_dtype,
71
+ ).to(weight_dtype)
72
+
73
+ if transformer_path is not None:
74
+ print(f"From checkpoint: {transformer_path}")
75
+ if transformer_path.endswith("safetensors"):
76
+ from safetensors.torch import load_file, safe_open
77
+ state_dict = load_file(transformer_path)
78
+ else:
79
+ state_dict = torch.load(transformer_path, map_location="cpu")
80
+ state_dict = state_dict["state_dict"] if "state_dict" in state_dict else state_dict
81
+
82
+ m, u = transformer.load_state_dict(state_dict, strict=False)
83
+ print(f"missing keys: {len(m)}, unexpected keys: {len(u)}")
84
+
85
+ # Get Vae
86
+ vae = AutoencoderKLCogVideoX.from_pretrained(
87
+ model_name,
88
+ subfolder="vae"
89
+ ).to(weight_dtype)
90
+
91
+ if vae_path is not None:
92
+ print(f"From checkpoint: {vae_path}")
93
+ if vae_path.endswith("safetensors"):
94
+ from safetensors.torch import load_file, safe_open
95
+ state_dict = load_file(vae_path)
96
+ else:
97
+ state_dict = torch.load(vae_path, map_location="cpu")
98
+ state_dict = state_dict["state_dict"] if "state_dict" in state_dict else state_dict
99
+
100
+ m, u = vae.load_state_dict(state_dict, strict=False)
101
+ print(f"missing keys: {len(m)}, unexpected keys: {len(u)}")
102
+
103
+ # Get tokenizer and text_encoder
104
+ tokenizer = T5Tokenizer.from_pretrained(
105
+ model_name, subfolder="tokenizer"
106
+ )
107
+ text_encoder = T5EncoderModel.from_pretrained(
108
+ model_name, subfolder="text_encoder", torch_dtype=weight_dtype
109
+ )
110
+
111
+ # Get Scheduler
112
+ Choosen_Scheduler = scheduler_dict = {
113
+ "Euler": EulerDiscreteScheduler,
114
+ "Euler A": EulerAncestralDiscreteScheduler,
115
+ "DPM++": DPMSolverMultistepScheduler,
116
+ "PNDM": PNDMScheduler,
117
+ "DDIM_Cog": CogVideoXDDIMScheduler,
118
+ "DDIM_Origin": DDIMScheduler,
119
+ }[sampler_name]
120
+ scheduler = Choosen_Scheduler.from_pretrained(
121
+ model_name,
122
+ subfolder="scheduler"
123
+ )
124
+
125
+ if transformer.config.in_channels != vae.config.latent_channels:
126
+ pipeline = CogVideoXFunInpaintPipeline(
127
+ vae=vae,
128
+ tokenizer=tokenizer,
129
+ text_encoder=text_encoder,
130
+ transformer=transformer,
131
+ scheduler=scheduler,
132
+ )
133
+ else:
134
+ pipeline = CogVideoXFunPipeline(
135
+ vae=vae,
136
+ tokenizer=tokenizer,
137
+ text_encoder=text_encoder,
138
+ transformer=transformer,
139
+ scheduler=scheduler,
140
+ )
141
+ if GPU_memory_mode == "sequential_cpu_offload":
142
+ pipeline.enable_sequential_cpu_offload()
143
+ elif GPU_memory_mode == "model_cpu_offload_and_qfloat8":
144
+ convert_weight_dtype_wrapper(transformer, weight_dtype)
145
+ pipeline.enable_model_cpu_offload()
146
+ else:
147
+ pipeline.enable_model_cpu_offload()
148
+
149
+ generator = torch.Generator(device="cuda").manual_seed(seed)
150
+
151
+ if lora_path is not None:
152
+ pipeline = merge_lora(pipeline, lora_path, lora_weight)
153
+
154
+ with torch.no_grad():
155
+ video_length = int((video_length - 1) // vae.config.temporal_compression_ratio * vae.config.temporal_compression_ratio) + 1 if video_length != 1 else 1
156
+ latent_frames = (video_length - 1) // vae.config.temporal_compression_ratio + 1
157
+ if video_length != 1 and transformer.config.patch_size_t is not None and latent_frames % transformer.config.patch_size_t != 0:
158
+ additional_frames = transformer.config.patch_size_t - latent_frames % transformer.config.patch_size_t
159
+ video_length += additional_frames * vae.config.temporal_compression_ratio
160
+
161
+ if transformer.config.in_channels != vae.config.latent_channels:
162
+ input_video, input_video_mask, _ = get_image_to_video_latent(None, None, video_length=video_length, sample_size=sample_size)
163
+
164
+ sample = pipeline(
165
+ prompt,
166
+ num_frames = video_length,
167
+ negative_prompt = negative_prompt,
168
+ height = sample_size[0],
169
+ width = sample_size[1],
170
+ generator = generator,
171
+ guidance_scale = guidance_scale,
172
+ num_inference_steps = num_inference_steps,
173
+
174
+ video = input_video,
175
+ mask_video = input_video_mask,
176
+ ).videos
177
+ else:
178
+ sample = pipeline(
179
+ prompt,
180
+ num_frames = video_length,
181
+ negative_prompt = negative_prompt,
182
+ height = sample_size[0],
183
+ width = sample_size[1],
184
+ generator = generator,
185
+ guidance_scale = guidance_scale,
186
+ num_inference_steps = num_inference_steps,
187
+ ).videos
188
+
189
+ if lora_path is not None:
190
+ pipeline = unmerge_lora(pipeline, lora_path, lora_weight)
191
+
192
+ if not os.path.exists(save_path):
193
+ os.makedirs(save_path, exist_ok=True)
194
+
195
+ index = len([path for path in os.listdir(save_path)]) + 1
196
+ prefix = str(index).zfill(8)
197
+
198
+ if video_length == 1:
199
+ video_path = os.path.join(save_path, prefix + ".png")
200
+
201
+ image = sample[0, :, 0]
202
+ image = image.transpose(0, 1).transpose(1, 2)
203
+ image = (image * 255).numpy().astype(np.uint8)
204
+ image = Image.fromarray(image)
205
+ image.save(video_path)
206
+ else:
207
+ video_path = os.path.join(save_path, prefix + ".mp4")
208
+ save_videos_grid(sample, video_path, fps=fps)
examples/cogvideox_fun/predict_v2v.py ADDED
@@ -0,0 +1,203 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+
4
+ import numpy as np
5
+ import torch
6
+ from diffusers import (CogVideoXDDIMScheduler, DDIMScheduler,
7
+ DPMSolverMultistepScheduler,
8
+ EulerAncestralDiscreteScheduler, EulerDiscreteScheduler,
9
+ PNDMScheduler)
10
+ from PIL import Image
11
+
12
+ current_file_path = os.path.abspath(__file__)
13
+ project_roots = [os.path.dirname(current_file_path), os.path.dirname(os.path.dirname(current_file_path)), os.path.dirname(os.path.dirname(os.path.dirname(current_file_path)))]
14
+ for project_root in project_roots:
15
+ sys.path.insert(0, project_root) if project_root not in sys.path else None
16
+
17
+ from cogvideox.models import (AutoencoderKLCogVideoX,
18
+ CogVideoXTransformer3DModel, T5EncoderModel,
19
+ T5Tokenizer)
20
+ from cogvideox.pipeline import (CogVideoXFunPipeline,
21
+ CogVideoXFunInpaintPipeline)
22
+ from cogvideox.utils.lora_utils import merge_lora, unmerge_lora
23
+ from cogvideox.utils.fp8_optimization import convert_weight_dtype_wrapper
24
+ from cogvideox.utils.utils import get_video_to_video_latent, save_videos_grid
25
+
26
+ # GPU memory mode, which can be choosen in [model_cpu_offload, model_cpu_offload_and_qfloat8, sequential_cpu_offload].
27
+ # model_cpu_offload means that the entire model will be moved to the CPU after use, which can save some GPU memory.
28
+ #
29
+ # model_cpu_offload_and_qfloat8 indicates that the entire model will be moved to the CPU after use,
30
+ # and the transformer model has been quantized to float8, which can save more GPU memory.
31
+ #
32
+ # sequential_cpu_offload means that each layer of the model will be moved to the CPU after use,
33
+ # resulting in slower speeds but saving a large amount of GPU memory.
34
+ GPU_memory_mode = "model_cpu_offload_and_qfloat8"
35
+
36
+ # model path
37
+ model_name = "models/Diffusion_Transformer/CogVideoX-Fun-V1.1-2b-InP"
38
+
39
+ # Choose the sampler in "Euler" "Euler A" "DPM++" "PNDM" "DDIM_Cog" and "DDIM_Origin"
40
+ sampler_name = "DDIM_Origin"
41
+
42
+ # Load pretrained model if need
43
+ transformer_path = None
44
+ vae_path = None
45
+ lora_path = None
46
+ # Other params
47
+ sample_size = [384, 672]
48
+ # V1.0 and V1.1 support up to 49 frames of video generation,
49
+ # while V1.5 supports up to 85 frames.
50
+ video_length = 49
51
+ fps = 8
52
+
53
+ # Use torch.float16 if GPU does not support torch.bfloat16
54
+ # ome graphics cards, such as v100, 2080ti, do not support torch.bfloat16
55
+ weight_dtype = torch.bfloat16
56
+ # If you are preparing to redraw the reference video, set validation_video and validation_video_mask.
57
+ # If you do not use validation_video_mask, the entire video will be redrawn;
58
+ # if you use validation_video_mask, only a portion of the video will be redrawn.
59
+ # Please set a larger denoise_strength when using validation_video_mask, such as 1.00 instead of 0.70
60
+ validation_video = "asset/1.mp4"
61
+ validation_video_mask = None
62
+ denoise_strength = 0.70
63
+
64
+ # prompts
65
+ prompt = "A cute cat is playing the guitar. "
66
+ negative_prompt = "The video is not of a high quality, it has a low resolution. Watermark present in each frame. The background is solid. Strange body and strange trajectory. Distortion. "
67
+ guidance_scale = 6.0
68
+ seed = 43
69
+ num_inference_steps = 50
70
+ lora_weight = 0.55
71
+ save_path = "samples/cogvideox-fun-videos_v2v"
72
+
73
+ transformer = CogVideoXTransformer3DModel.from_pretrained(
74
+ model_name,
75
+ subfolder="transformer",
76
+ low_cpu_mem_usage=True,
77
+ torch_dtype=torch.float8_e4m3fn if GPU_memory_mode == "model_cpu_offload_and_qfloat8" else weight_dtype,
78
+ ).to(weight_dtype)
79
+
80
+ if transformer_path is not None:
81
+ print(f"From checkpoint: {transformer_path}")
82
+ if transformer_path.endswith("safetensors"):
83
+ from safetensors.torch import load_file, safe_open
84
+ state_dict = load_file(transformer_path)
85
+ else:
86
+ state_dict = torch.load(transformer_path, map_location="cpu")
87
+ state_dict = state_dict["state_dict"] if "state_dict" in state_dict else state_dict
88
+
89
+ m, u = transformer.load_state_dict(state_dict, strict=False)
90
+ print(f"missing keys: {len(m)}, unexpected keys: {len(u)}")
91
+
92
+ # Get Vae
93
+ vae = AutoencoderKLCogVideoX.from_pretrained(
94
+ model_name,
95
+ subfolder="vae"
96
+ ).to(weight_dtype)
97
+
98
+ if vae_path is not None:
99
+ print(f"From checkpoint: {vae_path}")
100
+ if vae_path.endswith("safetensors"):
101
+ from safetensors.torch import load_file, safe_open
102
+ state_dict = load_file(vae_path)
103
+ else:
104
+ state_dict = torch.load(vae_path, map_location="cpu")
105
+ state_dict = state_dict["state_dict"] if "state_dict" in state_dict else state_dict
106
+
107
+ m, u = vae.load_state_dict(state_dict, strict=False)
108
+ print(f"missing keys: {len(m)}, unexpected keys: {len(u)}")
109
+
110
+ # Get tokenizer and text_encoder
111
+ tokenizer = T5Tokenizer.from_pretrained(
112
+ model_name, subfolder="tokenizer"
113
+ )
114
+ text_encoder = T5EncoderModel.from_pretrained(
115
+ model_name, subfolder="text_encoder", torch_dtype=weight_dtype
116
+ )
117
+
118
+ # Get Scheduler
119
+ Choosen_Scheduler = scheduler_dict = {
120
+ "Euler": EulerDiscreteScheduler,
121
+ "Euler A": EulerAncestralDiscreteScheduler,
122
+ "DPM++": DPMSolverMultistepScheduler,
123
+ "PNDM": PNDMScheduler,
124
+ "DDIM_Cog": CogVideoXDDIMScheduler,
125
+ "DDIM_Origin": DDIMScheduler,
126
+ }[sampler_name]
127
+ scheduler = Choosen_Scheduler.from_pretrained(
128
+ model_name,
129
+ subfolder="scheduler"
130
+ )
131
+
132
+ if transformer.config.in_channels != vae.config.latent_channels:
133
+ pipeline = CogVideoXFunInpaintPipeline(
134
+ vae=vae,
135
+ tokenizer=tokenizer,
136
+ text_encoder=text_encoder,
137
+ transformer=transformer,
138
+ scheduler=scheduler,
139
+ )
140
+ else:
141
+ pipeline = CogVideoXFunPipeline(
142
+ vae=vae,
143
+ tokenizer=tokenizer,
144
+ text_encoder=text_encoder,
145
+ transformer=transformer,
146
+ scheduler=scheduler,
147
+ )
148
+ if GPU_memory_mode == "sequential_cpu_offload":
149
+ pipeline.enable_sequential_cpu_offload()
150
+ elif GPU_memory_mode == "model_cpu_offload_and_qfloat8":
151
+ convert_weight_dtype_wrapper(transformer, weight_dtype)
152
+ pipeline.enable_model_cpu_offload()
153
+ else:
154
+ pipeline.enable_model_cpu_offload()
155
+
156
+ generator = torch.Generator(device="cuda").manual_seed(seed)
157
+
158
+ if lora_path is not None:
159
+ pipeline = merge_lora(pipeline, lora_path, lora_weight)
160
+
161
+ video_length = int((video_length - 1) // vae.config.temporal_compression_ratio * vae.config.temporal_compression_ratio) + 1 if video_length != 1 else 1
162
+ latent_frames = (video_length - 1) // vae.config.temporal_compression_ratio + 1
163
+ if video_length != 1 and transformer.config.patch_size_t is not None and latent_frames % transformer.config.patch_size_t != 0:
164
+ additional_frames = transformer.config.patch_size_t - latent_frames % transformer.config.patch_size_t
165
+ video_length += additional_frames * vae.config.temporal_compression_ratio
166
+ input_video, input_video_mask, clip_image = get_video_to_video_latent(validation_video, video_length=video_length, sample_size=sample_size, validation_video_mask=validation_video_mask, fps=fps)
167
+
168
+ with torch.no_grad():
169
+ sample = pipeline(
170
+ prompt,
171
+ num_frames = video_length,
172
+ negative_prompt = negative_prompt,
173
+ height = sample_size[0],
174
+ width = sample_size[1],
175
+ generator = generator,
176
+ guidance_scale = guidance_scale,
177
+ num_inference_steps = num_inference_steps,
178
+
179
+ video = input_video,
180
+ mask_video = input_video_mask,
181
+ strength = denoise_strength,
182
+ ).videos
183
+
184
+ if lora_path is not None:
185
+ pipeline = unmerge_lora(pipeline, lora_path, lora_weight)
186
+
187
+ if not os.path.exists(save_path):
188
+ os.makedirs(save_path, exist_ok=True)
189
+
190
+ index = len([path for path in os.listdir(save_path)]) + 1
191
+ prefix = str(index).zfill(8)
192
+
193
+ if video_length == 1:
194
+ save_sample_path = os.path.join(save_path, prefix + f".png")
195
+
196
+ image = sample[0, :, 0]
197
+ image = image.transpose(0, 1).transpose(1, 2)
198
+ image = (image * 255).numpy().astype(np.uint8)
199
+ image = Image.fromarray(image)
200
+ image.save(save_sample_path)
201
+ else:
202
+ video_path = os.path.join(save_path, prefix + ".mp4")
203
+ save_videos_grid(sample, video_path, fps=fps)
examples/cogvideox_fun/predict_v2v_control.py ADDED
@@ -0,0 +1,188 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+
4
+ import cv2
5
+ import numpy as np
6
+ import torch
7
+ from diffusers import (CogVideoXDDIMScheduler, DDIMScheduler,
8
+ DPMSolverMultistepScheduler,
9
+ EulerAncestralDiscreteScheduler, EulerDiscreteScheduler,
10
+ PNDMScheduler)
11
+ from PIL import Image
12
+ from transformers import T5EncoderModel
13
+
14
+ current_file_path = os.path.abspath(__file__)
15
+ project_roots = [os.path.dirname(current_file_path), os.path.dirname(os.path.dirname(current_file_path)), os.path.dirname(os.path.dirname(os.path.dirname(current_file_path)))]
16
+ for project_root in project_roots:
17
+ sys.path.insert(0, project_root) if project_root not in sys.path else None
18
+
19
+ from cogvideox.models import (AutoencoderKLCogVideoX,
20
+ CogVideoXTransformer3DModel, T5EncoderModel,
21
+ T5Tokenizer)
22
+ from cogvideox.pipeline import (CogVideoXFunControlPipeline,
23
+ CogVideoXFunInpaintPipeline)
24
+ from cogvideox.utils.fp8_optimization import convert_weight_dtype_wrapper
25
+ from cogvideox.utils.lora_utils import merge_lora, unmerge_lora
26
+ from cogvideox.utils.utils import get_video_to_video_latent, save_videos_grid
27
+
28
+ # GPU memory mode, which can be choosen in [model_cpu_offload, model_cpu_offload_and_qfloat8, sequential_cpu_offload].
29
+ # model_cpu_offload means that the entire model will be moved to the CPU after use, which can save some GPU memory.
30
+ #
31
+ # model_cpu_offload_and_qfloat8 indicates that the entire model will be moved to the CPU after use,
32
+ # and the transformer model has been quantized to float8, which can save more GPU memory.
33
+ #
34
+ # sequential_cpu_offload means that each layer of the model will be moved to the CPU after use,
35
+ # resulting in slower speeds but saving a large amount of GPU memory.
36
+ GPU_memory_mode = "model_cpu_offload_and_qfloat8"
37
+
38
+ # model path
39
+ model_name = "models/Diffusion_Transformer/CogVideoX-Fun-V1.1-2b-Pose"
40
+
41
+ # Choose the sampler in "Euler" "Euler A" "DPM++" "PNDM" "DDIM_Cog" and "DDIM_Origin"
42
+ sampler_name = "DDIM_Origin"
43
+
44
+ # Load pretrained model if need
45
+ transformer_path = None
46
+ vae_path = None
47
+ lora_path = None
48
+ # Other params
49
+ sample_size = [672, 384]
50
+ # V1.0 and V1.1 support up to 49 frames of video generation,
51
+ # while V1.5 supports up to 85 frames.
52
+ video_length = 49
53
+ fps = 8
54
+
55
+ # Use torch.float16 if GPU does not support torch.bfloat16
56
+ # ome graphics cards, such as v100, 2080ti, do not support torch.bfloat16
57
+ weight_dtype = torch.bfloat16
58
+ control_video = "asset/pose.mp4"
59
+
60
+ # prompts
61
+ prompt = "A young woman with beautiful face, dressed in white, is moving her body. "
62
+ negative_prompt = "The video is not of a high quality, it has a low resolution. Watermark present in each frame. The background is solid. Strange body and strange trajectory. Distortion. "
63
+ guidance_scale = 6.0
64
+ seed = 43
65
+ num_inference_steps = 50
66
+ lora_weight = 0.55
67
+ save_path = "samples/cogvideox-fun-videos_control"
68
+
69
+ transformer = CogVideoXTransformer3DModel.from_pretrained(
70
+ model_name,
71
+ subfolder="transformer",
72
+ low_cpu_mem_usage=True,
73
+ torch_dtype=torch.float8_e4m3fn if GPU_memory_mode == "model_cpu_offload_and_qfloat8" else weight_dtype,
74
+ ).to(weight_dtype)
75
+
76
+ if transformer_path is not None:
77
+ print(f"From checkpoint: {transformer_path}")
78
+ if transformer_path.endswith("safetensors"):
79
+ from safetensors.torch import load_file, safe_open
80
+ state_dict = load_file(transformer_path)
81
+ else:
82
+ state_dict = torch.load(transformer_path, map_location="cpu")
83
+ state_dict = state_dict["state_dict"] if "state_dict" in state_dict else state_dict
84
+
85
+ m, u = transformer.load_state_dict(state_dict, strict=False)
86
+ print(f"missing keys: {len(m)}, unexpected keys: {len(u)}")
87
+
88
+ # Get Vae
89
+ vae = AutoencoderKLCogVideoX.from_pretrained(
90
+ model_name,
91
+ subfolder="vae"
92
+ ).to(weight_dtype)
93
+
94
+ if vae_path is not None:
95
+ print(f"From checkpoint: {vae_path}")
96
+ if vae_path.endswith("safetensors"):
97
+ from safetensors.torch import load_file, safe_open
98
+ state_dict = load_file(vae_path)
99
+ else:
100
+ state_dict = torch.load(vae_path, map_location="cpu")
101
+ state_dict = state_dict["state_dict"] if "state_dict" in state_dict else state_dict
102
+
103
+ m, u = vae.load_state_dict(state_dict, strict=False)
104
+ print(f"missing keys: {len(m)}, unexpected keys: {len(u)}")
105
+
106
+ # Get tokenizer and text_encoder
107
+ tokenizer = T5Tokenizer.from_pretrained(
108
+ model_name, subfolder="tokenizer"
109
+ )
110
+ text_encoder = T5EncoderModel.from_pretrained(
111
+ model_name, subfolder="text_encoder", torch_dtype=weight_dtype
112
+ )
113
+
114
+ # Get Scheduler
115
+ Choosen_Scheduler = scheduler_dict = {
116
+ "Euler": EulerDiscreteScheduler,
117
+ "Euler A": EulerAncestralDiscreteScheduler,
118
+ "DPM++": DPMSolverMultistepScheduler,
119
+ "PNDM": PNDMScheduler,
120
+ "DDIM_Cog": CogVideoXDDIMScheduler,
121
+ "DDIM_Origin": DDIMScheduler,
122
+ }[sampler_name]
123
+ scheduler = Choosen_Scheduler.from_pretrained(
124
+ model_name,
125
+ subfolder="scheduler"
126
+ )
127
+
128
+ pipeline = CogVideoXFunControlPipeline.from_pretrained(
129
+ vae=vae,
130
+ tokenizer=tokenizer,
131
+ text_encoder=text_encoder,
132
+ transformer=transformer,
133
+ scheduler=scheduler,
134
+ )
135
+ if GPU_memory_mode == "sequential_cpu_offload":
136
+ pipeline.enable_sequential_cpu_offload()
137
+ elif GPU_memory_mode == "model_cpu_offload_and_qfloat8":
138
+ convert_weight_dtype_wrapper(transformer, weight_dtype)
139
+ pipeline.enable_model_cpu_offload()
140
+ else:
141
+ pipeline.enable_model_cpu_offload()
142
+
143
+ generator = torch.Generator(device="cuda").manual_seed(seed)
144
+
145
+ if lora_path is not None:
146
+ pipeline = merge_lora(pipeline, lora_path, lora_weight)
147
+
148
+ video_length = int((video_length - 1) // vae.config.temporal_compression_ratio * vae.config.temporal_compression_ratio) + 1 if video_length != 1 else 1
149
+ latent_frames = (video_length - 1) // vae.config.temporal_compression_ratio + 1
150
+ if video_length != 1 and transformer.config.patch_size_t is not None and latent_frames % transformer.config.patch_size_t != 0:
151
+ additional_frames = transformer.config.patch_size_t - latent_frames % transformer.config.patch_size_t
152
+ video_length += additional_frames * vae.config.temporal_compression_ratio
153
+ input_video, input_video_mask, clip_image = get_video_to_video_latent(control_video, video_length=video_length, sample_size=sample_size, fps=fps)
154
+
155
+ with torch.no_grad():
156
+ sample = pipeline(
157
+ prompt,
158
+ num_frames = video_length,
159
+ negative_prompt = negative_prompt,
160
+ height = sample_size[0],
161
+ width = sample_size[1],
162
+ generator = generator,
163
+ guidance_scale = guidance_scale,
164
+ num_inference_steps = num_inference_steps,
165
+
166
+ control_video = input_video,
167
+ ).videos
168
+
169
+ if lora_path is not None:
170
+ pipeline = unmerge_lora(pipeline, lora_path, lora_weight)
171
+
172
+ if not os.path.exists(save_path):
173
+ os.makedirs(save_path, exist_ok=True)
174
+
175
+ index = len([path for path in os.listdir(save_path)]) + 1
176
+ prefix = str(index).zfill(8)
177
+
178
+ if video_length == 1:
179
+ save_sample_path = os.path.join(save_path, prefix + f".png")
180
+
181
+ image = sample[0, :, 0]
182
+ image = image.transpose(0, 1).transpose(1, 2)
183
+ image = (image * 255).numpy().astype(np.uint8)
184
+ image = Image.fromarray(image)
185
+ image.save(save_sample_path)
186
+ else:
187
+ video_path = os.path.join(save_path, prefix + ".mp4")
188
+ save_videos_grid(sample, video_path, fps=fps)
examples/wan2.1/app.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import time
4
+
5
+ import torch
6
+
7
+ current_file_path = os.path.abspath(__file__)
8
+ project_roots = [os.path.dirname(current_file_path), os.path.dirname(os.path.dirname(current_file_path)), os.path.dirname(os.path.dirname(os.path.dirname(current_file_path)))]
9
+ for project_root in project_roots:
10
+ sys.path.insert(0, project_root) if project_root not in sys.path else None
11
+
12
+ from cogvideox.api.api import (infer_forward_api,
13
+ update_diffusion_transformer_api,
14
+ update_edition_api)
15
+ from cogvideox.ui.controller import flow_scheduler_dict
16
+ from cogvideox.ui.wan_ui import ui, ui_eas, ui_modelscope
17
+
18
+ if __name__ == "__main__":
19
+ # Choose the ui mode
20
+ ui_mode = "normal"
21
+
22
+ # GPU memory mode, which can be choosen in [model_cpu_offload, model_cpu_offload_and_qfloat8, sequential_cpu_offload].
23
+ # model_cpu_offload means that the entire model will be moved to the CPU after use, which can save some GPU memory.
24
+ #
25
+ # model_cpu_offload_and_qfloat8 indicates that the entire model will be moved to the CPU after use,
26
+ # and the transformer model has been quantized to float8, which can save more GPU memory.
27
+ #
28
+ # sequential_cpu_offload means that each layer of the model will be moved to the CPU after use,
29
+ # resulting in slower speeds but saving a large amount of GPU memory.
30
+ GPU_memory_mode = "sequential_cpu_offload"
31
+ # Use torch.float16 if GPU does not support torch.bfloat16
32
+ # ome graphics cards, such as v100, 2080ti, do not support torch.bfloat16
33
+ weight_dtype = torch.bfloat16
34
+ # Config path
35
+ config_path = "config/wan2.1/wan_civitai.yaml"
36
+
37
+ # Server ip
38
+ server_name = "0.0.0.0"
39
+ server_port = 7860
40
+
41
+ # Params below is used when ui_mode = "modelscope"
42
+ model_name = "models/Diffusion_Transformer/Wan2.1-I2V-14B-480P"
43
+ # "Inpaint" or "Control"
44
+ model_type = "Inpaint"
45
+ # Save dir of this model
46
+ savedir_sample = "samples"
47
+
48
+ if ui_mode == "modelscope":
49
+ demo, controller = ui_modelscope(model_name, model_type, savedir_sample, GPU_memory_mode, flow_scheduler_dict, weight_dtype, config_path)
50
+ elif ui_mode == "eas":
51
+ demo, controller = ui_eas(model_name, flow_scheduler_dict, savedir_sample, config_path)
52
+ else:
53
+ demo, controller = ui(GPU_memory_mode, flow_scheduler_dict, weight_dtype, config_path)
54
+
55
+ # launch gradio
56
+ app, _, _ = demo.queue(status_update_rate=1).launch(
57
+ server_name=server_name,
58
+ server_port=server_port,
59
+ prevent_thread_lock=True
60
+ )
61
+
62
+ # launch api
63
+ infer_forward_api(None, app, controller)
64
+ update_diffusion_transformer_api(None, app, controller)
65
+ update_edition_api(None, app, controller)
66
+
67
+ # not close the python
68
+ while True:
69
+ time.sleep(5)
examples/wan2.1/predict_i2v.py ADDED
@@ -0,0 +1,198 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+
4
+ import numpy as np
5
+ import torch
6
+ from diffusers import FlowMatchEulerDiscreteScheduler
7
+ from omegaconf import OmegaConf
8
+ from PIL import Image
9
+ from transformers import AutoTokenizer
10
+
11
+ current_file_path = os.path.abspath(__file__)
12
+ project_roots = [os.path.dirname(current_file_path), os.path.dirname(os.path.dirname(current_file_path)), os.path.dirname(os.path.dirname(os.path.dirname(current_file_path)))]
13
+ for project_root in project_roots:
14
+ sys.path.insert(0, project_root) if project_root not in sys.path else None
15
+
16
+ from cogvideox.models import (AutoencoderKLWan, AutoTokenizer, CLIPModel,
17
+ WanT5EncoderModel, WanTransformer3DModel)
18
+ from cogvideox.pipeline import WanI2VPipeline
19
+ from cogvideox.utils.fp8_optimization import (convert_model_weight_to_float8, replace_parameters_by_name,
20
+ convert_weight_dtype_wrapper)
21
+ from cogvideox.utils.lora_utils import merge_lora, unmerge_lora
22
+ from cogvideox.utils.utils import (filter_kwargs, get_image_to_video_latent,
23
+ save_videos_grid)
24
+
25
+ # GPU memory mode, which can be choosen in [model_cpu_offload, model_cpu_offload_and_qfloat8, sequential_cpu_offload].
26
+ # model_cpu_offload means that the entire model will be moved to the CPU after use, which can save some GPU memory.
27
+ #
28
+ # model_cpu_offload_and_qfloat8 indicates that the entire model will be moved to the CPU after use,
29
+ # and the transformer model has been quantized to float8, which can save more GPU memory.
30
+ #
31
+ # sequential_cpu_offload means that each layer of the model will be moved to the CPU after use,
32
+ # resulting in slower speeds but saving a large amount of GPU memory.
33
+ GPU_memory_mode = "sequential_cpu_offload"
34
+
35
+ # Config and model path
36
+ config_path = "config/wan2.1/wan_civitai.yaml"
37
+ # model path
38
+ model_name = "models/Diffusion_Transformer/Wan2.1-I2V-14B-480P"
39
+
40
+ # Choose the sampler in "Euler" "Euler A" "DPM++" "PNDM" and "DDIM"
41
+ sampler_name = "Flow"
42
+
43
+ # Load pretrained model if need
44
+ transformer_path = None
45
+ vae_path = None
46
+ lora_path = None
47
+
48
+ # Other params
49
+ sample_size = [480, 832]
50
+ video_length = 81
51
+ fps = 16
52
+
53
+ # Use torch.float16 if GPU does not support torch.bfloat16
54
+ # ome graphics cards, such as v100, 2080ti, do not support torch.bfloat16
55
+ weight_dtype = torch.bfloat16
56
+ # If you want to generate from text, please set the validation_image_start = None and validation_image_end = None
57
+ validation_image_start = "asset/1.png"
58
+
59
+ # prompts
60
+ prompt = "一只棕褐色的狗正摇晃着脑袋,坐在一个舒适的房间里的浅色沙发上。沙发看起来柔软而宽敞,为这只活泼的狗狗提供了一个完美的休息地点。在狗的后面,靠墙摆放着一个架子,架子上挂着一幅精美的镶框画,画中描绘着一些美丽的风景或场景。画框周围装饰着粉红色的花朵,这些花朵不仅增添了房间的色彩,还带来了一丝自然和生机。房间里的灯光柔和而温暖,从天花板上的吊灯和角落里的台灯散发出来,营造出一种温馨舒适的氛围。整个空间给人一种宁静和谐的感觉,仿佛时间在这里变得缓慢而美好。"
61
+ negative_prompt = "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走"
62
+ guidance_scale = 6.0
63
+ seed = 43
64
+ num_inference_steps = 50
65
+ lora_weight = 0.55
66
+ save_path = "samples/wan-videos-i2v"
67
+
68
+ config = OmegaConf.load(config_path)
69
+
70
+ transformer = WanTransformer3DModel.from_pretrained(
71
+ os.path.join(model_name, config['transformer_additional_kwargs'].get('transformer_subpath', 'transformer')),
72
+ transformer_additional_kwargs=OmegaConf.to_container(config['transformer_additional_kwargs']),
73
+ low_cpu_mem_usage=True,
74
+ torch_dtype=weight_dtype,
75
+ )
76
+
77
+ if transformer_path is not None:
78
+ print(f"From checkpoint: {transformer_path}")
79
+ if transformer_path.endswith("safetensors"):
80
+ from safetensors.torch import load_file, safe_open
81
+ state_dict = load_file(transformer_path)
82
+ else:
83
+ state_dict = torch.load(transformer_path, map_location="cpu")
84
+ state_dict = state_dict["state_dict"] if "state_dict" in state_dict else state_dict
85
+
86
+ m, u = transformer.load_state_dict(state_dict, strict=False)
87
+ print(f"missing keys: {len(m)}, unexpected keys: {len(u)}")
88
+
89
+ # Get Vae
90
+ vae = AutoencoderKLWan.from_pretrained(
91
+ os.path.join(model_name, config['vae_kwargs'].get('vae_subpath', 'vae')),
92
+ additional_kwargs=OmegaConf.to_container(config['vae_kwargs']),
93
+ ).to(weight_dtype)
94
+
95
+ if vae_path is not None:
96
+ print(f"From checkpoint: {vae_path}")
97
+ if vae_path.endswith("safetensors"):
98
+ from safetensors.torch import load_file, safe_open
99
+ state_dict = load_file(vae_path)
100
+ else:
101
+ state_dict = torch.load(vae_path, map_location="cpu")
102
+ state_dict = state_dict["state_dict"] if "state_dict" in state_dict else state_dict
103
+
104
+ m, u = vae.load_state_dict(state_dict, strict=False)
105
+ print(f"missing keys: {len(m)}, unexpected keys: {len(u)}")
106
+
107
+ # Get Tokenizer
108
+ tokenizer = AutoTokenizer.from_pretrained(
109
+ os.path.join(model_name, config['text_encoder_kwargs'].get('tokenizer_subpath', 'tokenizer')),
110
+ )
111
+
112
+ # Get Text encoder
113
+ text_encoder = WanT5EncoderModel.from_pretrained(
114
+ os.path.join(model_name, config['text_encoder_kwargs'].get('text_encoder_subpath', 'text_encoder')),
115
+ additional_kwargs=OmegaConf.to_container(config['text_encoder_kwargs']),
116
+ ).to(weight_dtype)
117
+ text_encoder = text_encoder.eval()
118
+
119
+ # Get Clip Image Encoder
120
+ clip_image_encoder = CLIPModel.from_pretrained(
121
+ os.path.join(model_name, config['image_encoder_kwargs'].get('image_encoder_subpath', 'image_encoder')),
122
+ ).to(weight_dtype)
123
+ clip_image_encoder = clip_image_encoder.eval()
124
+
125
+ # Get Scheduler
126
+ Choosen_Scheduler = scheduler_dict = {
127
+ "Flow": FlowMatchEulerDiscreteScheduler,
128
+ }[sampler_name]
129
+ scheduler = Choosen_Scheduler(
130
+ **filter_kwargs(Choosen_Scheduler, OmegaConf.to_container(config['scheduler_kwargs']))
131
+ )
132
+
133
+ # Get Pipeline
134
+ pipeline = WanI2VPipeline(
135
+ transformer=transformer,
136
+ vae=vae,
137
+ tokenizer=tokenizer,
138
+ text_encoder=text_encoder,
139
+ scheduler=scheduler,
140
+ clip_image_encoder=clip_image_encoder
141
+ )
142
+ if GPU_memory_mode == "sequential_cpu_offload":
143
+ replace_parameters_by_name(transformer, ["modulation",], device="cuda")
144
+ transformer.freqs = transformer.freqs.to(device="cuda")
145
+ pipeline.enable_sequential_cpu_offload()
146
+ elif GPU_memory_mode == "model_cpu_offload_and_qfloat8":
147
+ convert_model_weight_to_float8(transformer, exclude_module_name=["modulation",])
148
+ convert_weight_dtype_wrapper(transformer, weight_dtype)
149
+ pipeline.enable_model_cpu_offload()
150
+ else:
151
+ pipeline.enable_model_cpu_offload()
152
+
153
+ generator = torch.Generator(device="cuda").manual_seed(seed)
154
+
155
+ if lora_path is not None:
156
+ pipeline = merge_lora(pipeline, lora_path, lora_weight)
157
+
158
+ with torch.no_grad():
159
+ video_length = int((video_length - 1) // vae.config.temporal_compression_ratio * vae.config.temporal_compression_ratio) + 1 if video_length != 1 else 1
160
+ latent_frames = (video_length - 1) // vae.config.temporal_compression_ratio + 1
161
+
162
+ input_video, input_video_mask, clip_image = get_image_to_video_latent(validation_image_start, None, video_length=video_length, sample_size=sample_size)
163
+
164
+ sample = pipeline(
165
+ prompt,
166
+ num_frames = video_length,
167
+ negative_prompt = negative_prompt,
168
+ height = sample_size[0],
169
+ width = sample_size[1],
170
+ generator = generator,
171
+ guidance_scale = guidance_scale,
172
+ num_inference_steps = num_inference_steps,
173
+
174
+ video = input_video,
175
+ mask_video = input_video_mask,
176
+ clip_image = clip_image,
177
+ ).videos
178
+
179
+ if lora_path is not None:
180
+ pipeline = unmerge_lora(pipeline, lora_path, lora_weight)
181
+
182
+ if not os.path.exists(save_path):
183
+ os.makedirs(save_path, exist_ok=True)
184
+
185
+ index = len([path for path in os.listdir(save_path)]) + 1
186
+ prefix = str(index).zfill(8)
187
+
188
+ if video_length == 1:
189
+ video_path = os.path.join(save_path, prefix + ".png")
190
+
191
+ image = sample[0, :, 0]
192
+ image = image.transpose(0, 1).transpose(1, 2)
193
+ image = (image * 255).numpy().astype(np.uint8)
194
+ image = Image.fromarray(image)
195
+ image.save(video_path)
196
+ else:
197
+ video_path = os.path.join(save_path, prefix + ".mp4")
198
+ save_videos_grid(sample, video_path, fps=fps)
examples/wan2.1/predict_t2v.py ADDED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+
4
+ import numpy as np
5
+ import torch
6
+ from diffusers import FlowMatchEulerDiscreteScheduler
7
+ from omegaconf import OmegaConf
8
+ from PIL import Image
9
+
10
+ current_file_path = os.path.abspath(__file__)
11
+ project_roots = [os.path.dirname(current_file_path), os.path.dirname(os.path.dirname(current_file_path)), os.path.dirname(os.path.dirname(os.path.dirname(current_file_path)))]
12
+ for project_root in project_roots:
13
+ sys.path.insert(0, project_root) if project_root not in sys.path else None
14
+
15
+ from cogvideox.models import (AutoencoderKLWan, WanT5EncoderModel, AutoTokenizer,
16
+ WanTransformer3DModel)
17
+ from cogvideox.pipeline import WanPipeline
18
+ from cogvideox.utils.fp8_optimization import (convert_model_weight_to_float8, replace_parameters_by_name,
19
+ convert_weight_dtype_wrapper)
20
+ from cogvideox.utils.lora_utils import merge_lora, unmerge_lora
21
+ from cogvideox.utils.utils import (filter_kwargs, get_image_to_video_latent,
22
+ save_videos_grid)
23
+
24
+ # GPU memory mode, which can be choosen in [model_cpu_offload, model_cpu_offload_and_qfloat8, sequential_cpu_offload].
25
+ # model_cpu_offload means that the entire model will be moved to the CPU after use, which can save some GPU memory.
26
+ #
27
+ # model_cpu_offload_and_qfloat8 indicates that the entire model will be moved to the CPU after use,
28
+ # and the transformer model has been quantized to float8, which can save more GPU memory.
29
+ #
30
+ # sequential_cpu_offload means that each layer of the model will be moved to the CPU after use,
31
+ # resulting in slower speeds but saving a large amount of GPU memory.
32
+ GPU_memory_mode = "sequential_cpu_offload"
33
+
34
+ # Config and model path
35
+ config_path = "config/wan2.1/wan_civitai.yaml"
36
+ # model path
37
+ model_name = "models/Diffusion_Transformer/Wan2.1-T2V-14B"
38
+
39
+ # Choose the sampler in "Euler" "Euler A" "DPM++" "PNDM" and "DDIM"
40
+ sampler_name = "Flow"
41
+
42
+ # Load pretrained model if need
43
+ transformer_path = None
44
+ vae_path = None
45
+ lora_path = None
46
+
47
+ # Other params
48
+ sample_size = [480, 832]
49
+ video_length = 81
50
+ fps = 16
51
+
52
+ # Use torch.float16 if GPU does not support torch.bfloat16
53
+ # ome graphics cards, such as v100, 2080ti, do not support torch.bfloat16
54
+ weight_dtype = torch.bfloat16
55
+ prompt = "一只棕褐色的狗正摇晃着脑袋,坐在一个舒适的房间里的浅色沙发上。沙发看起来柔软而宽敞,为这只活泼的狗狗提供了一个完美的休息地点。在狗的后面,靠墙摆放着一个架子,架子上挂着一幅精美的镶框画,画中描绘着一些美丽的风景或场景。画框周围装饰着粉红色的花朵,这些花朵不仅增添了房间的色彩,还带来了一丝自然和生机。房间里的灯光柔和而温暖,从天花板上的吊灯和角落里的台灯散发出来,营造出一种温馨舒适的氛围。整个空间给人一种宁静和谐的感觉,仿佛时间在这里变得缓慢而美好。"
56
+ negative_prompt = "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走"
57
+ guidance_scale = 6.0
58
+ seed = 43
59
+ num_inference_steps = 50
60
+ lora_weight = 0.55
61
+ save_path = "samples/wan-videos-t2v"
62
+
63
+ config = OmegaConf.load(config_path)
64
+
65
+ transformer = WanTransformer3DModel.from_pretrained(
66
+ os.path.join(model_name, config['transformer_additional_kwargs'].get('transformer_subpath', 'transformer')),
67
+ transformer_additional_kwargs=OmegaConf.to_container(config['transformer_additional_kwargs']),
68
+ low_cpu_mem_usage=True,
69
+ torch_dtype=weight_dtype,
70
+ )
71
+
72
+ if transformer_path is not None:
73
+ print(f"From checkpoint: {transformer_path}")
74
+ if transformer_path.endswith("safetensors"):
75
+ from safetensors.torch import load_file, safe_open
76
+ state_dict = load_file(transformer_path)
77
+ else:
78
+ state_dict = torch.load(transformer_path, map_location="cpu")
79
+ state_dict = state_dict["state_dict"] if "state_dict" in state_dict else state_dict
80
+
81
+ m, u = transformer.load_state_dict(state_dict, strict=False)
82
+ print(f"missing keys: {len(m)}, unexpected keys: {len(u)}")
83
+
84
+ # Get Vae
85
+ vae = AutoencoderKLWan.from_pretrained(
86
+ os.path.join(model_name, config['vae_kwargs'].get('vae_subpath', 'vae')),
87
+ additional_kwargs=OmegaConf.to_container(config['vae_kwargs']),
88
+ ).to(weight_dtype)
89
+
90
+ if vae_path is not None:
91
+ print(f"From checkpoint: {vae_path}")
92
+ if vae_path.endswith("safetensors"):
93
+ from safetensors.torch import load_file, safe_open
94
+ state_dict = load_file(vae_path)
95
+ else:
96
+ state_dict = torch.load(vae_path, map_location="cpu")
97
+ state_dict = state_dict["state_dict"] if "state_dict" in state_dict else state_dict
98
+
99
+ m, u = vae.load_state_dict(state_dict, strict=False)
100
+ print(f"missing keys: {len(m)}, unexpected keys: {len(u)}")
101
+
102
+ # Get Tokenizer
103
+ tokenizer = AutoTokenizer.from_pretrained(
104
+ os.path.join(model_name, config['text_encoder_kwargs'].get('tokenizer_subpath', 'tokenizer')),
105
+ )
106
+
107
+ # Get Text encoder
108
+ text_encoder = WanT5EncoderModel.from_pretrained(
109
+ os.path.join(model_name, config['text_encoder_kwargs'].get('text_encoder_subpath', 'text_encoder')),
110
+ additional_kwargs=OmegaConf.to_container(config['text_encoder_kwargs']),
111
+ ).to(weight_dtype)
112
+
113
+ # Get Scheduler
114
+ Choosen_Scheduler = scheduler_dict = {
115
+ "Flow": FlowMatchEulerDiscreteScheduler,
116
+ }[sampler_name]
117
+ scheduler = Choosen_Scheduler(
118
+ **filter_kwargs(Choosen_Scheduler, OmegaConf.to_container(config['scheduler_kwargs']))
119
+ )
120
+
121
+ # Get Pipeline
122
+ pipeline = WanPipeline(
123
+ transformer=transformer,
124
+ vae=vae,
125
+ tokenizer=tokenizer,
126
+ text_encoder=text_encoder,
127
+ scheduler=scheduler,
128
+ )
129
+ if GPU_memory_mode == "sequential_cpu_offload":
130
+ replace_parameters_by_name(transformer, ["modulation",], device="cuda")
131
+ transformer.freqs = transformer.freqs.to(device="cuda")
132
+ pipeline.enable_sequential_cpu_offload()
133
+ elif GPU_memory_mode == "model_cpu_offload_and_qfloat8":
134
+ convert_model_weight_to_float8(transformer, exclude_module_name=["modulation",])
135
+ convert_weight_dtype_wrapper(transformer, weight_dtype)
136
+ pipeline.enable_model_cpu_offload()
137
+ else:
138
+ pipeline.enable_model_cpu_offload()
139
+
140
+ generator = torch.Generator(device="cuda").manual_seed(seed)
141
+
142
+ if lora_path is not None:
143
+ pipeline = merge_lora(pipeline, lora_path, lora_weight)
144
+
145
+ with torch.no_grad():
146
+ video_length = int((video_length - 1) // vae.config.temporal_compression_ratio * vae.config.temporal_compression_ratio) + 1 if video_length != 1 else 1
147
+ latent_frames = (video_length - 1) // vae.config.temporal_compression_ratio + 1
148
+
149
+ sample = pipeline(
150
+ prompt,
151
+ num_frames = video_length,
152
+ negative_prompt = negative_prompt,
153
+ height = sample_size[0],
154
+ width = sample_size[1],
155
+ generator = generator,
156
+ guidance_scale = guidance_scale,
157
+ num_inference_steps = num_inference_steps,
158
+ ).videos
159
+
160
+ if lora_path is not None:
161
+ pipeline = unmerge_lora(pipeline, lora_path, lora_weight)
162
+
163
+ if not os.path.exists(save_path):
164
+ os.makedirs(save_path, exist_ok=True)
165
+
166
+ index = len([path for path in os.listdir(save_path)]) + 1
167
+ prefix = str(index).zfill(8)
168
+
169
+ if video_length == 1:
170
+ video_path = os.path.join(save_path, prefix + ".png")
171
+
172
+ image = sample[0, :, 0]
173
+ image = image.transpose(0, 1).transpose(1, 2)
174
+ image = (image * 255).numpy().astype(np.uint8)
175
+ image = Image.fromarray(image)
176
+ image.save(video_path)
177
+ else:
178
+ video_path = os.path.join(save_path, prefix + ".mp4")
179
+ save_videos_grid(sample, video_path, fps=fps)
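
Note on the frame-count handling above: before sampling, the script snaps `video_length` so that it lines up with the VAE's temporal compression. A minimal sketch of that calculation, assuming a compression ratio of 4 (the script reads the actual value from `vae.config.temporal_compression_ratio` at runtime):

```python
# Sketch only, not part of the commit: mirrors the snapping done in the predict scripts.
def snap_video_length(video_length: int, temporal_compression_ratio: int = 4):
    if video_length != 1:
        # Round (video_length - 1) down to a multiple of the ratio, then add 1,
        # e.g. 81 stays 81 (80 is a multiple of 4) while 83 becomes 81.
        video_length = int((video_length - 1) // temporal_compression_ratio * temporal_compression_ratio) + 1
    latent_frames = (video_length - 1) // temporal_compression_ratio + 1
    return video_length, latent_frames

print(snap_video_length(81))  # -> (81, 21)
print(snap_video_length(83))  # -> (81, 21)
print(snap_video_length(1))   # -> (1, 1)
```
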
examples/wan2.1_fun/app.py ADDED
@@ -0,0 +1,69 @@
+ import os
+ import sys
+ import time
+
+ import torch
+
+ current_file_path = os.path.abspath(__file__)
+ project_roots = [os.path.dirname(current_file_path), os.path.dirname(os.path.dirname(current_file_path)), os.path.dirname(os.path.dirname(os.path.dirname(current_file_path)))]
+ for project_root in project_roots:
+     sys.path.insert(0, project_root) if project_root not in sys.path else None
+
+ from cogvideox.api.api import (infer_forward_api,
+                                update_diffusion_transformer_api,
+                                update_edition_api)
+ from cogvideox.ui.controller import flow_scheduler_dict
+ from cogvideox.ui.wan_fun_ui import ui, ui_eas, ui_modelscope
+
+ if __name__ == "__main__":
+     # Choose the ui mode
+     ui_mode = "eas"
+
+     # GPU memory mode, which can be chosen from [model_cpu_offload, model_cpu_offload_and_qfloat8, sequential_cpu_offload].
+     # model_cpu_offload means that the entire model will be moved to the CPU after use, which can save some GPU memory.
+     #
+     # model_cpu_offload_and_qfloat8 indicates that the entire model will be moved to the CPU after use,
+     # and the transformer model has been quantized to float8, which can save more GPU memory.
+     #
+     # sequential_cpu_offload means that each layer of the model will be moved to the CPU after use,
+     # resulting in slower speeds but saving a large amount of GPU memory.
+     GPU_memory_mode = "model_cpu_offload"
+     # Use torch.float16 if the GPU does not support torch.bfloat16.
+     # Some graphics cards, such as the V100 and 2080 Ti, do not support torch.bfloat16.
+     weight_dtype = torch.bfloat16
+     # Config path
+     config_path = "config/wan2.1/wan_civitai.yaml"
+
+     # Server IP
+     server_name = "0.0.0.0"
+     server_port = 7860
+
+     # The params below are used when ui_mode = "modelscope"
+     model_name = "models/Diffusion_Transformer/Wan2.1-Fun-1.3B-InP"
+     # "Inpaint" or "Control"
+     model_type = "Inpaint"
+     # Save dir of this model
+     savedir_sample = "samples"
+
+     if ui_mode == "modelscope":
+         demo, controller = ui_modelscope(model_name, model_type, savedir_sample, GPU_memory_mode, flow_scheduler_dict, weight_dtype, config_path)
+     elif ui_mode == "eas":
+         demo, controller = ui_eas(model_name, flow_scheduler_dict, savedir_sample, config_path)
+     else:
+         demo, controller = ui(GPU_memory_mode, flow_scheduler_dict, weight_dtype, config_path)
+
+     # launch gradio
+     app, _, _ = demo.queue(status_update_rate=1).launch(
+         server_name=server_name,
+         server_port=server_port,
+         prevent_thread_lock=True
+     )
+
+     # launch api
+     infer_forward_api(None, app, controller)
+     update_diffusion_transformer_api(None, app, controller)
+     update_edition_api(None, app, controller)
+
+     # keep the Python process alive
+     while True:
+         time.sleep(5)
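
The `GPU_memory_mode` options described in the comments above are wired up in the predict scripts added by this commit. A condensed sketch of that branching, using the helpers from `cogvideox.utils.fp8_optimization` the same way the scripts do (the wrapper function name `apply_gpu_memory_mode` is hypothetical, not part of the repo):

```python
from cogvideox.utils.fp8_optimization import (convert_model_weight_to_float8,
                                               convert_weight_dtype_wrapper,
                                               replace_parameters_by_name)

def apply_gpu_memory_mode(pipeline, transformer, weight_dtype, GPU_memory_mode):
    # Hypothetical helper mirroring the branching in the predict_* scripts.
    if GPU_memory_mode == "sequential_cpu_offload":
        # Keep the modulation parameters and rotary frequencies on the GPU while
        # the rest of the model is offloaded layer by layer.
        replace_parameters_by_name(transformer, ["modulation",], device="cuda")
        transformer.freqs = transformer.freqs.to(device="cuda")
        pipeline.enable_sequential_cpu_offload()
    elif GPU_memory_mode == "model_cpu_offload_and_qfloat8":
        # Quantize the transformer weights to float8, then offload whole models.
        convert_model_weight_to_float8(transformer, exclude_module_name=["modulation",])
        convert_weight_dtype_wrapper(transformer, weight_dtype)
        pipeline.enable_model_cpu_offload()
    else:  # "model_cpu_offload"
        pipeline.enable_model_cpu_offload()
    return pipeline
```
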
examples/wan2.1_fun/predict_i2v.py ADDED
@@ -0,0 +1,199 @@
+ import os
+ import sys
+
+ import numpy as np
+ import torch
+ from diffusers import FlowMatchEulerDiscreteScheduler
+ from omegaconf import OmegaConf
+ from PIL import Image
+ from transformers import AutoTokenizer
+
+ current_file_path = os.path.abspath(__file__)
+ project_roots = [os.path.dirname(current_file_path), os.path.dirname(os.path.dirname(current_file_path)), os.path.dirname(os.path.dirname(os.path.dirname(current_file_path)))]
+ for project_root in project_roots:
+     sys.path.insert(0, project_root) if project_root not in sys.path else None
+
+ from cogvideox.models import (AutoencoderKLWan, CLIPModel, WanT5EncoderModel,
+                               WanTransformer3DModel)
+ from cogvideox.pipeline import WanFunInpaintPipeline
+ from cogvideox.utils.fp8_optimization import (convert_model_weight_to_float8, replace_parameters_by_name,
+                                               convert_weight_dtype_wrapper)
+ from cogvideox.utils.lora_utils import merge_lora, unmerge_lora
+ from cogvideox.utils.utils import (filter_kwargs, get_image_to_video_latent,
+                                    save_videos_grid)
+
+ # GPU memory mode, which can be chosen from [model_cpu_offload, model_cpu_offload_and_qfloat8, sequential_cpu_offload].
+ # model_cpu_offload means that the entire model will be moved to the CPU after use, which can save some GPU memory.
+ #
+ # model_cpu_offload_and_qfloat8 indicates that the entire model will be moved to the CPU after use,
+ # and the transformer model has been quantized to float8, which can save more GPU memory.
+ #
+ # sequential_cpu_offload means that each layer of the model will be moved to the CPU after use,
+ # resulting in slower speeds but saving a large amount of GPU memory.
+ GPU_memory_mode = "sequential_cpu_offload"
+
+ # Config and model path
+ config_path = "config/wan2.1/wan_civitai.yaml"
+ # model path
+ model_name = "models/Diffusion_Transformer/Wan2.1-Fun-1.3B-InP"
+
+ # Choose the sampler; "Flow" maps to FlowMatchEulerDiscreteScheduler below
+ sampler_name = "Flow"
+
+ # Load a pretrained checkpoint if needed
+ transformer_path = None
+ vae_path = None
+ lora_path = None
+
+ # Other params
+ sample_size = [480, 832]
+ video_length = 81
+ fps = 16
+
+ # Use torch.float16 if the GPU does not support torch.bfloat16.
+ # Some graphics cards, such as the V100 and 2080 Ti, do not support torch.bfloat16.
+ weight_dtype = torch.bfloat16
+ # If you want to generate from text, set validation_image_start = None and validation_image_end = None
+ validation_image_start = "asset/1.png"
+ validation_image_end = None
+
+ # prompts
+ prompt = "一只棕褐色的狗正摇晃着脑袋,坐在一个舒适的房间里的浅色沙发上。沙发看起来柔软而宽敞,为这只活泼的狗狗提供了一个完美的休息地点。在狗的后面,靠墙摆放着一个架子,架子上挂着一幅精美的镶框画,画中描绘着一些美丽的风景或场景。画框周围装饰着粉红色的花朵,这些花朵不仅增添了房间的色彩,还带来了一丝自然和生机。房间里的灯光柔和而温暖,从天花板上的吊灯和角落里的台灯散发出来,营造出一种温馨舒适的氛围。整个空间给人一种宁静和谐的感觉,仿佛时间在这里变得缓慢而美好。"
+ negative_prompt = "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走"
+ guidance_scale = 6.0
+ seed = 43
+ num_inference_steps = 50
+ lora_weight = 0.55
+ save_path = "samples/wan-videos-fun-i2v"
+
+ config = OmegaConf.load(config_path)
+
+ transformer = WanTransformer3DModel.from_pretrained(
+     os.path.join(model_name, config['transformer_additional_kwargs'].get('transformer_subpath', 'transformer')),
+     transformer_additional_kwargs=OmegaConf.to_container(config['transformer_additional_kwargs']),
+     low_cpu_mem_usage=True,
+     torch_dtype=weight_dtype,
+ )
+
+ if transformer_path is not None:
+     print(f"From checkpoint: {transformer_path}")
+     if transformer_path.endswith("safetensors"):
+         from safetensors.torch import load_file, safe_open
+         state_dict = load_file(transformer_path)
+     else:
+         state_dict = torch.load(transformer_path, map_location="cpu")
+     state_dict = state_dict["state_dict"] if "state_dict" in state_dict else state_dict
+
+     m, u = transformer.load_state_dict(state_dict, strict=False)
+     print(f"missing keys: {len(m)}, unexpected keys: {len(u)}")
+
+ # Get Vae
+ vae = AutoencoderKLWan.from_pretrained(
+     os.path.join(model_name, config['vae_kwargs'].get('vae_subpath', 'vae')),
+     additional_kwargs=OmegaConf.to_container(config['vae_kwargs']),
+ ).to(weight_dtype)
+
+ if vae_path is not None:
+     print(f"From checkpoint: {vae_path}")
+     if vae_path.endswith("safetensors"):
+         from safetensors.torch import load_file, safe_open
+         state_dict = load_file(vae_path)
+     else:
+         state_dict = torch.load(vae_path, map_location="cpu")
+     state_dict = state_dict["state_dict"] if "state_dict" in state_dict else state_dict
+
+     m, u = vae.load_state_dict(state_dict, strict=False)
+     print(f"missing keys: {len(m)}, unexpected keys: {len(u)}")
+
+ # Get Tokenizer
+ tokenizer = AutoTokenizer.from_pretrained(
+     os.path.join(model_name, config['text_encoder_kwargs'].get('tokenizer_subpath', 'tokenizer')),
+ )
+
+ # Get Text encoder
+ text_encoder = WanT5EncoderModel.from_pretrained(
+     os.path.join(model_name, config['text_encoder_kwargs'].get('text_encoder_subpath', 'text_encoder')),
+     additional_kwargs=OmegaConf.to_container(config['text_encoder_kwargs']),
+ ).to(weight_dtype)
+ text_encoder = text_encoder.eval()
+
+ # Get Clip Image Encoder
+ clip_image_encoder = CLIPModel.from_pretrained(
+     os.path.join(model_name, config['image_encoder_kwargs'].get('image_encoder_subpath', 'image_encoder')),
+ ).to(weight_dtype)
+ clip_image_encoder = clip_image_encoder.eval()
+
+ # Get Scheduler
+ Choosen_Scheduler = scheduler_dict = {
+     "Flow": FlowMatchEulerDiscreteScheduler,
+ }[sampler_name]
+ scheduler = Choosen_Scheduler(
+     **filter_kwargs(Choosen_Scheduler, OmegaConf.to_container(config['scheduler_kwargs']))
+ )
+
+ # Get Pipeline
+ pipeline = WanFunInpaintPipeline(
+     transformer=transformer,
+     vae=vae,
+     tokenizer=tokenizer,
+     text_encoder=text_encoder,
+     scheduler=scheduler,
+     clip_image_encoder=clip_image_encoder
+ )
+ if GPU_memory_mode == "sequential_cpu_offload":
+     replace_parameters_by_name(transformer, ["modulation",], device="cuda")
+     transformer.freqs = transformer.freqs.to(device="cuda")
+     pipeline.enable_sequential_cpu_offload()
+ elif GPU_memory_mode == "model_cpu_offload_and_qfloat8":
+     convert_model_weight_to_float8(transformer, exclude_module_name=["modulation",])
+     convert_weight_dtype_wrapper(transformer, weight_dtype)
+     pipeline.enable_model_cpu_offload()
+ else:
+     pipeline.enable_model_cpu_offload()
+
+ generator = torch.Generator(device="cuda").manual_seed(seed)
+
+ if lora_path is not None:
+     pipeline = merge_lora(pipeline, lora_path, lora_weight)
+
+ with torch.no_grad():
+     video_length = int((video_length - 1) // vae.config.temporal_compression_ratio * vae.config.temporal_compression_ratio) + 1 if video_length != 1 else 1
+     latent_frames = (video_length - 1) // vae.config.temporal_compression_ratio + 1
+
+     input_video, input_video_mask, clip_image = get_image_to_video_latent(validation_image_start, validation_image_end, video_length=video_length, sample_size=sample_size)
+
+     sample = pipeline(
+         prompt,
+         num_frames = video_length,
+         negative_prompt = negative_prompt,
+         height = sample_size[0],
+         width = sample_size[1],
+         generator = generator,
+         guidance_scale = guidance_scale,
+         num_inference_steps = num_inference_steps,
+
+         video = input_video,
+         mask_video = input_video_mask,
+         clip_image = clip_image,
+     ).videos
+
+ if lora_path is not None:
+     pipeline = unmerge_lora(pipeline, lora_path, lora_weight)
+
+ if not os.path.exists(save_path):
+     os.makedirs(save_path, exist_ok=True)
+
+ index = len([path for path in os.listdir(save_path)]) + 1
+ prefix = str(index).zfill(8)
+
+ if video_length == 1:
+     video_path = os.path.join(save_path, prefix + ".png")
+
+     image = sample[0, :, 0]
+     image = image.transpose(0, 1).transpose(1, 2)
+     image = (image * 255).numpy().astype(np.uint8)
+     image = Image.fromarray(image)
+     image.save(video_path)
+ else:
+     video_path = os.path.join(save_path, prefix + ".mp4")
+     save_videos_grid(sample, video_path, fps=fps)
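
Both predict scripts repeat the same checkpoint-loading pattern for the transformer and the VAE: a `.safetensors` file is read with `safetensors.torch.load_file`, anything else with `torch.load`, a wrapping `state_dict` key is unwrapped, and the weights are applied with a non-strict `load_state_dict`. A sketch of that pattern factored into a helper (the function name `load_partial_state_dict` is hypothetical, not part of the repo):

```python
import torch
from safetensors.torch import load_file


def load_partial_state_dict(model, checkpoint_path):
    # Hypothetical helper mirroring the inline loading code in the predict scripts.
    if checkpoint_path.endswith("safetensors"):
        state_dict = load_file(checkpoint_path)
    else:
        state_dict = torch.load(checkpoint_path, map_location="cpu")
    # Training checkpoints may wrap the weights in a "state_dict" key.
    state_dict = state_dict["state_dict"] if "state_dict" in state_dict else state_dict
    # strict=False tolerates missing/extra keys, e.g. when loading a partial finetune.
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    print(f"missing keys: {len(missing)}, unexpected keys: {len(unexpected)}")
    return model
```
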