JasonSmithSO committed
Commit 33585de · verified · 1 Parent(s): aebc1d4

Upload 124 files

This view is limited to 50 files because it contains too many changes. See the raw diff for the full changeset.
Files changed (50)
  1. CODEOWNERS +1 -0
  2. api_server/__init__.py +0 -0
  3. api_server/routes/__init__.py +0 -0
  4. api_server/routes/internal/README.md +3 -0
  5. api_server/routes/internal/__init__.py +0 -0
  6. api_server/routes/internal/internal_routes.py +44 -0
  7. api_server/services/__init__.py +0 -0
  8. api_server/services/file_service.py +13 -0
  9. api_server/utils/file_operations.py +42 -0
  10. app/__init__.py +0 -0
  11. app/app_settings.py +54 -0
  12. app/frontend_management.py +195 -0
  13. app/logger.py +31 -0
  14. app/user_manager.py +232 -0
  15. comfy/checkpoint_pickle.py +13 -0
  16. comfy/cldm/cldm.py +437 -0
  17. comfy/cldm/control_types.py +10 -0
  18. comfy/cldm/mmdit.py +81 -0
  19. comfy/cli_args.py +192 -0
  20. comfy/clip_config_bigg.json +23 -0
  21. comfy/clip_model.py +196 -0
  22. comfy/clip_vision.py +121 -0
  23. comfy/clip_vision_config_g.json +18 -0
  24. comfy/clip_vision_config_h.json +18 -0
  25. comfy/clip_vision_config_vitl.json +18 -0
  26. comfy/clip_vision_config_vitl_336.json +18 -0
  27. comfy/comfy_types.py +32 -0
  28. comfy/conds.py +83 -0
  29. comfy/controlnet.py +737 -0
  30. comfy/diffusers_convert.py +281 -0
  31. comfy/diffusers_load.py +36 -0
  32. comfy/extra_samplers/uni_pc.py +875 -0
  33. comfy/float.py +66 -0
  34. comfy/gligen.py +343 -0
  35. comfy/k_diffusion/deis.py +121 -0
  36. comfy/k_diffusion/sampling.py +1145 -0
  37. comfy/k_diffusion/utils.py +313 -0
  38. comfy/latent_formats.py +172 -0
  39. comfy/ldm/audio/autoencoder.py +282 -0
  40. comfy/ldm/audio/dit.py +891 -0
  41. comfy/ldm/audio/embedders.py +108 -0
  42. comfy/ldm/aura/mmdit.py +478 -0
  43. comfy/ldm/cascade/common.py +154 -0
  44. comfy/ldm/cascade/controlnet.py +93 -0
  45. comfy/ldm/cascade/stage_a.py +255 -0
  46. comfy/ldm/cascade/stage_b.py +256 -0
  47. comfy/ldm/cascade/stage_c.py +273 -0
  48. comfy/ldm/cascade/stage_c_coder.py +95 -0
  49. comfy/ldm/common_dit.py +21 -0
  50. comfy/ldm/flux/controlnet.py +205 -0
CODEOWNERS ADDED
@@ -0,0 +1 @@
+ * @comfyanonymous
api_server/__init__.py ADDED
File without changes
api_server/routes/__init__.py ADDED
File without changes
api_server/routes/internal/README.md ADDED
@@ -0,0 +1,3 @@
+ # ComfyUI Internal Routes
+
+ All routes under the `/internal` path are designated for **internal use by ComfyUI only**. These routes are not intended for use by external applications and may change at any time without notice.
api_server/routes/internal/__init__.py ADDED
File without changes
api_server/routes/internal/internal_routes.py ADDED
@@ -0,0 +1,44 @@
+ from aiohttp import web
+ from typing import Optional
+ from folder_paths import models_dir, user_directory, output_directory
+ from api_server.services.file_service import FileService
+ import app.logger
+
+ class InternalRoutes:
+     '''
+     The top level web router for internal routes: /internal/*
+     The endpoints here should NOT be depended upon. It is for ComfyUI frontend use only.
+     Check README.md for more information.
+
+     '''
+     def __init__(self):
+         self.routes: web.RouteTableDef = web.RouteTableDef()
+         self._app: Optional[web.Application] = None
+         self.file_service = FileService({
+             "models": models_dir,
+             "user": user_directory,
+             "output": output_directory
+         })
+
+     def setup_routes(self):
+         @self.routes.get('/files')
+         async def list_files(request):
+             directory_key = request.query.get('directory', '')
+             try:
+                 file_list = self.file_service.list_files(directory_key)
+                 return web.json_response({"files": file_list})
+             except ValueError as e:
+                 return web.json_response({"error": str(e)}, status=400)
+             except Exception as e:
+                 return web.json_response({"error": str(e)}, status=500)
+
+         @self.routes.get('/logs')
+         async def get_logs(request):
+             return web.json_response(app.logger.get_logs())
+
+     def get_app(self):
+         if self._app is None:
+             self._app = web.Application()
+             self.setup_routes()
+             self._app.add_routes(self.routes)
+         return self._app
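
Note (not part of the diff): a minimal sketch of how this router could be mounted as a sub-application of a parent aiohttp server; the parent app, prefix, and port below are assumptions for illustration only.

# Illustrative sketch only: mounting the internal routes under /internal.
from aiohttp import web
from api_server.routes.internal.internal_routes import InternalRoutes

app = web.Application()                                   # assumed parent application
internal_routes = InternalRoutes()
app.add_subapp("/internal", internal_routes.get_app())    # exposes /internal/files and /internal/logs
web.run_app(app, port=8188)                               # port value is an assumption
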
api_server/services/__init__.py ADDED
File without changes
api_server/services/file_service.py ADDED
@@ -0,0 +1,13 @@
+ from typing import Dict, List, Optional
+ from api_server.utils.file_operations import FileSystemOperations, FileSystemItem
+
+ class FileService:
+     def __init__(self, allowed_directories: Dict[str, str], file_system_ops: Optional[FileSystemOperations] = None):
+         self.allowed_directories: Dict[str, str] = allowed_directories
+         self.file_system_ops: FileSystemOperations = file_system_ops or FileSystemOperations()
+
+     def list_files(self, directory_key: str) -> List[FileSystemItem]:
+         if directory_key not in self.allowed_directories:
+             raise ValueError("Invalid directory key")
+         directory_path: str = self.allowed_directories[directory_key]
+         return self.file_system_ops.walk_directory(directory_path)
api_server/utils/file_operations.py ADDED
@@ -0,0 +1,42 @@
+ import os
+ from typing import List, Union, TypedDict, Literal
+ from typing_extensions import TypeGuard
+ class FileInfo(TypedDict):
+     name: str
+     path: str
+     type: Literal["file"]
+     size: int
+
+ class DirectoryInfo(TypedDict):
+     name: str
+     path: str
+     type: Literal["directory"]
+
+ FileSystemItem = Union[FileInfo, DirectoryInfo]
+
+ def is_file_info(item: FileSystemItem) -> TypeGuard[FileInfo]:
+     return item["type"] == "file"
+
+ class FileSystemOperations:
+     @staticmethod
+     def walk_directory(directory: str) -> List[FileSystemItem]:
+         file_list: List[FileSystemItem] = []
+         for root, dirs, files in os.walk(directory):
+             for name in files:
+                 file_path = os.path.join(root, name)
+                 relative_path = os.path.relpath(file_path, directory)
+                 file_list.append({
+                     "name": name,
+                     "path": relative_path,
+                     "type": "file",
+                     "size": os.path.getsize(file_path)
+                 })
+             for name in dirs:
+                 dir_path = os.path.join(root, name)
+                 relative_path = os.path.relpath(dir_path, directory)
+                 file_list.append({
+                     "name": name,
+                     "path": relative_path,
+                     "type": "directory"
+                 })
+         return file_list
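
Note (not part of the diff): a short usage sketch of the file service above, serving the same shape of data the /internal/files route returns; the directory mapping is a made-up example.

# Illustrative sketch only: listing files through FileService.
from api_server.services.file_service import FileService

service = FileService({"models": "/tmp/models"})   # hypothetical directory mapping
for item in service.list_files("models"):          # raises ValueError for unknown keys
    if item["type"] == "file":
        print(item["path"], item["size"])
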
app/__init__.py ADDED
File without changes
app/app_settings.py ADDED
@@ -0,0 +1,54 @@
+ import os
+ import json
+ from aiohttp import web
+
+
+ class AppSettings():
+     def __init__(self, user_manager):
+         self.user_manager = user_manager
+
+     def get_settings(self, request):
+         file = self.user_manager.get_request_user_filepath(
+             request, "comfy.settings.json")
+         if os.path.isfile(file):
+             with open(file) as f:
+                 return json.load(f)
+         else:
+             return {}
+
+     def save_settings(self, request, settings):
+         file = self.user_manager.get_request_user_filepath(
+             request, "comfy.settings.json")
+         with open(file, "w") as f:
+             f.write(json.dumps(settings, indent=4))
+
+     def add_routes(self, routes):
+         @routes.get("/settings")
+         async def get_settings(request):
+             return web.json_response(self.get_settings(request))
+
+         @routes.get("/settings/{id}")
+         async def get_setting(request):
+             value = None
+             settings = self.get_settings(request)
+             setting_id = request.match_info.get("id", None)
+             if setting_id and setting_id in settings:
+                 value = settings[setting_id]
+             return web.json_response(value)
+
+         @routes.post("/settings")
+         async def post_settings(request):
+             settings = self.get_settings(request)
+             new_settings = await request.json()
+             self.save_settings(request, {**settings, **new_settings})
+             return web.Response(status=200)
+
+         @routes.post("/settings/{id}")
+         async def post_setting(request):
+             setting_id = request.match_info.get("id", None)
+             if not setting_id:
+                 return web.Response(status=400)
+             settings = self.get_settings(request)
+             settings[setting_id] = await request.json()
+             self.save_settings(request, settings)
+             return web.Response(status=200)
app/frontend_management.py ADDED
@@ -0,0 +1,195 @@
+ from __future__ import annotations
+ import argparse
+ import logging
+ import os
+ import re
+ import tempfile
+ import zipfile
+ from dataclasses import dataclass
+ from functools import cached_property
+ from pathlib import Path
+ from typing import TypedDict, Optional
+
+ import requests
+ from typing_extensions import NotRequired
+ from comfy.cli_args import DEFAULT_VERSION_STRING
+
+
+ REQUEST_TIMEOUT = 10  # seconds
+
+
+ class Asset(TypedDict):
+     url: str
+
+
+ class Release(TypedDict):
+     id: int
+     tag_name: str
+     name: str
+     prerelease: bool
+     created_at: str
+     published_at: str
+     body: str
+     assets: NotRequired[list[Asset]]
+
+
+ @dataclass
+ class FrontEndProvider:
+     owner: str
+     repo: str
+
+     @property
+     def folder_name(self) -> str:
+         return f"{self.owner}_{self.repo}"
+
+     @property
+     def release_url(self) -> str:
+         return f"https://api.github.com/repos/{self.owner}/{self.repo}/releases"
+
+     @cached_property
+     def all_releases(self) -> list[Release]:
+         releases = []
+         api_url = self.release_url
+         while api_url:
+             response = requests.get(api_url, timeout=REQUEST_TIMEOUT)
+             response.raise_for_status()  # Raises an HTTPError if the response was an error
+             releases.extend(response.json())
+             # GitHub uses the Link header to provide pagination links. Check if it exists and update api_url accordingly.
+             if "next" in response.links:
+                 api_url = response.links["next"]["url"]
+             else:
+                 api_url = None
+         return releases
+
+     @cached_property
+     def latest_release(self) -> Release:
+         latest_release_url = f"{self.release_url}/latest"
+         response = requests.get(latest_release_url, timeout=REQUEST_TIMEOUT)
+         response.raise_for_status()  # Raises an HTTPError if the response was an error
+         return response.json()
+
+     def get_release(self, version: str) -> Release:
+         if version == "latest":
+             return self.latest_release
+         else:
+             for release in self.all_releases:
+                 if release["tag_name"] in [version, f"v{version}"]:
+                     return release
+             raise ValueError(f"Version {version} not found in releases")
+
+
+ def download_release_asset_zip(release: Release, destination_path: str) -> None:
+     """Download dist.zip from github release."""
+     asset_url = None
+     for asset in release.get("assets", []):
+         if asset["name"] == "dist.zip":
+             asset_url = asset["url"]
+             break
+
+     if not asset_url:
+         raise ValueError("dist.zip not found in the release assets")
+
+     # Use a temporary file to download the zip content
+     with tempfile.TemporaryFile() as tmp_file:
+         headers = {"Accept": "application/octet-stream"}
+         response = requests.get(
+             asset_url, headers=headers, allow_redirects=True, timeout=REQUEST_TIMEOUT
+         )
+         response.raise_for_status()  # Ensure we got a successful response
+
+         # Write the content to the temporary file
+         tmp_file.write(response.content)
+
+         # Go back to the beginning of the temporary file
+         tmp_file.seek(0)
+
+         # Extract the zip file content to the destination path
+         with zipfile.ZipFile(tmp_file, "r") as zip_ref:
+             zip_ref.extractall(destination_path)
+
+
+ class FrontendManager:
+     DEFAULT_FRONTEND_PATH = str(Path(__file__).parents[1] / "web")
+     CUSTOM_FRONTENDS_ROOT = str(Path(__file__).parents[1] / "web_custom_versions")
+
+     @classmethod
+     def parse_version_string(cls, value: str) -> tuple[str, str, str]:
+         """
+         Args:
+             value (str): The version string to parse.
+
+         Returns:
+             tuple[str, str, str]: A tuple containing the repo owner, repo name, and version.
+
+         Raises:
+             argparse.ArgumentTypeError: If the version string is invalid.
+         """
+         VERSION_PATTERN = r"^([a-zA-Z0-9][a-zA-Z0-9-]{0,38})/([a-zA-Z0-9_.-]+)@(v?\d+\.\d+\.\d+|latest)$"
+         match_result = re.match(VERSION_PATTERN, value)
+         if match_result is None:
+             raise argparse.ArgumentTypeError(f"Invalid version string: {value}")
+
+         return match_result.group(1), match_result.group(2), match_result.group(3)
+
+     @classmethod
+     def init_frontend_unsafe(cls, version_string: str, provider: Optional[FrontEndProvider] = None) -> str:
+         """
+         Initializes the frontend for the specified version.
+
+         Args:
+             version_string (str): The version string.
+             provider (FrontEndProvider, optional): The provider to use. Defaults to None.
+
+         Returns:
+             str: The path to the initialized frontend.
+
+         Raises:
+             Exception: If there is an error during the initialization process.
+             The main error sources are request timeouts and invalid URLs.
+         """
+         if version_string == DEFAULT_VERSION_STRING:
+             return cls.DEFAULT_FRONTEND_PATH
+
+         repo_owner, repo_name, version = cls.parse_version_string(version_string)
+         provider = provider or FrontEndProvider(repo_owner, repo_name)
+         release = provider.get_release(version)
+
+         semantic_version = release["tag_name"].lstrip("v")
+         web_root = str(
+             Path(cls.CUSTOM_FRONTENDS_ROOT) / provider.folder_name / semantic_version
+         )
+         if not os.path.exists(web_root):
+             try:
+                 os.makedirs(web_root, exist_ok=True)
+                 logging.info(
+                     "Downloading frontend(%s) version(%s) to (%s)",
+                     provider.folder_name,
+                     semantic_version,
+                     web_root,
+                 )
+                 logging.debug(release)
+                 download_release_asset_zip(release, destination_path=web_root)
+             finally:
+                 # Clean up the directory if it is empty, i.e. the download failed
+                 if not os.listdir(web_root):
+                     os.rmdir(web_root)
+
+         return web_root
+
+     @classmethod
+     def init_frontend(cls, version_string: str) -> str:
+         """
+         Initializes the frontend with the specified version string.
+
+         Args:
+             version_string (str): The version string to initialize the frontend with.
+
+         Returns:
+             str: The path of the initialized frontend.
+         """
+         try:
+             return cls.init_frontend_unsafe(version_string)
+         except Exception as e:
+             logging.error("Failed to initialize frontend: %s", e)
+             logging.info("Falling back to the default frontend.")
+             return cls.DEFAULT_FRONTEND_PATH
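
Note (not part of the diff): a quick sketch of how the "[repoOwner]/[repoName]@[version]" format accepted by --front-end-version maps onto FrontendManager; the repository name used here is only an example.

# Illustrative sketch only: parsing a frontend version string and initializing it.
from app.frontend_management import FrontendManager

owner, repo, version = FrontendManager.parse_version_string("Comfy-Org/ComfyUI_frontend@latest")
print(owner, repo, version)   # -> Comfy-Org ComfyUI_frontend latest

# init_frontend() downloads (or reuses) the matching release under web_custom_versions/
# and falls back to the bundled web/ directory if anything goes wrong.
web_root = FrontendManager.init_frontend("Comfy-Org/ComfyUI_frontend@latest")
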
app/logger.py ADDED
@@ -0,0 +1,31 @@
+ import logging
+ from logging.handlers import MemoryHandler
+ from collections import deque
+
+ logs = None
+ formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
+
+
+ def get_logs():
+     return "\n".join([formatter.format(x) for x in logs])
+
+
+ def setup_logger(verbose: bool = False, capacity: int = 300):
+     global logs
+     if logs:
+         return
+
+     # Setup default global logger
+     logger = logging.getLogger()
+     logger.setLevel(logging.DEBUG if verbose else logging.INFO)
+
+     stream_handler = logging.StreamHandler()
+     stream_handler.setFormatter(logging.Formatter("[Comfyd] %(message)s"))
+     logger.addHandler(stream_handler)
+
+     # Create a memory handler with a deque as its buffer
+     logs = deque(maxlen=capacity)
+     memory_handler = MemoryHandler(capacity, flushLevel=logging.INFO)
+     memory_handler.buffer = logs
+     memory_handler.setFormatter(formatter)
+     logger.addHandler(memory_handler)
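
Note (not part of the diff): the deque above is the buffer that /internal/logs serves; a minimal sketch of the intended call order, under the assumption that setup_logger() runs once at startup.

# Illustrative sketch only: capture log records in memory and read them back.
import logging
import app.logger

app.logger.setup_logger(verbose=False, capacity=300)
logging.info("model loaded")        # record lands in the MemoryHandler's deque buffer
print(app.logger.get_logs())        # formatted history, as returned by /internal/logs
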
app/user_manager.py ADDED
@@ -0,0 +1,232 @@
1
+ import json
2
+ import os
3
+ import re
4
+ import uuid
5
+ import glob
6
+ import shutil
7
+ from aiohttp import web
8
+ from urllib import parse
9
+ from comfy.cli_args import args
10
+ import folder_paths
11
+ from .app_settings import AppSettings
12
+
13
+ default_user = "default"
14
+
15
+
16
+ class UserManager():
17
+ def __init__(self):
18
+ user_directory = folder_paths.get_user_directory()
19
+
20
+ self.settings = AppSettings(self)
21
+ if not os.path.exists(user_directory):
22
+ os.mkdir(user_directory)
23
+ if not args.multi_user:
24
+ print("****** User settings have been changed to be stored on the server instead of browser storage. ******")
25
+ print("****** For multi-user setups add the --multi-user CLI argument to enable multiple user profiles. ******")
26
+
27
+ if args.multi_user:
28
+ if os.path.isfile(self.get_users_file()):
29
+ with open(self.get_users_file()) as f:
30
+ self.users = json.load(f)
31
+ else:
32
+ self.users = {}
33
+ else:
34
+ self.users = {"default": "default"}
35
+
36
+ def get_users_file(self):
37
+ return os.path.join(folder_paths.get_user_directory(), "users.json")
38
+
39
+ def get_request_user_id(self, request):
40
+ user = "default"
41
+ if args.multi_user and "comfy-user" in request.headers:
42
+ user = request.headers["comfy-user"]
43
+
44
+ if user not in self.users:
45
+ raise KeyError("Unknown user: " + user)
46
+
47
+ return user
48
+
49
+ def get_request_user_filepath(self, request, file, type="userdata", create_dir=True):
50
+ user_directory = folder_paths.get_user_directory()
51
+
52
+ if type == "userdata":
53
+ root_dir = user_directory
54
+ else:
55
+ raise KeyError("Unknown filepath type:" + type)
56
+
57
+ user = self.get_request_user_id(request)
58
+ path = user_root = os.path.abspath(os.path.join(root_dir, user))
59
+
60
+ # prevent leaving /{type}
61
+ if os.path.commonpath((root_dir, user_root)) != root_dir:
62
+ return None
63
+
64
+ if file is not None:
65
+ # Check if filename is url encoded
66
+ if "%" in file:
67
+ file = parse.unquote(file)
68
+
69
+ # prevent leaving /{type}/{user}
70
+ path = os.path.abspath(os.path.join(user_root, file))
71
+ if os.path.commonpath((user_root, path)) != user_root:
72
+ return None
73
+
74
+ parent = os.path.split(path)[0]
75
+
76
+ if create_dir and not os.path.exists(parent):
77
+ os.makedirs(parent, exist_ok=True)
78
+
79
+ return path
80
+
81
+ def add_user(self, name):
82
+ name = name.strip()
83
+ if not name:
84
+ raise ValueError("username not provided")
85
+ user_id = re.sub("[^a-zA-Z0-9-_]+", '-', name)
86
+ user_id = user_id + "_" + str(uuid.uuid4())
87
+
88
+ self.users[user_id] = name
89
+
90
+ with open(self.get_users_file(), "w") as f:
91
+ json.dump(self.users, f)
92
+
93
+ return user_id
94
+
95
+ def add_routes(self, routes):
96
+ self.settings.add_routes(routes)
97
+
98
+ @routes.get("/users")
99
+ async def get_users(request):
100
+ if args.multi_user:
101
+ return web.json_response({"storage": "server", "users": self.users})
102
+ else:
103
+ user_dir = self.get_request_user_filepath(request, None, create_dir=False)
104
+ return web.json_response({
105
+ "storage": "server",
106
+ "migrated": os.path.exists(user_dir)
107
+ })
108
+
109
+ @routes.post("/users")
110
+ async def post_users(request):
111
+ body = await request.json()
112
+ username = body["username"]
113
+ if username in self.users.values():
114
+ return web.json_response({"error": "Duplicate username."}, status=400)
115
+
116
+ user_id = self.add_user(username)
117
+ return web.json_response(user_id)
118
+
119
+ @routes.get("/userdata")
120
+ async def listuserdata(request):
121
+ directory = request.rel_url.query.get('dir', '')
122
+ if not directory:
123
+ return web.Response(status=400, text="Directory not provided")
124
+
125
+ path = self.get_request_user_filepath(request, directory)
126
+ if not path:
127
+ return web.Response(status=403, text="Invalid directory")
128
+
129
+ if not os.path.exists(path):
130
+ return web.Response(status=404, text="Directory not found")
131
+
132
+ recurse = request.rel_url.query.get('recurse', '').lower() == "true"
133
+ full_info = request.rel_url.query.get('full_info', '').lower() == "true"
134
+
135
+ # Use different patterns based on whether we're recursing or not
136
+ if recurse:
137
+ pattern = os.path.join(glob.escape(path), '**', '*')
138
+ else:
139
+ pattern = os.path.join(glob.escape(path), '*')
140
+
141
+ results = glob.glob(pattern, recursive=recurse)
142
+
143
+ if full_info:
144
+ results = [
145
+ {
146
+ 'path': os.path.relpath(x, path).replace(os.sep, '/'),
147
+ 'size': os.path.getsize(x),
148
+ 'modified': os.path.getmtime(x)
149
+ } for x in results if os.path.isfile(x)
150
+ ]
151
+ else:
152
+ results = [
153
+ os.path.relpath(x, path).replace(os.sep, '/')
154
+ for x in results
155
+ if os.path.isfile(x)
156
+ ]
157
+
158
+ split_path = request.rel_url.query.get('split', '').lower() == "true"
159
+ if split_path and not full_info:
160
+ results = [[x] + x.split('/') for x in results]
161
+
162
+ return web.json_response(results)
163
+
164
+ def get_user_data_path(request, check_exists = False, param = "file"):
165
+ file = request.match_info.get(param, None)
166
+ if not file:
167
+ return web.Response(status=400)
168
+
169
+ path = self.get_request_user_filepath(request, file)
170
+ if not path:
171
+ return web.Response(status=403)
172
+
173
+ if check_exists and not os.path.exists(path):
174
+ return web.Response(status=404)
175
+
176
+ return path
177
+
178
+ @routes.get("/userdata/{file}")
179
+ async def getuserdata(request):
180
+ path = get_user_data_path(request, check_exists=True)
181
+ if not isinstance(path, str):
182
+ return path
183
+
184
+ return web.FileResponse(path)
185
+
186
+ @routes.post("/userdata/{file}")
187
+ async def post_userdata(request):
188
+ path = get_user_data_path(request)
189
+ if not isinstance(path, str):
190
+ return path
191
+
192
+ overwrite = request.query["overwrite"] != "false"
193
+ if not overwrite and os.path.exists(path):
194
+ return web.Response(status=409)
195
+
196
+ body = await request.read()
197
+
198
+ with open(path, "wb") as f:
199
+ f.write(body)
200
+
201
+ resp = os.path.relpath(path, self.get_request_user_filepath(request, None))
202
+ return web.json_response(resp)
203
+
204
+ @routes.delete("/userdata/{file}")
205
+ async def delete_userdata(request):
206
+ path = get_user_data_path(request, check_exists=True)
207
+ if not isinstance(path, str):
208
+ return path
209
+
210
+ os.remove(path)
211
+
212
+ return web.Response(status=204)
213
+
214
+ @routes.post("/userdata/{file}/move/{dest}")
215
+ async def move_userdata(request):
216
+ source = get_user_data_path(request, check_exists=True)
217
+ if not isinstance(source, str):
218
+ return source
219
+
220
+ dest = get_user_data_path(request, check_exists=False, param="dest")
221
+ if not isinstance(dest, str):
222
+ return dest
223
+
224
+ overwrite = request.query["overwrite"] != "false"
225
+ if not overwrite and os.path.exists(dest):
226
+ return web.Response(status=409)
227
+
228
+ print(f"moving '{source}' -> '{dest}'")
229
+ shutil.move(source, dest)
230
+
231
+ resp = os.path.relpath(dest, self.get_request_user_filepath(request, None))
232
+ return web.json_response(resp)
comfy/checkpoint_pickle.py ADDED
@@ -0,0 +1,13 @@
+ import pickle
+
+ load = pickle.load
+
+ class Empty:
+     pass
+
+ class Unpickler(pickle.Unpickler):
+     def find_class(self, module, name):
+         #TODO: safe unpickle
+         if module.startswith("pytorch_lightning"):
+             return Empty
+         return super().find_class(module, name)
comfy/cldm/cldm.py ADDED
@@ -0,0 +1,437 @@
1
+ #taken from: https://github.com/lllyasviel/ControlNet
2
+ #and modified
3
+
4
+ import torch
5
+ import torch as th
6
+ import torch.nn as nn
7
+
8
+ from ..ldm.modules.diffusionmodules.util import (
9
+ zero_module,
10
+ timestep_embedding,
11
+ )
12
+
13
+ from ..ldm.modules.attention import SpatialTransformer
14
+ from ..ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSequential, ResBlock, Downsample
15
+ from ..ldm.util import exists
16
+ from .control_types import UNION_CONTROLNET_TYPES
17
+ from collections import OrderedDict
18
+ import comfy.ops
19
+ from comfy.ldm.modules.attention import optimized_attention
20
+
21
+ class OptimizedAttention(nn.Module):
22
+ def __init__(self, c, nhead, dropout=0.0, dtype=None, device=None, operations=None):
23
+ super().__init__()
24
+ self.heads = nhead
25
+ self.c = c
26
+
27
+ self.in_proj = operations.Linear(c, c * 3, bias=True, dtype=dtype, device=device)
28
+ self.out_proj = operations.Linear(c, c, bias=True, dtype=dtype, device=device)
29
+
30
+ def forward(self, x):
31
+ x = self.in_proj(x)
32
+ q, k, v = x.split(self.c, dim=2)
33
+ out = optimized_attention(q, k, v, self.heads)
34
+ return self.out_proj(out)
35
+
36
+ class QuickGELU(nn.Module):
37
+ def forward(self, x: torch.Tensor):
38
+ return x * torch.sigmoid(1.702 * x)
39
+
40
+ class ResBlockUnionControlnet(nn.Module):
41
+ def __init__(self, dim, nhead, dtype=None, device=None, operations=None):
42
+ super().__init__()
43
+ self.attn = OptimizedAttention(dim, nhead, dtype=dtype, device=device, operations=operations)
44
+ self.ln_1 = operations.LayerNorm(dim, dtype=dtype, device=device)
45
+ self.mlp = nn.Sequential(
46
+ OrderedDict([("c_fc", operations.Linear(dim, dim * 4, dtype=dtype, device=device)), ("gelu", QuickGELU()),
47
+ ("c_proj", operations.Linear(dim * 4, dim, dtype=dtype, device=device))]))
48
+ self.ln_2 = operations.LayerNorm(dim, dtype=dtype, device=device)
49
+
50
+ def attention(self, x: torch.Tensor):
51
+ return self.attn(x)
52
+
53
+ def forward(self, x: torch.Tensor):
54
+ x = x + self.attention(self.ln_1(x))
55
+ x = x + self.mlp(self.ln_2(x))
56
+ return x
57
+
58
+ class ControlledUnetModel(UNetModel):
59
+ #implemented in the ldm unet
60
+ pass
61
+
62
+ class ControlNet(nn.Module):
63
+ def __init__(
64
+ self,
65
+ image_size,
66
+ in_channels,
67
+ model_channels,
68
+ hint_channels,
69
+ num_res_blocks,
70
+ dropout=0,
71
+ channel_mult=(1, 2, 4, 8),
72
+ conv_resample=True,
73
+ dims=2,
74
+ num_classes=None,
75
+ use_checkpoint=False,
76
+ dtype=torch.float32,
77
+ num_heads=-1,
78
+ num_head_channels=-1,
79
+ num_heads_upsample=-1,
80
+ use_scale_shift_norm=False,
81
+ resblock_updown=False,
82
+ use_new_attention_order=False,
83
+ use_spatial_transformer=False, # custom transformer support
84
+ transformer_depth=1, # custom transformer support
85
+ context_dim=None, # custom transformer support
86
+ n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
87
+ legacy=True,
88
+ disable_self_attentions=None,
89
+ num_attention_blocks=None,
90
+ disable_middle_self_attn=False,
91
+ use_linear_in_transformer=False,
92
+ adm_in_channels=None,
93
+ transformer_depth_middle=None,
94
+ transformer_depth_output=None,
95
+ attn_precision=None,
96
+ union_controlnet_num_control_type=None,
97
+ device=None,
98
+ operations=comfy.ops.disable_weight_init,
99
+ **kwargs,
100
+ ):
101
+ super().__init__()
102
+ assert use_spatial_transformer == True, "use_spatial_transformer has to be true"
103
+ if use_spatial_transformer:
104
+ assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
105
+
106
+ if context_dim is not None:
107
+ assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
108
+ # from omegaconf.listconfig import ListConfig
109
+ # if type(context_dim) == ListConfig:
110
+ # context_dim = list(context_dim)
111
+
112
+ if num_heads_upsample == -1:
113
+ num_heads_upsample = num_heads
114
+
115
+ if num_heads == -1:
116
+ assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
117
+
118
+ if num_head_channels == -1:
119
+ assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
120
+
121
+ self.dims = dims
122
+ self.image_size = image_size
123
+ self.in_channels = in_channels
124
+ self.model_channels = model_channels
125
+
126
+ if isinstance(num_res_blocks, int):
127
+ self.num_res_blocks = len(channel_mult) * [num_res_blocks]
128
+ else:
129
+ if len(num_res_blocks) != len(channel_mult):
130
+ raise ValueError("provide num_res_blocks either as an int (globally constant) or "
131
+ "as a list/tuple (per-level) with the same length as channel_mult")
132
+ self.num_res_blocks = num_res_blocks
133
+
134
+ if disable_self_attentions is not None:
135
+ # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
136
+ assert len(disable_self_attentions) == len(channel_mult)
137
+ if num_attention_blocks is not None:
138
+ assert len(num_attention_blocks) == len(self.num_res_blocks)
139
+ assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
140
+
141
+ transformer_depth = transformer_depth[:]
142
+
143
+ self.dropout = dropout
144
+ self.channel_mult = channel_mult
145
+ self.conv_resample = conv_resample
146
+ self.num_classes = num_classes
147
+ self.use_checkpoint = use_checkpoint
148
+ self.dtype = dtype
149
+ self.num_heads = num_heads
150
+ self.num_head_channels = num_head_channels
151
+ self.num_heads_upsample = num_heads_upsample
152
+ self.predict_codebook_ids = n_embed is not None
153
+
154
+ time_embed_dim = model_channels * 4
155
+ self.time_embed = nn.Sequential(
156
+ operations.Linear(model_channels, time_embed_dim, dtype=self.dtype, device=device),
157
+ nn.SiLU(),
158
+ operations.Linear(time_embed_dim, time_embed_dim, dtype=self.dtype, device=device),
159
+ )
160
+
161
+ if self.num_classes is not None:
162
+ if isinstance(self.num_classes, int):
163
+ self.label_emb = nn.Embedding(num_classes, time_embed_dim)
164
+ elif self.num_classes == "continuous":
165
+ print("setting up linear c_adm embedding layer")
166
+ self.label_emb = nn.Linear(1, time_embed_dim)
167
+ elif self.num_classes == "sequential":
168
+ assert adm_in_channels is not None
169
+ self.label_emb = nn.Sequential(
170
+ nn.Sequential(
171
+ operations.Linear(adm_in_channels, time_embed_dim, dtype=self.dtype, device=device),
172
+ nn.SiLU(),
173
+ operations.Linear(time_embed_dim, time_embed_dim, dtype=self.dtype, device=device),
174
+ )
175
+ )
176
+ else:
177
+ raise ValueError()
178
+
179
+ self.input_blocks = nn.ModuleList(
180
+ [
181
+ TimestepEmbedSequential(
182
+ operations.conv_nd(dims, in_channels, model_channels, 3, padding=1, dtype=self.dtype, device=device)
183
+ )
184
+ ]
185
+ )
186
+ self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels, operations=operations, dtype=self.dtype, device=device)])
187
+
188
+ self.input_hint_block = TimestepEmbedSequential(
189
+ operations.conv_nd(dims, hint_channels, 16, 3, padding=1, dtype=self.dtype, device=device),
190
+ nn.SiLU(),
191
+ operations.conv_nd(dims, 16, 16, 3, padding=1, dtype=self.dtype, device=device),
192
+ nn.SiLU(),
193
+ operations.conv_nd(dims, 16, 32, 3, padding=1, stride=2, dtype=self.dtype, device=device),
194
+ nn.SiLU(),
195
+ operations.conv_nd(dims, 32, 32, 3, padding=1, dtype=self.dtype, device=device),
196
+ nn.SiLU(),
197
+ operations.conv_nd(dims, 32, 96, 3, padding=1, stride=2, dtype=self.dtype, device=device),
198
+ nn.SiLU(),
199
+ operations.conv_nd(dims, 96, 96, 3, padding=1, dtype=self.dtype, device=device),
200
+ nn.SiLU(),
201
+ operations.conv_nd(dims, 96, 256, 3, padding=1, stride=2, dtype=self.dtype, device=device),
202
+ nn.SiLU(),
203
+ operations.conv_nd(dims, 256, model_channels, 3, padding=1, dtype=self.dtype, device=device)
204
+ )
205
+
206
+ self._feature_size = model_channels
207
+ input_block_chans = [model_channels]
208
+ ch = model_channels
209
+ ds = 1
210
+ for level, mult in enumerate(channel_mult):
211
+ for nr in range(self.num_res_blocks[level]):
212
+ layers = [
213
+ ResBlock(
214
+ ch,
215
+ time_embed_dim,
216
+ dropout,
217
+ out_channels=mult * model_channels,
218
+ dims=dims,
219
+ use_checkpoint=use_checkpoint,
220
+ use_scale_shift_norm=use_scale_shift_norm,
221
+ dtype=self.dtype,
222
+ device=device,
223
+ operations=operations,
224
+ )
225
+ ]
226
+ ch = mult * model_channels
227
+ num_transformers = transformer_depth.pop(0)
228
+ if num_transformers > 0:
229
+ if num_head_channels == -1:
230
+ dim_head = ch // num_heads
231
+ else:
232
+ num_heads = ch // num_head_channels
233
+ dim_head = num_head_channels
234
+ if legacy:
235
+ #num_heads = 1
236
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
237
+ if exists(disable_self_attentions):
238
+ disabled_sa = disable_self_attentions[level]
239
+ else:
240
+ disabled_sa = False
241
+
242
+ if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
243
+ layers.append(
244
+ SpatialTransformer(
245
+ ch, num_heads, dim_head, depth=num_transformers, context_dim=context_dim,
246
+ disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
247
+ use_checkpoint=use_checkpoint, attn_precision=attn_precision, dtype=self.dtype, device=device, operations=operations
248
+ )
249
+ )
250
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
251
+ self.zero_convs.append(self.make_zero_conv(ch, operations=operations, dtype=self.dtype, device=device))
252
+ self._feature_size += ch
253
+ input_block_chans.append(ch)
254
+ if level != len(channel_mult) - 1:
255
+ out_ch = ch
256
+ self.input_blocks.append(
257
+ TimestepEmbedSequential(
258
+ ResBlock(
259
+ ch,
260
+ time_embed_dim,
261
+ dropout,
262
+ out_channels=out_ch,
263
+ dims=dims,
264
+ use_checkpoint=use_checkpoint,
265
+ use_scale_shift_norm=use_scale_shift_norm,
266
+ down=True,
267
+ dtype=self.dtype,
268
+ device=device,
269
+ operations=operations
270
+ )
271
+ if resblock_updown
272
+ else Downsample(
273
+ ch, conv_resample, dims=dims, out_channels=out_ch, dtype=self.dtype, device=device, operations=operations
274
+ )
275
+ )
276
+ )
277
+ ch = out_ch
278
+ input_block_chans.append(ch)
279
+ self.zero_convs.append(self.make_zero_conv(ch, operations=operations, dtype=self.dtype, device=device))
280
+ ds *= 2
281
+ self._feature_size += ch
282
+
283
+ if num_head_channels == -1:
284
+ dim_head = ch // num_heads
285
+ else:
286
+ num_heads = ch // num_head_channels
287
+ dim_head = num_head_channels
288
+ if legacy:
289
+ #num_heads = 1
290
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
291
+ mid_block = [
292
+ ResBlock(
293
+ ch,
294
+ time_embed_dim,
295
+ dropout,
296
+ dims=dims,
297
+ use_checkpoint=use_checkpoint,
298
+ use_scale_shift_norm=use_scale_shift_norm,
299
+ dtype=self.dtype,
300
+ device=device,
301
+ operations=operations
302
+ )]
303
+ if transformer_depth_middle >= 0:
304
+ mid_block += [SpatialTransformer( # always uses a self-attn
305
+ ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim,
306
+ disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
307
+ use_checkpoint=use_checkpoint, attn_precision=attn_precision, dtype=self.dtype, device=device, operations=operations
308
+ ),
309
+ ResBlock(
310
+ ch,
311
+ time_embed_dim,
312
+ dropout,
313
+ dims=dims,
314
+ use_checkpoint=use_checkpoint,
315
+ use_scale_shift_norm=use_scale_shift_norm,
316
+ dtype=self.dtype,
317
+ device=device,
318
+ operations=operations
319
+ )]
320
+ self.middle_block = TimestepEmbedSequential(*mid_block)
321
+ self.middle_block_out = self.make_zero_conv(ch, operations=operations, dtype=self.dtype, device=device)
322
+ self._feature_size += ch
323
+
324
+ if union_controlnet_num_control_type is not None:
325
+ self.num_control_type = union_controlnet_num_control_type
326
+ num_trans_channel = 320
327
+ num_trans_head = 8
328
+ num_trans_layer = 1
329
+ num_proj_channel = 320
330
+ # task_scale_factor = num_trans_channel ** 0.5
331
+ self.task_embedding = nn.Parameter(torch.empty(self.num_control_type, num_trans_channel, dtype=self.dtype, device=device))
332
+
333
+ self.transformer_layes = nn.Sequential(*[ResBlockUnionControlnet(num_trans_channel, num_trans_head, dtype=self.dtype, device=device, operations=operations) for _ in range(num_trans_layer)])
334
+ self.spatial_ch_projs = operations.Linear(num_trans_channel, num_proj_channel, dtype=self.dtype, device=device)
335
+ #-----------------------------------------------------------------------------------------------------
336
+
337
+ control_add_embed_dim = 256
338
+ class ControlAddEmbedding(nn.Module):
339
+ def __init__(self, in_dim, out_dim, num_control_type, dtype=None, device=None, operations=None):
340
+ super().__init__()
341
+ self.num_control_type = num_control_type
342
+ self.in_dim = in_dim
343
+ self.linear_1 = operations.Linear(in_dim * num_control_type, out_dim, dtype=dtype, device=device)
344
+ self.linear_2 = operations.Linear(out_dim, out_dim, dtype=dtype, device=device)
345
+ def forward(self, control_type, dtype, device):
346
+ c_type = torch.zeros((self.num_control_type,), device=device)
347
+ c_type[control_type] = 1.0
348
+ c_type = timestep_embedding(c_type.flatten(), self.in_dim, repeat_only=False).to(dtype).reshape((-1, self.num_control_type * self.in_dim))
349
+ return self.linear_2(torch.nn.functional.silu(self.linear_1(c_type)))
350
+
351
+ self.control_add_embedding = ControlAddEmbedding(control_add_embed_dim, time_embed_dim, self.num_control_type, dtype=self.dtype, device=device, operations=operations)
352
+ else:
353
+ self.task_embedding = None
354
+ self.control_add_embedding = None
355
+
356
+ def union_controlnet_merge(self, hint, control_type, emb, context):
357
+ # Equivalent to: https://github.com/xinsir6/ControlNetPlus/tree/main
358
+ inputs = []
359
+ condition_list = []
360
+
361
+ for idx in range(min(1, len(control_type))):
362
+ controlnet_cond = self.input_hint_block(hint[idx], emb, context)
363
+ feat_seq = torch.mean(controlnet_cond, dim=(2, 3))
364
+ if idx < len(control_type):
365
+ feat_seq += self.task_embedding[control_type[idx]].to(dtype=feat_seq.dtype, device=feat_seq.device)
366
+
367
+ inputs.append(feat_seq.unsqueeze(1))
368
+ condition_list.append(controlnet_cond)
369
+
370
+ x = torch.cat(inputs, dim=1)
371
+ x = self.transformer_layes(x)
372
+ controlnet_cond_fuser = None
373
+ for idx in range(len(control_type)):
374
+ alpha = self.spatial_ch_projs(x[:, idx])
375
+ alpha = alpha.unsqueeze(-1).unsqueeze(-1)
376
+ o = condition_list[idx] + alpha
377
+ if controlnet_cond_fuser is None:
378
+ controlnet_cond_fuser = o
379
+ else:
380
+ controlnet_cond_fuser += o
381
+ return controlnet_cond_fuser
382
+
383
+ def make_zero_conv(self, channels, operations=None, dtype=None, device=None):
384
+ return TimestepEmbedSequential(operations.conv_nd(self.dims, channels, channels, 1, padding=0, dtype=dtype, device=device))
385
+
386
+ def forward(self, x, hint, timesteps, context, y=None, **kwargs):
387
+ t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
388
+ emb = self.time_embed(t_emb)
389
+
390
+ guided_hint = None
391
+ if self.control_add_embedding is not None: #Union Controlnet
392
+ control_type = kwargs.get("control_type", [])
393
+
394
+ if any([c >= self.num_control_type for c in control_type]):
395
+ max_type = max(control_type)
396
+ max_type_name = {
397
+ v: k for k, v in UNION_CONTROLNET_TYPES.items()
398
+ }[max_type]
399
+ raise ValueError(
400
+ f"Control type {max_type_name}({max_type}) is out of range for the number of control types" +
401
+ f"({self.num_control_type}) supported.\n" +
402
+ "Please consider using the ProMax ControlNet Union model.\n" +
403
+ "https://huggingface.co/xinsir/controlnet-union-sdxl-1.0/tree/main"
404
+ )
405
+
406
+ emb += self.control_add_embedding(control_type, emb.dtype, emb.device)
407
+ if len(control_type) > 0:
408
+ if len(hint.shape) < 5:
409
+ hint = hint.unsqueeze(dim=0)
410
+ guided_hint = self.union_controlnet_merge(hint, control_type, emb, context)
411
+
412
+ if guided_hint is None:
413
+ guided_hint = self.input_hint_block(hint, emb, context)
414
+
415
+ out_output = []
416
+ out_middle = []
417
+
418
+ hs = []
419
+ if self.num_classes is not None:
420
+ assert y.shape[0] == x.shape[0]
421
+ emb = emb + self.label_emb(y)
422
+
423
+ h = x
424
+ for module, zero_conv in zip(self.input_blocks, self.zero_convs):
425
+ if guided_hint is not None:
426
+ h = module(h, emb, context)
427
+ h += guided_hint
428
+ guided_hint = None
429
+ else:
430
+ h = module(h, emb, context)
431
+ out_output.append(zero_conv(h, emb, context))
432
+
433
+ h = self.middle_block(h, emb, context)
434
+ out_middle.append(self.middle_block_out(h, emb, context))
435
+
436
+ return {"middle": out_middle, "output": out_output}
437
+
comfy/cldm/control_types.py ADDED
@@ -0,0 +1,10 @@
+ UNION_CONTROLNET_TYPES = {
+     "openpose": 0,
+     "depth": 1,
+     "hed/pidi/scribble/ted": 2,
+     "canny/lineart/anime_lineart/mlsd": 3,
+     "normal": 4,
+     "segment": 5,
+     "tile": 6,
+     "repaint": 7,
+ }
comfy/cldm/mmdit.py ADDED
@@ -0,0 +1,81 @@
+ import torch
+ from typing import Dict, Optional
+ import comfy.ldm.modules.diffusionmodules.mmdit
+
+ class ControlNet(comfy.ldm.modules.diffusionmodules.mmdit.MMDiT):
+     def __init__(
+         self,
+         num_blocks = None,
+         control_latent_channels = None,
+         dtype = None,
+         device = None,
+         operations = None,
+         **kwargs,
+     ):
+         super().__init__(dtype=dtype, device=device, operations=operations, final_layer=False, num_blocks=num_blocks, **kwargs)
+         # controlnet_blocks
+         self.controlnet_blocks = torch.nn.ModuleList([])
+         for _ in range(len(self.joint_blocks)):
+             self.controlnet_blocks.append(operations.Linear(self.hidden_size, self.hidden_size, device=device, dtype=dtype))
+
+         if control_latent_channels is None:
+             control_latent_channels = self.in_channels
+
+         self.pos_embed_input = comfy.ldm.modules.diffusionmodules.mmdit.PatchEmbed(
+             None,
+             self.patch_size,
+             control_latent_channels,
+             self.hidden_size,
+             bias=True,
+             strict_img_size=False,
+             dtype=dtype,
+             device=device,
+             operations=operations
+         )
+
+     def forward(
+         self,
+         x: torch.Tensor,
+         timesteps: torch.Tensor,
+         y: Optional[torch.Tensor] = None,
+         context: Optional[torch.Tensor] = None,
+         hint = None,
+     ) -> torch.Tensor:
+
+         #weird sd3 controlnet specific stuff
+         y = torch.zeros_like(y)
+
+         if self.context_processor is not None:
+             context = self.context_processor(context)
+
+         hw = x.shape[-2:]
+         x = self.x_embedder(x) + self.cropped_pos_embed(hw, device=x.device).to(dtype=x.dtype, device=x.device)
+         x += self.pos_embed_input(hint)
+
+         c = self.t_embedder(timesteps, dtype=x.dtype)
+         if y is not None and self.y_embedder is not None:
+             y = self.y_embedder(y)
+             c = c + y
+
+         if context is not None:
+             context = self.context_embedder(context)
+
+         output = []
+
+         blocks = len(self.joint_blocks)
+         for i in range(blocks):
+             context, x = self.joint_blocks[i](
+                 context,
+                 x,
+                 c=c,
+                 use_checkpoint=self.use_checkpoint,
+             )
+
+             out = self.controlnet_blocks[i](x)
+             count = self.depth // blocks
+             if i == blocks - 1:
+                 count -= 1
+             for j in range(count):
+                 output.append(out)
+
+         return {"output": output}
comfy/cli_args.py ADDED
@@ -0,0 +1,192 @@
1
+ import argparse
2
+ import enum
3
+ import os
4
+ from typing import Optional
5
+ import comfy.options
6
+
7
+
8
+ class EnumAction(argparse.Action):
9
+ """
10
+ Argparse action for handling Enums
11
+ """
12
+ def __init__(self, **kwargs):
13
+ # Pop off the type value
14
+ enum_type = kwargs.pop("type", None)
15
+
16
+ # Ensure an Enum subclass is provided
17
+ if enum_type is None:
18
+ raise ValueError("type must be assigned an Enum when using EnumAction")
19
+ if not issubclass(enum_type, enum.Enum):
20
+ raise TypeError("type must be an Enum when using EnumAction")
21
+
22
+ # Generate choices from the Enum
23
+ choices = tuple(e.value for e in enum_type)
24
+ kwargs.setdefault("choices", choices)
25
+ kwargs.setdefault("metavar", f"[{','.join(list(choices))}]")
26
+
27
+ super(EnumAction, self).__init__(**kwargs)
28
+
29
+ self._enum = enum_type
30
+
31
+ def __call__(self, parser, namespace, values, option_string=None):
32
+ # Convert value back into an Enum
33
+ value = self._enum(values)
34
+ setattr(namespace, self.dest, value)
35
+
36
+
37
+ parser = argparse.ArgumentParser()
38
+
39
+ parser.add_argument("--listen", type=str, default="127.0.0.1", metavar="IP", nargs="?", const="0.0.0.0", help="Specify the IP address to listen on (default: 127.0.0.1). If --listen is provided without an argument, it defaults to 0.0.0.0. (listens on all)")
40
+ parser.add_argument("--port", type=int, default=8188, help="Set the listen port.")
41
+ parser.add_argument("--tls-keyfile", type=str, help="Path to TLS (SSL) key file. Enables TLS, makes app accessible at https://... requires --tls-certfile to function")
42
+ parser.add_argument("--tls-certfile", type=str, help="Path to TLS (SSL) certificate file. Enables TLS, makes app accessible at https://... requires --tls-keyfile to function")
43
+ parser.add_argument("--enable-cors-header", type=str, default=None, metavar="ORIGIN", nargs="?", const="*", help="Enable CORS (Cross-Origin Resource Sharing) with optional origin or allow all with default '*'.")
44
+ parser.add_argument("--max-upload-size", type=float, default=100, help="Set the maximum upload size in MB.")
45
+
46
+ parser.add_argument("--extra-model-paths-config", type=str, default=None, metavar="PATH", nargs='+', action='append', help="Load one or more extra_model_paths.yaml files.")
47
+ parser.add_argument("--output-directory", type=str, default=None, help="Set the ComfyUI output directory.")
48
+ parser.add_argument("--temp-directory", type=str, default=None, help="Set the ComfyUI temp directory (default is in the ComfyUI directory).")
49
+ parser.add_argument("--input-directory", type=str, default=None, help="Set the ComfyUI input directory.")
50
+ parser.add_argument("--auto-launch", action="store_true", help="Automatically launch ComfyUI in the default browser.")
51
+ parser.add_argument("--disable-auto-launch", action="store_true", help="Disable auto launching the browser.")
52
+ parser.add_argument("--cuda-device", type=int, default=None, metavar="DEVICE_ID", help="Set the id of the cuda device this instance will use.")
53
+ cm_group = parser.add_mutually_exclusive_group()
54
+ cm_group.add_argument("--cuda-malloc", action="store_true", help="Enable cudaMallocAsync (enabled by default for torch 2.0 and up).")
55
+ cm_group.add_argument("--disable-cuda-malloc", action="store_true", help="Disable cudaMallocAsync.")
56
+
57
+
58
+ fp_group = parser.add_mutually_exclusive_group()
59
+ fp_group.add_argument("--force-fp32", action="store_true", help="Force fp32 (If this makes your GPU work better please report it).")
60
+ fp_group.add_argument("--force-fp16", action="store_true", help="Force fp16.")
61
+
62
+ fpunet_group = parser.add_mutually_exclusive_group()
63
+ fpunet_group.add_argument("--bf16-unet", action="store_true", help="Run the UNET in bf16. This should only be used for testing stuff.")
64
+ fpunet_group.add_argument("--fp16-unet", action="store_true", help="Store unet weights in fp16.")
65
+ fpunet_group.add_argument("--fp8_e4m3fn-unet", action="store_true", help="Store unet weights in fp8_e4m3fn.")
66
+ fpunet_group.add_argument("--fp8_e5m2-unet", action="store_true", help="Store unet weights in fp8_e5m2.")
67
+
68
+ fpvae_group = parser.add_mutually_exclusive_group()
69
+ fpvae_group.add_argument("--fp16-vae", action="store_true", help="Run the VAE in fp16, might cause black images.")
70
+ fpvae_group.add_argument("--fp32-vae", action="store_true", help="Run the VAE in full precision fp32.")
71
+ fpvae_group.add_argument("--bf16-vae", action="store_true", help="Run the VAE in bf16.")
72
+
73
+ parser.add_argument("--cpu-vae", action="store_true", help="Run the VAE on the CPU.")
74
+
75
+ fpte_group = parser.add_mutually_exclusive_group()
76
+ fpte_group.add_argument("--fp8_e4m3fn-text-enc", action="store_true", help="Store text encoder weights in fp8 (e4m3fn variant).")
77
+ fpte_group.add_argument("--fp8_e5m2-text-enc", action="store_true", help="Store text encoder weights in fp8 (e5m2 variant).")
78
+ fpte_group.add_argument("--fp16-text-enc", action="store_true", help="Store text encoder weights in fp16.")
79
+ fpte_group.add_argument("--fp32-text-enc", action="store_true", help="Store text encoder weights in fp32.")
80
+
81
+ parser.add_argument("--force-channels-last", action="store_true", help="Force channels last format when inferencing the models.")
82
+
83
+ parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE", const=-1, help="Use torch-directml.")
84
+
85
+ parser.add_argument("--disable-ipex-optimize", action="store_true", help="Disables ipex.optimize when loading models with Intel GPUs.")
86
+
87
+ class LatentPreviewMethod(enum.Enum):
88
+ NoPreviews = "none"
89
+ Auto = "auto"
90
+ Latent2RGB = "latent2rgb"
91
+ TAESD = "taesd"
92
+
93
+ parser.add_argument("--preview-method", type=LatentPreviewMethod, default=LatentPreviewMethod.NoPreviews, help="Default preview method for sampler nodes.", action=EnumAction)
94
+
95
+ parser.add_argument("--preview-size", type=int, default=512, help="Sets the maximum preview size for sampler nodes.")
96
+
97
+ cache_group = parser.add_mutually_exclusive_group()
98
+ cache_group.add_argument("--cache-classic", action="store_true", help="Use the old style (aggressive) caching.")
99
+ cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.")
100
+
101
+ attn_group = parser.add_mutually_exclusive_group()
102
+ attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.")
103
+ attn_group.add_argument("--use-quad-cross-attention", action="store_true", help="Use the sub-quadratic cross attention optimization . Ignored when xformers is used.")
104
+ attn_group.add_argument("--use-pytorch-cross-attention", action="store_true", help="Use the new pytorch 2.0 cross attention function.")
105
+
106
+ parser.add_argument("--disable-xformers", action="store_true", help="Disable xformers.")
107
+
108
+ upcast = parser.add_mutually_exclusive_group()
109
+ upcast.add_argument("--force-upcast-attention", action="store_true", help="Force enable attention upcasting, please report if it fixes black images.")
110
+ upcast.add_argument("--dont-upcast-attention", action="store_true", help="Disable all upcasting of attention. Should be unnecessary except for debugging.")
111
+
112
+
113
+ vram_group = parser.add_mutually_exclusive_group()
114
+ vram_group.add_argument("--gpu-only", action="store_true", help="Store and run everything (text encoders/CLIP models, etc... on the GPU).")
115
+ vram_group.add_argument("--highvram", action="store_true", help="By default models will be unloaded to CPU memory after being used. This option keeps them in GPU memory.")
116
+ vram_group.add_argument("--normalvram", action="store_true", help="Used to force normal vram use if lowvram gets automatically enabled.")
117
+ vram_group.add_argument("--lowvram", action="store_true", help="Split the unet in parts to use less vram.")
118
+ vram_group.add_argument("--novram", action="store_true", help="When lowvram isn't enough.")
119
+ vram_group.add_argument("--cpu", action="store_true", help="To use the CPU for everything (slow).")
120
+
121
+ parser.add_argument("--reserve-vram", type=float, default=None, help="Set the amount of vram in GB you want to reserve for use by your OS/other software. By default some amount is reverved depending on your OS.")
122
+
123
+
124
+ parser.add_argument("--default-hashing-function", type=str, choices=['md5', 'sha1', 'sha256', 'sha512'], default='sha256', help="Allows you to choose the hash function to use for duplicate filename / contents comparison. Default is sha256.")
125
+
126
+ parser.add_argument("--disable-smart-memory", action="store_true", help="Force ComfyUI to agressively offload to regular ram instead of keeping models in vram when it can.")
127
+ parser.add_argument("--deterministic", action="store_true", help="Make pytorch use slower deterministic algorithms when it can. Note that this might not make images deterministic in all cases.")
128
+ parser.add_argument("--fast", action="store_true", help="Enable some untested and potentially quality deteriorating optimizations.")
129
+
130
+ parser.add_argument("--dont-print-server", action="store_true", help="Don't print server output.")
131
+ parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.")
132
+ parser.add_argument("--windows-standalone-build", action="store_true", help="Windows standalone build: Enable convenient things that most people using the standalone windows build will probably enjoy (like auto opening the page on startup).")
133
+
134
+ parser.add_argument("--disable-metadata", action="store_true", help="Disable saving prompt metadata in files.")
135
+ parser.add_argument("--disable-all-custom-nodes", action="store_true", help="Disable loading all custom nodes.")
136
+
137
+ parser.add_argument("--multi-user", action="store_true", help="Enables per-user storage.")
138
+
139
+ parser.add_argument("--verbose", action="store_true", help="Enables more debug prints.")
140
+
141
+ # The default built-in provider hosted under web/
142
+ DEFAULT_VERSION_STRING = "comfyanonymous/ComfyUI@latest"
143
+
144
+ parser.add_argument(
145
+ "--front-end-version",
146
+ type=str,
147
+ default=DEFAULT_VERSION_STRING,
148
+ help="""
149
+ Specifies the version of the frontend to be used. This command needs internet connectivity to query and
150
+ download available frontend implementations from GitHub releases.
151
+
152
+ The version string should be in the format of:
153
+ [repoOwner]/[repoName]@[version]
154
+ where version is one of: "latest" or a valid version number (e.g. "1.0.0")
155
+ """,
156
+ )
157
+
158
+ def is_valid_directory(path: Optional[str]) -> Optional[str]:
159
+ """Validate if the given path is a directory."""
160
+ if path is None:
161
+ return None
162
+
163
+ if not os.path.isdir(path):
164
+ raise argparse.ArgumentTypeError(f"{path} is not a valid directory.")
165
+ return path
166
+
167
+ parser.add_argument(
168
+ "--front-end-root",
169
+ type=is_valid_directory,
170
+ default=None,
171
+ help="The local filesystem path to the directory where the frontend is located. Overrides --front-end-version.",
172
+ )
173
+
174
+ parser.add_argument("--user-directory", type=is_valid_directory, default=None, help="Set the ComfyUI user directory with an absolute path.")
175
+
176
+ if comfy.options.args_parsing:
177
+ args = parser.parse_args()
178
+ else:
179
+ args = parser.parse_args([])
180
+
181
+ if args.windows_standalone_build:
182
+ args.auto_launch = True
183
+
184
+ if args.disable_auto_launch:
185
+ args.auto_launch = False
186
+
187
+ #import logging
188
+ #logging_level = logging.INFO
189
+ #if args.verbose:
190
+ # logging_level = logging.DEBUG
191
+
192
+ #logging.basicConfig(format="[Comfyd] %(message)s", level=logging_level)
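The commented-out block above shows how the --verbose flag is meant to drive the log level. A minimal sketch of that wiring, assuming a plain logging.basicConfig setup (the format string and empty argv here are illustrative, not taken from the repository):

import argparse
import logging

parser = argparse.ArgumentParser()
parser.add_argument("--verbose", action="store_true", help="Enables more debug prints.")
args = parser.parse_args([])  # empty argv for illustration only

# Mirror the commented-out lines: INFO by default, DEBUG when --verbose is passed.
logging_level = logging.DEBUG if args.verbose else logging.INFO
logging.basicConfig(format="%(message)s", level=logging_level)
logging.debug("only visible with --verbose")
logging.info("always visible")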
comfy/clip_config_bigg.json ADDED
@@ -0,0 +1,23 @@
1
+ {
2
+ "architectures": [
3
+ "CLIPTextModel"
4
+ ],
5
+ "attention_dropout": 0.0,
6
+ "bos_token_id": 0,
7
+ "dropout": 0.0,
8
+ "eos_token_id": 49407,
9
+ "hidden_act": "gelu",
10
+ "hidden_size": 1280,
11
+ "initializer_factor": 1.0,
12
+ "initializer_range": 0.02,
13
+ "intermediate_size": 5120,
14
+ "layer_norm_eps": 1e-05,
15
+ "max_position_embeddings": 77,
16
+ "model_type": "clip_text_model",
17
+ "num_attention_heads": 20,
18
+ "num_hidden_layers": 32,
19
+ "pad_token_id": 1,
20
+ "projection_dim": 1280,
21
+ "torch_dtype": "float32",
22
+ "vocab_size": 49408
23
+ }
comfy/clip_model.py ADDED
@@ -0,0 +1,196 @@
1
+ import torch
2
+ from comfy.ldm.modules.attention import optimized_attention_for_device
3
+ import comfy.ops
4
+
5
+ class CLIPAttention(torch.nn.Module):
6
+ def __init__(self, embed_dim, heads, dtype, device, operations):
7
+ super().__init__()
8
+
9
+ self.heads = heads
10
+ self.q_proj = operations.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
11
+ self.k_proj = operations.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
12
+ self.v_proj = operations.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
13
+
14
+ self.out_proj = operations.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
15
+
16
+ def forward(self, x, mask=None, optimized_attention=None):
17
+ q = self.q_proj(x)
18
+ k = self.k_proj(x)
19
+ v = self.v_proj(x)
20
+
21
+ out = optimized_attention(q, k, v, self.heads, mask)
22
+ return self.out_proj(out)
23
+
24
+ ACTIVATIONS = {"quick_gelu": lambda a: a * torch.sigmoid(1.702 * a),
25
+ "gelu": torch.nn.functional.gelu,
26
+ }
27
+
28
+ class CLIPMLP(torch.nn.Module):
29
+ def __init__(self, embed_dim, intermediate_size, activation, dtype, device, operations):
30
+ super().__init__()
31
+ self.fc1 = operations.Linear(embed_dim, intermediate_size, bias=True, dtype=dtype, device=device)
32
+ self.activation = ACTIVATIONS[activation]
33
+ self.fc2 = operations.Linear(intermediate_size, embed_dim, bias=True, dtype=dtype, device=device)
34
+
35
+ def forward(self, x):
36
+ x = self.fc1(x)
37
+ x = self.activation(x)
38
+ x = self.fc2(x)
39
+ return x
40
+
41
+ class CLIPLayer(torch.nn.Module):
42
+ def __init__(self, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations):
43
+ super().__init__()
44
+ self.layer_norm1 = operations.LayerNorm(embed_dim, dtype=dtype, device=device)
45
+ self.self_attn = CLIPAttention(embed_dim, heads, dtype, device, operations)
46
+ self.layer_norm2 = operations.LayerNorm(embed_dim, dtype=dtype, device=device)
47
+ self.mlp = CLIPMLP(embed_dim, intermediate_size, intermediate_activation, dtype, device, operations)
48
+
49
+ def forward(self, x, mask=None, optimized_attention=None):
50
+ x += self.self_attn(self.layer_norm1(x), mask, optimized_attention)
51
+ x += self.mlp(self.layer_norm2(x))
52
+ return x
53
+
54
+
55
+ class CLIPEncoder(torch.nn.Module):
56
+ def __init__(self, num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations):
57
+ super().__init__()
58
+ self.layers = torch.nn.ModuleList([CLIPLayer(embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations) for i in range(num_layers)])
59
+
60
+ def forward(self, x, mask=None, intermediate_output=None):
61
+ optimized_attention = optimized_attention_for_device(x.device, mask=mask is not None, small_input=True)
62
+
63
+ if intermediate_output is not None:
64
+ if intermediate_output < 0:
65
+ intermediate_output = len(self.layers) + intermediate_output
66
+
67
+ intermediate = None
68
+ for i, l in enumerate(self.layers):
69
+ x = l(x, mask, optimized_attention)
70
+ if i == intermediate_output:
71
+ intermediate = x.clone()
72
+ return x, intermediate
73
+
74
+ class CLIPEmbeddings(torch.nn.Module):
75
+ def __init__(self, embed_dim, vocab_size=49408, num_positions=77, dtype=None, device=None, operations=None):
76
+ super().__init__()
77
+ self.token_embedding = operations.Embedding(vocab_size, embed_dim, dtype=dtype, device=device)
78
+ self.position_embedding = operations.Embedding(num_positions, embed_dim, dtype=dtype, device=device)
79
+
80
+ def forward(self, input_tokens, dtype=torch.float32):
81
+ return self.token_embedding(input_tokens, out_dtype=dtype) + comfy.ops.cast_to(self.position_embedding.weight, dtype=dtype, device=input_tokens.device)
82
+
83
+
84
+ class CLIPTextModel_(torch.nn.Module):
85
+ def __init__(self, config_dict, dtype, device, operations):
86
+ num_layers = config_dict["num_hidden_layers"]
87
+ embed_dim = config_dict["hidden_size"]
88
+ heads = config_dict["num_attention_heads"]
89
+ intermediate_size = config_dict["intermediate_size"]
90
+ intermediate_activation = config_dict["hidden_act"]
91
+ num_positions = config_dict["max_position_embeddings"]
92
+ self.eos_token_id = config_dict["eos_token_id"]
93
+
94
+ super().__init__()
95
+ self.embeddings = CLIPEmbeddings(embed_dim, num_positions=num_positions, dtype=dtype, device=device, operations=operations)
96
+ self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations)
97
+ self.final_layer_norm = operations.LayerNorm(embed_dim, dtype=dtype, device=device)
98
+
99
+ def forward(self, input_tokens, attention_mask=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=torch.float32):
100
+ x = self.embeddings(input_tokens, dtype=dtype)
101
+ mask = None
102
+ if attention_mask is not None:
103
+ mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1])
104
+ mask = mask.masked_fill(mask.to(torch.bool), float("-inf"))
105
+
106
+ causal_mask = torch.empty(x.shape[1], x.shape[1], dtype=x.dtype, device=x.device).fill_(float("-inf")).triu_(1)
107
+ if mask is not None:
108
+ mask += causal_mask
109
+ else:
110
+ mask = causal_mask
111
+
112
+ x, i = self.encoder(x, mask=mask, intermediate_output=intermediate_output)
113
+ x = self.final_layer_norm(x)
114
+ if i is not None and final_layer_norm_intermediate:
115
+ i = self.final_layer_norm(i)
116
+
117
+ pooled_output = x[torch.arange(x.shape[0], device=x.device), (torch.round(input_tokens).to(dtype=torch.int, device=x.device) == self.eos_token_id).int().argmax(dim=-1),]
118
+ return x, i, pooled_output
119
+
120
+ class CLIPTextModel(torch.nn.Module):
121
+ def __init__(self, config_dict, dtype, device, operations):
122
+ super().__init__()
123
+ self.num_layers = config_dict["num_hidden_layers"]
124
+ self.text_model = CLIPTextModel_(config_dict, dtype, device, operations)
125
+ embed_dim = config_dict["hidden_size"]
126
+ self.text_projection = operations.Linear(embed_dim, embed_dim, bias=False, dtype=dtype, device=device)
127
+ self.dtype = dtype
128
+
129
+ def get_input_embeddings(self):
130
+ return self.text_model.embeddings.token_embedding
131
+
132
+ def set_input_embeddings(self, embeddings):
133
+ self.text_model.embeddings.token_embedding = embeddings
134
+
135
+ def forward(self, *args, **kwargs):
136
+ x = self.text_model(*args, **kwargs)
137
+ out = self.text_projection(x[2])
138
+ return (x[0], x[1], out, x[2])
139
+
140
+
141
+ class CLIPVisionEmbeddings(torch.nn.Module):
142
+ def __init__(self, embed_dim, num_channels=3, patch_size=14, image_size=224, dtype=None, device=None, operations=None):
143
+ super().__init__()
144
+ self.class_embedding = torch.nn.Parameter(torch.empty(embed_dim, dtype=dtype, device=device))
145
+
146
+ self.patch_embedding = operations.Conv2d(
147
+ in_channels=num_channels,
148
+ out_channels=embed_dim,
149
+ kernel_size=patch_size,
150
+ stride=patch_size,
151
+ bias=False,
152
+ dtype=dtype,
153
+ device=device
154
+ )
155
+
156
+ num_patches = (image_size // patch_size) ** 2
157
+ num_positions = num_patches + 1
158
+ self.position_embedding = operations.Embedding(num_positions, embed_dim, dtype=dtype, device=device)
159
+
160
+ def forward(self, pixel_values):
161
+ embeds = self.patch_embedding(pixel_values).flatten(2).transpose(1, 2)
162
+ return torch.cat([comfy.ops.cast_to_input(self.class_embedding, embeds).expand(pixel_values.shape[0], 1, -1), embeds], dim=1) + comfy.ops.cast_to_input(self.position_embedding.weight, embeds)
163
+
164
+
165
+ class CLIPVision(torch.nn.Module):
166
+ def __init__(self, config_dict, dtype, device, operations):
167
+ super().__init__()
168
+ num_layers = config_dict["num_hidden_layers"]
169
+ embed_dim = config_dict["hidden_size"]
170
+ heads = config_dict["num_attention_heads"]
171
+ intermediate_size = config_dict["intermediate_size"]
172
+ intermediate_activation = config_dict["hidden_act"]
173
+
174
+ self.embeddings = CLIPVisionEmbeddings(embed_dim, config_dict["num_channels"], config_dict["patch_size"], config_dict["image_size"], dtype=dtype, device=device, operations=operations)
175
+ self.pre_layrnorm = operations.LayerNorm(embed_dim)  # "layrnorm" spelling is intentional: it matches the key names used in CLIP vision state dicts
176
+ self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations)
177
+ self.post_layernorm = operations.LayerNorm(embed_dim)
178
+
179
+ def forward(self, pixel_values, attention_mask=None, intermediate_output=None):
180
+ x = self.embeddings(pixel_values)
181
+ x = self.pre_layrnorm(x)
182
+ #TODO: attention_mask?
183
+ x, i = self.encoder(x, mask=None, intermediate_output=intermediate_output)
184
+ pooled_output = self.post_layernorm(x[:, 0, :])
185
+ return x, i, pooled_output
186
+
187
+ class CLIPVisionModelProjection(torch.nn.Module):
188
+ def __init__(self, config_dict, dtype, device, operations):
189
+ super().__init__()
190
+ self.vision_model = CLIPVision(config_dict, dtype, device, operations)
191
+ self.visual_projection = operations.Linear(config_dict["hidden_size"], config_dict["projection_dim"], bias=False)
192
+
193
+ def forward(self, *args, **kwargs):
194
+ x = self.vision_model(*args, **kwargs)
195
+ out = self.visual_projection(x[2])
196
+ return (x[0], x[1], out)
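To see how a config such as comfy/clip_config_bigg.json above feeds these classes, here is a minimal construction sketch. It assumes comfy.ops.manual_cast as the operations provider (as clip_vision.py below does) and a CPU device; the weights are left uninitialized, so the outputs are only useful as a shape check.

import json
import torch
import comfy.ops
import comfy.clip_model

with open("comfy/clip_config_bigg.json") as f:
    config = json.load(f)

# Build the text encoder described by the config (32 layers, hidden size 1280).
model = comfy.clip_model.CLIPTextModel(config, dtype=torch.float32, device="cpu", operations=comfy.ops.manual_cast)

tokens = torch.zeros((1, 77), dtype=torch.long)  # dummy token batch
hidden, intermediate, projected, pooled = model(tokens, intermediate_output=-2)
print(hidden.shape, projected.shape)  # expected: (1, 77, 1280) and (1, 1280)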
comfy/clip_vision.py ADDED
@@ -0,0 +1,121 @@
1
+ from .utils import load_torch_file, transformers_convert, state_dict_prefix_replace
2
+ import os
3
+ import torch
4
+ import json
5
+ import logging
6
+
7
+ import comfy.ops
8
+ import comfy.model_patcher
9
+ import comfy.model_management
10
+ import comfy.utils
11
+ import comfy.clip_model
12
+
13
+ class Output:
14
+ def __getitem__(self, key):
15
+ return getattr(self, key)
16
+ def __setitem__(self, key, item):
17
+ setattr(self, key, item)
18
+
19
+ def clip_preprocess(image, size=224):
20
+ mean = torch.tensor([ 0.48145466,0.4578275,0.40821073], device=image.device, dtype=image.dtype)
21
+ std = torch.tensor([0.26862954,0.26130258,0.27577711], device=image.device, dtype=image.dtype)
22
+ image = image.movedim(-1, 1)
23
+ if not (image.shape[2] == size and image.shape[3] == size):
24
+ scale = (size / min(image.shape[2], image.shape[3]))
25
+ image = torch.nn.functional.interpolate(image, size=(round(scale * image.shape[2]), round(scale * image.shape[3])), mode="bicubic", antialias=True)
26
+ h = (image.shape[2] - size)//2
27
+ w = (image.shape[3] - size)//2
28
+ image = image[:,:,h:h+size,w:w+size]
29
+ image = torch.clip((255. * image), 0, 255).round() / 255.0
30
+ return (image - mean.view([3,1,1])) / std.view([3,1,1])
31
+
32
+ class ClipVisionModel():
33
+ def __init__(self, json_config):
34
+ with open(json_config) as f:
35
+ config = json.load(f)
36
+
37
+ self.image_size = config.get("image_size", 224)
38
+ self.load_device = comfy.model_management.text_encoder_device()
39
+ offload_device = comfy.model_management.text_encoder_offload_device()
40
+ self.dtype = comfy.model_management.text_encoder_dtype(self.load_device)
41
+ self.model = comfy.clip_model.CLIPVisionModelProjection(config, self.dtype, offload_device, comfy.ops.manual_cast)
42
+ self.model.eval()
43
+
44
+ self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
45
+
46
+ def load_sd(self, sd):
47
+ return self.model.load_state_dict(sd, strict=False)
48
+
49
+ def get_sd(self):
50
+ return self.model.state_dict()
51
+
52
+ def encode_image(self, image):
53
+ comfy.model_management.load_model_gpu(self.patcher)
54
+ pixel_values = clip_preprocess(image.to(self.load_device), size=self.image_size).float()
55
+ out = self.model(pixel_values=pixel_values, intermediate_output=-2)
56
+
57
+ outputs = Output()
58
+ outputs["last_hidden_state"] = out[0].to(comfy.model_management.intermediate_device())
59
+ outputs["image_embeds"] = out[2].to(comfy.model_management.intermediate_device())
60
+ outputs["penultimate_hidden_states"] = out[1].to(comfy.model_management.intermediate_device())
61
+ return outputs
62
+
63
+ def convert_to_transformers(sd, prefix):
64
+ sd_k = sd.keys()
65
+ if "{}transformer.resblocks.0.attn.in_proj_weight".format(prefix) in sd_k:
66
+ keys_to_replace = {
67
+ "{}class_embedding".format(prefix): "vision_model.embeddings.class_embedding",
68
+ "{}conv1.weight".format(prefix): "vision_model.embeddings.patch_embedding.weight",
69
+ "{}positional_embedding".format(prefix): "vision_model.embeddings.position_embedding.weight",
70
+ "{}ln_post.bias".format(prefix): "vision_model.post_layernorm.bias",
71
+ "{}ln_post.weight".format(prefix): "vision_model.post_layernorm.weight",
72
+ "{}ln_pre.bias".format(prefix): "vision_model.pre_layrnorm.bias",
73
+ "{}ln_pre.weight".format(prefix): "vision_model.pre_layrnorm.weight",
74
+ }
75
+
76
+ for x in keys_to_replace:
77
+ if x in sd_k:
78
+ sd[keys_to_replace[x]] = sd.pop(x)
79
+
80
+ if "{}proj".format(prefix) in sd_k:
81
+ sd['visual_projection.weight'] = sd.pop("{}proj".format(prefix)).transpose(0, 1)
82
+
83
+ sd = transformers_convert(sd, prefix, "vision_model.", 48)
84
+ else:
85
+ replace_prefix = {prefix: ""}
86
+ sd = state_dict_prefix_replace(sd, replace_prefix)
87
+ return sd
88
+
89
+ def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
90
+ if convert_keys:
91
+ sd = convert_to_transformers(sd, prefix)
92
+ if "vision_model.encoder.layers.47.layer_norm1.weight" in sd:
93
+ json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_g.json")
94
+ elif "vision_model.encoder.layers.30.layer_norm1.weight" in sd:
95
+ json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_h.json")
96
+ elif "vision_model.encoder.layers.22.layer_norm1.weight" in sd:
97
+ if sd["vision_model.embeddings.position_embedding.weight"].shape[0] == 577:
98
+ json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl_336.json")
99
+ else:
100
+ json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl.json")
101
+ else:
102
+ return None
103
+
104
+ clip = ClipVisionModel(json_config)
105
+ m, u = clip.load_sd(sd)
106
+ if len(m) > 0:
107
+ logging.warning("missing clip vision: {}".format(m))
108
+ u = set(u)
109
+ keys = list(sd.keys())
110
+ for k in keys:
111
+ if k not in u:
112
+ t = sd.pop(k)
113
+ del t
114
+ return clip
115
+
116
+ def load(ckpt_path):
117
+ sd = load_torch_file(ckpt_path)
118
+ if "visual.transformer.resblocks.0.attn.in_proj_weight" in sd:
119
+ return load_clipvision_from_sd(sd, prefix="visual.", convert_keys=True)
120
+ else:
121
+ return load_clipvision_from_sd(sd)
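A short usage sketch for the loader above, assuming a checkpoint path (hypothetical) that matches one of the bundled configs. encode_image expects images as [batch, height, width, channels] floats in 0..1, which clip_preprocess then resizes, center-crops and normalizes:

import torch
import comfy.clip_vision

clip = comfy.clip_vision.load("models/clip_vision/ViT-H-14.safetensors")  # hypothetical path

image = torch.rand(1, 512, 512, 3)  # [B, H, W, C] floats in 0..1
out = clip.encode_image(image)
print(out["image_embeds"].shape)               # projected CLIP image embedding
print(out["penultimate_hidden_states"].shape)  # hidden states from the second-to-last layer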
comfy/clip_vision_config_g.json ADDED
@@ -0,0 +1,18 @@
1
+ {
2
+ "attention_dropout": 0.0,
3
+ "dropout": 0.0,
4
+ "hidden_act": "gelu",
5
+ "hidden_size": 1664,
6
+ "image_size": 224,
7
+ "initializer_factor": 1.0,
8
+ "initializer_range": 0.02,
9
+ "intermediate_size": 8192,
10
+ "layer_norm_eps": 1e-05,
11
+ "model_type": "clip_vision_model",
12
+ "num_attention_heads": 16,
13
+ "num_channels": 3,
14
+ "num_hidden_layers": 48,
15
+ "patch_size": 14,
16
+ "projection_dim": 1280,
17
+ "torch_dtype": "float32"
18
+ }
comfy/clip_vision_config_h.json ADDED
@@ -0,0 +1,18 @@
1
+ {
2
+ "attention_dropout": 0.0,
3
+ "dropout": 0.0,
4
+ "hidden_act": "gelu",
5
+ "hidden_size": 1280,
6
+ "image_size": 224,
7
+ "initializer_factor": 1.0,
8
+ "initializer_range": 0.02,
9
+ "intermediate_size": 5120,
10
+ "layer_norm_eps": 1e-05,
11
+ "model_type": "clip_vision_model",
12
+ "num_attention_heads": 16,
13
+ "num_channels": 3,
14
+ "num_hidden_layers": 32,
15
+ "patch_size": 14,
16
+ "projection_dim": 1024,
17
+ "torch_dtype": "float32"
18
+ }
comfy/clip_vision_config_vitl.json ADDED
@@ -0,0 +1,18 @@
1
+ {
2
+ "attention_dropout": 0.0,
3
+ "dropout": 0.0,
4
+ "hidden_act": "quick_gelu",
5
+ "hidden_size": 1024,
6
+ "image_size": 224,
7
+ "initializer_factor": 1.0,
8
+ "initializer_range": 0.02,
9
+ "intermediate_size": 4096,
10
+ "layer_norm_eps": 1e-05,
11
+ "model_type": "clip_vision_model",
12
+ "num_attention_heads": 16,
13
+ "num_channels": 3,
14
+ "num_hidden_layers": 24,
15
+ "patch_size": 14,
16
+ "projection_dim": 768,
17
+ "torch_dtype": "float32"
18
+ }
comfy/clip_vision_config_vitl_336.json ADDED
@@ -0,0 +1,18 @@
1
+ {
2
+ "attention_dropout": 0.0,
3
+ "dropout": 0.0,
4
+ "hidden_act": "quick_gelu",
5
+ "hidden_size": 1024,
6
+ "image_size": 336,
7
+ "initializer_factor": 1.0,
8
+ "initializer_range": 0.02,
9
+ "intermediate_size": 4096,
10
+ "layer_norm_eps": 1e-5,
11
+ "model_type": "clip_vision_model",
12
+ "num_attention_heads": 16,
13
+ "num_channels": 3,
14
+ "num_hidden_layers": 24,
15
+ "patch_size": 14,
16
+ "projection_dim": 768,
17
+ "torch_dtype": "float32"
18
+ }
comfy/comfy_types.py ADDED
@@ -0,0 +1,32 @@
1
+ import torch
2
+ from typing import Callable, Protocol, TypedDict, Optional, List
3
+
4
+
5
+ class UnetApplyFunction(Protocol):
6
+ """Function signature protocol on comfy.model_base.BaseModel.apply_model"""
7
+
8
+ def __call__(self, x: torch.Tensor, t: torch.Tensor, **kwargs) -> torch.Tensor:
9
+ pass
10
+
11
+
12
+ class UnetApplyConds(TypedDict):
13
+ """Optional conditions for unet apply function."""
14
+
15
+ c_concat: Optional[torch.Tensor]
16
+ c_crossattn: Optional[torch.Tensor]
17
+ control: Optional[torch.Tensor]
18
+ transformer_options: Optional[dict]
19
+
20
+
21
+ class UnetParams(TypedDict):
22
+ # Tensor of shape [B, C, H, W]
23
+ input: torch.Tensor
24
+ # Tensor of shape [B]
25
+ timestep: torch.Tensor
26
+ c: UnetApplyConds
27
+ # List of [0, 1], [0], [1], ...
28
+ # 0 means conditional, 1 means unconditional
29
+ cond_or_uncond: List[int]
30
+
31
+
32
+ UnetWrapperFunction = Callable[[UnetApplyFunction, UnetParams], torch.Tensor]
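A pass-through example of the UnetWrapperFunction signature defined above: it simply forwards the call, while a real wrapper could batch-split the input or edit the conds first. The registration method named in the trailing comment is assumed from ComfyUI's ModelPatcher API:

import torch
from comfy.comfy_types import UnetApplyFunction, UnetParams

def passthrough_wrapper(apply_model: UnetApplyFunction, params: UnetParams) -> torch.Tensor:
    # Unpack the params dict and hand everything to the underlying apply function unchanged.
    return apply_model(params["input"], params["timestep"], **params["c"])

# model_patcher.set_model_unet_function_wrapper(passthrough_wrapper)  # assumed registration point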
comfy/conds.py ADDED
@@ -0,0 +1,83 @@
1
+ import torch
2
+ import math
3
+ import comfy.utils
4
+
5
+
6
+ def lcm(a, b): #TODO: eventually replace by math.lcm (added in python3.9)
7
+ return abs(a*b) // math.gcd(a, b)
8
+
9
+ class CONDRegular:
10
+ def __init__(self, cond):
11
+ self.cond = cond
12
+
13
+ def _copy_with(self, cond):
14
+ return self.__class__(cond)
15
+
16
+ def process_cond(self, batch_size, device, **kwargs):
17
+ return self._copy_with(comfy.utils.repeat_to_batch_size(self.cond, batch_size).to(device))
18
+
19
+ def can_concat(self, other):
20
+ if self.cond.shape != other.cond.shape:
21
+ return False
22
+ return True
23
+
24
+ def concat(self, others):
25
+ conds = [self.cond]
26
+ for x in others:
27
+ conds.append(x.cond)
28
+ return torch.cat(conds)
29
+
30
+ class CONDNoiseShape(CONDRegular):
31
+ def process_cond(self, batch_size, device, area, **kwargs):
32
+ data = self.cond
33
+ if area is not None:
34
+ dims = len(area) // 2
35
+ for i in range(dims):
36
+ data = data.narrow(i + 2, area[i + dims], area[i])
37
+
38
+ return self._copy_with(comfy.utils.repeat_to_batch_size(data, batch_size).to(device))
39
+
40
+
41
+ class CONDCrossAttn(CONDRegular):
42
+ def can_concat(self, other):
43
+ s1 = self.cond.shape
44
+ s2 = other.cond.shape
45
+ if s1 != s2:
46
+ if s1[0] != s2[0] or s1[2] != s2[2]: #these 2 cases should not happen
47
+ return False
48
+
49
+ mult_min = lcm(s1[1], s2[1])
50
+ diff = mult_min // min(s1[1], s2[1])
51
+ if diff > 4: #arbitrary limit on the padding because it's probably going to impact performance negatively if it's too much
52
+ return False
53
+ return True
54
+
55
+ def concat(self, others):
56
+ conds = [self.cond]
57
+ crossattn_max_len = self.cond.shape[1]
58
+ for x in others:
59
+ c = x.cond
60
+ crossattn_max_len = lcm(crossattn_max_len, c.shape[1])
61
+ conds.append(c)
62
+
63
+ out = []
64
+ for c in conds:
65
+ if c.shape[1] < crossattn_max_len:
66
+ c = c.repeat(1, crossattn_max_len // c.shape[1], 1) #padding with repeat doesn't change result
67
+ out.append(c)
68
+ return torch.cat(out)
69
+
70
+ class CONDConstant(CONDRegular):
71
+ def __init__(self, cond):
72
+ self.cond = cond
73
+
74
+ def process_cond(self, batch_size, device, **kwargs):
75
+ return self._copy_with(self.cond)
76
+
77
+ def can_concat(self, other):
78
+ if self.cond != other.cond:
79
+ return False
80
+ return True
81
+
82
+ def concat(self, others):
83
+ return self.cond
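A small sketch of the CONDCrossAttn padding rule above: two conds with different token counts can be concatenated as long as the repeat factor stays small, and the shorter one is tiled up to the common length.

import torch
from comfy.conds import CONDCrossAttn

a = CONDCrossAttn(torch.randn(1, 77, 2048))   # 77 tokens
b = CONDCrossAttn(torch.randn(1, 154, 2048))  # 154 tokens, same batch/channel dims

print(a.can_concat(b))   # True: lcm(77, 154) / 77 == 2, which is within the limit of 4
merged = a.concat([b])   # the 77-token cond is repeated twice before concatenation
print(merged.shape)      # torch.Size([2, 154, 2048])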
comfy/controlnet.py ADDED
@@ -0,0 +1,737 @@
1
+ """
2
+ This file is part of ComfyUI.
3
+ Copyright (C) 2024 Comfy
4
+
5
+ This program is free software: you can redistribute it and/or modify
6
+ it under the terms of the GNU General Public License as published by
7
+ the Free Software Foundation, either version 3 of the License, or
8
+ (at your option) any later version.
9
+
10
+ This program is distributed in the hope that it will be useful,
11
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
12
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13
+ GNU General Public License for more details.
14
+
15
+ You should have received a copy of the GNU General Public License
16
+ along with this program. If not, see <https://www.gnu.org/licenses/>.
17
+ """
18
+
19
+
20
+ import torch
21
+ from enum import Enum
22
+ import math
23
+ import os
24
+ import logging
25
+ import comfy.utils
26
+ import comfy.model_management
27
+ import comfy.model_detection
28
+ import comfy.model_patcher
29
+ import comfy.ops
30
+ import comfy.latent_formats
31
+
32
+ import comfy.cldm.cldm
33
+ import comfy.t2i_adapter.adapter
34
+ import comfy.ldm.cascade.controlnet
35
+ import comfy.cldm.mmdit
36
+ import comfy.ldm.hydit.controlnet
37
+ import comfy.ldm.flux.controlnet
38
+
39
+
40
+ def broadcast_image_to(tensor, target_batch_size, batched_number):
41
+ current_batch_size = tensor.shape[0]
42
+ #print(current_batch_size, target_batch_size)
43
+ if current_batch_size == 1:
44
+ return tensor
45
+
46
+ per_batch = target_batch_size // batched_number
47
+ tensor = tensor[:per_batch]
48
+
49
+ if per_batch > tensor.shape[0]:
50
+ tensor = torch.cat([tensor] * (per_batch // tensor.shape[0]) + [tensor[:(per_batch % tensor.shape[0])]], dim=0)
51
+
52
+ current_batch_size = tensor.shape[0]
53
+ if current_batch_size == target_batch_size:
54
+ return tensor
55
+ else:
56
+ return torch.cat([tensor] * batched_number, dim=0)
57
+
58
+ class StrengthType(Enum):
59
+ CONSTANT = 1
60
+ LINEAR_UP = 2
61
+
62
+ class ControlBase:
63
+ def __init__(self, device=None):
64
+ self.cond_hint_original = None
65
+ self.cond_hint = None
66
+ self.strength = 1.0
67
+ self.timestep_percent_range = (0.0, 1.0)
68
+ self.latent_format = None
69
+ self.vae = None
70
+ self.global_average_pooling = False
71
+ self.timestep_range = None
72
+ self.compression_ratio = 8
73
+ self.upscale_algorithm = 'nearest-exact'
74
+ self.extra_args = {}
75
+
76
+ if device is None:
77
+ device = comfy.model_management.get_torch_device()
78
+ self.device = device
79
+ self.previous_controlnet = None
80
+ self.extra_conds = []
81
+ self.strength_type = StrengthType.CONSTANT
82
+ self.concat_mask = False
83
+ self.extra_concat_orig = []
84
+ self.extra_concat = None
85
+
86
+ def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0), vae=None, extra_concat=[]):
87
+ self.cond_hint_original = cond_hint
88
+ self.strength = strength
89
+ self.timestep_percent_range = timestep_percent_range
90
+ if self.latent_format is not None:
91
+ self.vae = vae
92
+ self.extra_concat_orig = extra_concat.copy()
93
+ if self.concat_mask and len(self.extra_concat_orig) == 0:
94
+ self.extra_concat_orig.append(torch.tensor([[[[1.0]]]]))
95
+ return self
96
+
97
+ def pre_run(self, model, percent_to_timestep_function):
98
+ self.timestep_range = (percent_to_timestep_function(self.timestep_percent_range[0]), percent_to_timestep_function(self.timestep_percent_range[1]))
99
+ if self.previous_controlnet is not None:
100
+ self.previous_controlnet.pre_run(model, percent_to_timestep_function)
101
+
102
+ def set_previous_controlnet(self, controlnet):
103
+ self.previous_controlnet = controlnet
104
+ return self
105
+
106
+ def cleanup(self):
107
+ if self.previous_controlnet is not None:
108
+ self.previous_controlnet.cleanup()
109
+
110
+ self.cond_hint = None
111
+ self.extra_concat = None
112
+ self.timestep_range = None
113
+
114
+ def get_models(self):
115
+ out = []
116
+ if self.previous_controlnet is not None:
117
+ out += self.previous_controlnet.get_models()
118
+ return out
119
+
120
+ def copy_to(self, c):
121
+ c.cond_hint_original = self.cond_hint_original
122
+ c.strength = self.strength
123
+ c.timestep_percent_range = self.timestep_percent_range
124
+ c.global_average_pooling = self.global_average_pooling
125
+ c.compression_ratio = self.compression_ratio
126
+ c.upscale_algorithm = self.upscale_algorithm
127
+ c.latent_format = self.latent_format
128
+ c.extra_args = self.extra_args.copy()
129
+ c.vae = self.vae
130
+ c.extra_conds = self.extra_conds.copy()
131
+ c.strength_type = self.strength_type
132
+ c.concat_mask = self.concat_mask
133
+ c.extra_concat_orig = self.extra_concat_orig.copy()
134
+
135
+ def inference_memory_requirements(self, dtype):
136
+ if self.previous_controlnet is not None:
137
+ return self.previous_controlnet.inference_memory_requirements(dtype)
138
+ return 0
139
+
140
+ def control_merge(self, control, control_prev, output_dtype):
141
+ out = {'input':[], 'middle':[], 'output': []}
142
+
143
+ for key in control:
144
+ control_output = control[key]
145
+ applied_to = set()
146
+ for i in range(len(control_output)):
147
+ x = control_output[i]
148
+ if x is not None:
149
+ if self.global_average_pooling:
150
+ x = torch.mean(x, dim=(2, 3), keepdim=True).repeat(1, 1, x.shape[2], x.shape[3])
151
+
152
+ if x not in applied_to: #memory saving strategy, allow shared tensors and only apply strength to shared tensors once
153
+ applied_to.add(x)
154
+ if self.strength_type == StrengthType.CONSTANT:
155
+ x *= self.strength
156
+ elif self.strength_type == StrengthType.LINEAR_UP:
157
+ x *= (self.strength ** float(len(control_output) - i))
158
+
159
+ if output_dtype is not None and x.dtype != output_dtype:
160
+ x = x.to(output_dtype)
161
+
162
+ out[key].append(x)
163
+
164
+ if control_prev is not None:
165
+ for x in ['input', 'middle', 'output']:
166
+ o = out[x]
167
+ for i in range(len(control_prev[x])):
168
+ prev_val = control_prev[x][i]
169
+ if i >= len(o):
170
+ o.append(prev_val)
171
+ elif prev_val is not None:
172
+ if o[i] is None:
173
+ o[i] = prev_val
174
+ else:
175
+ if o[i].shape[0] < prev_val.shape[0]:
176
+ o[i] = prev_val + o[i]
177
+ else:
178
+ o[i] = prev_val + o[i] #TODO: change back to inplace add if shared tensors stop being an issue
179
+ return out
180
+
181
+ def set_extra_arg(self, argument, value=None):
182
+ self.extra_args[argument] = value
183
+
184
+
185
+ class ControlNet(ControlBase):
186
+ def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, device=None, load_device=None, manual_cast_dtype=None, extra_conds=["y"], strength_type=StrengthType.CONSTANT, concat_mask=False):
187
+ super().__init__(device)
188
+ self.control_model = control_model
189
+ self.load_device = load_device
190
+ if control_model is not None:
191
+ self.control_model_wrapped = comfy.model_patcher.ModelPatcher(self.control_model, load_device=load_device, offload_device=comfy.model_management.unet_offload_device())
192
+
193
+ self.compression_ratio = compression_ratio
194
+ self.global_average_pooling = global_average_pooling
195
+ self.model_sampling_current = None
196
+ self.manual_cast_dtype = manual_cast_dtype
197
+ self.latent_format = latent_format
198
+ self.extra_conds += extra_conds
199
+ self.strength_type = strength_type
200
+ self.concat_mask = concat_mask
201
+
202
+ def get_control(self, x_noisy, t, cond, batched_number):
203
+ control_prev = None
204
+ if self.previous_controlnet is not None:
205
+ control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number)
206
+
207
+ if self.timestep_range is not None:
208
+ if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]:
209
+ if control_prev is not None:
210
+ return control_prev
211
+ else:
212
+ return None
213
+
214
+ dtype = self.control_model.dtype
215
+ if self.manual_cast_dtype is not None:
216
+ dtype = self.manual_cast_dtype
217
+
218
+ if self.cond_hint is None or x_noisy.shape[2] * self.compression_ratio != self.cond_hint.shape[2] or x_noisy.shape[3] * self.compression_ratio != self.cond_hint.shape[3]:
219
+ if self.cond_hint is not None:
220
+ del self.cond_hint
221
+ self.cond_hint = None
222
+ compression_ratio = self.compression_ratio
223
+ if self.vae is not None:
224
+ compression_ratio *= self.vae.downscale_ratio
225
+ self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * compression_ratio, x_noisy.shape[2] * compression_ratio, self.upscale_algorithm, "center")
226
+ if self.vae is not None:
227
+ loaded_models = comfy.model_management.loaded_models(only_currently_used=True)
228
+ self.cond_hint = self.vae.encode(self.cond_hint.movedim(1, -1))
229
+ comfy.model_management.load_models_gpu(loaded_models)
230
+ if self.latent_format is not None:
231
+ self.cond_hint = self.latent_format.process_in(self.cond_hint)
232
+ if len(self.extra_concat_orig) > 0:
233
+ to_concat = []
234
+ for c in self.extra_concat_orig:
235
+ c = comfy.utils.common_upscale(c, self.cond_hint.shape[3], self.cond_hint.shape[2], self.upscale_algorithm, "center")
236
+ to_concat.append(comfy.utils.repeat_to_batch_size(c, self.cond_hint.shape[0]))
237
+ self.cond_hint = torch.cat([self.cond_hint] + to_concat, dim=1)
238
+
239
+ self.cond_hint = self.cond_hint.to(device=self.device, dtype=dtype)
240
+ if x_noisy.shape[0] != self.cond_hint.shape[0]:
241
+ self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
242
+
243
+ context = cond.get('crossattn_controlnet', cond['c_crossattn'])
244
+ extra = self.extra_args.copy()
245
+ for c in self.extra_conds:
246
+ temp = cond.get(c, None)
247
+ if temp is not None:
248
+ extra[c] = temp.to(dtype)
249
+
250
+ timestep = self.model_sampling_current.timestep(t)
251
+ x_noisy = self.model_sampling_current.calculate_input(t, x_noisy)
252
+
253
+ control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.to(dtype), context=context.to(dtype), **extra)
254
+ return self.control_merge(control, control_prev, output_dtype=None)
255
+
256
+ def copy(self):
257
+ c = ControlNet(None, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
258
+ c.control_model = self.control_model
259
+ c.control_model_wrapped = self.control_model_wrapped
260
+ self.copy_to(c)
261
+ return c
262
+
263
+ def get_models(self):
264
+ out = super().get_models()
265
+ out.append(self.control_model_wrapped)
266
+ return out
267
+
268
+ def pre_run(self, model, percent_to_timestep_function):
269
+ super().pre_run(model, percent_to_timestep_function)
270
+ self.model_sampling_current = model.model_sampling
271
+
272
+ def cleanup(self):
273
+ self.model_sampling_current = None
274
+ super().cleanup()
275
+
276
+ class ControlLoraOps:
277
+ class Linear(torch.nn.Module, comfy.ops.CastWeightBiasOp):
278
+ def __init__(self, in_features: int, out_features: int, bias: bool = True,
279
+ device=None, dtype=None) -> None:
280
+ factory_kwargs = {'device': device, 'dtype': dtype}
281
+ super().__init__()
282
+ self.in_features = in_features
283
+ self.out_features = out_features
284
+ self.weight = None
285
+ self.up = None
286
+ self.down = None
287
+ self.bias = None
288
+
289
+ def forward(self, input):
290
+ weight, bias = comfy.ops.cast_bias_weight(self, input)
291
+ if self.up is not None:
292
+ return torch.nn.functional.linear(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias)
293
+ else:
294
+ return torch.nn.functional.linear(input, weight, bias)
295
+
296
+ class Conv2d(torch.nn.Module, comfy.ops.CastWeightBiasOp):
297
+ def __init__(
298
+ self,
299
+ in_channels,
300
+ out_channels,
301
+ kernel_size,
302
+ stride=1,
303
+ padding=0,
304
+ dilation=1,
305
+ groups=1,
306
+ bias=True,
307
+ padding_mode='zeros',
308
+ device=None,
309
+ dtype=None
310
+ ):
311
+ super().__init__()
312
+ self.in_channels = in_channels
313
+ self.out_channels = out_channels
314
+ self.kernel_size = kernel_size
315
+ self.stride = stride
316
+ self.padding = padding
317
+ self.dilation = dilation
318
+ self.transposed = False
319
+ self.output_padding = 0
320
+ self.groups = groups
321
+ self.padding_mode = padding_mode
322
+
323
+ self.weight = None
324
+ self.bias = None
325
+ self.up = None
326
+ self.down = None
327
+
328
+
329
+ def forward(self, input):
330
+ weight, bias = comfy.ops.cast_bias_weight(self, input)
331
+ if self.up is not None:
332
+ return torch.nn.functional.conv2d(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias, self.stride, self.padding, self.dilation, self.groups)
333
+ else:
334
+ return torch.nn.functional.conv2d(input, weight, bias, self.stride, self.padding, self.dilation, self.groups)
335
+
336
+
337
+ class ControlLora(ControlNet):
338
+ def __init__(self, control_weights, global_average_pooling=False, device=None):
339
+ ControlBase.__init__(self, device)
340
+ self.control_weights = control_weights
341
+ self.global_average_pooling = global_average_pooling
342
+ self.extra_conds += ["y"]
343
+
344
+ def pre_run(self, model, percent_to_timestep_function):
345
+ super().pre_run(model, percent_to_timestep_function)
346
+ controlnet_config = model.model_config.unet_config.copy()
347
+ controlnet_config.pop("out_channels")
348
+ controlnet_config["hint_channels"] = self.control_weights["input_hint_block.0.weight"].shape[1]
349
+ self.manual_cast_dtype = model.manual_cast_dtype
350
+ dtype = model.get_dtype()
351
+ if self.manual_cast_dtype is None:
352
+ class control_lora_ops(ControlLoraOps, comfy.ops.disable_weight_init):
353
+ pass
354
+ else:
355
+ class control_lora_ops(ControlLoraOps, comfy.ops.manual_cast):
356
+ pass
357
+ dtype = self.manual_cast_dtype
358
+
359
+ controlnet_config["operations"] = control_lora_ops
360
+ controlnet_config["dtype"] = dtype
361
+ self.control_model = comfy.cldm.cldm.ControlNet(**controlnet_config)
362
+ self.control_model.to(comfy.model_management.get_torch_device())
363
+ diffusion_model = model.diffusion_model
364
+ sd = diffusion_model.state_dict()
365
+ cm = self.control_model.state_dict()
366
+
367
+ for k in sd:
368
+ weight = sd[k]
369
+ try:
370
+ comfy.utils.set_attr_param(self.control_model, k, weight)
371
+ except:
372
+ pass
373
+
374
+ for k in self.control_weights:
375
+ if k not in {"lora_controlnet"}:
376
+ comfy.utils.set_attr_param(self.control_model, k, self.control_weights[k].to(dtype).to(comfy.model_management.get_torch_device()))
377
+
378
+ def copy(self):
379
+ c = ControlLora(self.control_weights, global_average_pooling=self.global_average_pooling)
380
+ self.copy_to(c)
381
+ return c
382
+
383
+ def cleanup(self):
384
+ del self.control_model
385
+ self.control_model = None
386
+ super().cleanup()
387
+
388
+ def get_models(self):
389
+ out = ControlBase.get_models(self)
390
+ return out
391
+
392
+ def inference_memory_requirements(self, dtype):
393
+ return comfy.utils.calculate_parameters(self.control_weights) * comfy.model_management.dtype_size(dtype) + ControlBase.inference_memory_requirements(self, dtype)
394
+
395
+ def controlnet_config(sd):
396
+ model_config = comfy.model_detection.model_config_from_unet(sd, "", True)
397
+
398
+ supported_inference_dtypes = model_config.supported_inference_dtypes
399
+
400
+ controlnet_config = model_config.unet_config
401
+ unet_dtype = comfy.model_management.unet_dtype(supported_dtypes=supported_inference_dtypes)
402
+ load_device = comfy.model_management.get_torch_device()
403
+ manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
404
+ if manual_cast_dtype is not None:
405
+ operations = comfy.ops.manual_cast
406
+ else:
407
+ operations = comfy.ops.disable_weight_init
408
+
409
+ offload_device = comfy.model_management.unet_offload_device()
410
+ return model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device
411
+
412
+ def controlnet_load_state_dict(control_model, sd):
413
+ missing, unexpected = control_model.load_state_dict(sd, strict=False)
414
+
415
+ if len(missing) > 0:
416
+ logging.warning("missing controlnet keys: {}".format(missing))
417
+
418
+ if len(unexpected) > 0:
419
+ logging.debug("unexpected controlnet keys: {}".format(unexpected))
420
+ return control_model
421
+
422
+ def load_controlnet_mmdit(sd):
423
+ new_sd = comfy.model_detection.convert_diffusers_mmdit(sd, "")
424
+ model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(new_sd)
425
+ num_blocks = comfy.model_detection.count_blocks(new_sd, 'joint_blocks.{}.')
426
+ for k in sd:
427
+ new_sd[k] = sd[k]
428
+
429
+ concat_mask = False
430
+ control_latent_channels = new_sd.get("pos_embed_input.proj.weight").shape[1]
431
+ if control_latent_channels == 17: #inpaint controlnet
432
+ concat_mask = True
433
+
434
+ control_model = comfy.cldm.mmdit.ControlNet(num_blocks=num_blocks, control_latent_channels=control_latent_channels, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
435
+ control_model = controlnet_load_state_dict(control_model, new_sd)
436
+
437
+ latent_format = comfy.latent_formats.SD3()
438
+ latent_format.shift_factor = 0 #SD3 controlnet weirdness
439
+ control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, concat_mask=concat_mask, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
440
+ return control
441
+
442
+
443
+ def load_controlnet_hunyuandit(controlnet_data):
444
+ model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(controlnet_data)
445
+
446
+ control_model = comfy.ldm.hydit.controlnet.HunYuanControlNet(operations=operations, device=offload_device, dtype=unet_dtype)
447
+ control_model = controlnet_load_state_dict(control_model, controlnet_data)
448
+
449
+ latent_format = comfy.latent_formats.SDXL()
450
+ extra_conds = ['text_embedding_mask', 'encoder_hidden_states_t5', 'text_embedding_mask_t5', 'image_meta_size', 'style', 'cos_cis_img', 'sin_cis_img']
451
+ control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds, strength_type=StrengthType.CONSTANT)
452
+ return control
453
+
454
+ def load_controlnet_flux_xlabs_mistoline(sd, mistoline=False):
455
+ model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(sd)
456
+ control_model = comfy.ldm.flux.controlnet.ControlNetFlux(mistoline=mistoline, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
457
+ control_model = controlnet_load_state_dict(control_model, sd)
458
+ extra_conds = ['y', 'guidance']
459
+ control = ControlNet(control_model, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
460
+ return control
461
+
462
+ def load_controlnet_flux_instantx(sd):
463
+ new_sd = comfy.model_detection.convert_diffusers_mmdit(sd, "")
464
+ model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(new_sd)
465
+ for k in sd:
466
+ new_sd[k] = sd[k]
467
+
468
+ num_union_modes = 0
469
+ union_cnet = "controlnet_mode_embedder.weight"
470
+ if union_cnet in new_sd:
471
+ num_union_modes = new_sd[union_cnet].shape[0]
472
+
473
+ control_latent_channels = new_sd.get("pos_embed_input.weight").shape[1] // 4
474
+ concat_mask = False
475
+ if control_latent_channels == 17:
476
+ concat_mask = True
477
+
478
+ control_model = comfy.ldm.flux.controlnet.ControlNetFlux(latent_input=True, num_union_modes=num_union_modes, control_latent_channels=control_latent_channels, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
479
+ control_model = controlnet_load_state_dict(control_model, new_sd)
480
+
481
+ latent_format = comfy.latent_formats.Flux()
482
+ extra_conds = ['y', 'guidance']
483
+ control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, concat_mask=concat_mask, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
484
+ return control
485
+
486
+ def convert_mistoline(sd):
487
+ return comfy.utils.state_dict_prefix_replace(sd, {"single_controlnet_blocks.": "controlnet_single_blocks."})
488
+
489
+
490
+ def load_controlnet(ckpt_path, model=None):
491
+ controlnet_data = comfy.utils.load_torch_file(ckpt_path, safe_load=True)
492
+ if 'after_proj_list.18.bias' in controlnet_data.keys(): #Hunyuan DiT
493
+ return load_controlnet_hunyuandit(controlnet_data)
494
+
495
+ if "lora_controlnet" in controlnet_data:
496
+ return ControlLora(controlnet_data)
497
+
498
+ controlnet_config = None
499
+ supported_inference_dtypes = None
500
+
501
+ if "controlnet_cond_embedding.conv_in.weight" in controlnet_data: #diffusers format
502
+ controlnet_config = comfy.model_detection.unet_config_from_diffusers_unet(controlnet_data)
503
+ diffusers_keys = comfy.utils.unet_to_diffusers(controlnet_config)
504
+ diffusers_keys["controlnet_mid_block.weight"] = "middle_block_out.0.weight"
505
+ diffusers_keys["controlnet_mid_block.bias"] = "middle_block_out.0.bias"
506
+
507
+ count = 0
508
+ loop = True
509
+ while loop:
510
+ suffix = [".weight", ".bias"]
511
+ for s in suffix:
512
+ k_in = "controlnet_down_blocks.{}{}".format(count, s)
513
+ k_out = "zero_convs.{}.0{}".format(count, s)
514
+ if k_in not in controlnet_data:
515
+ loop = False
516
+ break
517
+ diffusers_keys[k_in] = k_out
518
+ count += 1
519
+
520
+ count = 0
521
+ loop = True
522
+ while loop:
523
+ suffix = [".weight", ".bias"]
524
+ for s in suffix:
525
+ if count == 0:
526
+ k_in = "controlnet_cond_embedding.conv_in{}".format(s)
527
+ else:
528
+ k_in = "controlnet_cond_embedding.blocks.{}{}".format(count - 1, s)
529
+ k_out = "input_hint_block.{}{}".format(count * 2, s)
530
+ if k_in not in controlnet_data:
531
+ k_in = "controlnet_cond_embedding.conv_out{}".format(s)
532
+ loop = False
533
+ diffusers_keys[k_in] = k_out
534
+ count += 1
535
+
536
+ new_sd = {}
537
+ for k in diffusers_keys:
538
+ if k in controlnet_data:
539
+ new_sd[diffusers_keys[k]] = controlnet_data.pop(k)
540
+
541
+ if "control_add_embedding.linear_1.bias" in controlnet_data: #Union Controlnet
542
+ controlnet_config["union_controlnet_num_control_type"] = controlnet_data["task_embedding"].shape[0]
543
+ for k in list(controlnet_data.keys()):
544
+ new_k = k.replace('.attn.in_proj_', '.attn.in_proj.')
545
+ new_sd[new_k] = controlnet_data.pop(k)
546
+
547
+ leftover_keys = controlnet_data.keys()
548
+ if len(leftover_keys) > 0:
549
+ logging.warning("leftover keys: {}".format(leftover_keys))
550
+ controlnet_data = new_sd
551
+ elif "controlnet_blocks.0.weight" in controlnet_data:
552
+ if "double_blocks.0.img_attn.norm.key_norm.scale" in controlnet_data:
553
+ return load_controlnet_flux_xlabs_mistoline(controlnet_data)
554
+ elif "pos_embed_input.proj.weight" in controlnet_data:
555
+ return load_controlnet_mmdit(controlnet_data) #SD3 diffusers controlnet
556
+ elif "controlnet_x_embedder.weight" in controlnet_data:
557
+ return load_controlnet_flux_instantx(controlnet_data)
558
+ elif "controlnet_blocks.0.linear.weight" in controlnet_data: #mistoline flux
559
+ return load_controlnet_flux_xlabs_mistoline(convert_mistoline(controlnet_data), mistoline=True)
560
+
561
+ pth_key = 'control_model.zero_convs.0.0.weight'
562
+ pth = False
563
+ key = 'zero_convs.0.0.weight'
564
+ if pth_key in controlnet_data:
565
+ pth = True
566
+ key = pth_key
567
+ prefix = "control_model."
568
+ elif key in controlnet_data:
569
+ prefix = ""
570
+ else:
571
+ net = load_t2i_adapter(controlnet_data)
572
+ if net is None:
573
+ logging.error("error checkpoint does not contain controlnet or t2i adapter data {}".format(ckpt_path))
574
+ return net
575
+
576
+ if controlnet_config is None:
577
+ model_config = comfy.model_detection.model_config_from_unet(controlnet_data, prefix, True)
578
+ supported_inference_dtypes = model_config.supported_inference_dtypes
579
+ controlnet_config = model_config.unet_config
580
+
581
+ load_device = comfy.model_management.get_torch_device()
582
+ if supported_inference_dtypes is None:
583
+ unet_dtype = comfy.model_management.unet_dtype()
584
+ else:
585
+ unet_dtype = comfy.model_management.unet_dtype(supported_dtypes=supported_inference_dtypes)
586
+
587
+ manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
588
+ if manual_cast_dtype is not None:
589
+ controlnet_config["operations"] = comfy.ops.manual_cast
590
+ controlnet_config["dtype"] = unet_dtype
591
+ controlnet_config["device"] = comfy.model_management.unet_offload_device()
592
+ controlnet_config.pop("out_channels")
593
+ controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1]
594
+ control_model = comfy.cldm.cldm.ControlNet(**controlnet_config)
595
+
596
+ if pth:
597
+ if 'difference' in controlnet_data:
598
+ if model is not None:
599
+ comfy.model_management.load_models_gpu([model])
600
+ model_sd = model.model_state_dict()
601
+ for x in controlnet_data:
602
+ c_m = "control_model."
603
+ if x.startswith(c_m):
604
+ sd_key = "diffusion_model.{}".format(x[len(c_m):])
605
+ if sd_key in model_sd:
606
+ cd = controlnet_data[x]
607
+ cd += model_sd[sd_key].type(cd.dtype).to(cd.device)
608
+ else:
609
+ logging.warning("WARNING: Loaded a diff controlnet without a model. It will very likely not work.")
610
+
611
+ class WeightsLoader(torch.nn.Module):
612
+ pass
613
+ w = WeightsLoader()
614
+ w.control_model = control_model
615
+ missing, unexpected = w.load_state_dict(controlnet_data, strict=False)
616
+ else:
617
+ missing, unexpected = control_model.load_state_dict(controlnet_data, strict=False)
618
+
619
+ if len(missing) > 0:
620
+ logging.warning("missing controlnet keys: {}".format(missing))
621
+
622
+ if len(unexpected) > 0:
623
+ logging.debug("unexpected controlnet keys: {}".format(unexpected))
624
+
625
+ global_average_pooling = False
626
+ filename = os.path.splitext(ckpt_path)[0]
627
+ if filename.endswith("_shuffle") or filename.endswith("_shuffle_fp16"): #TODO: smarter way of enabling global_average_pooling
628
+ global_average_pooling = True
629
+
630
+ control = ControlNet(control_model, global_average_pooling=global_average_pooling, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
631
+ return control
632
+
633
+ class T2IAdapter(ControlBase):
634
+ def __init__(self, t2i_model, channels_in, compression_ratio, upscale_algorithm, device=None):
635
+ super().__init__(device)
636
+ self.t2i_model = t2i_model
637
+ self.channels_in = channels_in
638
+ self.control_input = None
639
+ self.compression_ratio = compression_ratio
640
+ self.upscale_algorithm = upscale_algorithm
641
+
642
+ def scale_image_to(self, width, height):
643
+ unshuffle_amount = self.t2i_model.unshuffle_amount
644
+ width = math.ceil(width / unshuffle_amount) * unshuffle_amount
645
+ height = math.ceil(height / unshuffle_amount) * unshuffle_amount
646
+ return width, height
647
+
648
+ def get_control(self, x_noisy, t, cond, batched_number):
649
+ control_prev = None
650
+ if self.previous_controlnet is not None:
651
+ control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number)
652
+
653
+ if self.timestep_range is not None:
654
+ if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]:
655
+ if control_prev is not None:
656
+ return control_prev
657
+ else:
658
+ return None
659
+
660
+ if self.cond_hint is None or x_noisy.shape[2] * self.compression_ratio != self.cond_hint.shape[2] or x_noisy.shape[3] * self.compression_ratio != self.cond_hint.shape[3]:
661
+ if self.cond_hint is not None:
662
+ del self.cond_hint
663
+ self.control_input = None
664
+ self.cond_hint = None
665
+ width, height = self.scale_image_to(x_noisy.shape[3] * self.compression_ratio, x_noisy.shape[2] * self.compression_ratio)
666
+ self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, width, height, self.upscale_algorithm, "center").float().to(self.device)
667
+ if self.channels_in == 1 and self.cond_hint.shape[1] > 1:
668
+ self.cond_hint = torch.mean(self.cond_hint, 1, keepdim=True)
669
+ if x_noisy.shape[0] != self.cond_hint.shape[0]:
670
+ self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
671
+ if self.control_input is None:
672
+ self.t2i_model.to(x_noisy.dtype)
673
+ self.t2i_model.to(self.device)
674
+ self.control_input = self.t2i_model(self.cond_hint.to(x_noisy.dtype))
675
+ self.t2i_model.cpu()
676
+
677
+ control_input = {}
678
+ for k in self.control_input:
679
+ control_input[k] = list(map(lambda a: None if a is None else a.clone(), self.control_input[k]))
680
+
681
+ return self.control_merge(control_input, control_prev, x_noisy.dtype)
682
+
683
+ def copy(self):
684
+ c = T2IAdapter(self.t2i_model, self.channels_in, self.compression_ratio, self.upscale_algorithm)
685
+ self.copy_to(c)
686
+ return c
687
+
688
+ def load_t2i_adapter(t2i_data):
689
+ compression_ratio = 8
690
+ upscale_algorithm = 'nearest-exact'
691
+
692
+ if 'adapter' in t2i_data:
693
+ t2i_data = t2i_data['adapter']
694
+ if 'adapter.body.0.resnets.0.block1.weight' in t2i_data: #diffusers format
695
+ prefix_replace = {}
696
+ for i in range(4):
697
+ for j in range(2):
698
+ prefix_replace["adapter.body.{}.resnets.{}.".format(i, j)] = "body.{}.".format(i * 2 + j)
699
+ prefix_replace["adapter.body.{}.".format(i, j)] = "body.{}.".format(i * 2)
700
+ prefix_replace["adapter."] = ""
701
+ t2i_data = comfy.utils.state_dict_prefix_replace(t2i_data, prefix_replace)
702
+ keys = t2i_data.keys()
703
+
704
+ if "body.0.in_conv.weight" in keys:
705
+ cin = t2i_data['body.0.in_conv.weight'].shape[1]
706
+ model_ad = comfy.t2i_adapter.adapter.Adapter_light(cin=cin, channels=[320, 640, 1280, 1280], nums_rb=4)
707
+ elif 'conv_in.weight' in keys:
708
+ cin = t2i_data['conv_in.weight'].shape[1]
709
+ channel = t2i_data['conv_in.weight'].shape[0]
710
+ ksize = t2i_data['body.0.block2.weight'].shape[2]
711
+ use_conv = False
712
+ down_opts = list(filter(lambda a: a.endswith("down_opt.op.weight"), keys))
713
+ if len(down_opts) > 0:
714
+ use_conv = True
715
+ xl = False
716
+ if cin == 256 or cin == 768:
717
+ xl = True
718
+ model_ad = comfy.t2i_adapter.adapter.Adapter(cin=cin, channels=[channel, channel*2, channel*4, channel*4][:4], nums_rb=2, ksize=ksize, sk=True, use_conv=use_conv, xl=xl)
719
+ elif "backbone.0.0.weight" in keys:
720
+ model_ad = comfy.ldm.cascade.controlnet.ControlNet(c_in=t2i_data['backbone.0.0.weight'].shape[1], proj_blocks=[0, 4, 8, 12, 51, 55, 59, 63])
721
+ compression_ratio = 32
722
+ upscale_algorithm = 'bilinear'
723
+ elif "backbone.10.blocks.0.weight" in keys:
724
+ model_ad = comfy.ldm.cascade.controlnet.ControlNet(c_in=t2i_data['backbone.0.weight'].shape[1], bottleneck_mode="large", proj_blocks=[0, 4, 8, 12, 51, 55, 59, 63])
725
+ compression_ratio = 1
726
+ upscale_algorithm = 'nearest-exact'
727
+ else:
728
+ return None
729
+
730
+ missing, unexpected = model_ad.load_state_dict(t2i_data)
731
+ if len(missing) > 0:
732
+ logging.warning("t2i missing {}".format(missing))
733
+
734
+ if len(unexpected) > 0:
735
+ logging.debug("t2i unexpected {}".format(unexpected))
736
+
737
+ return T2IAdapter(model_ad, model_ad.input_channels, compression_ratio, upscale_algorithm)
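A usage sketch for the loaders above, with a hypothetical checkpoint path. load_controlnet() auto-detects the format (diffusers, .pth, ControlLora, T2I-Adapter), and set_cond_hint stores the hint channels-first, which is why the image is moved to [B, C, H, W] before being passed in (get_control later calls movedim(1, -1) before any VAE encode):

import torch
import comfy.controlnet

cnet = comfy.controlnet.load_controlnet("models/controlnet/canny.safetensors")  # hypothetical path

hint = torch.rand(1, 512, 512, 3).movedim(-1, 1)  # [B, H, W, C] image -> [B, C, H, W] hint
cnet = cnet.set_cond_hint(hint, strength=0.8, timestep_percent_range=(0.0, 1.0))

# During sampling, ComfyUI calls cnet.get_control(x_noisy, t, cond, batched_number)
# and merges the returned {'input', 'middle', 'output'} residuals into the UNet.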
comfy/diffusers_convert.py ADDED
@@ -0,0 +1,281 @@
1
+ import re
2
+ import torch
3
+ import logging
4
+
5
+ # conversion code from https://github.com/huggingface/diffusers/blob/main/scripts/convert_diffusers_to_original_stable_diffusion.py
6
+
7
+ # =================#
8
+ # UNet Conversion #
9
+ # =================#
10
+
11
+ unet_conversion_map = [
12
+ # (stable-diffusion, HF Diffusers)
13
+ ("time_embed.0.weight", "time_embedding.linear_1.weight"),
14
+ ("time_embed.0.bias", "time_embedding.linear_1.bias"),
15
+ ("time_embed.2.weight", "time_embedding.linear_2.weight"),
16
+ ("time_embed.2.bias", "time_embedding.linear_2.bias"),
17
+ ("input_blocks.0.0.weight", "conv_in.weight"),
18
+ ("input_blocks.0.0.bias", "conv_in.bias"),
19
+ ("out.0.weight", "conv_norm_out.weight"),
20
+ ("out.0.bias", "conv_norm_out.bias"),
21
+ ("out.2.weight", "conv_out.weight"),
22
+ ("out.2.bias", "conv_out.bias"),
23
+ ]
24
+
25
+ unet_conversion_map_resnet = [
26
+ # (stable-diffusion, HF Diffusers)
27
+ ("in_layers.0", "norm1"),
28
+ ("in_layers.2", "conv1"),
29
+ ("out_layers.0", "norm2"),
30
+ ("out_layers.3", "conv2"),
31
+ ("emb_layers.1", "time_emb_proj"),
32
+ ("skip_connection", "conv_shortcut"),
33
+ ]
34
+
35
+ unet_conversion_map_layer = []
36
+ # hardcoded number of downblocks and resnets/attentions...
37
+ # would need smarter logic for other networks.
38
+ for i in range(4):
39
+ # loop over downblocks/upblocks
40
+
41
+ for j in range(2):
42
+ # loop over resnets/attentions for downblocks
43
+ hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
44
+ sd_down_res_prefix = f"input_blocks.{3 * i + j + 1}.0."
45
+ unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
46
+
47
+ if i < 3:
48
+ # no attention layers in down_blocks.3
49
+ hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
50
+ sd_down_atn_prefix = f"input_blocks.{3 * i + j + 1}.1."
51
+ unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
52
+
53
+ for j in range(3):
54
+ # loop over resnets/attentions for upblocks
55
+ hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
56
+ sd_up_res_prefix = f"output_blocks.{3 * i + j}.0."
57
+ unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
58
+
59
+ if i > 0:
60
+ # no attention layers in up_blocks.0
61
+ hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
62
+ sd_up_atn_prefix = f"output_blocks.{3 * i + j}.1."
63
+ unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
64
+
65
+ if i < 3:
66
+ # no downsample in down_blocks.3
67
+ hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
68
+ sd_downsample_prefix = f"input_blocks.{3 * (i + 1)}.0.op."
69
+ unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
70
+
71
+ # no upsample in up_blocks.3
72
+ hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
73
+ sd_upsample_prefix = f"output_blocks.{3 * i + 2}.{1 if i == 0 else 2}."
74
+ unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
75
+
76
+ hf_mid_atn_prefix = "mid_block.attentions.0."
77
+ sd_mid_atn_prefix = "middle_block.1."
78
+ unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
79
+
80
+ for j in range(2):
81
+ hf_mid_res_prefix = f"mid_block.resnets.{j}."
82
+ sd_mid_res_prefix = f"middle_block.{2 * j}."
83
+ unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
84
+
85
+
86
+ def convert_unet_state_dict(unet_state_dict):
87
+ # buyer beware: this is a *brittle* function,
88
+ # and correct output requires that all of these pieces interact in
89
+ # the exact order in which I have arranged them.
90
+ mapping = {k: k for k in unet_state_dict.keys()}
91
+ for sd_name, hf_name in unet_conversion_map:
92
+ mapping[hf_name] = sd_name
93
+ for k, v in mapping.items():
94
+ if "resnets" in k:
95
+ for sd_part, hf_part in unet_conversion_map_resnet:
96
+ v = v.replace(hf_part, sd_part)
97
+ mapping[k] = v
98
+ for k, v in mapping.items():
99
+ for sd_part, hf_part in unet_conversion_map_layer:
100
+ v = v.replace(hf_part, sd_part)
101
+ mapping[k] = v
102
+ new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()}
103
+ return new_state_dict
104
+
105
+
106
+ # ================#
107
+ # VAE Conversion #
108
+ # ================#
109
+
110
+ vae_conversion_map = [
111
+ # (stable-diffusion, HF Diffusers)
112
+ ("nin_shortcut", "conv_shortcut"),
113
+ ("norm_out", "conv_norm_out"),
114
+ ("mid.attn_1.", "mid_block.attentions.0."),
115
+ ]
116
+
117
+ for i in range(4):
118
+ # down_blocks have two resnets
119
+ for j in range(2):
120
+ hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}."
121
+ sd_down_prefix = f"encoder.down.{i}.block.{j}."
122
+ vae_conversion_map.append((sd_down_prefix, hf_down_prefix))
123
+
124
+ if i < 3:
125
+ hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0."
126
+ sd_downsample_prefix = f"down.{i}.downsample."
127
+ vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix))
128
+
129
+ hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
130
+ sd_upsample_prefix = f"up.{3 - i}.upsample."
131
+ vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix))
132
+
133
+ # up_blocks have three resnets
134
+ # also, up blocks in hf are numbered in reverse from sd
135
+ for j in range(3):
136
+ hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}."
137
+ sd_up_prefix = f"decoder.up.{3 - i}.block.{j}."
138
+ vae_conversion_map.append((sd_up_prefix, hf_up_prefix))
139
+
140
+ # this part accounts for mid blocks in both the encoder and the decoder
141
+ for i in range(2):
142
+ hf_mid_res_prefix = f"mid_block.resnets.{i}."
143
+ sd_mid_res_prefix = f"mid.block_{i + 1}."
144
+ vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))
145
+
146
+ vae_conversion_map_attn = [
147
+ # (stable-diffusion, HF Diffusers)
148
+ ("norm.", "group_norm."),
149
+ ("q.", "query."),
150
+ ("k.", "key."),
151
+ ("v.", "value."),
152
+ ("q.", "to_q."),
153
+ ("k.", "to_k."),
154
+ ("v.", "to_v."),
155
+ ("proj_out.", "to_out.0."),
156
+ ("proj_out.", "proj_attn."),
157
+ ]
158
+
159
+
160
+ def reshape_weight_for_sd(w):
161
+ # convert HF linear weights to SD conv2d weights
162
+ return w.reshape(*w.shape, 1, 1)
163
+
164
+
165
+ def convert_vae_state_dict(vae_state_dict):
166
+ mapping = {k: k for k in vae_state_dict.keys()}
167
+ for k, v in mapping.items():
168
+ for sd_part, hf_part in vae_conversion_map:
169
+ v = v.replace(hf_part, sd_part)
170
+ mapping[k] = v
171
+ for k, v in mapping.items():
172
+ if "attentions" in k:
173
+ for sd_part, hf_part in vae_conversion_map_attn:
174
+ v = v.replace(hf_part, sd_part)
175
+ mapping[k] = v
176
+ new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()}
177
+ weights_to_convert = ["q", "k", "v", "proj_out"]
178
+ for k, v in new_state_dict.items():
179
+ for weight_name in weights_to_convert:
180
+ if f"mid.attn_1.{weight_name}.weight" in k:
181
+ logging.debug(f"Reshaping {k} for SD format")
182
+ new_state_dict[k] = reshape_weight_for_sd(v)
183
+ return new_state_dict
184
+
185
+
186
+ # =========================#
187
+ # Text Encoder Conversion #
188
+ # =========================#
189
+
190
+
191
+ textenc_conversion_lst = [
192
+ # (stable-diffusion, HF Diffusers)
193
+ ("resblocks.", "text_model.encoder.layers."),
194
+ ("ln_1", "layer_norm1"),
195
+ ("ln_2", "layer_norm2"),
196
+ (".c_fc.", ".fc1."),
197
+ (".c_proj.", ".fc2."),
198
+ (".attn", ".self_attn"),
199
+ ("ln_final.", "transformer.text_model.final_layer_norm."),
200
+ ("token_embedding.weight", "transformer.text_model.embeddings.token_embedding.weight"),
201
+ ("positional_embedding", "transformer.text_model.embeddings.position_embedding.weight"),
202
+ ]
203
+ protected = {re.escape(x[1]): x[0] for x in textenc_conversion_lst}
204
+ textenc_pattern = re.compile("|".join(protected.keys()))
205
+
206
+ # Ordering is from https://github.com/pytorch/pytorch/blob/master/test/cpp/api/modules.cpp
207
+ code2idx = {"q": 0, "k": 1, "v": 2}
208
+
209
+ # This function exists because at the time of writing torch.cat can't do fp8 with cuda
210
+ def cat_tensors(tensors):
211
+ x = 0
212
+ for t in tensors:
213
+ x += t.shape[0]
214
+
215
+ shape = [x] + list(tensors[0].shape)[1:]
216
+ out = torch.empty(shape, device=tensors[0].device, dtype=tensors[0].dtype)
217
+
218
+ x = 0
219
+ for t in tensors:
220
+ out[x:x + t.shape[0]] = t
221
+ x += t.shape[0]
222
+
223
+ return out
224
+
225
+ def convert_text_enc_state_dict_v20(text_enc_dict, prefix=""):
226
+ new_state_dict = {}
227
+ capture_qkv_weight = {}
228
+ capture_qkv_bias = {}
229
+ for k, v in text_enc_dict.items():
230
+ if not k.startswith(prefix):
231
+ continue
232
+ if (
233
+ k.endswith(".self_attn.q_proj.weight")
234
+ or k.endswith(".self_attn.k_proj.weight")
235
+ or k.endswith(".self_attn.v_proj.weight")
236
+ ):
237
+ k_pre = k[: -len(".q_proj.weight")]
238
+ k_code = k[-len("q_proj.weight")]
239
+ if k_pre not in capture_qkv_weight:
240
+ capture_qkv_weight[k_pre] = [None, None, None]
241
+ capture_qkv_weight[k_pre][code2idx[k_code]] = v
242
+ continue
243
+
244
+ if (
245
+ k.endswith(".self_attn.q_proj.bias")
246
+ or k.endswith(".self_attn.k_proj.bias")
247
+ or k.endswith(".self_attn.v_proj.bias")
248
+ ):
249
+ k_pre = k[: -len(".q_proj.bias")]
250
+ k_code = k[-len("q_proj.bias")]
251
+ if k_pre not in capture_qkv_bias:
252
+ capture_qkv_bias[k_pre] = [None, None, None]
253
+ capture_qkv_bias[k_pre][code2idx[k_code]] = v
254
+ continue
255
+
256
+ text_proj = "transformer.text_projection.weight"
257
+ if k.endswith(text_proj):
258
+ new_state_dict[k.replace(text_proj, "text_projection")] = v.transpose(0, 1).contiguous()
259
+ else:
260
+ relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k)
261
+ new_state_dict[relabelled_key] = v
262
+
263
+ for k_pre, tensors in capture_qkv_weight.items():
264
+ if None in tensors:
265
+ raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing")
266
+ relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre)
267
+ new_state_dict[relabelled_key + ".in_proj_weight"] = cat_tensors(tensors)
268
+
269
+ for k_pre, tensors in capture_qkv_bias.items():
270
+ if None in tensors:
271
+ raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing")
272
+ relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k_pre)
273
+ new_state_dict[relabelled_key + ".in_proj_bias"] = cat_tensors(tensors)
274
+
275
+ return new_state_dict
276
+
277
+
278
+ def convert_text_enc_state_dict(text_enc_dict):
279
+ return text_enc_dict
280
+
281
+
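A short sketch of how these key-mapping helpers are usually applied (safetensors is assumed to be installed; the two file paths are placeholders for a locally downloaded diffusers checkpoint):

```python
# Sketch only: converts diffusers-format UNet/VAE weights to the original
# stable-diffusion key layout; paths are placeholders.
import safetensors.torch
from comfy.diffusers_convert import convert_unet_state_dict, convert_vae_state_dict

unet_sd = safetensors.torch.load_file("unet/diffusion_pytorch_model.safetensors")
vae_sd = safetensors.torch.load_file("vae/diffusion_pytorch_model.safetensors")

sd_unet = convert_unet_state_dict(unet_sd)  # HF keys -> SD UNet keys
sd_vae = convert_vae_state_dict(vae_sd)     # HF keys -> SD VAE keys (attention weights reshaped)
print(len(sd_unet), len(sd_vae))
```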
comfy/diffusers_load.py ADDED
@@ -0,0 +1,36 @@
1
+ import os
2
+
3
+ import comfy.sd
+ import comfy.utils
4
+
5
+ def first_file(path, filenames):
6
+ for f in filenames:
7
+ p = os.path.join(path, f)
8
+ if os.path.exists(p):
9
+ return p
10
+ return None
11
+
12
+ def load_diffusers(model_path, output_vae=True, output_clip=True, embedding_directory=None):
13
+ diffusion_model_names = ["diffusion_pytorch_model.fp16.safetensors", "diffusion_pytorch_model.safetensors", "diffusion_pytorch_model.fp16.bin", "diffusion_pytorch_model.bin"]
14
+ unet_path = first_file(os.path.join(model_path, "unet"), diffusion_model_names)
15
+ vae_path = first_file(os.path.join(model_path, "vae"), diffusion_model_names)
16
+
17
+ text_encoder_model_names = ["model.fp16.safetensors", "model.safetensors", "pytorch_model.fp16.bin", "pytorch_model.bin"]
18
+ text_encoder1_path = first_file(os.path.join(model_path, "text_encoder"), text_encoder_model_names)
19
+ text_encoder2_path = first_file(os.path.join(model_path, "text_encoder_2"), text_encoder_model_names)
20
+
21
+ text_encoder_paths = [text_encoder1_path]
22
+ if text_encoder2_path is not None:
23
+ text_encoder_paths.append(text_encoder2_path)
24
+
25
+ unet = comfy.sd.load_diffusion_model(unet_path)
26
+
27
+ clip = None
28
+ if output_clip:
29
+ clip = comfy.sd.load_clip(text_encoder_paths, embedding_directory=embedding_directory)
30
+
31
+ vae = None
32
+ if output_vae:
33
+ sd = comfy.utils.load_torch_file(vae_path)
34
+ vae = comfy.sd.VAE(sd=sd)
35
+
36
+ return (unet, clip, vae)
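Assuming a local diffusers-style directory with the usual unet/, vae/ and text_encoder/ subfolders, usage would look roughly like this (the directory name is a placeholder):

```python
# Rough usage sketch; "path/to/diffusers_model" is a placeholder directory.
import comfy.diffusers_load

unet, clip, vae = comfy.diffusers_load.load_diffusers(
    "path/to/diffusers_model",
    output_vae=True,
    output_clip=True,
    embedding_directory=None,
)
```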
comfy/extra_samplers/uni_pc.py ADDED
@@ -0,0 +1,875 @@
1
+ #code taken from: https://github.com/wl-zhao/UniPC and modified
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+ import math
6
+
7
+ from tqdm.auto import trange, tqdm
8
+
9
+
10
+ class NoiseScheduleVP:
11
+ def __init__(
12
+ self,
13
+ schedule='discrete',
14
+ betas=None,
15
+ alphas_cumprod=None,
16
+ continuous_beta_0=0.1,
17
+ continuous_beta_1=20.,
18
+ ):
19
+ """Create a wrapper class for the forward SDE (VP type).
20
+
21
+ ***
22
+ Update: We support discrete-time diffusion models by implementing a piecewise linear interpolation for log_alpha_t.
23
+ We recommend using schedule='discrete' for discrete-time diffusion models, especially for high-resolution images.
24
+ ***
25
+
26
+ The forward SDE ensures that the conditional distribution q_{t|0}(x_t | x_0) = N ( alpha_t * x_0, sigma_t^2 * I ).
27
+ We further define lambda_t = log(alpha_t) - log(sigma_t), which is the half-logSNR (described in the DPM-Solver paper).
28
+ Therefore, we implement the functions for computing alpha_t, sigma_t and lambda_t. For t in [0, T], we have:
29
+
30
+ log_alpha_t = self.marginal_log_mean_coeff(t)
31
+ sigma_t = self.marginal_std(t)
32
+ lambda_t = self.marginal_lambda(t)
33
+
34
+ Moreover, as lambda(t) is an invertible function, we also support its inverse function:
35
+
36
+ t = self.inverse_lambda(lambda_t)
37
+
38
+ ===============================================================
39
+
40
+ We support both discrete-time DPMs (trained on n = 0, 1, ..., N-1) and continuous-time DPMs (trained on t in [t_0, T]).
41
+
42
+ 1. For discrete-time DPMs:
43
+
44
+ For discrete-time DPMs trained on n = 0, 1, ..., N-1, we convert the discrete steps to continuous time steps by:
45
+ t_i = (i + 1) / N
46
+ e.g. for N = 1000, we have t_0 = 1e-3 and T = t_{N-1} = 1.
47
+ We solve the corresponding diffusion ODE from time T = 1 to time t_0 = 1e-3.
48
+
49
+ Args:
50
+ betas: A `torch.Tensor`. The beta array for the discrete-time DPM. (See the original DDPM paper for details)
51
+ alphas_cumprod: A `torch.Tensor`. The cumprod alphas for the discrete-time DPM. (See the original DDPM paper for details)
52
+
53
+ Note that we always have alphas_cumprod = cumprod(1 - betas). Therefore, we only need to set one of `betas` and `alphas_cumprod`.
54
+
55
+ **Important**: Please pay special attention for the args for `alphas_cumprod`:
56
+ The `alphas_cumprod` is the \hat{alpha_n} arrays in the notations of DDPM. Specifically, DDPMs assume that
57
+ q_{t_n | 0}(x_{t_n} | x_0) = N ( \sqrt{\hat{alpha_n}} * x_0, (1 - \hat{alpha_n}) * I ).
58
+ Therefore, the notation \hat{alpha_n} is different from the notation alpha_t in DPM-Solver. In fact, we have
59
+ alpha_{t_n} = \sqrt{\hat{alpha_n}},
60
+ and
61
+ log(alpha_{t_n}) = 0.5 * log(\hat{alpha_n}).
62
+
63
+
64
+ 2. For continuous-time DPMs:
65
+
66
+ We support two types of VPSDEs: linear (DDPM) and cosine (improved-DDPM). The hyperparameters for the noise
67
+ schedule are the default settings in DDPM and improved-DDPM:
68
+
69
+ Args:
70
+ beta_min: A `float` number. The smallest beta for the linear schedule.
71
+ beta_max: A `float` number. The largest beta for the linear schedule.
72
+ cosine_s: A `float` number. The hyperparameter in the cosine schedule.
73
+ cosine_beta_max: A `float` number. The hyperparameter in the cosine schedule.
74
+ T: A `float` number. The ending time of the forward process.
75
+
76
+ ===============================================================
77
+
78
+ Args:
79
+ schedule: A `str`. The noise schedule of the forward SDE. 'discrete' for discrete-time DPMs,
80
+ 'linear' or 'cosine' for continuous-time DPMs.
81
+ Returns:
82
+ A wrapper object of the forward SDE (VP type).
83
+
84
+ ===============================================================
85
+
86
+ Example:
87
+
88
+ # For discrete-time DPMs, given betas (the beta array for n = 0, 1, ..., N - 1):
89
+ >>> ns = NoiseScheduleVP('discrete', betas=betas)
90
+
91
+ # For discrete-time DPMs, given alphas_cumprod (the \hat{alpha_n} array for n = 0, 1, ..., N - 1):
92
+ >>> ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod)
93
+
94
+ # For continuous-time DPMs (VPSDE), linear schedule:
95
+ >>> ns = NoiseScheduleVP('linear', continuous_beta_0=0.1, continuous_beta_1=20.)
96
+
97
+ """
98
+
99
+ if schedule not in ['discrete', 'linear', 'cosine']:
100
+ raise ValueError("Unsupported noise schedule {}. The schedule needs to be 'discrete' or 'linear' or 'cosine'".format(schedule))
101
+
102
+ self.schedule = schedule
103
+ if schedule == 'discrete':
104
+ if betas is not None:
105
+ log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0)
106
+ else:
107
+ assert alphas_cumprod is not None
108
+ log_alphas = 0.5 * torch.log(alphas_cumprod)
109
+ self.total_N = len(log_alphas)
110
+ self.T = 1.
111
+ self.t_array = torch.linspace(0., 1., self.total_N + 1)[1:].reshape((1, -1))
112
+ self.log_alpha_array = log_alphas.reshape((1, -1,))
113
+ else:
114
+ self.total_N = 1000
115
+ self.beta_0 = continuous_beta_0
116
+ self.beta_1 = continuous_beta_1
117
+ self.cosine_s = 0.008
118
+ self.cosine_beta_max = 999.
119
+ self.cosine_t_max = math.atan(self.cosine_beta_max * (1. + self.cosine_s) / math.pi) * 2. * (1. + self.cosine_s) / math.pi - self.cosine_s
120
+ self.cosine_log_alpha_0 = math.log(math.cos(self.cosine_s / (1. + self.cosine_s) * math.pi / 2.))
121
+ self.schedule = schedule
122
+ if schedule == 'cosine':
123
+ # For the cosine schedule, T = 1 will have numerical issues. So we manually set the ending time T.
124
+ # Note that T = 0.9946 may not be the optimal setting. However, we find it works well.
125
+ self.T = 0.9946
126
+ else:
127
+ self.T = 1.
128
+
129
+ def marginal_log_mean_coeff(self, t):
130
+ """
131
+ Compute log(alpha_t) of a given continuous-time label t in [0, T].
132
+ """
133
+ if self.schedule == 'discrete':
134
+ return interpolate_fn(t.reshape((-1, 1)), self.t_array.to(t.device), self.log_alpha_array.to(t.device)).reshape((-1))
135
+ elif self.schedule == 'linear':
136
+ return -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0
137
+ elif self.schedule == 'cosine':
138
+ log_alpha_fn = lambda s: torch.log(torch.cos((s + self.cosine_s) / (1. + self.cosine_s) * math.pi / 2.))
139
+ log_alpha_t = log_alpha_fn(t) - self.cosine_log_alpha_0
140
+ return log_alpha_t
141
+
142
+ def marginal_alpha(self, t):
143
+ """
144
+ Compute alpha_t of a given continuous-time label t in [0, T].
145
+ """
146
+ return torch.exp(self.marginal_log_mean_coeff(t))
147
+
148
+ def marginal_std(self, t):
149
+ """
150
+ Compute sigma_t of a given continuous-time label t in [0, T].
151
+ """
152
+ return torch.sqrt(1. - torch.exp(2. * self.marginal_log_mean_coeff(t)))
153
+
154
+ def marginal_lambda(self, t):
155
+ """
156
+ Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T].
157
+ """
158
+ log_mean_coeff = self.marginal_log_mean_coeff(t)
159
+ log_std = 0.5 * torch.log(1. - torch.exp(2. * log_mean_coeff))
160
+ return log_mean_coeff - log_std
161
+
162
+ def inverse_lambda(self, lamb):
163
+ """
164
+ Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t.
165
+ """
166
+ if self.schedule == 'linear':
167
+ tmp = 2. * (self.beta_1 - self.beta_0) * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
168
+ Delta = self.beta_0**2 + tmp
169
+ return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0)
170
+ elif self.schedule == 'discrete':
171
+ log_alpha = -0.5 * torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2. * lamb)
172
+ t = interpolate_fn(log_alpha.reshape((-1, 1)), torch.flip(self.log_alpha_array.to(lamb.device), [1]), torch.flip(self.t_array.to(lamb.device), [1]))
173
+ return t.reshape((-1,))
174
+ else:
175
+ log_alpha = -0.5 * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
176
+ t_fn = lambda log_alpha_t: torch.arccos(torch.exp(log_alpha_t + self.cosine_log_alpha_0)) * 2. * (1. + self.cosine_s) / math.pi - self.cosine_s
177
+ t = t_fn(log_alpha)
178
+ return t
179
+
180
+
181
+ def model_wrapper(
182
+ model,
183
+ noise_schedule,
184
+ model_type="noise",
185
+ model_kwargs={},
186
+ guidance_type="uncond",
187
+ condition=None,
188
+ unconditional_condition=None,
189
+ guidance_scale=1.,
190
+ classifier_fn=None,
191
+ classifier_kwargs={},
192
+ ):
193
+ """Create a wrapper function for the noise prediction model.
194
+
195
+ DPM-Solver needs to solve the continuous-time diffusion ODEs. For DPMs trained on discrete-time labels, we need to
196
+ firstly wrap the model function to a noise prediction model that accepts the continuous time as the input.
197
+
198
+ We support four types of the diffusion model by setting `model_type`:
199
+
200
+ 1. "noise": noise prediction model. (Trained by predicting noise).
201
+
202
+ 2. "x_start": data prediction model. (Trained by predicting the data x_0 at time 0).
203
+
204
+ 3. "v": velocity prediction model. (Trained by predicting the velocity).
205
+ The "v" prediction is derivation detailed in Appendix D of [1], and is used in Imagen-Video [2].
206
+
207
+ [1] Salimans, Tim, and Jonathan Ho. "Progressive distillation for fast sampling of diffusion models."
208
+ arXiv preprint arXiv:2202.00512 (2022).
209
+ [2] Ho, Jonathan, et al. "Imagen Video: High Definition Video Generation with Diffusion Models."
210
+ arXiv preprint arXiv:2210.02303 (2022).
211
+
212
+ 4. "score": marginal score function. (Trained by denoising score matching).
213
+ Note that the score function and the noise prediction model follows a simple relationship:
214
+ ```
215
+ noise(x_t, t) = -sigma_t * score(x_t, t)
216
+ ```
217
+
218
+ We support three types of guided sampling by DPMs by setting `guidance_type`:
219
+ 1. "uncond": unconditional sampling by DPMs.
220
+ The input `model` has the following format:
221
+ ``
222
+ model(x, t_input, **model_kwargs) -> noise | x_start | v | score
223
+ ``
224
+
225
+ 2. "classifier": classifier guidance sampling [3] by DPMs and another classifier.
226
+ The input `model` has the following format:
227
+ ``
228
+ model(x, t_input, **model_kwargs) -> noise | x_start | v | score
229
+ ``
230
+
231
+ The input `classifier_fn` has the following format:
232
+ ``
233
+ classifier_fn(x, t_input, cond, **classifier_kwargs) -> logits(x, t_input, cond)
234
+ ``
235
+
236
+ [3] P. Dhariwal and A. Q. Nichol, "Diffusion models beat GANs on image synthesis,"
237
+ in Advances in Neural Information Processing Systems, vol. 34, 2021, pp. 8780-8794.
238
+
239
+ 3. "classifier-free": classifier-free guidance sampling by conditional DPMs.
240
+ The input `model` has the following format:
241
+ ``
242
+ model(x, t_input, cond, **model_kwargs) -> noise | x_start | v | score
243
+ ``
244
+ And if cond == `unconditional_condition`, the model output is the unconditional DPM output.
245
+
246
+ [4] Ho, Jonathan, and Tim Salimans. "Classifier-free diffusion guidance."
247
+ arXiv preprint arXiv:2207.12598 (2022).
248
+
249
+
250
+ The `t_input` is the time label of the model, which may be discrete-time labels (i.e. 0 to 999)
251
+ or continuous-time labels (i.e. epsilon to T).
252
+
253
+ We wrap the model function to accept only `x` and `t_continuous` as inputs, and outputs the predicted noise:
254
+ ``
255
+ def model_fn(x, t_continuous) -> noise:
256
+ t_input = get_model_input_time(t_continuous)
257
+ return noise_pred(model, x, t_input, **model_kwargs)
258
+ ``
259
+ where `t_continuous` is the continuous time labels (i.e. epsilon to T). And we use `model_fn` for DPM-Solver.
260
+
261
+ ===============================================================
262
+
263
+ Args:
264
+ model: A diffusion model with the corresponding format described above.
265
+ noise_schedule: A noise schedule object, such as NoiseScheduleVP.
266
+ model_type: A `str`. The parameterization type of the diffusion model.
267
+ "noise" or "x_start" or "v" or "score".
268
+ model_kwargs: A `dict`. A dict for the other inputs of the model function.
269
+ guidance_type: A `str`. The type of the guidance for sampling.
270
+ "uncond" or "classifier" or "classifier-free".
271
+ condition: A pytorch tensor. The condition for the guided sampling.
272
+ Only used for "classifier" or "classifier-free" guidance type.
273
+ unconditional_condition: A pytorch tensor. The condition for the unconditional sampling.
274
+ Only used for "classifier-free" guidance type.
275
+ guidance_scale: A `float`. The scale for the guided sampling.
276
+ classifier_fn: A classifier function. Only used for the classifier guidance.
277
+ classifier_kwargs: A `dict`. A dict for the other inputs of the classifier function.
278
+ Returns:
279
+ A noise prediction model that accepts the noised data and the continuous time as the inputs.
280
+ """
281
+
282
+ def get_model_input_time(t_continuous):
283
+ """
284
+ Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time.
285
+ For discrete-time DPMs, we convert `t_continuous` in [1 / N, 1] to `t_input` in [0, 1000 * (N - 1) / N].
286
+ For continuous-time DPMs, we just use `t_continuous`.
287
+ """
288
+ if noise_schedule.schedule == 'discrete':
289
+ return (t_continuous - 1. / noise_schedule.total_N) * 1000.
290
+ else:
291
+ return t_continuous
292
+
293
+ def noise_pred_fn(x, t_continuous, cond=None):
294
+ if t_continuous.reshape((-1,)).shape[0] == 1:
295
+ t_continuous = t_continuous.expand((x.shape[0]))
296
+ t_input = get_model_input_time(t_continuous)
297
+ output = model(x, t_input, **model_kwargs)
298
+ if model_type == "noise":
299
+ return output
300
+ elif model_type == "x_start":
301
+ alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
302
+ dims = x.dim()
303
+ return (x - expand_dims(alpha_t, dims) * output) / expand_dims(sigma_t, dims)
304
+ elif model_type == "v":
305
+ alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
306
+ dims = x.dim()
307
+ return expand_dims(alpha_t, dims) * output + expand_dims(sigma_t, dims) * x
308
+ elif model_type == "score":
309
+ sigma_t = noise_schedule.marginal_std(t_continuous)
310
+ dims = x.dim()
311
+ return -expand_dims(sigma_t, dims) * output
312
+
313
+ def cond_grad_fn(x, t_input):
314
+ """
315
+ Compute the gradient of the classifier, i.e. nabla_{x} log p_t(cond | x_t).
316
+ """
317
+ with torch.enable_grad():
318
+ x_in = x.detach().requires_grad_(True)
319
+ log_prob = classifier_fn(x_in, t_input, condition, **classifier_kwargs)
320
+ return torch.autograd.grad(log_prob.sum(), x_in)[0]
321
+
322
+ def model_fn(x, t_continuous):
323
+ """
324
+ The noise prediction model function that is used for DPM-Solver.
325
+ """
326
+ if t_continuous.reshape((-1,)).shape[0] == 1:
327
+ t_continuous = t_continuous.expand((x.shape[0]))
328
+ if guidance_type == "uncond":
329
+ return noise_pred_fn(x, t_continuous)
330
+ elif guidance_type == "classifier":
331
+ assert classifier_fn is not None
332
+ t_input = get_model_input_time(t_continuous)
333
+ cond_grad = cond_grad_fn(x, t_input)
334
+ sigma_t = noise_schedule.marginal_std(t_continuous)
335
+ noise = noise_pred_fn(x, t_continuous)
336
+ return noise - guidance_scale * expand_dims(sigma_t, dims=cond_grad.dim()) * cond_grad
337
+ elif guidance_type == "classifier-free":
338
+ if guidance_scale == 1. or unconditional_condition is None:
339
+ return noise_pred_fn(x, t_continuous, cond=condition)
340
+ else:
341
+ x_in = torch.cat([x] * 2)
342
+ t_in = torch.cat([t_continuous] * 2)
343
+ c_in = torch.cat([unconditional_condition, condition])
344
+ noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2)
345
+ return noise_uncond + guidance_scale * (noise - noise_uncond)
346
+
347
+ assert model_type in ["noise", "x_start", "v"]
348
+ assert guidance_type in ["uncond", "classifier", "classifier-free"]
349
+ return model_fn
350
+
351
+
352
+ class UniPC:
353
+ def __init__(
354
+ self,
355
+ model_fn,
356
+ noise_schedule,
357
+ predict_x0=True,
358
+ thresholding=False,
359
+ max_val=1.,
360
+ variant='bh1',
361
+ ):
362
+ """Construct a UniPC.
363
+
364
+ We support both data_prediction and noise_prediction.
365
+ """
366
+ self.model = model_fn
367
+ self.noise_schedule = noise_schedule
368
+ self.variant = variant
369
+ self.predict_x0 = predict_x0
370
+ self.thresholding = thresholding
371
+ self.max_val = max_val
372
+
373
+ def dynamic_thresholding_fn(self, x0, t=None):
374
+ """
375
+ The dynamic thresholding method.
376
+ """
377
+ dims = x0.dim()
378
+ p = self.dynamic_thresholding_ratio
379
+ s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
380
+ s = expand_dims(torch.maximum(s, self.thresholding_max_val * torch.ones_like(s).to(s.device)), dims)
381
+ x0 = torch.clamp(x0, -s, s) / s
382
+ return x0
383
+
384
+ def noise_prediction_fn(self, x, t):
385
+ """
386
+ Return the noise prediction model.
387
+ """
388
+ return self.model(x, t)
389
+
390
+ def data_prediction_fn(self, x, t):
391
+ """
392
+ Return the data prediction model (with thresholding).
393
+ """
394
+ noise = self.noise_prediction_fn(x, t)
395
+ dims = x.dim()
396
+ alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t)
397
+ x0 = (x - expand_dims(sigma_t, dims) * noise) / expand_dims(alpha_t, dims)
398
+ if self.thresholding:
399
+ p = 0.995 # A hyperparameter in the paper of "Imagen" [1].
400
+ s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
401
+ s = expand_dims(torch.maximum(s, self.max_val * torch.ones_like(s).to(s.device)), dims)
402
+ x0 = torch.clamp(x0, -s, s) / s
403
+ return x0
404
+
405
+ def model_fn(self, x, t):
406
+ """
407
+ Convert the model to the noise prediction model or the data prediction model.
408
+ """
409
+ if self.predict_x0:
410
+ return self.data_prediction_fn(x, t)
411
+ else:
412
+ return self.noise_prediction_fn(x, t)
413
+
414
+ def get_time_steps(self, skip_type, t_T, t_0, N, device):
415
+ """Compute the intermediate time steps for sampling.
416
+ """
417
+ if skip_type == 'logSNR':
418
+ lambda_T = self.noise_schedule.marginal_lambda(torch.tensor(t_T).to(device))
419
+ lambda_0 = self.noise_schedule.marginal_lambda(torch.tensor(t_0).to(device))
420
+ logSNR_steps = torch.linspace(lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1).to(device)
421
+ return self.noise_schedule.inverse_lambda(logSNR_steps)
422
+ elif skip_type == 'time_uniform':
423
+ return torch.linspace(t_T, t_0, N + 1).to(device)
424
+ elif skip_type == 'time_quadratic':
425
+ t_order = 2
426
+ t = torch.linspace(t_T**(1. / t_order), t_0**(1. / t_order), N + 1).pow(t_order).to(device)
427
+ return t
428
+ else:
429
+ raise ValueError("Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type))
430
+
431
+ def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device):
432
+ """
433
+ Get the order of each step for sampling by the singlestep DPM-Solver.
434
+ """
435
+ if order == 3:
436
+ K = steps // 3 + 1
437
+ if steps % 3 == 0:
438
+ orders = [3,] * (K - 2) + [2, 1]
439
+ elif steps % 3 == 1:
440
+ orders = [3,] * (K - 1) + [1]
441
+ else:
442
+ orders = [3,] * (K - 1) + [2]
443
+ elif order == 2:
444
+ if steps % 2 == 0:
445
+ K = steps // 2
446
+ orders = [2,] * K
447
+ else:
448
+ K = steps // 2 + 1
449
+ orders = [2,] * (K - 1) + [1]
450
+ elif order == 1:
451
+ K = steps
452
+ orders = [1,] * steps
453
+ else:
454
+ raise ValueError("'order' must be '1' or '2' or '3'.")
455
+ if skip_type == 'logSNR':
456
+ # To reproduce the results in DPM-Solver paper
457
+ timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, K, device)
458
+ else:
459
+ timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, steps, device)[torch.cumsum(torch.tensor([0,] + orders), 0).to(device)]
460
+ return timesteps_outer, orders
461
+
462
+ def denoise_to_zero_fn(self, x, s):
463
+ """
464
+ Denoise at the final step, which is equivalent to solving the ODE from lambda_s to infty by first-order discretization.
465
+ """
466
+ return self.data_prediction_fn(x, s)
467
+
468
+ def multistep_uni_pc_update(self, x, model_prev_list, t_prev_list, t, order, **kwargs):
469
+ if len(t.shape) == 0:
470
+ t = t.view(-1)
471
+ if 'bh' in self.variant:
472
+ return self.multistep_uni_pc_bh_update(x, model_prev_list, t_prev_list, t, order, **kwargs)
473
+ else:
474
+ assert self.variant == 'vary_coeff'
475
+ return self.multistep_uni_pc_vary_update(x, model_prev_list, t_prev_list, t, order, **kwargs)
476
+
477
+ def multistep_uni_pc_vary_update(self, x, model_prev_list, t_prev_list, t, order, use_corrector=True):
478
+ print(f'using unified predictor-corrector with order {order} (solver type: vary coeff)')
479
+ ns = self.noise_schedule
480
+ assert order <= len(model_prev_list)
481
+
482
+ # first compute rks
483
+ t_prev_0 = t_prev_list[-1]
484
+ lambda_prev_0 = ns.marginal_lambda(t_prev_0)
485
+ lambda_t = ns.marginal_lambda(t)
486
+ model_prev_0 = model_prev_list[-1]
487
+ sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
488
+ log_alpha_t = ns.marginal_log_mean_coeff(t)
489
+ alpha_t = torch.exp(log_alpha_t)
490
+
491
+ h = lambda_t - lambda_prev_0
492
+
493
+ rks = []
494
+ D1s = []
495
+ for i in range(1, order):
496
+ t_prev_i = t_prev_list[-(i + 1)]
497
+ model_prev_i = model_prev_list[-(i + 1)]
498
+ lambda_prev_i = ns.marginal_lambda(t_prev_i)
499
+ rk = (lambda_prev_i - lambda_prev_0) / h
500
+ rks.append(rk)
501
+ D1s.append((model_prev_i - model_prev_0) / rk)
502
+
503
+ rks.append(1.)
504
+ rks = torch.tensor(rks, device=x.device)
505
+
506
+ K = len(rks)
507
+ # build C matrix
508
+ C = []
509
+
510
+ col = torch.ones_like(rks)
511
+ for k in range(1, K + 1):
512
+ C.append(col)
513
+ col = col * rks / (k + 1)
514
+ C = torch.stack(C, dim=1)
515
+
516
+ if len(D1s) > 0:
517
+ D1s = torch.stack(D1s, dim=1) # (B, K)
518
+ C_inv_p = torch.linalg.inv(C[:-1, :-1])
519
+ A_p = C_inv_p
520
+
521
+ if use_corrector:
522
+ print('using corrector')
523
+ C_inv = torch.linalg.inv(C)
524
+ A_c = C_inv
525
+
526
+ hh = -h if self.predict_x0 else h
527
+ h_phi_1 = torch.expm1(hh)
528
+ h_phi_ks = []
529
+ factorial_k = 1
530
+ h_phi_k = h_phi_1
531
+ for k in range(1, K + 2):
532
+ h_phi_ks.append(h_phi_k)
533
+ h_phi_k = h_phi_k / hh - 1 / factorial_k
534
+ factorial_k *= (k + 1)
535
+
536
+ model_t = None
537
+ if self.predict_x0:
538
+ x_t_ = (
539
+ sigma_t / sigma_prev_0 * x
540
+ - alpha_t * h_phi_1 * model_prev_0
541
+ )
542
+ # now predictor
543
+ x_t = x_t_
544
+ if len(D1s) > 0:
545
+ # compute the residuals for predictor
546
+ for k in range(K - 1):
547
+ x_t = x_t - alpha_t * h_phi_ks[k + 1] * torch.einsum('bkchw,k->bchw', D1s, A_p[k])
548
+ # now corrector
549
+ if use_corrector:
550
+ model_t = self.model_fn(x_t, t)
551
+ D1_t = (model_t - model_prev_0)
552
+ x_t = x_t_
553
+ k = 0
554
+ for k in range(K - 1):
555
+ x_t = x_t - alpha_t * h_phi_ks[k + 1] * torch.einsum('bkchw,k->bchw', D1s, A_c[k][:-1])
556
+ x_t = x_t - alpha_t * h_phi_ks[K] * (D1_t * A_c[k][-1])
557
+ else:
558
+ log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
559
+ x_t_ = (
560
+ (torch.exp(log_alpha_t - log_alpha_prev_0)) * x
561
+ - (sigma_t * h_phi_1) * model_prev_0
562
+ )
563
+ # now predictor
564
+ x_t = x_t_
565
+ if len(D1s) > 0:
566
+ # compute the residuals for predictor
567
+ for k in range(K - 1):
568
+ x_t = x_t - sigma_t * h_phi_ks[k + 1] * torch.einsum('bkchw,k->bchw', D1s, A_p[k])
569
+ # now corrector
570
+ if use_corrector:
571
+ model_t = self.model_fn(x_t, t)
572
+ D1_t = (model_t - model_prev_0)
573
+ x_t = x_t_
574
+ k = 0
575
+ for k in range(K - 1):
576
+ x_t = x_t - sigma_t * h_phi_ks[k + 1] * torch.einsum('bkchw,k->bchw', D1s, A_c[k][:-1])
577
+ x_t = x_t - sigma_t * h_phi_ks[K] * (D1_t * A_c[k][-1])
578
+ return x_t, model_t
579
+
580
+ def multistep_uni_pc_bh_update(self, x, model_prev_list, t_prev_list, t, order, x_t=None, use_corrector=True):
581
+ # print(f'using unified predictor-corrector with order {order} (solver type: B(h))')
582
+ ns = self.noise_schedule
583
+ assert order <= len(model_prev_list)
584
+ dims = x.dim()
585
+
586
+ # first compute rks
587
+ t_prev_0 = t_prev_list[-1]
588
+ lambda_prev_0 = ns.marginal_lambda(t_prev_0)
589
+ lambda_t = ns.marginal_lambda(t)
590
+ model_prev_0 = model_prev_list[-1]
591
+ sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
592
+ log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
593
+ alpha_t = torch.exp(log_alpha_t)
594
+
595
+ h = lambda_t - lambda_prev_0
596
+
597
+ rks = []
598
+ D1s = []
599
+ for i in range(1, order):
600
+ t_prev_i = t_prev_list[-(i + 1)]
601
+ model_prev_i = model_prev_list[-(i + 1)]
602
+ lambda_prev_i = ns.marginal_lambda(t_prev_i)
603
+ rk = ((lambda_prev_i - lambda_prev_0) / h)[0]
604
+ rks.append(rk)
605
+ D1s.append((model_prev_i - model_prev_0) / rk)
606
+
607
+ rks.append(1.)
608
+ rks = torch.tensor(rks, device=x.device)
609
+
610
+ R = []
611
+ b = []
612
+
613
+ hh = -h[0] if self.predict_x0 else h[0]
614
+ h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1
615
+ h_phi_k = h_phi_1 / hh - 1
616
+
617
+ factorial_i = 1
618
+
619
+ if self.variant == 'bh1':
620
+ B_h = hh
621
+ elif self.variant == 'bh2':
622
+ B_h = torch.expm1(hh)
623
+ else:
624
+ raise NotImplementedError()
625
+
626
+ for i in range(1, order + 1):
627
+ R.append(torch.pow(rks, i - 1))
628
+ b.append(h_phi_k * factorial_i / B_h)
629
+ factorial_i *= (i + 1)
630
+ h_phi_k = h_phi_k / hh - 1 / factorial_i
631
+
632
+ R = torch.stack(R)
633
+ b = torch.tensor(b, device=x.device)
634
+
635
+ # now predictor
636
+ use_predictor = len(D1s) > 0 and x_t is None
637
+ if len(D1s) > 0:
638
+ D1s = torch.stack(D1s, dim=1) # (B, K)
639
+ if x_t is None:
640
+ # for order 2, we use a simplified version
641
+ if order == 2:
642
+ rhos_p = torch.tensor([0.5], device=b.device)
643
+ else:
644
+ rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1])
645
+ else:
646
+ D1s = None
647
+
648
+ if use_corrector:
649
+ # print('using corrector')
650
+ # for order 1, we use a simplified version
651
+ if order == 1:
652
+ rhos_c = torch.tensor([0.5], device=b.device)
653
+ else:
654
+ rhos_c = torch.linalg.solve(R, b)
655
+
656
+ model_t = None
657
+ if self.predict_x0:
658
+ x_t_ = (
659
+ expand_dims(sigma_t / sigma_prev_0, dims) * x
660
+ - expand_dims(alpha_t * h_phi_1, dims)* model_prev_0
661
+ )
662
+
663
+ if x_t is None:
664
+ if use_predictor:
665
+ pred_res = torch.einsum('k,bkchw->bchw', rhos_p, D1s)
666
+ else:
667
+ pred_res = 0
668
+ x_t = x_t_ - expand_dims(alpha_t * B_h, dims) * pred_res
669
+
670
+ if use_corrector:
671
+ model_t = self.model_fn(x_t, t)
672
+ if D1s is not None:
673
+ corr_res = torch.einsum('k,bkchw->bchw', rhos_c[:-1], D1s)
674
+ else:
675
+ corr_res = 0
676
+ D1_t = (model_t - model_prev_0)
677
+ x_t = x_t_ - expand_dims(alpha_t * B_h, dims) * (corr_res + rhos_c[-1] * D1_t)
678
+ else:
679
+ x_t_ = (
680
+ expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
681
+ - expand_dims(sigma_t * h_phi_1, dims) * model_prev_0
682
+ )
683
+ if x_t is None:
684
+ if use_predictor:
685
+ pred_res = torch.einsum('k,bkchw->bchw', rhos_p, D1s)
686
+ else:
687
+ pred_res = 0
688
+ x_t = x_t_ - expand_dims(sigma_t * B_h, dims) * pred_res
689
+
690
+ if use_corrector:
691
+ model_t = self.model_fn(x_t, t)
692
+ if D1s is not None:
693
+ corr_res = torch.einsum('k,bkchw->bchw', rhos_c[:-1], D1s)
694
+ else:
695
+ corr_res = 0
696
+ D1_t = (model_t - model_prev_0)
697
+ x_t = x_t_ - expand_dims(sigma_t * B_h, dims) * (corr_res + rhos_c[-1] * D1_t)
698
+ return x_t, model_t
699
+
700
+
701
+ def sample(self, x, timesteps, t_start=None, t_end=None, order=3, skip_type='time_uniform',
702
+ method='singlestep', lower_order_final=True, denoise_to_zero=False, solver_type='dpm_solver',
703
+ atol=0.0078, rtol=0.05, corrector=False, callback=None, disable_pbar=False
704
+ ):
705
+ # t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end
706
+ # t_T = self.noise_schedule.T if t_start is None else t_start
707
+ device = x.device
708
+ steps = len(timesteps) - 1
709
+ if method == 'multistep':
710
+ assert steps >= order
711
+ # timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device)
712
+ assert timesteps.shape[0] - 1 == steps
713
+ # with torch.no_grad():
714
+ for step_index in trange(steps, disable=disable_pbar):
715
+ if step_index == 0:
716
+ vec_t = timesteps[0].expand((x.shape[0]))
717
+ model_prev_list = [self.model_fn(x, vec_t)]
718
+ t_prev_list = [vec_t]
719
+ elif step_index < order:
720
+ init_order = step_index
721
+ # Init the first `order` values by lower order multistep DPM-Solver.
722
+ # for init_order in range(1, order):
723
+ vec_t = timesteps[init_order].expand(x.shape[0])
724
+ x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, init_order, use_corrector=True)
725
+ if model_x is None:
726
+ model_x = self.model_fn(x, vec_t)
727
+ model_prev_list.append(model_x)
728
+ t_prev_list.append(vec_t)
729
+ else:
730
+ extra_final_step = 0
731
+ if step_index == (steps - 1):
732
+ extra_final_step = 1
733
+ for step in range(step_index, step_index + 1 + extra_final_step):
734
+ vec_t = timesteps[step].expand(x.shape[0])
735
+ if lower_order_final:
736
+ step_order = min(order, steps + 1 - step)
737
+ else:
738
+ step_order = order
739
+ # print('this step order:', step_order)
740
+ if step == steps:
741
+ # print('do not run corrector at the last step')
742
+ use_corrector = False
743
+ else:
744
+ use_corrector = True
745
+ x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, step_order, use_corrector=use_corrector)
746
+ for i in range(order - 1):
747
+ t_prev_list[i] = t_prev_list[i + 1]
748
+ model_prev_list[i] = model_prev_list[i + 1]
749
+ t_prev_list[-1] = vec_t
750
+ # We do not need to evaluate the final model value.
751
+ if step < steps:
752
+ if model_x is None:
753
+ model_x = self.model_fn(x, vec_t)
754
+ model_prev_list[-1] = model_x
755
+ if callback is not None:
756
+ callback({'x': x, 'i': step_index, 'denoised': model_prev_list[-1]})
757
+ else:
758
+ raise NotImplementedError()
759
+ # if denoise_to_zero:
760
+ # x = self.denoise_to_zero_fn(x, torch.ones((x.shape[0],)).to(device) * t_0)
761
+ return x
762
+
763
+
764
+ #############################################################
765
+ # other utility functions
766
+ #############################################################
767
+
768
+ def interpolate_fn(x, xp, yp):
769
+ """
770
+ A piecewise linear function y = f(x), using xp and yp as keypoints.
771
+ We implement f(x) in a differentiable way (i.e. applicable for autograd).
772
+ The function f(x) is well-defined over the whole x-axis. (For x beyond the bounds of xp, we use the outermost points of xp to define the linear function.)
773
+
774
+ Args:
775
+ x: PyTorch tensor with shape [N, C], where N is the batch size, C is the number of channels (we use C = 1 for DPM-Solver).
776
+ xp: PyTorch tensor with shape [C, K], where K is the number of keypoints.
777
+ yp: PyTorch tensor with shape [C, K].
778
+ Returns:
779
+ The function values f(x), with shape [N, C].
780
+ """
781
+ N, K = x.shape[0], xp.shape[1]
782
+ all_x = torch.cat([x.unsqueeze(2), xp.unsqueeze(0).repeat((N, 1, 1))], dim=2)
783
+ sorted_all_x, x_indices = torch.sort(all_x, dim=2)
784
+ x_idx = torch.argmin(x_indices, dim=2)
785
+ cand_start_idx = x_idx - 1
786
+ start_idx = torch.where(
787
+ torch.eq(x_idx, 0),
788
+ torch.tensor(1, device=x.device),
789
+ torch.where(
790
+ torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx,
791
+ ),
792
+ )
793
+ end_idx = torch.where(torch.eq(start_idx, cand_start_idx), start_idx + 2, start_idx + 1)
794
+ start_x = torch.gather(sorted_all_x, dim=2, index=start_idx.unsqueeze(2)).squeeze(2)
795
+ end_x = torch.gather(sorted_all_x, dim=2, index=end_idx.unsqueeze(2)).squeeze(2)
796
+ start_idx2 = torch.where(
797
+ torch.eq(x_idx, 0),
798
+ torch.tensor(0, device=x.device),
799
+ torch.where(
800
+ torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx,
801
+ ),
802
+ )
803
+ y_positions_expanded = yp.unsqueeze(0).expand(N, -1, -1)
804
+ start_y = torch.gather(y_positions_expanded, dim=2, index=start_idx2.unsqueeze(2)).squeeze(2)
805
+ end_y = torch.gather(y_positions_expanded, dim=2, index=(start_idx2 + 1).unsqueeze(2)).squeeze(2)
806
+ cand = start_y + (x - start_x) * (end_y - start_y) / (end_x - start_x)
807
+ return cand
808
+
809
+
810
+ def expand_dims(v, dims):
811
+ """
812
+ Expand the tensor `v` to the dim `dims`.
813
+
814
+ Args:
815
+ `v`: a PyTorch tensor with shape [N].
816
+ `dims`: an `int`.
817
+ Returns:
818
+ a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`.
819
+ """
820
+ return v[(...,) + (None,)*(dims - 1)]
821
+
822
+
823
+ class SigmaConvert:
824
+ schedule = ""
825
+ def marginal_log_mean_coeff(self, sigma):
826
+ return 0.5 * torch.log(1 / ((sigma * sigma) + 1))
827
+
828
+ def marginal_alpha(self, t):
829
+ return torch.exp(self.marginal_log_mean_coeff(t))
830
+
831
+ def marginal_std(self, t):
832
+ return torch.sqrt(1. - torch.exp(2. * self.marginal_log_mean_coeff(t)))
833
+
834
+ def marginal_lambda(self, t):
835
+ """
836
+ Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T].
837
+ """
838
+ log_mean_coeff = self.marginal_log_mean_coeff(t)
839
+ log_std = 0.5 * torch.log(1. - torch.exp(2. * log_mean_coeff))
840
+ return log_mean_coeff - log_std
841
+
842
+ def predict_eps_sigma(model, input, sigma_in, **kwargs):
843
+ sigma = sigma_in.view(sigma_in.shape[:1] + (1,) * (input.ndim - 1))
844
+ input = input * ((sigma ** 2 + 1.0) ** 0.5)
845
+ return (input - model(input, sigma_in, **kwargs)) / sigma
846
+
847
+
848
+ def sample_unipc(model, noise, sigmas, extra_args=None, callback=None, disable=False, variant='bh1'):
849
+ timesteps = sigmas.clone()
850
+ if sigmas[-1] == 0:
851
+ timesteps = sigmas[:]
852
+ timesteps[-1] = 0.001
853
+ else:
854
+ timesteps = sigmas.clone()
855
+ ns = SigmaConvert()
856
+
857
+ noise = noise / torch.sqrt(1.0 + timesteps[0] ** 2.0)
858
+ model_type = "noise"
859
+
860
+ model_fn = model_wrapper(
861
+ lambda input, sigma, **kwargs: predict_eps_sigma(model, input, sigma, **kwargs),
862
+ ns,
863
+ model_type=model_type,
864
+ guidance_type="uncond",
865
+ model_kwargs=extra_args,
866
+ )
867
+
868
+ order = min(3, len(timesteps) - 2)
869
+ uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False, variant=variant)
870
+ x = uni_pc.sample(noise, timesteps=timesteps, skip_type="time_uniform", method="multistep", order=order, lower_order_final=True, callback=callback, disable_pbar=disable)
871
+ x /= ns.marginal_alpha(timesteps[-1])
872
+ return x
873
+
874
+ def sample_unipc_bh2(model, noise, sigmas, extra_args=None, callback=None, disable=False):
875
+ return sample_unipc(model, noise, sigmas, extra_args, callback, disable, variant='bh2')
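The entry point above expects a denoiser-style callable, i.e. model(x, sigma, **extra_args) returning a prediction of the clean sample, and a sigma schedule running from high noise down to (optionally) zero. A toy sketch with a stand-in denoiser, only to show the calling convention:

```python
# Toy sketch: the "denoiser" below always predicts an all-zero clean image, so
# the output is meaningless; it only demonstrates the expected interface.
import torch
from comfy.extra_samplers.uni_pc import sample_unipc

def dummy_denoiser(x, sigma, **kwargs):
    return torch.zeros_like(x)

sigmas = torch.linspace(14.6, 0.0, 11)          # 10 steps, high noise -> 0
noise = torch.randn(1, 4, 64, 64) * sigmas[0]   # start from scaled Gaussian noise
out = sample_unipc(dummy_denoiser, noise, sigmas, extra_args={}, disable=True)
print(out.shape)  # torch.Size([1, 4, 64, 64])
```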
comfy/float.py ADDED
@@ -0,0 +1,66 @@
1
+ import torch
2
+ import math
3
+
4
+ def calc_mantissa(abs_x, exponent, normal_mask, MANTISSA_BITS, EXPONENT_BIAS, generator=None):
5
+ mantissa_scaled = torch.where(
6
+ normal_mask,
7
+ (abs_x / (2.0 ** (exponent - EXPONENT_BIAS)) - 1.0) * (2**MANTISSA_BITS),
8
+ (abs_x / (2.0 ** (-EXPONENT_BIAS + 1 - MANTISSA_BITS)))
9
+ )
10
+
11
+ mantissa_scaled += torch.rand(mantissa_scaled.size(), dtype=mantissa_scaled.dtype, layout=mantissa_scaled.layout, device=mantissa_scaled.device, generator=generator)
12
+ return mantissa_scaled.floor() / (2**MANTISSA_BITS)
13
+
14
+ #Not 100% sure about this
15
+ def manual_stochastic_round_to_float8(x, dtype, generator=None):
16
+ if dtype == torch.float8_e4m3fn:
17
+ EXPONENT_BITS, MANTISSA_BITS, EXPONENT_BIAS = 4, 3, 7
18
+ elif dtype == torch.float8_e5m2:
19
+ EXPONENT_BITS, MANTISSA_BITS, EXPONENT_BIAS = 5, 2, 15
20
+ else:
21
+ raise ValueError("Unsupported dtype")
22
+
23
+ x = x.half()
24
+ sign = torch.sign(x)
25
+ abs_x = x.abs()
26
+ sign = torch.where(abs_x == 0, 0, sign)
27
+
28
+ # Combine exponent calculation and clamping
29
+ exponent = torch.clamp(
30
+ torch.floor(torch.log2(abs_x)) + EXPONENT_BIAS,
31
+ 0, 2**EXPONENT_BITS - 1
32
+ )
33
+
34
+ # Combine mantissa calculation and rounding
35
+ normal_mask = ~(exponent == 0)
36
+
37
+ abs_x[:] = calc_mantissa(abs_x, exponent, normal_mask, MANTISSA_BITS, EXPONENT_BIAS, generator=generator)
38
+
39
+ sign *= torch.where(
40
+ normal_mask,
41
+ (2.0 ** (exponent - EXPONENT_BIAS)) * (1.0 + abs_x),
42
+ (2.0 ** (-EXPONENT_BIAS + 1)) * abs_x
43
+ )
44
+
45
+ return sign
46
+
47
+
48
+
49
+ def stochastic_rounding(value, dtype, seed=0):
50
+ if dtype == torch.float32:
51
+ return value.to(dtype=torch.float32)
52
+ if dtype == torch.float16:
53
+ return value.to(dtype=torch.float16)
54
+ if dtype == torch.bfloat16:
55
+ return value.to(dtype=torch.bfloat16)
56
+ if dtype == torch.float8_e4m3fn or dtype == torch.float8_e5m2:
57
+ generator = torch.Generator(device=value.device)
58
+ generator.manual_seed(seed)
59
+ output = torch.empty_like(value, dtype=dtype)
60
+ num_slices = max(1, (value.numel() / (4096 * 4096)))
61
+ slice_size = max(1, round(value.shape[0] / num_slices))
62
+ for i in range(0, value.shape[0], slice_size):
63
+ output[i:i+slice_size].copy_(manual_stochastic_round_to_float8(value[i:i+slice_size], dtype, generator=generator))
64
+ return output
65
+
66
+ return value.to(dtype=dtype)
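The point of the stochastic path is that, in expectation, the rounded value equals the original, so the mean of a tensor survives quantization much better than with plain round-to-nearest. A quick sanity sketch (assumes a PyTorch build with float8 dtypes, e.g. 2.1+):

```python
# Sanity sketch: 0.3 is not exactly representable in float8_e4m3fn, so plain
# casting (round-to-nearest) biases the mean while stochastic rounding does not.
import torch
from comfy.float import stochastic_rounding

x = torch.full((10000,), 0.3)
rtn = x.to(torch.float8_e4m3fn).float().mean().item()
sr = stochastic_rounding(x, torch.float8_e4m3fn, seed=0).float().mean().item()
print(f"target 0.300, nearest {rtn:.4f}, stochastic {sr:.4f}")
```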
comfy/gligen.py ADDED
@@ -0,0 +1,343 @@
1
+ import torch
2
+ from torch import nn
3
+ from .ldm.modules.attention import CrossAttention
4
+ from inspect import isfunction
5
+ import math
+ import comfy.ops
6
+ ops = comfy.ops.manual_cast
7
+
8
+ def exists(val):
9
+ return val is not None
10
+
11
+
12
+ def uniq(arr):
13
+ return{el: True for el in arr}.keys()
14
+
15
+
16
+ def default(val, d):
17
+ if exists(val):
18
+ return val
19
+ return d() if isfunction(d) else d
20
+
21
+
22
+ # feedforward
23
+ class GEGLU(nn.Module):
24
+ def __init__(self, dim_in, dim_out):
25
+ super().__init__()
26
+ self.proj = ops.Linear(dim_in, dim_out * 2)
27
+
28
+ def forward(self, x):
29
+ x, gate = self.proj(x).chunk(2, dim=-1)
30
+ return x * torch.nn.functional.gelu(gate)
31
+
32
+
33
+ class FeedForward(nn.Module):
34
+ def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
35
+ super().__init__()
36
+ inner_dim = int(dim * mult)
37
+ dim_out = default(dim_out, dim)
38
+ project_in = nn.Sequential(
39
+ ops.Linear(dim, inner_dim),
40
+ nn.GELU()
41
+ ) if not glu else GEGLU(dim, inner_dim)
42
+
43
+ self.net = nn.Sequential(
44
+ project_in,
45
+ nn.Dropout(dropout),
46
+ ops.Linear(inner_dim, dim_out)
47
+ )
48
+
49
+ def forward(self, x):
50
+ return self.net(x)
51
+
52
+
53
+ class GatedCrossAttentionDense(nn.Module):
54
+ def __init__(self, query_dim, context_dim, n_heads, d_head):
55
+ super().__init__()
56
+
57
+ self.attn = CrossAttention(
58
+ query_dim=query_dim,
59
+ context_dim=context_dim,
60
+ heads=n_heads,
61
+ dim_head=d_head,
62
+ operations=ops)
63
+ self.ff = FeedForward(query_dim, glu=True)
64
+
65
+ self.norm1 = ops.LayerNorm(query_dim)
66
+ self.norm2 = ops.LayerNorm(query_dim)
67
+
68
+ self.register_parameter('alpha_attn', nn.Parameter(torch.tensor(0.)))
69
+ self.register_parameter('alpha_dense', nn.Parameter(torch.tensor(0.)))
70
+
71
+ # this can be useful: we can externally change magnitude of tanh(alpha)
72
+ # for example, when it is set to 0, then the entire model is the same as
73
+ # the original one
74
+ self.scale = 1
75
+
76
+ def forward(self, x, objs):
77
+
78
+ x = x + self.scale * \
79
+ torch.tanh(self.alpha_attn) * self.attn(self.norm1(x), objs, objs)
80
+ x = x + self.scale * \
81
+ torch.tanh(self.alpha_dense) * self.ff(self.norm2(x))
82
+
83
+ return x
84
+
85
+
86
+ class GatedSelfAttentionDense(nn.Module):
87
+ def __init__(self, query_dim, context_dim, n_heads, d_head):
88
+ super().__init__()
89
+
90
+ # we need a linear projection since we need to concatenate the visual feature and the object
91
+ # feature
92
+ self.linear = ops.Linear(context_dim, query_dim)
93
+
94
+ self.attn = CrossAttention(
95
+ query_dim=query_dim,
96
+ context_dim=query_dim,
97
+ heads=n_heads,
98
+ dim_head=d_head,
99
+ operations=ops)
100
+ self.ff = FeedForward(query_dim, glu=True)
101
+
102
+ self.norm1 = ops.LayerNorm(query_dim)
103
+ self.norm2 = ops.LayerNorm(query_dim)
104
+
105
+ self.register_parameter('alpha_attn', nn.Parameter(torch.tensor(0.)))
106
+ self.register_parameter('alpha_dense', nn.Parameter(torch.tensor(0.)))
107
+
108
+ # this can be useful: we can externally change magnitude of tanh(alpha)
109
+ # for example, when it is set to 0, then the entire model is the same as
110
+ # the original one
111
+ self.scale = 1
112
+
113
+ def forward(self, x, objs):
114
+
115
+ N_visual = x.shape[1]
116
+ objs = self.linear(objs)
117
+
118
+ x = x + self.scale * torch.tanh(self.alpha_attn) * self.attn(
119
+ self.norm1(torch.cat([x, objs], dim=1)))[:, 0:N_visual, :]
120
+ x = x + self.scale * \
121
+ torch.tanh(self.alpha_dense) * self.ff(self.norm2(x))
122
+
123
+ return x
124
+
125
+
126
+ class GatedSelfAttentionDense2(nn.Module):
127
+ def __init__(self, query_dim, context_dim, n_heads, d_head):
128
+ super().__init__()
129
+
130
+ # we need a linear projection since we need to concatenate the visual feature and the object
131
+ # feature
132
+ self.linear = ops.Linear(context_dim, query_dim)
133
+
134
+ self.attn = CrossAttention(
135
+ query_dim=query_dim, context_dim=query_dim, dim_head=d_head, operations=ops)
136
+ self.ff = FeedForward(query_dim, glu=True)
137
+
138
+ self.norm1 = ops.LayerNorm(query_dim)
139
+ self.norm2 = ops.LayerNorm(query_dim)
140
+
141
+ self.register_parameter('alpha_attn', nn.Parameter(torch.tensor(0.)))
142
+ self.register_parameter('alpha_dense', nn.Parameter(torch.tensor(0.)))
143
+
144
+ # this can be useful: we can externally change the magnitude of tanh(alpha);
145
+ # for example, when it is set to 0, the entire model is the same as the
146
+ # original one
147
+ self.scale = 1
148
+
149
+ def forward(self, x, objs):
150
+
151
+ B, N_visual, _ = x.shape
152
+ B, N_ground, _ = objs.shape
153
+
154
+ objs = self.linear(objs)
155
+
156
+ # sanity check
157
+ size_v = math.sqrt(N_visual)
158
+ size_g = math.sqrt(N_ground)
159
+ assert int(size_v) == size_v, "Visual tokens must be square rootable"
160
+ assert int(size_g) == size_g, "Grounding tokens must be square rootable"
161
+ size_v = int(size_v)
162
+ size_g = int(size_g)
163
+
164
+ # select grounding token and resize it to visual token size as residual
165
+ out = self.attn(self.norm1(torch.cat([x, objs], dim=1)))[
166
+ :, N_visual:, :]
167
+ out = out.permute(0, 2, 1).reshape(B, -1, size_g, size_g)
168
+ out = torch.nn.functional.interpolate(
169
+ out, (size_v, size_v), mode='bicubic')
170
+ residual = out.reshape(B, -1, N_visual).permute(0, 2, 1)
171
+
172
+ # add residual to visual feature
173
+ x = x + self.scale * torch.tanh(self.alpha_attn) * residual
174
+ x = x + self.scale * \
175
+ torch.tanh(self.alpha_dense) * self.ff(self.norm2(x))
176
+
177
+ return x
178
+
179
+
180
+ class FourierEmbedder():
181
+ def __init__(self, num_freqs=64, temperature=100):
182
+
183
+ self.num_freqs = num_freqs
184
+ self.temperature = temperature
185
+ self.freq_bands = temperature ** (torch.arange(num_freqs) / num_freqs)
186
+
187
+ @torch.no_grad()
188
+ def __call__(self, x, cat_dim=-1):
189
+ "x: tensor of arbitrary shape. cat_dim: dimension to concatenate along"
190
+ out = []
191
+ for freq in self.freq_bands:
192
+ out.append(torch.sin(freq * x))
193
+ out.append(torch.cos(freq * x))
194
+ return torch.cat(out, cat_dim)
195
+
196
+
197
+ class PositionNet(nn.Module):
198
+ def __init__(self, in_dim, out_dim, fourier_freqs=8):
199
+ super().__init__()
200
+ self.in_dim = in_dim
201
+ self.out_dim = out_dim
202
+
203
+ self.fourier_embedder = FourierEmbedder(num_freqs=fourier_freqs)
204
+ self.position_dim = fourier_freqs * 2 * 4 # 2 is sin&cos, 4 is xyxy
205
+
206
+ self.linears = nn.Sequential(
207
+ ops.Linear(self.in_dim + self.position_dim, 512),
208
+ nn.SiLU(),
209
+ ops.Linear(512, 512),
210
+ nn.SiLU(),
211
+ ops.Linear(512, out_dim),
212
+ )
213
+
214
+ self.null_positive_feature = torch.nn.Parameter(
215
+ torch.zeros([self.in_dim]))
216
+ self.null_position_feature = torch.nn.Parameter(
217
+ torch.zeros([self.position_dim]))
218
+
219
+ def forward(self, boxes, masks, positive_embeddings):
220
+ B, N, _ = boxes.shape
221
+ masks = masks.unsqueeze(-1)
222
+ positive_embeddings = positive_embeddings
223
+
224
+ # embedding position (it may include padding as a placeholder)
225
+ xyxy_embedding = self.fourier_embedder(boxes) # B*N*4 --> B*N*C
226
+
227
+ # learnable null embedding
228
+ positive_null = self.null_positive_feature.to(device=boxes.device, dtype=boxes.dtype).view(1, 1, -1)
229
+ xyxy_null = self.null_position_feature.to(device=boxes.device, dtype=boxes.dtype).view(1, 1, -1)
230
+
231
+ # replace padding with learnable null embedding
232
+ positive_embeddings = positive_embeddings * \
233
+ masks + (1 - masks) * positive_null
234
+ xyxy_embedding = xyxy_embedding * masks + (1 - masks) * xyxy_null
235
+
236
+ objs = self.linears(
237
+ torch.cat([positive_embeddings, xyxy_embedding], dim=-1))
238
+ assert objs.shape == torch.Size([B, N, self.out_dim])
239
+ return objs
240
+
241
+
242
+ class Gligen(nn.Module):
243
+ def __init__(self, modules, position_net, key_dim):
244
+ super().__init__()
245
+ self.module_list = nn.ModuleList(modules)
246
+ self.position_net = position_net
247
+ self.key_dim = key_dim
248
+ self.max_objs = 30
249
+ self.current_device = torch.device("cpu")
250
+
251
+ def _set_position(self, boxes, masks, positive_embeddings):
252
+ objs = self.position_net(boxes, masks, positive_embeddings)
253
+ def func(x, extra_options):
254
+ key = extra_options["transformer_index"]
255
+ module = self.module_list[key]
256
+ return module(x, objs.to(device=x.device, dtype=x.dtype))
257
+ return func
258
+
259
+ def set_position(self, latent_image_shape, position_params, device):
260
+ batch, c, h, w = latent_image_shape
261
+ masks = torch.zeros([self.max_objs], device="cpu")
262
+ boxes = []
263
+ positive_embeddings = []
264
+ for p in position_params:
265
+ x1 = (p[4]) / w
266
+ y1 = (p[3]) / h
267
+ x2 = (p[4] + p[2]) / w
268
+ y2 = (p[3] + p[1]) / h
269
+ masks[len(boxes)] = 1.0
270
+ boxes += [torch.tensor((x1, y1, x2, y2)).unsqueeze(0)]
271
+ positive_embeddings += [p[0]]
272
+ append_boxes = []
273
+ append_conds = []
274
+ if len(boxes) < self.max_objs:
275
+ append_boxes = [torch.zeros(
276
+ [self.max_objs - len(boxes), 4], device="cpu")]
277
+ append_conds = [torch.zeros(
278
+ [self.max_objs - len(boxes), self.key_dim], device="cpu")]
279
+
280
+ box_out = torch.cat(
281
+ boxes + append_boxes).unsqueeze(0).repeat(batch, 1, 1)
282
+ masks = masks.unsqueeze(0).repeat(batch, 1)
283
+ conds = torch.cat(positive_embeddings +
284
+ append_conds).unsqueeze(0).repeat(batch, 1, 1)
285
+ return self._set_position(
286
+ box_out.to(device),
287
+ masks.to(device),
288
+ conds.to(device))
289
+
290
+ def set_empty(self, latent_image_shape, device):
291
+ batch, c, h, w = latent_image_shape
292
+ masks = torch.zeros([self.max_objs], device="cpu").repeat(batch, 1)
293
+ box_out = torch.zeros([self.max_objs, 4],
294
+ device="cpu").repeat(batch, 1, 1)
295
+ conds = torch.zeros([self.max_objs, self.key_dim],
296
+ device="cpu").repeat(batch, 1, 1)
297
+ return self._set_position(
298
+ box_out.to(device),
299
+ masks.to(device),
300
+ conds.to(device))
301
+
302
+
303
+ def load_gligen(sd):
304
+ sd_k = sd.keys()
305
+ output_list = []
306
+ key_dim = 768
307
+ for a in ["input_blocks", "middle_block", "output_blocks"]:
308
+ for b in range(20):
309
+ k_temp = filter(lambda k: "{}.{}.".format(a, b)
310
+ in k and ".fuser." in k, sd_k)
311
+ k_temp = map(lambda k: (k, k.split(".fuser.")[-1]), k_temp)
312
+
313
+ n_sd = {}
314
+ for k in k_temp:
315
+ n_sd[k[1]] = sd[k[0]]
316
+ if len(n_sd) > 0:
317
+ query_dim = n_sd["linear.weight"].shape[0]
318
+ key_dim = n_sd["linear.weight"].shape[1]
319
+
320
+ if key_dim == 768: # SD1.x
321
+ n_heads = 8
322
+ d_head = query_dim // n_heads
323
+ else:
324
+ d_head = 64
325
+ n_heads = query_dim // d_head
326
+
327
+ gated = GatedSelfAttentionDense(
328
+ query_dim, key_dim, n_heads, d_head)
329
+ gated.load_state_dict(n_sd, strict=False)
330
+ output_list.append(gated)
331
+
332
+ if "position_net.null_positive_feature" in sd_k:
333
+ in_dim = sd["position_net.null_positive_feature"].shape[0]
334
+ out_dim = sd["position_net.linears.4.weight"].shape[0]
335
+
336
+ class WeightsLoader(torch.nn.Module):
337
+ pass
338
+ w = WeightsLoader()
339
+ w.position_net = PositionNet(in_dim, out_dim)
340
+ w.load_state_dict(sd, strict=False)
341
+
342
+ gligen = Gligen(output_list, w.position_net, key_dim)
343
+ return gligen
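A minimal standalone sketch of the box-coordinate Fourier embedding used by `PositionNet` above, re-implemented here so it runs without the comfy package (the helper name `fourier_embed` and the example box are illustrative only). Judging from the indexing in `set_position`, each entry of `position_params` is a tuple of (text embedding, height, width, y, x) in latent pixels, and the resulting boxes are normalized to [0, 1] before being embedded like this:

```python
# Standalone sketch mirroring FourierEmbedder.__call__ above (hypothetical helper
# name; box values are illustrative only).
import torch

def fourier_embed(x, num_freqs=8, temperature=100):
    # sin/cos pairs over a geometric frequency ramp, concatenated on the last dim
    freq_bands = temperature ** (torch.arange(num_freqs) / num_freqs)
    out = []
    for freq in freq_bands:
        out.append(torch.sin(freq * x))
        out.append(torch.cos(freq * x))
    return torch.cat(out, dim=-1)

boxes = torch.tensor([[[0.1, 0.2, 0.5, 0.6]]])  # [B=1, N=1, 4] normalized xyxy box
emb = fourier_embed(boxes)
print(emb.shape)  # torch.Size([1, 1, 64]) == fourier_freqs * 2 * 4, PositionNet's position_dim
```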
comfy/k_diffusion/deis.py ADDED
@@ -0,0 +1,121 @@
1
+ #Taken from: https://github.com/zju-pi/diff-sampler/blob/main/gits-main/solver_utils.py
2
+ #under Apache 2 license
3
+ import torch
4
+ import numpy as np
5
+
6
+ # A pytorch reimplementation of DEIS (https://github.com/qsh-zh/deis).
7
+ #############################
8
+ ### Utils for DEIS solver ###
9
+ #############################
10
+ #----------------------------------------------------------------------------
11
+ # Transfer from the input time (sigma) used in EDM to that (t) used in DEIS.
12
+
13
+ def edm2t(edm_steps, epsilon_s=1e-3, sigma_min=0.002, sigma_max=80):
14
+ vp_sigma = lambda beta_d, beta_min: lambda t: (np.e ** (0.5 * beta_d * (t ** 2) + beta_min * t) - 1) ** 0.5
15
+ vp_sigma_inv = lambda beta_d, beta_min: lambda sigma: ((beta_min ** 2 + 2 * beta_d * (sigma ** 2 + 1).log()).sqrt() - beta_min) / beta_d
16
+ vp_beta_d = 2 * (np.log(torch.tensor(sigma_min).cpu() ** 2 + 1) / epsilon_s - np.log(torch.tensor(sigma_max).cpu() ** 2 + 1)) / (epsilon_s - 1)
17
+ vp_beta_min = np.log(torch.tensor(sigma_max).cpu() ** 2 + 1) - 0.5 * vp_beta_d
18
+ t_steps = vp_sigma_inv(vp_beta_d.clone().detach().cpu(), vp_beta_min.clone().detach().cpu())(edm_steps.clone().detach().cpu())
19
+ return t_steps, vp_beta_min, vp_beta_d + vp_beta_min
20
+
21
+ #----------------------------------------------------------------------------
22
+
23
+ def cal_poly(prev_t, j, taus):
24
+ poly = 1
25
+ for k in range(prev_t.shape[0]):
26
+ if k == j:
27
+ continue
28
+ poly *= (taus - prev_t[k]) / (prev_t[j] - prev_t[k])
29
+ return poly
30
+
31
+ #----------------------------------------------------------------------------
32
+ # Transfer from t to alpha_t.
33
+
34
+ def t2alpha_fn(beta_0, beta_1, t):
35
+ return torch.exp(-0.5 * t ** 2 * (beta_1 - beta_0) - t * beta_0)
36
+
37
+ #----------------------------------------------------------------------------
38
+
39
+ def cal_intergrand(beta_0, beta_1, taus):
40
+ with torch.inference_mode(mode=False):
41
+ taus = taus.clone()
42
+ beta_0 = beta_0.clone()
43
+ beta_1 = beta_1.clone()
44
+ with torch.enable_grad():
45
+ taus.requires_grad_(True)
46
+ alpha = t2alpha_fn(beta_0, beta_1, taus)
47
+ log_alpha = alpha.log()
48
+ log_alpha.sum().backward()
49
+ d_log_alpha_dtau = taus.grad
50
+ integrand = -0.5 * d_log_alpha_dtau / torch.sqrt(alpha * (1 - alpha))
51
+ return integrand
52
+
53
+ #----------------------------------------------------------------------------
54
+
55
+ def get_deis_coeff_list(t_steps, max_order, N=10000, deis_mode='tab'):
56
+ """
57
+ Get the coefficient list for DEIS sampling.
58
+
59
+ Args:
60
+ t_steps: A pytorch tensor. The time steps for sampling.
61
+ max_order: An `int`. Maximum order of the solver. 1 <= max_order <= 4
62
+ N: An `int`. How many points to use for the numerical integration when deis_mode=='tab'.
63
+ deis_mode: A `str`. Select between 'tab' and 'rhoab'. Type of DEIS.
64
+ Returns:
65
+ A list with one entry per sampling step, each holding the DEIS coefficients for that step.
66
+ """
67
+ if deis_mode == 'tab':
68
+ t_steps, beta_0, beta_1 = edm2t(t_steps)
69
+ C = []
70
+ for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])):
71
+ order = min(i+1, max_order)
72
+ if order == 1:
73
+ C.append([])
74
+ else:
75
+ taus = torch.linspace(t_cur, t_next, N) # split the interval for integral approximation
76
+ dtau = (t_next - t_cur) / N
77
+ prev_t = t_steps[[i - k for k in range(order)]]
78
+ coeff_temp = []
79
+ integrand = cal_intergrand(beta_0, beta_1, taus)
80
+ for j in range(order):
81
+ poly = cal_poly(prev_t, j, taus)
82
+ coeff_temp.append(torch.sum(integrand * poly) * dtau)
83
+ C.append(coeff_temp)
84
+
85
+ elif deis_mode == 'rhoab':
86
+ # Analytical solution, second order
87
+ def get_def_intergral_2(a, b, start, end, c):
88
+ coeff = (end**3 - start**3) / 3 - (end**2 - start**2) * (a + b) / 2 + (end - start) * a * b
89
+ return coeff / ((c - a) * (c - b))
90
+
91
+ # Analytical solution, third order
92
+ def get_def_intergral_3(a, b, c, start, end, d):
93
+ coeff = (end**4 - start**4) / 4 - (end**3 - start**3) * (a + b + c) / 3 \
94
+ + (end**2 - start**2) * (a*b + a*c + b*c) / 2 - (end - start) * a * b * c
95
+ return coeff / ((d - a) * (d - b) * (d - c))
96
+
97
+ C = []
98
+ for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])):
99
+ order = min(i, max_order)
100
+ if order == 0:
101
+ C.append([])
102
+ else:
103
+ prev_t = t_steps[[i - k for k in range(order+1)]]
104
+ if order == 1:
105
+ coeff_cur = ((t_next - prev_t[1])**2 - (t_cur - prev_t[1])**2) / (2 * (t_cur - prev_t[1]))
106
+ coeff_prev1 = (t_next - t_cur)**2 / (2 * (prev_t[1] - t_cur))
107
+ coeff_temp = [coeff_cur, coeff_prev1]
108
+ elif order == 2:
109
+ coeff_cur = get_def_intergral_2(prev_t[1], prev_t[2], t_cur, t_next, t_cur)
110
+ coeff_prev1 = get_def_intergral_2(t_cur, prev_t[2], t_cur, t_next, prev_t[1])
111
+ coeff_prev2 = get_def_intergral_2(t_cur, prev_t[1], t_cur, t_next, prev_t[2])
112
+ coeff_temp = [coeff_cur, coeff_prev1, coeff_prev2]
113
+ elif order == 3:
114
+ coeff_cur = get_def_intergral_3(prev_t[1], prev_t[2], prev_t[3], t_cur, t_next, t_cur)
115
+ coeff_prev1 = get_def_intergral_3(t_cur, prev_t[2], prev_t[3], t_cur, t_next, prev_t[1])
116
+ coeff_prev2 = get_def_intergral_3(t_cur, prev_t[1], prev_t[3], t_cur, t_next, prev_t[2])
117
+ coeff_prev3 = get_def_intergral_3(t_cur, prev_t[1], prev_t[2], t_cur, t_next, prev_t[3])
118
+ coeff_temp = [coeff_cur, coeff_prev1, coeff_prev2, coeff_prev3]
119
+ C.append(coeff_temp)
120
+ return C
121
+
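A minimal usage sketch for `get_deis_coeff_list`; it mirrors the call made by `sample_deis` in `sampling.py` below. The import path assumes this file is importable as `comfy.k_diffusion.deis`, and the toy sigma schedule is illustrative only.

```python
# Usage sketch for get_deis_coeff_list (toy sigma schedule; not a real model's schedule).
import torch
from comfy.k_diffusion import deis

sigmas = torch.linspace(14.6, 0.03, 10)  # decreasing EDM-style noise levels
coeffs = deis.get_deis_coeff_list(sigmas, max_order=3, deis_mode='tab')

print(len(coeffs))               # one coefficient list per step: len(sigmas) - 1 == 9
print([len(c) for c in coeffs])  # the first step is plain Euler (empty list); later
                                 # steps carry up to max_order coefficients
```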
comfy/k_diffusion/sampling.py ADDED
@@ -0,0 +1,1145 @@
1
+ import math
2
+
3
+ from scipy import integrate
4
+ import torch
5
+ from torch import nn
6
+ import torchsde
7
+ from tqdm.auto import trange, tqdm
8
+
9
+ from . import utils
10
+ from . import deis
11
+ import comfy.model_patcher
12
+ import comfy.model_sampling
13
+
14
+ def append_zero(x):
15
+ return torch.cat([x, x.new_zeros([1])])
16
+
17
+
18
+ def get_sigmas_karras(n, sigma_min, sigma_max, rho=7., device='cpu'):
19
+ """Constructs the noise schedule of Karras et al. (2022)."""
20
+ ramp = torch.linspace(0, 1, n, device=device)
21
+ min_inv_rho = sigma_min ** (1 / rho)
22
+ max_inv_rho = sigma_max ** (1 / rho)
23
+ sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
24
+ return append_zero(sigmas).to(device)
25
+
26
+
27
+ def get_sigmas_exponential(n, sigma_min, sigma_max, device='cpu'):
28
+ """Constructs an exponential noise schedule."""
29
+ sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), n, device=device).exp()
30
+ return append_zero(sigmas)
31
+
32
+
33
+ def get_sigmas_polyexponential(n, sigma_min, sigma_max, rho=1., device='cpu'):
34
+ """Constructs a polynomial in log sigma noise schedule."""
35
+ ramp = torch.linspace(1, 0, n, device=device) ** rho
36
+ sigmas = torch.exp(ramp * (math.log(sigma_max) - math.log(sigma_min)) + math.log(sigma_min))
37
+ return append_zero(sigmas)
38
+
39
+
40
+ def get_sigmas_vp(n, beta_d=19.9, beta_min=0.1, eps_s=1e-3, device='cpu'):
41
+ """Constructs a continuous VP noise schedule."""
42
+ t = torch.linspace(1, eps_s, n, device=device)
43
+ sigmas = torch.sqrt(torch.exp(beta_d * t ** 2 / 2 + beta_min * t) - 1)
44
+ return append_zero(sigmas)
45
+
46
+
47
+ def to_d(x, sigma, denoised):
48
+ """Converts a denoiser output to a Karras ODE derivative."""
49
+ return (x - denoised) / utils.append_dims(sigma, x.ndim)
50
+
51
+
52
+ def get_ancestral_step(sigma_from, sigma_to, eta=1.):
53
+ """Calculates the noise level (sigma_down) to step down to and the amount
54
+ of noise to add (sigma_up) when doing an ancestral sampling step."""
55
+ if not eta:
56
+ return sigma_to, 0.
57
+ sigma_up = min(sigma_to, eta * (sigma_to ** 2 * (sigma_from ** 2 - sigma_to ** 2) / sigma_from ** 2) ** 0.5)
58
+ sigma_down = (sigma_to ** 2 - sigma_up ** 2) ** 0.5
59
+ return sigma_down, sigma_up
60
+
61
+
62
+ def default_noise_sampler(x):
63
+ return lambda sigma, sigma_next: torch.randn_like(x)
64
+
65
+
66
+ class BatchedBrownianTree:
67
+ """A wrapper around torchsde.BrownianTree that enables batches of entropy."""
68
+
69
+ def __init__(self, x, t0, t1, seed=None, **kwargs):
70
+ self.cpu_tree = True
71
+ if "cpu" in kwargs:
72
+ self.cpu_tree = kwargs.pop("cpu")
73
+ t0, t1, self.sign = self.sort(t0, t1)
74
+ w0 = kwargs.get('w0', torch.zeros_like(x))
75
+ if seed is None:
76
+ seed = torch.randint(0, 2 ** 63 - 1, []).item()
77
+ self.batched = True
78
+ try:
79
+ assert len(seed) == x.shape[0]
80
+ w0 = w0[0]
81
+ except TypeError:
82
+ seed = [seed]
83
+ self.batched = False
84
+ if self.cpu_tree:
85
+ self.trees = [torchsde.BrownianTree(t0.cpu(), w0.cpu(), t1.cpu(), entropy=s, **kwargs) for s in seed]
86
+ else:
87
+ self.trees = [torchsde.BrownianTree(t0, w0, t1, entropy=s, **kwargs) for s in seed]
88
+
89
+ @staticmethod
90
+ def sort(a, b):
91
+ return (a, b, 1) if a < b else (b, a, -1)
92
+
93
+ def __call__(self, t0, t1):
94
+ t0, t1, sign = self.sort(t0, t1)
95
+ if self.cpu_tree:
96
+ w = torch.stack([tree(t0.cpu().float(), t1.cpu().float()).to(t0.dtype).to(t0.device) for tree in self.trees]) * (self.sign * sign)
97
+ else:
98
+ w = torch.stack([tree(t0, t1) for tree in self.trees]) * (self.sign * sign)
99
+
100
+ return w if self.batched else w[0]
101
+
102
+
103
+ class BrownianTreeNoiseSampler:
104
+ """A noise sampler backed by a torchsde.BrownianTree.
105
+
106
+ Args:
107
+ x (Tensor): The tensor whose shape, device and dtype to use to generate
108
+ random samples.
109
+ sigma_min (float): The low end of the valid interval.
110
+ sigma_max (float): The high end of the valid interval.
111
+ seed (int or List[int]): The random seed. If a list of seeds is
112
+ supplied instead of a single integer, then the noise sampler will
113
+ use one BrownianTree per batch item, each with its own seed.
114
+ transform (callable): A function that maps sigma to the sampler's
115
+ internal timestep.
116
+ """
117
+
118
+ def __init__(self, x, sigma_min, sigma_max, seed=None, transform=lambda x: x, cpu=False):
119
+ self.transform = transform
120
+ t0, t1 = self.transform(torch.as_tensor(sigma_min)), self.transform(torch.as_tensor(sigma_max))
121
+ self.tree = BatchedBrownianTree(x, t0, t1, seed, cpu=cpu)
122
+
123
+ def __call__(self, sigma, sigma_next):
124
+ t0, t1 = self.transform(torch.as_tensor(sigma)), self.transform(torch.as_tensor(sigma_next))
125
+ return self.tree(t0, t1) / (t1 - t0).abs().sqrt()
126
+
127
+
128
+ @torch.no_grad()
129
+ def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
130
+ """Implements Algorithm 2 (Euler steps) from Karras et al. (2022)."""
131
+ extra_args = {} if extra_args is None else extra_args
132
+ s_in = x.new_ones([x.shape[0]])
133
+ for i in trange(len(sigmas) - 1, disable=disable):
134
+ if s_churn > 0:
135
+ gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
136
+ sigma_hat = sigmas[i] * (gamma + 1)
137
+ else:
138
+ gamma = 0
139
+ sigma_hat = sigmas[i]
140
+
141
+ if gamma > 0:
142
+ eps = torch.randn_like(x) * s_noise
143
+ x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
144
+ denoised = model(x, sigma_hat * s_in, **extra_args)
145
+ d = to_d(x, sigma_hat, denoised)
146
+ if callback is not None:
147
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
148
+ dt = sigmas[i + 1] - sigma_hat
149
+ # Euler method
150
+ x = x + d * dt
151
+ return x
152
+
153
+
154
+ @torch.no_grad()
155
+ def sample_euler_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
156
+ """Ancestral sampling with Euler method steps."""
157
+ extra_args = {} if extra_args is None else extra_args
158
+ noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
159
+ s_in = x.new_ones([x.shape[0]])
160
+ for i in trange(len(sigmas) - 1, disable=disable):
161
+ denoised = model(x, sigmas[i] * s_in, **extra_args)
162
+ sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
163
+ if callback is not None:
164
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
165
+ d = to_d(x, sigmas[i], denoised)
166
+ # Euler method
167
+ dt = sigma_down - sigmas[i]
168
+ x = x + d * dt
169
+ if sigmas[i + 1] > 0:
170
+ x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
171
+ return x
172
+
173
+
174
+ @torch.no_grad()
175
+ def sample_heun(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
176
+ """Implements Algorithm 2 (Heun steps) from Karras et al. (2022)."""
177
+ extra_args = {} if extra_args is None else extra_args
178
+ s_in = x.new_ones([x.shape[0]])
179
+ for i in trange(len(sigmas) - 1, disable=disable):
180
+ if s_churn > 0:
181
+ gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
182
+ sigma_hat = sigmas[i] * (gamma + 1)
183
+ else:
184
+ gamma = 0
185
+ sigma_hat = sigmas[i]
186
+
187
+ sigma_hat = sigmas[i] * (gamma + 1)
188
+ if gamma > 0:
189
+ eps = torch.randn_like(x) * s_noise
190
+ x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
191
+ denoised = model(x, sigma_hat * s_in, **extra_args)
192
+ d = to_d(x, sigma_hat, denoised)
193
+ if callback is not None:
194
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
195
+ dt = sigmas[i + 1] - sigma_hat
196
+ if sigmas[i + 1] == 0:
197
+ # Euler method
198
+ x = x + d * dt
199
+ else:
200
+ # Heun's method
201
+ x_2 = x + d * dt
202
+ denoised_2 = model(x_2, sigmas[i + 1] * s_in, **extra_args)
203
+ d_2 = to_d(x_2, sigmas[i + 1], denoised_2)
204
+ d_prime = (d + d_2) / 2
205
+ x = x + d_prime * dt
206
+ return x
207
+
208
+
209
+ @torch.no_grad()
210
+ def sample_dpm_2(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
211
+ """A sampler inspired by DPM-Solver-2 and Algorithm 2 from Karras et al. (2022)."""
212
+ extra_args = {} if extra_args is None else extra_args
213
+ s_in = x.new_ones([x.shape[0]])
214
+ for i in trange(len(sigmas) - 1, disable=disable):
215
+ if s_churn > 0:
216
+ gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
217
+ sigma_hat = sigmas[i] * (gamma + 1)
218
+ else:
219
+ gamma = 0
220
+ sigma_hat = sigmas[i]
221
+
222
+ if gamma > 0:
223
+ eps = torch.randn_like(x) * s_noise
224
+ x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
225
+ denoised = model(x, sigma_hat * s_in, **extra_args)
226
+ d = to_d(x, sigma_hat, denoised)
227
+ if callback is not None:
228
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
229
+ if sigmas[i + 1] == 0:
230
+ # Euler method
231
+ dt = sigmas[i + 1] - sigma_hat
232
+ x = x + d * dt
233
+ else:
234
+ # DPM-Solver-2
235
+ sigma_mid = sigma_hat.log().lerp(sigmas[i + 1].log(), 0.5).exp()
236
+ dt_1 = sigma_mid - sigma_hat
237
+ dt_2 = sigmas[i + 1] - sigma_hat
238
+ x_2 = x + d * dt_1
239
+ denoised_2 = model(x_2, sigma_mid * s_in, **extra_args)
240
+ d_2 = to_d(x_2, sigma_mid, denoised_2)
241
+ x = x + d_2 * dt_2
242
+ return x
243
+
244
+
245
+ @torch.no_grad()
246
+ def sample_dpm_2_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
247
+ """Ancestral sampling with DPM-Solver second-order steps."""
248
+ extra_args = {} if extra_args is None else extra_args
249
+ noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
250
+ s_in = x.new_ones([x.shape[0]])
251
+ for i in trange(len(sigmas) - 1, disable=disable):
252
+ denoised = model(x, sigmas[i] * s_in, **extra_args)
253
+ sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
254
+ if callback is not None:
255
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
256
+ d = to_d(x, sigmas[i], denoised)
257
+ if sigma_down == 0:
258
+ # Euler method
259
+ dt = sigma_down - sigmas[i]
260
+ x = x + d * dt
261
+ else:
262
+ # DPM-Solver-2
263
+ sigma_mid = sigmas[i].log().lerp(sigma_down.log(), 0.5).exp()
264
+ dt_1 = sigma_mid - sigmas[i]
265
+ dt_2 = sigma_down - sigmas[i]
266
+ x_2 = x + d * dt_1
267
+ denoised_2 = model(x_2, sigma_mid * s_in, **extra_args)
268
+ d_2 = to_d(x_2, sigma_mid, denoised_2)
269
+ x = x + d_2 * dt_2
270
+ x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
271
+ return x
272
+
273
+
274
+ def linear_multistep_coeff(order, t, i, j):
275
+ if order - 1 > i:
276
+ raise ValueError(f'Order {order} too high for step {i}')
277
+ def fn(tau):
278
+ prod = 1.
279
+ for k in range(order):
280
+ if j == k:
281
+ continue
282
+ prod *= (tau - t[i - k]) / (t[i - j] - t[i - k])
283
+ return prod
284
+ return integrate.quad(fn, t[i], t[i + 1], epsrel=1e-4)[0]
285
+
286
+
287
+ @torch.no_grad()
288
+ def sample_lms(model, x, sigmas, extra_args=None, callback=None, disable=None, order=4):
289
+ extra_args = {} if extra_args is None else extra_args
290
+ s_in = x.new_ones([x.shape[0]])
291
+ sigmas_cpu = sigmas.detach().cpu().numpy()
292
+ ds = []
293
+ for i in trange(len(sigmas) - 1, disable=disable):
294
+ denoised = model(x, sigmas[i] * s_in, **extra_args)
295
+ d = to_d(x, sigmas[i], denoised)
296
+ ds.append(d)
297
+ if len(ds) > order:
298
+ ds.pop(0)
299
+ if callback is not None:
300
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
301
+ cur_order = min(i + 1, order)
302
+ coeffs = [linear_multistep_coeff(cur_order, sigmas_cpu, i, j) for j in range(cur_order)]
303
+ x = x + sum(coeff * d for coeff, d in zip(coeffs, reversed(ds)))
304
+ return x
305
+
306
+
307
+ class PIDStepSizeController:
308
+ """A PID controller for ODE adaptive step size control."""
309
+ def __init__(self, h, pcoeff, icoeff, dcoeff, order=1, accept_safety=0.81, eps=1e-8):
310
+ self.h = h
311
+ self.b1 = (pcoeff + icoeff + dcoeff) / order
312
+ self.b2 = -(pcoeff + 2 * dcoeff) / order
313
+ self.b3 = dcoeff / order
314
+ self.accept_safety = accept_safety
315
+ self.eps = eps
316
+ self.errs = []
317
+
318
+ def limiter(self, x):
319
+ return 1 + math.atan(x - 1)
320
+
321
+ def propose_step(self, error):
322
+ inv_error = 1 / (float(error) + self.eps)
323
+ if not self.errs:
324
+ self.errs = [inv_error, inv_error, inv_error]
325
+ self.errs[0] = inv_error
326
+ factor = self.errs[0] ** self.b1 * self.errs[1] ** self.b2 * self.errs[2] ** self.b3
327
+ factor = self.limiter(factor)
328
+ accept = factor >= self.accept_safety
329
+ if accept:
330
+ self.errs[2] = self.errs[1]
331
+ self.errs[1] = self.errs[0]
332
+ self.h *= factor
333
+ return accept
334
+
335
+
336
+ class DPMSolver(nn.Module):
337
+ """DPM-Solver. See https://arxiv.org/abs/2206.00927."""
338
+
339
+ def __init__(self, model, extra_args=None, eps_callback=None, info_callback=None):
340
+ super().__init__()
341
+ self.model = model
342
+ self.extra_args = {} if extra_args is None else extra_args
343
+ self.eps_callback = eps_callback
344
+ self.info_callback = info_callback
345
+
346
+ def t(self, sigma):
347
+ return -sigma.log()
348
+
349
+ def sigma(self, t):
350
+ return t.neg().exp()
351
+
352
+ def eps(self, eps_cache, key, x, t, *args, **kwargs):
353
+ if key in eps_cache:
354
+ return eps_cache[key], eps_cache
355
+ sigma = self.sigma(t) * x.new_ones([x.shape[0]])
356
+ eps = (x - self.model(x, sigma, *args, **self.extra_args, **kwargs)) / self.sigma(t)
357
+ if self.eps_callback is not None:
358
+ self.eps_callback()
359
+ return eps, {key: eps, **eps_cache}
360
+
361
+ def dpm_solver_1_step(self, x, t, t_next, eps_cache=None):
362
+ eps_cache = {} if eps_cache is None else eps_cache
363
+ h = t_next - t
364
+ eps, eps_cache = self.eps(eps_cache, 'eps', x, t)
365
+ x_1 = x - self.sigma(t_next) * h.expm1() * eps
366
+ return x_1, eps_cache
367
+
368
+ def dpm_solver_2_step(self, x, t, t_next, r1=1 / 2, eps_cache=None):
369
+ eps_cache = {} if eps_cache is None else eps_cache
370
+ h = t_next - t
371
+ eps, eps_cache = self.eps(eps_cache, 'eps', x, t)
372
+ s1 = t + r1 * h
373
+ u1 = x - self.sigma(s1) * (r1 * h).expm1() * eps
374
+ eps_r1, eps_cache = self.eps(eps_cache, 'eps_r1', u1, s1)
375
+ x_2 = x - self.sigma(t_next) * h.expm1() * eps - self.sigma(t_next) / (2 * r1) * h.expm1() * (eps_r1 - eps)
376
+ return x_2, eps_cache
377
+
378
+ def dpm_solver_3_step(self, x, t, t_next, r1=1 / 3, r2=2 / 3, eps_cache=None):
379
+ eps_cache = {} if eps_cache is None else eps_cache
380
+ h = t_next - t
381
+ eps, eps_cache = self.eps(eps_cache, 'eps', x, t)
382
+ s1 = t + r1 * h
383
+ s2 = t + r2 * h
384
+ u1 = x - self.sigma(s1) * (r1 * h).expm1() * eps
385
+ eps_r1, eps_cache = self.eps(eps_cache, 'eps_r1', u1, s1)
386
+ u2 = x - self.sigma(s2) * (r2 * h).expm1() * eps - self.sigma(s2) * (r2 / r1) * ((r2 * h).expm1() / (r2 * h) - 1) * (eps_r1 - eps)
387
+ eps_r2, eps_cache = self.eps(eps_cache, 'eps_r2', u2, s2)
388
+ x_3 = x - self.sigma(t_next) * h.expm1() * eps - self.sigma(t_next) / r2 * (h.expm1() / h - 1) * (eps_r2 - eps)
389
+ return x_3, eps_cache
390
+
391
+ def dpm_solver_fast(self, x, t_start, t_end, nfe, eta=0., s_noise=1., noise_sampler=None):
392
+ noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
393
+ if not t_end > t_start and eta:
394
+ raise ValueError('eta must be 0 for reverse sampling')
395
+
396
+ m = math.floor(nfe / 3) + 1
397
+ ts = torch.linspace(t_start, t_end, m + 1, device=x.device)
398
+
399
+ if nfe % 3 == 0:
400
+ orders = [3] * (m - 2) + [2, 1]
401
+ else:
402
+ orders = [3] * (m - 1) + [nfe % 3]
403
+
404
+ for i in range(len(orders)):
405
+ eps_cache = {}
406
+ t, t_next = ts[i], ts[i + 1]
407
+ if eta:
408
+ sd, su = get_ancestral_step(self.sigma(t), self.sigma(t_next), eta)
409
+ t_next_ = torch.minimum(t_end, self.t(sd))
410
+ su = (self.sigma(t_next) ** 2 - self.sigma(t_next_) ** 2) ** 0.5
411
+ else:
412
+ t_next_, su = t_next, 0.
413
+
414
+ eps, eps_cache = self.eps(eps_cache, 'eps', x, t)
415
+ denoised = x - self.sigma(t) * eps
416
+ if self.info_callback is not None:
417
+ self.info_callback({'x': x, 'i': i, 't': ts[i], 't_up': t, 'denoised': denoised})
418
+
419
+ if orders[i] == 1:
420
+ x, eps_cache = self.dpm_solver_1_step(x, t, t_next_, eps_cache=eps_cache)
421
+ elif orders[i] == 2:
422
+ x, eps_cache = self.dpm_solver_2_step(x, t, t_next_, eps_cache=eps_cache)
423
+ else:
424
+ x, eps_cache = self.dpm_solver_3_step(x, t, t_next_, eps_cache=eps_cache)
425
+
426
+ x = x + su * s_noise * noise_sampler(self.sigma(t), self.sigma(t_next))
427
+
428
+ return x
429
+
430
+ def dpm_solver_adaptive(self, x, t_start, t_end, order=3, rtol=0.05, atol=0.0078, h_init=0.05, pcoeff=0., icoeff=1., dcoeff=0., accept_safety=0.81, eta=0., s_noise=1., noise_sampler=None):
431
+ noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
432
+ if order not in {2, 3}:
433
+ raise ValueError('order should be 2 or 3')
434
+ forward = t_end > t_start
435
+ if not forward and eta:
436
+ raise ValueError('eta must be 0 for reverse sampling')
437
+ h_init = abs(h_init) * (1 if forward else -1)
438
+ atol = torch.tensor(atol)
439
+ rtol = torch.tensor(rtol)
440
+ s = t_start
441
+ x_prev = x
442
+ accept = True
443
+ pid = PIDStepSizeController(h_init, pcoeff, icoeff, dcoeff, 1.5 if eta else order, accept_safety)
444
+ info = {'steps': 0, 'nfe': 0, 'n_accept': 0, 'n_reject': 0}
445
+
446
+ while s < t_end - 1e-5 if forward else s > t_end + 1e-5:
447
+ eps_cache = {}
448
+ t = torch.minimum(t_end, s + pid.h) if forward else torch.maximum(t_end, s + pid.h)
449
+ if eta:
450
+ sd, su = get_ancestral_step(self.sigma(s), self.sigma(t), eta)
451
+ t_ = torch.minimum(t_end, self.t(sd))
452
+ su = (self.sigma(t) ** 2 - self.sigma(t_) ** 2) ** 0.5
453
+ else:
454
+ t_, su = t, 0.
455
+
456
+ eps, eps_cache = self.eps(eps_cache, 'eps', x, s)
457
+ denoised = x - self.sigma(s) * eps
458
+
459
+ if order == 2:
460
+ x_low, eps_cache = self.dpm_solver_1_step(x, s, t_, eps_cache=eps_cache)
461
+ x_high, eps_cache = self.dpm_solver_2_step(x, s, t_, eps_cache=eps_cache)
462
+ else:
463
+ x_low, eps_cache = self.dpm_solver_2_step(x, s, t_, r1=1 / 3, eps_cache=eps_cache)
464
+ x_high, eps_cache = self.dpm_solver_3_step(x, s, t_, eps_cache=eps_cache)
465
+ delta = torch.maximum(atol, rtol * torch.maximum(x_low.abs(), x_prev.abs()))
466
+ error = torch.linalg.norm((x_low - x_high) / delta) / x.numel() ** 0.5
467
+ accept = pid.propose_step(error)
468
+ if accept:
469
+ x_prev = x_low
470
+ x = x_high + su * s_noise * noise_sampler(self.sigma(s), self.sigma(t))
471
+ s = t
472
+ info['n_accept'] += 1
473
+ else:
474
+ info['n_reject'] += 1
475
+ info['nfe'] += order
476
+ info['steps'] += 1
477
+
478
+ if self.info_callback is not None:
479
+ self.info_callback({'x': x, 'i': info['steps'] - 1, 't': s, 't_up': s, 'denoised': denoised, 'error': error, 'h': pid.h, **info})
480
+
481
+ return x, info
482
+
483
+
484
+ @torch.no_grad()
485
+ def sample_dpm_fast(model, x, sigma_min, sigma_max, n, extra_args=None, callback=None, disable=None, eta=0., s_noise=1., noise_sampler=None):
486
+ """DPM-Solver-Fast (fixed step size). See https://arxiv.org/abs/2206.00927."""
487
+ if sigma_min <= 0 or sigma_max <= 0:
488
+ raise ValueError('sigma_min and sigma_max must not be 0')
489
+ with tqdm(total=n, disable=disable) as pbar:
490
+ dpm_solver = DPMSolver(model, extra_args, eps_callback=pbar.update)
491
+ if callback is not None:
492
+ dpm_solver.info_callback = lambda info: callback({'sigma': dpm_solver.sigma(info['t']), 'sigma_hat': dpm_solver.sigma(info['t_up']), **info})
493
+ return dpm_solver.dpm_solver_fast(x, dpm_solver.t(torch.tensor(sigma_max)), dpm_solver.t(torch.tensor(sigma_min)), n, eta, s_noise, noise_sampler)
494
+
495
+
496
+ @torch.no_grad()
497
+ def sample_dpm_adaptive(model, x, sigma_min, sigma_max, extra_args=None, callback=None, disable=None, order=3, rtol=0.05, atol=0.0078, h_init=0.05, pcoeff=0., icoeff=1., dcoeff=0., accept_safety=0.81, eta=0., s_noise=1., noise_sampler=None, return_info=False):
498
+ """DPM-Solver-12 and 23 (adaptive step size). See https://arxiv.org/abs/2206.00927."""
499
+ if sigma_min <= 0 or sigma_max <= 0:
500
+ raise ValueError('sigma_min and sigma_max must not be 0')
501
+ with tqdm(disable=disable) as pbar:
502
+ dpm_solver = DPMSolver(model, extra_args, eps_callback=pbar.update)
503
+ if callback is not None:
504
+ dpm_solver.info_callback = lambda info: callback({'sigma': dpm_solver.sigma(info['t']), 'sigma_hat': dpm_solver.sigma(info['t_up']), **info})
505
+ x, info = dpm_solver.dpm_solver_adaptive(x, dpm_solver.t(torch.tensor(sigma_max)), dpm_solver.t(torch.tensor(sigma_min)), order, rtol, atol, h_init, pcoeff, icoeff, dcoeff, accept_safety, eta, s_noise, noise_sampler)
506
+ if return_info:
507
+ return x, info
508
+ return x
509
+
510
+
511
+ @torch.no_grad()
512
+ def sample_dpmpp_2s_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
513
+ if isinstance(model.inner_model.inner_model.model_sampling, comfy.model_sampling.CONST):
514
+ return sample_dpmpp_2s_ancestral_RF(model, x, sigmas, extra_args, callback, disable, eta, s_noise, noise_sampler)
515
+
516
+ """Ancestral sampling with DPM-Solver++(2S) second-order steps."""
517
+ extra_args = {} if extra_args is None else extra_args
518
+ noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
519
+ s_in = x.new_ones([x.shape[0]])
520
+ sigma_fn = lambda t: t.neg().exp()
521
+ t_fn = lambda sigma: sigma.log().neg()
522
+
523
+ for i in trange(len(sigmas) - 1, disable=disable):
524
+ denoised = model(x, sigmas[i] * s_in, **extra_args)
525
+ sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
526
+ if callback is not None:
527
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
528
+ if sigma_down == 0:
529
+ # Euler method
530
+ d = to_d(x, sigmas[i], denoised)
531
+ dt = sigma_down - sigmas[i]
532
+ x = x + d * dt
533
+ else:
534
+ # DPM-Solver++(2S)
535
+ t, t_next = t_fn(sigmas[i]), t_fn(sigma_down)
536
+ r = 1 / 2
537
+ h = t_next - t
538
+ s = t + r * h
539
+ x_2 = (sigma_fn(s) / sigma_fn(t)) * x - (-h * r).expm1() * denoised
540
+ denoised_2 = model(x_2, sigma_fn(s) * s_in, **extra_args)
541
+ x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised_2
542
+ # Noise addition
543
+ if sigmas[i + 1] > 0:
544
+ x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
545
+ return x
546
+
547
+
548
+ @torch.no_grad()
549
+ def sample_dpmpp_2s_ancestral_RF(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
550
+ """Ancestral sampling with DPM-Solver++(2S) second-order steps."""
551
+ extra_args = {} if extra_args is None else extra_args
552
+ noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
553
+ s_in = x.new_ones([x.shape[0]])
554
+ sigma_fn = lambda lbda: (lbda.exp() + 1) ** -1
555
+ lambda_fn = lambda sigma: ((1-sigma)/sigma).log()
556
+
557
+ # logged_x = x.unsqueeze(0)
558
+
559
+ for i in trange(len(sigmas) - 1, disable=disable):
560
+ denoised = model(x, sigmas[i] * s_in, **extra_args)
561
+ downstep_ratio = 1 + (sigmas[i+1]/sigmas[i] - 1) * eta
562
+ sigma_down = sigmas[i+1] * downstep_ratio
563
+ alpha_ip1 = 1 - sigmas[i+1]
564
+ alpha_down = 1 - sigma_down
565
+ renoise_coeff = (sigmas[i+1]**2 - sigma_down**2*alpha_ip1**2/alpha_down**2)**0.5
566
+ # sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
567
+ if callback is not None:
568
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
569
+ if sigmas[i + 1] == 0:
570
+ # Euler method
571
+ d = to_d(x, sigmas[i], denoised)
572
+ dt = sigma_down - sigmas[i]
573
+ x = x + d * dt
574
+ else:
575
+ # DPM-Solver++(2S)
576
+ if sigmas[i] == 1.0:
577
+ sigma_s = 0.9999
578
+ else:
579
+ t_i, t_down = lambda_fn(sigmas[i]), lambda_fn(sigma_down)
580
+ r = 1 / 2
581
+ h = t_down - t_i
582
+ s = t_i + r * h
583
+ sigma_s = sigma_fn(s)
584
+ # sigma_s = sigmas[i+1]
585
+ sigma_s_i_ratio = sigma_s / sigmas[i]
586
+ u = sigma_s_i_ratio * x + (1 - sigma_s_i_ratio) * denoised
587
+ D_i = model(u, sigma_s * s_in, **extra_args)
588
+ sigma_down_i_ratio = sigma_down / sigmas[i]
589
+ x = sigma_down_i_ratio * x + (1 - sigma_down_i_ratio) * D_i
590
+ # print("sigma_i", sigmas[i], "sigma_ip1", sigmas[i+1],"sigma_down", sigma_down, "sigma_down_i_ratio", sigma_down_i_ratio, "sigma_s_i_ratio", sigma_s_i_ratio, "renoise_coeff", renoise_coeff)
591
+ # Noise addition
592
+ if sigmas[i + 1] > 0 and eta > 0:
593
+ x = (alpha_ip1/alpha_down) * x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * renoise_coeff
594
+ # logged_x = torch.cat((logged_x, x.unsqueeze(0)), dim=0)
595
+ return x
596
+
597
+ @torch.no_grad()
598
+ def sample_dpmpp_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2):
599
+ """DPM-Solver++ (stochastic)."""
600
+ if len(sigmas) <= 1:
601
+ return x
602
+
603
+ sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
604
+ seed = extra_args.get("seed", None)
605
+ noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler
606
+ extra_args = {} if extra_args is None else extra_args
607
+ s_in = x.new_ones([x.shape[0]])
608
+ sigma_fn = lambda t: t.neg().exp()
609
+ t_fn = lambda sigma: sigma.log().neg()
610
+
611
+ for i in trange(len(sigmas) - 1, disable=disable):
612
+ denoised = model(x, sigmas[i] * s_in, **extra_args)
613
+ if callback is not None:
614
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
615
+ if sigmas[i + 1] == 0:
616
+ # Euler method
617
+ d = to_d(x, sigmas[i], denoised)
618
+ dt = sigmas[i + 1] - sigmas[i]
619
+ x = x + d * dt
620
+ else:
621
+ # DPM-Solver++
622
+ t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
623
+ h = t_next - t
624
+ s = t + h * r
625
+ fac = 1 / (2 * r)
626
+
627
+ # Step 1
628
+ sd, su = get_ancestral_step(sigma_fn(t), sigma_fn(s), eta)
629
+ s_ = t_fn(sd)
630
+ x_2 = (sigma_fn(s_) / sigma_fn(t)) * x - (t - s_).expm1() * denoised
631
+ x_2 = x_2 + noise_sampler(sigma_fn(t), sigma_fn(s)) * s_noise * su
632
+ denoised_2 = model(x_2, sigma_fn(s) * s_in, **extra_args)
633
+
634
+ # Step 2
635
+ sd, su = get_ancestral_step(sigma_fn(t), sigma_fn(t_next), eta)
636
+ t_next_ = t_fn(sd)
637
+ denoised_d = (1 - fac) * denoised + fac * denoised_2
638
+ x = (sigma_fn(t_next_) / sigma_fn(t)) * x - (t - t_next_).expm1() * denoised_d
639
+ x = x + noise_sampler(sigma_fn(t), sigma_fn(t_next)) * s_noise * su
640
+ return x
641
+
642
+
643
+ @torch.no_grad()
644
+ def sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=None):
645
+ """DPM-Solver++(2M)."""
646
+ extra_args = {} if extra_args is None else extra_args
647
+ s_in = x.new_ones([x.shape[0]])
648
+ sigma_fn = lambda t: t.neg().exp()
649
+ t_fn = lambda sigma: sigma.log().neg()
650
+ old_denoised = None
651
+
652
+ for i in trange(len(sigmas) - 1, disable=disable):
653
+ denoised = model(x, sigmas[i] * s_in, **extra_args)
654
+ if callback is not None:
655
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
656
+ t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
657
+ h = t_next - t
658
+ if old_denoised is None or sigmas[i + 1] == 0:
659
+ x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised
660
+ else:
661
+ h_last = t - t_fn(sigmas[i - 1])
662
+ r = h_last / h
663
+ denoised_d = (1 + 1 / (2 * r)) * denoised - (1 / (2 * r)) * old_denoised
664
+ x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised_d
665
+ old_denoised = denoised
666
+ return x
667
+
668
+ @torch.no_grad()
669
+ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'):
670
+ """DPM-Solver++(2M) SDE."""
671
+ if len(sigmas) <= 1:
672
+ return x
673
+
674
+ if solver_type not in {'heun', 'midpoint'}:
675
+ raise ValueError('solver_type must be \'heun\' or \'midpoint\'')
676
+
677
+ seed = extra_args.get("seed", None)
678
+ sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
679
+ noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler
680
+ extra_args = {} if extra_args is None else extra_args
681
+ s_in = x.new_ones([x.shape[0]])
682
+
683
+ old_denoised = None
684
+ h_last = None
685
+ h = None
686
+
687
+ for i in trange(len(sigmas) - 1, disable=disable):
688
+ denoised = model(x, sigmas[i] * s_in, **extra_args)
689
+ if callback is not None:
690
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
691
+ if sigmas[i + 1] == 0:
692
+ # Denoising step
693
+ x = denoised
694
+ else:
695
+ # DPM-Solver++(2M) SDE
696
+ t, s = -sigmas[i].log(), -sigmas[i + 1].log()
697
+ h = s - t
698
+ eta_h = eta * h
699
+
700
+ x = sigmas[i + 1] / sigmas[i] * (-eta_h).exp() * x + (-h - eta_h).expm1().neg() * denoised
701
+
702
+ if old_denoised is not None:
703
+ r = h_last / h
704
+ if solver_type == 'heun':
705
+ x = x + ((-h - eta_h).expm1().neg() / (-h - eta_h) + 1) * (1 / r) * (denoised - old_denoised)
706
+ elif solver_type == 'midpoint':
707
+ x = x + 0.5 * (-h - eta_h).expm1().neg() * (1 / r) * (denoised - old_denoised)
708
+
709
+ if eta:
710
+ x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * eta_h).expm1().neg().sqrt() * s_noise
711
+
712
+ old_denoised = denoised
713
+ h_last = h
714
+ return x
715
+
716
+ @torch.no_grad()
717
+ def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
718
+ """DPM-Solver++(3M) SDE."""
719
+
720
+ if len(sigmas) <= 1:
721
+ return x
722
+
723
+ seed = extra_args.get("seed", None)
724
+ sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
725
+ noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler
726
+ extra_args = {} if extra_args is None else extra_args
727
+ s_in = x.new_ones([x.shape[0]])
728
+
729
+ denoised_1, denoised_2 = None, None
730
+ h, h_1, h_2 = None, None, None
731
+
732
+ for i in trange(len(sigmas) - 1, disable=disable):
733
+ denoised = model(x, sigmas[i] * s_in, **extra_args)
734
+ if callback is not None:
735
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
736
+ if sigmas[i + 1] == 0:
737
+ # Denoising step
738
+ x = denoised
739
+ else:
740
+ t, s = -sigmas[i].log(), -sigmas[i + 1].log()
741
+ h = s - t
742
+ h_eta = h * (eta + 1)
743
+
744
+ x = torch.exp(-h_eta) * x + (-h_eta).expm1().neg() * denoised
745
+
746
+ if h_2 is not None:
747
+ r0 = h_1 / h
748
+ r1 = h_2 / h
749
+ d1_0 = (denoised - denoised_1) / r0
750
+ d1_1 = (denoised_1 - denoised_2) / r1
751
+ d1 = d1_0 + (d1_0 - d1_1) * r0 / (r0 + r1)
752
+ d2 = (d1_0 - d1_1) / (r0 + r1)
753
+ phi_2 = h_eta.neg().expm1() / h_eta + 1
754
+ phi_3 = phi_2 / h_eta - 0.5
755
+ x = x + phi_2 * d1 - phi_3 * d2
756
+ elif h_1 is not None:
757
+ r = h_1 / h
758
+ d = (denoised - denoised_1) / r
759
+ phi_2 = h_eta.neg().expm1() / h_eta + 1
760
+ x = x + phi_2 * d
761
+
762
+ if eta:
763
+ x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * h * eta).expm1().neg().sqrt() * s_noise
764
+
765
+ denoised_1, denoised_2 = denoised, denoised_1
766
+ h_1, h_2 = h, h_1
767
+ return x
768
+
769
+ @torch.no_grad()
770
+ def sample_dpmpp_3m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
771
+ if len(sigmas) <= 1:
772
+ return x
773
+
774
+ sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
775
+ noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
776
+ return sample_dpmpp_3m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler)
777
+
778
+ @torch.no_grad()
779
+ def sample_dpmpp_2m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'):
780
+ if len(sigmas) <= 1:
781
+ return x
782
+
783
+ sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
784
+ noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
785
+ return sample_dpmpp_2m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, solver_type=solver_type)
786
+
787
+ @torch.no_grad()
788
+ def sample_dpmpp_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2):
789
+ if len(sigmas) <= 1:
790
+ return x
791
+
792
+ sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
793
+ noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
794
+ return sample_dpmpp_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, r=r)
795
+
796
+
797
+ def DDPMSampler_step(x, sigma, sigma_prev, noise, noise_sampler):
798
+ alpha_cumprod = 1 / ((sigma * sigma) + 1)
799
+ alpha_cumprod_prev = 1 / ((sigma_prev * sigma_prev) + 1)
800
+ alpha = (alpha_cumprod / alpha_cumprod_prev)
801
+
802
+ mu = (1.0 / alpha).sqrt() * (x - (1 - alpha) * noise / (1 - alpha_cumprod).sqrt())
803
+ if sigma_prev > 0:
804
+ mu += ((1 - alpha) * (1. - alpha_cumprod_prev) / (1. - alpha_cumprod)).sqrt() * noise_sampler(sigma, sigma_prev)
805
+ return mu
806
+
807
+ def generic_step_sampler(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None, step_function=None):
808
+ extra_args = {} if extra_args is None else extra_args
809
+ noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
810
+ s_in = x.new_ones([x.shape[0]])
811
+
812
+ for i in trange(len(sigmas) - 1, disable=disable):
813
+ denoised = model(x, sigmas[i] * s_in, **extra_args)
814
+ if callback is not None:
815
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
816
+ x = step_function(x / torch.sqrt(1.0 + sigmas[i] ** 2.0), sigmas[i], sigmas[i + 1], (x - denoised) / sigmas[i], noise_sampler)
817
+ if sigmas[i + 1] != 0:
818
+ x *= torch.sqrt(1.0 + sigmas[i + 1] ** 2.0)
819
+ return x
820
+
821
+
822
+ @torch.no_grad()
823
+ def sample_ddpm(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None):
824
+ return generic_step_sampler(model, x, sigmas, extra_args, callback, disable, noise_sampler, DDPMSampler_step)
825
+
826
+ @torch.no_grad()
827
+ def sample_lcm(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None):
828
+ extra_args = {} if extra_args is None else extra_args
829
+ noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
830
+ s_in = x.new_ones([x.shape[0]])
831
+ for i in trange(len(sigmas) - 1, disable=disable):
832
+ denoised = model(x, sigmas[i] * s_in, **extra_args)
833
+ if callback is not None:
834
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
835
+
836
+ x = denoised
837
+ if sigmas[i + 1] > 0:
838
+ x = model.inner_model.inner_model.model_sampling.noise_scaling(sigmas[i + 1], noise_sampler(sigmas[i], sigmas[i + 1]), x)
839
+ return x
840
+
841
+
842
+
843
+ @torch.no_grad()
844
+ def sample_heunpp2(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
845
+ # From MIT licensed: https://github.com/Carzit/sd-webui-samplers-scheduler/
846
+ extra_args = {} if extra_args is None else extra_args
847
+ s_in = x.new_ones([x.shape[0]])
848
+ s_end = sigmas[-1]
849
+ for i in trange(len(sigmas) - 1, disable=disable):
850
+ gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
851
+ eps = torch.randn_like(x) * s_noise
852
+ sigma_hat = sigmas[i] * (gamma + 1)
853
+ if gamma > 0:
854
+ x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
855
+ denoised = model(x, sigma_hat * s_in, **extra_args)
856
+ d = to_d(x, sigma_hat, denoised)
857
+ if callback is not None:
858
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
859
+ dt = sigmas[i + 1] - sigma_hat
860
+ if sigmas[i + 1] == s_end:
861
+ # Euler method
862
+ x = x + d * dt
863
+ elif sigmas[i + 2] == s_end:
864
+
865
+ # Heun's method
866
+ x_2 = x + d * dt
867
+ denoised_2 = model(x_2, sigmas[i + 1] * s_in, **extra_args)
868
+ d_2 = to_d(x_2, sigmas[i + 1], denoised_2)
869
+
870
+ w = 2 * sigmas[0]
871
+ w2 = sigmas[i+1]/w
872
+ w1 = 1 - w2
873
+
874
+ d_prime = d * w1 + d_2 * w2
875
+
876
+
877
+ x = x + d_prime * dt
878
+
879
+ else:
880
+ # Heun++
881
+ x_2 = x + d * dt
882
+ denoised_2 = model(x_2, sigmas[i + 1] * s_in, **extra_args)
883
+ d_2 = to_d(x_2, sigmas[i + 1], denoised_2)
884
+ dt_2 = sigmas[i + 2] - sigmas[i + 1]
885
+
886
+ x_3 = x_2 + d_2 * dt_2
887
+ denoised_3 = model(x_3, sigmas[i + 2] * s_in, **extra_args)
888
+ d_3 = to_d(x_3, sigmas[i + 2], denoised_3)
889
+
890
+ w = 3 * sigmas[0]
891
+ w2 = sigmas[i + 1] / w
892
+ w3 = sigmas[i + 2] / w
893
+ w1 = 1 - w2 - w3
894
+
895
+ d_prime = w1 * d + w2 * d_2 + w3 * d_3
896
+ x = x + d_prime * dt
897
+ return x
898
+
899
+
900
+ #From https://github.com/zju-pi/diff-sampler/blob/main/diff-solvers-main/solvers.py
901
+ #under Apache 2 license
902
+ def sample_ipndm(model, x, sigmas, extra_args=None, callback=None, disable=None, max_order=4):
903
+ extra_args = {} if extra_args is None else extra_args
904
+ s_in = x.new_ones([x.shape[0]])
905
+
906
+ x_next = x
907
+
908
+ buffer_model = []
909
+ for i in trange(len(sigmas) - 1, disable=disable):
910
+ t_cur = sigmas[i]
911
+ t_next = sigmas[i + 1]
912
+
913
+ x_cur = x_next
914
+
915
+ denoised = model(x_cur, t_cur * s_in, **extra_args)
916
+ if callback is not None:
917
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
918
+
919
+ d_cur = (x_cur - denoised) / t_cur
920
+
921
+ order = min(max_order, i+1)
922
+ if order == 1: # First Euler step.
923
+ x_next = x_cur + (t_next - t_cur) * d_cur
924
+ elif order == 2: # Use one history point.
925
+ x_next = x_cur + (t_next - t_cur) * (3 * d_cur - buffer_model[-1]) / 2
926
+ elif order == 3: # Use two history points.
927
+ x_next = x_cur + (t_next - t_cur) * (23 * d_cur - 16 * buffer_model[-1] + 5 * buffer_model[-2]) / 12
928
+ elif order == 4: # Use three history points.
929
+ x_next = x_cur + (t_next - t_cur) * (55 * d_cur - 59 * buffer_model[-1] + 37 * buffer_model[-2] - 9 * buffer_model[-3]) / 24
930
+
931
+ if len(buffer_model) == max_order - 1:
932
+ for k in range(max_order - 2):
933
+ buffer_model[k] = buffer_model[k+1]
934
+ buffer_model[-1] = d_cur
935
+ else:
936
+ buffer_model.append(d_cur)
937
+
938
+ return x_next
939
+
940
+ #From https://github.com/zju-pi/diff-sampler/blob/main/diff-solvers-main/solvers.py
941
+ #under Apache 2 license
942
+ def sample_ipndm_v(model, x, sigmas, extra_args=None, callback=None, disable=None, max_order=4):
943
+ extra_args = {} if extra_args is None else extra_args
944
+ s_in = x.new_ones([x.shape[0]])
945
+
946
+ x_next = x
947
+ t_steps = sigmas
948
+
949
+ buffer_model = []
950
+ for i in trange(len(sigmas) - 1, disable=disable):
951
+ t_cur = sigmas[i]
952
+ t_next = sigmas[i + 1]
953
+
954
+ x_cur = x_next
955
+
956
+ denoised = model(x_cur, t_cur * s_in, **extra_args)
957
+ if callback is not None:
958
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
959
+
960
+ d_cur = (x_cur - denoised) / t_cur
961
+
962
+ order = min(max_order, i+1)
963
+ if order == 1: # First Euler step.
964
+ x_next = x_cur + (t_next - t_cur) * d_cur
965
+ elif order == 2: # Use one history point.
966
+ h_n = (t_next - t_cur)
967
+ h_n_1 = (t_cur - t_steps[i-1])
968
+ coeff1 = (2 + (h_n / h_n_1)) / 2
969
+ coeff2 = -(h_n / h_n_1) / 2
970
+ x_next = x_cur + (t_next - t_cur) * (coeff1 * d_cur + coeff2 * buffer_model[-1])
971
+ elif order == 3: # Use two history points.
972
+ h_n = (t_next - t_cur)
973
+ h_n_1 = (t_cur - t_steps[i-1])
974
+ h_n_2 = (t_steps[i-1] - t_steps[i-2])
975
+ temp = (1 - h_n / (3 * (h_n + h_n_1)) * (h_n * (h_n + h_n_1)) / (h_n_1 * (h_n_1 + h_n_2))) / 2
976
+ coeff1 = (2 + (h_n / h_n_1)) / 2 + temp
977
+ coeff2 = -(h_n / h_n_1) / 2 - (1 + h_n_1 / h_n_2) * temp
978
+ coeff3 = temp * h_n_1 / h_n_2
979
+ x_next = x_cur + (t_next - t_cur) * (coeff1 * d_cur + coeff2 * buffer_model[-1] + coeff3 * buffer_model[-2])
980
+ elif order == 4: # Use three history points.
981
+ h_n = (t_next - t_cur)
982
+ h_n_1 = (t_cur - t_steps[i-1])
983
+ h_n_2 = (t_steps[i-1] - t_steps[i-2])
984
+ h_n_3 = (t_steps[i-2] - t_steps[i-3])
985
+ temp1 = (1 - h_n / (3 * (h_n + h_n_1)) * (h_n * (h_n + h_n_1)) / (h_n_1 * (h_n_1 + h_n_2))) / 2
986
+ temp2 = ((1 - h_n / (3 * (h_n + h_n_1))) / 2 + (1 - h_n / (2 * (h_n + h_n_1))) * h_n / (6 * (h_n + h_n_1 + h_n_2))) \
987
+ * (h_n * (h_n + h_n_1) * (h_n + h_n_1 + h_n_2)) / (h_n_1 * (h_n_1 + h_n_2) * (h_n_1 + h_n_2 + h_n_3))
988
+ coeff1 = (2 + (h_n / h_n_1)) / 2 + temp1 + temp2
989
+ coeff2 = -(h_n / h_n_1) / 2 - (1 + h_n_1 / h_n_2) * temp1 - (1 + (h_n_1 / h_n_2) + (h_n_1 * (h_n_1 + h_n_2) / (h_n_2 * (h_n_2 + h_n_3)))) * temp2
990
+ coeff3 = temp1 * h_n_1 / h_n_2 + ((h_n_1 / h_n_2) + (h_n_1 * (h_n_1 + h_n_2) / (h_n_2 * (h_n_2 + h_n_3))) * (1 + h_n_2 / h_n_3)) * temp2
991
+ coeff4 = -temp2 * (h_n_1 * (h_n_1 + h_n_2) / (h_n_2 * (h_n_2 + h_n_3))) * h_n_1 / h_n_2
992
+ x_next = x_cur + (t_next - t_cur) * (coeff1 * d_cur + coeff2 * buffer_model[-1] + coeff3 * buffer_model[-2] + coeff4 * buffer_model[-3])
993
+
994
+ if len(buffer_model) == max_order - 1:
995
+ for k in range(max_order - 2):
996
+ buffer_model[k] = buffer_model[k+1]
997
+ buffer_model[-1] = d_cur.detach()
998
+ else:
999
+ buffer_model.append(d_cur.detach())
1000
+
1001
+ return x_next
1002
+
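The _v variant recomputes the multistep weights from the actual step sizes instead of assuming uniform steps. As a quick check, when consecutive steps are equal the order-2 weights above collapse to the fixed-step values used by sample_ipndm:

h_n, h_n_1 = 0.1, 0.1                      # equal step sizes
coeff1 = (2 + (h_n / h_n_1)) / 2           # -> 1.5, same as 3/2 in sample_ipndm
coeff2 = -(h_n / h_n_1) / 2                # -> -0.5, same as -1/2 in sample_ipndm
assert abs(coeff1 - 1.5) < 1e-12 and abs(coeff2 + 0.5) < 1e-12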
1003
+ #From https://github.com/zju-pi/diff-sampler/blob/main/diff-solvers-main/solvers.py
1004
+ #under Apache 2 license
1005
+ @torch.no_grad()
1006
+ def sample_deis(model, x, sigmas, extra_args=None, callback=None, disable=None, max_order=3, deis_mode='tab'):
1007
+ extra_args = {} if extra_args is None else extra_args
1008
+ s_in = x.new_ones([x.shape[0]])
1009
+
1010
+ x_next = x
1011
+ t_steps = sigmas
1012
+
1013
+ coeff_list = deis.get_deis_coeff_list(t_steps, max_order, deis_mode=deis_mode)
1014
+
1015
+ buffer_model = []
1016
+ for i in trange(len(sigmas) - 1, disable=disable):
1017
+ t_cur = sigmas[i]
1018
+ t_next = sigmas[i + 1]
1019
+
1020
+ x_cur = x_next
1021
+
1022
+ denoised = model(x_cur, t_cur * s_in, **extra_args)
1023
+ if callback is not None:
1024
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
1025
+
1026
+ d_cur = (x_cur - denoised) / t_cur
1027
+
1028
+ order = min(max_order, i+1)
1029
+ if t_next <= 0:
1030
+ order = 1
1031
+
1032
+ if order == 1: # First Euler step.
1033
+ x_next = x_cur + (t_next - t_cur) * d_cur
1034
+ elif order == 2: # Use one history point.
1035
+ coeff_cur, coeff_prev1 = coeff_list[i]
1036
+ x_next = x_cur + coeff_cur * d_cur + coeff_prev1 * buffer_model[-1]
1037
+ elif order == 3: # Use two history points.
1038
+ coeff_cur, coeff_prev1, coeff_prev2 = coeff_list[i]
1039
+ x_next = x_cur + coeff_cur * d_cur + coeff_prev1 * buffer_model[-1] + coeff_prev2 * buffer_model[-2]
1040
+ elif order == 4: # Use three history points.
1041
+ coeff_cur, coeff_prev1, coeff_prev2, coeff_prev3 = coeff_list[i]
1042
+ x_next = x_cur + coeff_cur * d_cur + coeff_prev1 * buffer_model[-1] + coeff_prev2 * buffer_model[-2] + coeff_prev3 * buffer_model[-3]
1043
+
1044
+ if len(buffer_model) == max_order - 1:
1045
+ for k in range(max_order - 2):
1046
+ buffer_model[k] = buffer_model[k+1]
1047
+ buffer_model[-1] = d_cur.detach()
1048
+ else:
1049
+ buffer_model.append(d_cur.detach())
1050
+
1051
+ return x_next
1052
+
1053
+ @torch.no_grad()
1054
+ def sample_euler_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None):
1055
+ extra_args = {} if extra_args is None else extra_args
1056
+
1057
+ temp = [0]
1058
+ def post_cfg_function(args):
1059
+ temp[0] = args["uncond_denoised"]
1060
+ return args["denoised"]
1061
+
1062
+ model_options = extra_args.get("model_options", {}).copy()
1063
+ extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True)
1064
+
1065
+ s_in = x.new_ones([x.shape[0]])
1066
+ for i in trange(len(sigmas) - 1, disable=disable):
1067
+ sigma_hat = sigmas[i]
1068
+ denoised = model(x, sigma_hat * s_in, **extra_args)
1069
+ d = to_d(x, sigma_hat, temp[0])
1070
+ if callback is not None:
1071
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
1072
+ dt = sigmas[i + 1] - sigma_hat
1073
+ # Euler method
1074
+ x = denoised + d * sigmas[i + 1]
1075
+ return x
1076
+
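The cfg_pp samplers rely on a small closure trick: a post-CFG hook stores the unconditional prediction in a one-element list so the sampling loop can read it after each model call. A minimal sketch of just that pattern; the hook invocation is simulated directly here, whereas the real code registers it via comfy.model_patcher.set_model_options_post_cfg_function as shown above:

temp = [None]

def post_cfg_function(args):
    temp[0] = args["uncond_denoised"]   # stash the uncond prediction as a side effect
    return args["denoised"]             # leave the CFG result unchanged

post_cfg_function({"denoised": 1.0, "uncond_denoised": 0.25})
assert temp[0] == 0.25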
1077
+ @torch.no_grad()
1078
+ def sample_euler_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
1079
+ """Ancestral sampling with Euler method steps."""
1080
+ extra_args = {} if extra_args is None else extra_args
1081
+ noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
1082
+
1083
+ temp = [0]
1084
+ def post_cfg_function(args):
1085
+ temp[0] = args["uncond_denoised"]
1086
+ return args["denoised"]
1087
+
1088
+ model_options = extra_args.get("model_options", {}).copy()
1089
+ extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True)
1090
+
1091
+ s_in = x.new_ones([x.shape[0]])
1092
+ for i in trange(len(sigmas) - 1, disable=disable):
1093
+ denoised = model(x, sigmas[i] * s_in, **extra_args)
1094
+ sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
1095
+ if callback is not None:
1096
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
1097
+ d = to_d(x, sigmas[i], temp[0])
1098
+ # Euler method
1099
+ dt = sigma_down - sigmas[i]
1100
+ x = denoised + d * sigma_down
1101
+ if sigmas[i + 1] > 0:
1102
+ x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
1103
+ return x
1104
+ @torch.no_grad()
1105
+ def sample_dpmpp_2s_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
1106
+ """Ancestral sampling with DPM-Solver++(2S) second-order steps."""
1107
+ extra_args = {} if extra_args is None else extra_args
1108
+ noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
1109
+
1110
+ temp = [0]
1111
+ def post_cfg_function(args):
1112
+ temp[0] = args["uncond_denoised"]
1113
+ return args["denoised"]
1114
+
1115
+ model_options = extra_args.get("model_options", {}).copy()
1116
+ extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True)
1117
+
1118
+ s_in = x.new_ones([x.shape[0]])
1119
+ sigma_fn = lambda t: t.neg().exp()
1120
+ t_fn = lambda sigma: sigma.log().neg()
1121
+
1122
+ for i in trange(len(sigmas) - 1, disable=disable):
1123
+ denoised = model(x, sigmas[i] * s_in, **extra_args)
1124
+ sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
1125
+ if callback is not None:
1126
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
1127
+ if sigma_down == 0:
1128
+ # Euler method
1129
+ d = to_d(x, sigmas[i], temp[0])
1130
+ dt = sigma_down - sigmas[i]
1131
+ x = denoised + d * sigma_down
1132
+ else:
1133
+ # DPM-Solver++(2S)
1134
+ t, t_next = t_fn(sigmas[i]), t_fn(sigma_down)
1135
+ # r = torch.sinh(1 + (2 - eta) * (t_next - t) / (t - t_fn(sigma_up))) works only on non-cfgpp, weird
1136
+ r = 1 / 2
1137
+ h = t_next - t
1138
+ s = t + r * h
1139
+ x_2 = (sigma_fn(s) / sigma_fn(t)) * (x + (denoised - temp[0])) - (-h * r).expm1() * denoised
1140
+ denoised_2 = model(x_2, sigma_fn(s) * s_in, **extra_args)
1141
+ x = (sigma_fn(t_next) / sigma_fn(t)) * (x + (denoised - temp[0])) - (-h).expm1() * denoised_2
1142
+ # Noise addition
1143
+ if sigmas[i + 1] > 0:
1144
+ x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
1145
+ return x
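All of these samplers expect sigmas to be a descending 1-D tensor of noise levels, usually ending in 0 (the sigmas[i + 1] > 0 and s_end checks above depend on that). A hedged sketch of such a schedule using the standard Karras rho-ramp as an example; this is not necessarily the exact schedule ComfyUI builds for a given model:

import torch

def example_karras_sigmas(n, sigma_min=0.0292, sigma_max=14.6146, rho=7.0):
    # Interpolate in sigma**(1/rho) space, then append a trailing 0.0 so the
    # final step lands on the clean image.
    ramp = torch.linspace(0, 1, n)
    min_inv_rho = sigma_min ** (1 / rho)
    max_inv_rho = sigma_max ** (1 / rho)
    sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
    return torch.cat([sigmas, sigmas.new_zeros(1)])

sigmas = example_karras_sigmas(10)
assert sigmas[0] > sigmas[-2] and sigmas[-1] == 0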
comfy/k_diffusion/utils.py ADDED
@@ -0,0 +1,313 @@
1
+ from contextlib import contextmanager
2
+ import hashlib
3
+ import math
4
+ from pathlib import Path
5
+ import shutil
6
+ import urllib
7
+ import warnings
8
+
9
+ from PIL import Image
10
+ import torch
11
+ from torch import nn, optim
12
+ from torch.utils import data
13
+
14
+
15
+ def hf_datasets_augs_helper(examples, transform, image_key, mode='RGB'):
16
+ """Apply passed in transforms for HuggingFace Datasets."""
17
+ images = [transform(image.convert(mode)) for image in examples[image_key]]
18
+ return {image_key: images}
19
+
20
+
21
+ def append_dims(x, target_dims):
22
+ """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
23
+ dims_to_append = target_dims - x.ndim
24
+ if dims_to_append < 0:
25
+ raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')
26
+ expanded = x[(...,) + (None,) * dims_to_append]
27
+ # MPS will get inf values if it tries to index into the new axes, but detaching fixes this.
28
+ # https://github.com/pytorch/pytorch/issues/84364
29
+ return expanded.detach().clone() if expanded.device.type == 'mps' else expanded
30
+
31
+
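append_dims exists so a per-sample value such as sigma can broadcast against image-shaped tensors. A short usage sketch with the function defined above:

import torch

sigma = torch.tensor([1.0, 0.5])           # one value per batch element
x = torch.randn(2, 4, 64, 64)
sigma_b = append_dims(sigma, x.ndim)       # -> shape (2, 1, 1, 1)
assert sigma_b.shape == (2, 1, 1, 1)
noised = x * sigma_b                       # broadcasts over C, H, W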
32
+ def n_params(module):
33
+ """Returns the number of trainable parameters in a module."""
34
+ return sum(p.numel() for p in module.parameters())
35
+
36
+
37
+ def download_file(path, url, digest=None):
38
+ """Downloads a file if it does not exist, optionally checking its SHA-256 hash."""
39
+ path = Path(path)
40
+ path.parent.mkdir(parents=True, exist_ok=True)
41
+ if not path.exists():
42
+ with urllib.request.urlopen(url) as response, open(path, 'wb') as f:
43
+ shutil.copyfileobj(response, f)
44
+ if digest is not None:
45
+ file_digest = hashlib.sha256(open(path, 'rb').read()).hexdigest()
46
+ if digest != file_digest:
47
+ raise OSError(f'hash of {path} (url: {url}) failed to validate')
48
+ return path
49
+
50
+
51
+ @contextmanager
52
+ def train_mode(model, mode=True):
53
+ """A context manager that places a model into training mode and restores
54
+ the previous mode on exit."""
55
+ modes = [module.training for module in model.modules()]
56
+ try:
57
+ yield model.train(mode)
58
+ finally:
59
+ for i, module in enumerate(model.modules()):
60
+ module.training = modes[i]
61
+
62
+
63
+ def eval_mode(model):
64
+ """A context manager that places a model into evaluation mode and restores
65
+ the previous mode on exit."""
66
+ return train_mode(model, False)
67
+
68
+
69
+ @torch.no_grad()
70
+ def ema_update(model, averaged_model, decay):
71
+ """Incorporates updated model parameters into an exponential moving averaged
72
+ version of a model. It should be called after each optimizer step."""
73
+ model_params = dict(model.named_parameters())
74
+ averaged_params = dict(averaged_model.named_parameters())
75
+ assert model_params.keys() == averaged_params.keys()
76
+
77
+ for name, param in model_params.items():
78
+ averaged_params[name].mul_(decay).add_(param, alpha=1 - decay)
79
+
80
+ model_buffers = dict(model.named_buffers())
81
+ averaged_buffers = dict(averaged_model.named_buffers())
82
+ assert model_buffers.keys() == averaged_buffers.keys()
83
+
84
+ for name, buf in model_buffers.items():
85
+ averaged_buffers[name].copy_(buf)
86
+
87
+
88
+ class EMAWarmup:
89
+ """Implements an EMA warmup using an inverse decay schedule.
90
+ If inv_gamma=1 and power=1, implements a simple average. inv_gamma=1, power=2/3 are
91
+ good values for models you plan to train for a million or more steps (reaches decay
92
+ factor 0.999 at 31.6K steps, 0.9999 at 1M steps), inv_gamma=1, power=3/4 for models
93
+ you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999 at
94
+ 215.4k steps).
95
+ Args:
96
+ inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1.
97
+ power (float): Exponential factor of EMA warmup. Default: 1.
98
+ min_value (float): The minimum EMA decay rate. Default: 0.
99
+ max_value (float): The maximum EMA decay rate. Default: 1.
100
+ start_at (int): The epoch to start averaging at. Default: 0.
101
+ last_epoch (int): The index of last epoch. Default: 0.
102
+ """
103
+
104
+ def __init__(self, inv_gamma=1., power=1., min_value=0., max_value=1., start_at=0,
105
+ last_epoch=0):
106
+ self.inv_gamma = inv_gamma
107
+ self.power = power
108
+ self.min_value = min_value
109
+ self.max_value = max_value
110
+ self.start_at = start_at
111
+ self.last_epoch = last_epoch
112
+
113
+ def state_dict(self):
114
+ """Returns the state of the class as a :class:`dict`."""
115
+ return dict(self.__dict__.items())
116
+
117
+ def load_state_dict(self, state_dict):
118
+ """Loads the class's state.
119
+ Args:
120
+ state_dict (dict): scaler state. Should be an object returned
121
+ from a call to :meth:`state_dict`.
122
+ """
123
+ self.__dict__.update(state_dict)
124
+
125
+ def get_value(self):
126
+ """Gets the current EMA decay rate."""
127
+ epoch = max(0, self.last_epoch - self.start_at)
128
+ value = 1 - (1 + epoch / self.inv_gamma) ** -self.power
129
+ return 0. if epoch < 0 else min(self.max_value, max(self.min_value, value))
130
+
131
+ def step(self):
132
+ """Updates the step count."""
133
+ self.last_epoch += 1
134
+
135
+
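The reference points quoted in the docstring follow directly from the decay formula 1 - (1 + epoch / inv_gamma) ** -power. A quick numeric check for inv_gamma=1, power=2/3:

def ema_decay(epoch, inv_gamma=1.0, power=2 / 3):
    return 1 - (1 + epoch / inv_gamma) ** -power

assert abs(ema_decay(31_600) - 0.999) < 1e-4        # ~0.999 around 31.6K steps
assert abs(ema_decay(1_000_000) - 0.9999) < 1e-5    # ~0.9999 around 1M steps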
136
+ class InverseLR(optim.lr_scheduler._LRScheduler):
137
+ """Implements an inverse decay learning rate schedule with an optional exponential
138
+ warmup. When last_epoch=-1, sets initial lr as lr.
139
+ inv_gamma is the number of steps/epochs required for the learning rate to decay to
140
+ (1 / 2)**power of its original value.
141
+ Args:
142
+ optimizer (Optimizer): Wrapped optimizer.
143
+ inv_gamma (float): Inverse multiplicative factor of learning rate decay. Default: 1.
144
+ power (float): Exponential factor of learning rate decay. Default: 1.
145
+ warmup (float): Exponential warmup factor (0 <= warmup < 1, 0 to disable)
146
+ Default: 0.
147
+ min_lr (float): The minimum learning rate. Default: 0.
148
+ last_epoch (int): The index of last epoch. Default: -1.
149
+ verbose (bool): If ``True``, prints a message to stdout for
150
+ each update. Default: ``False``.
151
+ """
152
+
153
+ def __init__(self, optimizer, inv_gamma=1., power=1., warmup=0., min_lr=0.,
154
+ last_epoch=-1, verbose=False):
155
+ self.inv_gamma = inv_gamma
156
+ self.power = power
157
+ if not 0. <= warmup < 1:
158
+ raise ValueError('Invalid value for warmup')
159
+ self.warmup = warmup
160
+ self.min_lr = min_lr
161
+ super().__init__(optimizer, last_epoch, verbose)
162
+
163
+ def get_lr(self):
164
+ if not self._get_lr_called_within_step:
165
+ warnings.warn("To get the last learning rate computed by the scheduler, "
166
+ "please use `get_last_lr()`.")
167
+
168
+ return self._get_closed_form_lr()
169
+
170
+ def _get_closed_form_lr(self):
171
+ warmup = 1 - self.warmup ** (self.last_epoch + 1)
172
+ lr_mult = (1 + self.last_epoch / self.inv_gamma) ** -self.power
173
+ return [warmup * max(self.min_lr, base_lr * lr_mult)
174
+ for base_lr in self.base_lrs]
175
+
176
+
177
+ class ExponentialLR(optim.lr_scheduler._LRScheduler):
178
+ """Implements an exponential learning rate schedule with an optional exponential
179
+ warmup. When last_epoch=-1, sets initial lr as lr. Decays the learning rate
180
+ continuously by decay (default 0.5) every num_steps steps.
181
+ Args:
182
+ optimizer (Optimizer): Wrapped optimizer.
183
+ num_steps (float): The number of steps to decay the learning rate by decay in.
184
+ decay (float): The factor by which to decay the learning rate every num_steps
185
+ steps. Default: 0.5.
186
+ warmup (float): Exponential warmup factor (0 <= warmup < 1, 0 to disable)
187
+ Default: 0.
188
+ min_lr (float): The minimum learning rate. Default: 0.
189
+ last_epoch (int): The index of last epoch. Default: -1.
190
+ verbose (bool): If ``True``, prints a message to stdout for
191
+ each update. Default: ``False``.
192
+ """
193
+
194
+ def __init__(self, optimizer, num_steps, decay=0.5, warmup=0., min_lr=0.,
195
+ last_epoch=-1, verbose=False):
196
+ self.num_steps = num_steps
197
+ self.decay = decay
198
+ if not 0. <= warmup < 1:
199
+ raise ValueError('Invalid value for warmup')
200
+ self.warmup = warmup
201
+ self.min_lr = min_lr
202
+ super().__init__(optimizer, last_epoch, verbose)
203
+
204
+ def get_lr(self):
205
+ if not self._get_lr_called_within_step:
206
+ warnings.warn("To get the last learning rate computed by the scheduler, "
207
+ "please use `get_last_lr()`.")
208
+
209
+ return self._get_closed_form_lr()
210
+
211
+ def _get_closed_form_lr(self):
212
+ warmup = 1 - self.warmup ** (self.last_epoch + 1)
213
+ lr_mult = (self.decay ** (1 / self.num_steps)) ** self.last_epoch
214
+ return [warmup * max(self.min_lr, base_lr * lr_mult)
215
+ for base_lr in self.base_lrs]
216
+
217
+
218
+ def rand_log_normal(shape, loc=0., scale=1., device='cpu', dtype=torch.float32):
219
+ """Draws samples from an lognormal distribution."""
220
+ return (torch.randn(shape, device=device, dtype=dtype) * scale + loc).exp()
221
+
222
+
223
+ def rand_log_logistic(shape, loc=0., scale=1., min_value=0., max_value=float('inf'), device='cpu', dtype=torch.float32):
224
+ """Draws samples from an optionally truncated log-logistic distribution."""
225
+ min_value = torch.as_tensor(min_value, device=device, dtype=torch.float64)
226
+ max_value = torch.as_tensor(max_value, device=device, dtype=torch.float64)
227
+ min_cdf = min_value.log().sub(loc).div(scale).sigmoid()
228
+ max_cdf = max_value.log().sub(loc).div(scale).sigmoid()
229
+ u = torch.rand(shape, device=device, dtype=torch.float64) * (max_cdf - min_cdf) + min_cdf
230
+ return u.logit().mul(scale).add(loc).exp().to(dtype)
231
+
232
+
233
+ def rand_log_uniform(shape, min_value, max_value, device='cpu', dtype=torch.float32):
234
+ """Draws samples from an log-uniform distribution."""
235
+ min_value = math.log(min_value)
236
+ max_value = math.log(max_value)
237
+ return (torch.rand(shape, device=device, dtype=dtype) * (max_value - min_value) + min_value).exp()
238
+
239
+
240
+ def rand_v_diffusion(shape, sigma_data=1., min_value=0., max_value=float('inf'), device='cpu', dtype=torch.float32):
241
+ """Draws samples from a truncated v-diffusion training timestep distribution."""
242
+ min_cdf = math.atan(min_value / sigma_data) * 2 / math.pi
243
+ max_cdf = math.atan(max_value / sigma_data) * 2 / math.pi
244
+ u = torch.rand(shape, device=device, dtype=dtype) * (max_cdf - min_cdf) + min_cdf
245
+ return torch.tan(u * math.pi / 2) * sigma_data
246
+
247
+
248
+ def rand_split_log_normal(shape, loc, scale_1, scale_2, device='cpu', dtype=torch.float32):
249
+ """Draws samples from a split lognormal distribution."""
250
+ n = torch.randn(shape, device=device, dtype=dtype).abs()
251
+ u = torch.rand(shape, device=device, dtype=dtype)
252
+ n_left = n * -scale_1 + loc
253
+ n_right = n * scale_2 + loc
254
+ ratio = scale_1 / (scale_1 + scale_2)
255
+ return torch.where(u < ratio, n_left, n_right).exp()
256
+
257
+
258
+ class FolderOfImages(data.Dataset):
259
+ """Recursively finds all images in a directory. It does not support
260
+ classes/targets."""
261
+
262
+ IMG_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp'}
263
+
264
+ def __init__(self, root, transform=None):
265
+ super().__init__()
266
+ self.root = Path(root)
267
+ self.transform = nn.Identity() if transform is None else transform
268
+ self.paths = sorted(path for path in self.root.rglob('*') if path.suffix.lower() in self.IMG_EXTENSIONS)
269
+
270
+ def __repr__(self):
271
+ return f'FolderOfImages(root="{self.root}", len: {len(self)})'
272
+
273
+ def __len__(self):
274
+ return len(self.paths)
275
+
276
+ def __getitem__(self, key):
277
+ path = self.paths[key]
278
+ with open(path, 'rb') as f:
279
+ image = Image.open(f).convert('RGB')
280
+ image = self.transform(image)
281
+ return image,
282
+
283
+
284
+ class CSVLogger:
285
+ def __init__(self, filename, columns):
286
+ self.filename = Path(filename)
287
+ self.columns = columns
288
+ if self.filename.exists():
289
+ self.file = open(self.filename, 'a')
290
+ else:
291
+ self.file = open(self.filename, 'w')
292
+ self.write(*self.columns)
293
+
294
+ def write(self, *args):
295
+ print(*args, sep=',', file=self.file, flush=True)
296
+
297
+
298
+ @contextmanager
299
+ def tf32_mode(cudnn=None, matmul=None):
300
+ """A context manager that sets whether TF32 is allowed on cuDNN or matmul."""
301
+ cudnn_old = torch.backends.cudnn.allow_tf32
302
+ matmul_old = torch.backends.cuda.matmul.allow_tf32
303
+ try:
304
+ if cudnn is not None:
305
+ torch.backends.cudnn.allow_tf32 = cudnn
306
+ if matmul is not None:
307
+ torch.backends.cuda.matmul.allow_tf32 = matmul
308
+ yield
309
+ finally:
310
+ if cudnn is not None:
311
+ torch.backends.cudnn.allow_tf32 = cudnn_old
312
+ if matmul is not None:
313
+ torch.backends.cuda.matmul.allow_tf32 = matmul_old
comfy/latent_formats.py ADDED
@@ -0,0 +1,172 @@
1
+ import torch
2
+
3
+ class LatentFormat:
4
+ scale_factor = 1.0
5
+ latent_channels = 4
6
+ latent_rgb_factors = None
7
+ taesd_decoder_name = None
8
+
9
+ def process_in(self, latent):
10
+ return latent * self.scale_factor
11
+
12
+ def process_out(self, latent):
13
+ return latent / self.scale_factor
14
+
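process_in and process_out are exact inverses for the plain scale-factor formats: latents coming out of the VAE are multiplied by scale_factor before the diffusion model sees them and divided again on the way back. A quick roundtrip check with the base class defined above:

import torch

fmt = LatentFormat()                        # scale_factor = 1.0 here; SD15 below uses 0.18215
vae_latent = torch.randn(1, 4, 64, 64)
model_latent = fmt.process_in(vae_latent)
assert torch.allclose(fmt.process_out(model_latent), vae_latent)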
15
+ class SD15(LatentFormat):
16
+ def __init__(self, scale_factor=0.18215):
17
+ self.scale_factor = scale_factor
18
+ self.latent_rgb_factors = [
19
+ # R G B
20
+ [ 0.3512, 0.2297, 0.3227],
21
+ [ 0.3250, 0.4974, 0.2350],
22
+ [-0.2829, 0.1762, 0.2721],
23
+ [-0.2120, -0.2616, -0.7177]
24
+ ]
25
+ self.taesd_decoder_name = "taesd_decoder"
26
+
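latent_rgb_factors is a rough channels-to-RGB projection used for cheap latent previews; the actual preview code lives elsewhere in ComfyUI, so this is only an illustrative sketch of how such a (latent_channels x 3) matrix can be applied:

import torch

factors = torch.tensor(SD15().latent_rgb_factors)      # shape (4, 3), defined above
latent = torch.randn(1, 4, 64, 64)
rgb = torch.einsum("bchw,cr->brhw", latent, factors)    # (1, 3, 64, 64) rough preview
assert rgb.shape == (1, 3, 64, 64)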
27
+ class SDXL(LatentFormat):
28
+ scale_factor = 0.13025
29
+
30
+ def __init__(self):
31
+ self.latent_rgb_factors = [
32
+ # R G B
33
+ [ 0.3920, 0.4054, 0.4549],
34
+ [-0.2634, -0.0196, 0.0653],
35
+ [ 0.0568, 0.1687, -0.0755],
36
+ [-0.3112, -0.2359, -0.2076]
37
+ ]
38
+ self.taesd_decoder_name = "taesdxl_decoder"
39
+
40
+ class SDXL_Playground_2_5(LatentFormat):
41
+ def __init__(self):
42
+ self.scale_factor = 0.5
43
+ self.latents_mean = torch.tensor([-1.6574, 1.886, -1.383, 2.5155]).view(1, 4, 1, 1)
44
+ self.latents_std = torch.tensor([8.4927, 5.9022, 6.5498, 5.2299]).view(1, 4, 1, 1)
45
+
46
+ self.latent_rgb_factors = [
47
+ # R G B
48
+ [ 0.3920, 0.4054, 0.4549],
49
+ [-0.2634, -0.0196, 0.0653],
50
+ [ 0.0568, 0.1687, -0.0755],
51
+ [-0.3112, -0.2359, -0.2076]
52
+ ]
53
+ self.taesd_decoder_name = "taesdxl_decoder"
54
+
55
+ def process_in(self, latent):
56
+ latents_mean = self.latents_mean.to(latent.device, latent.dtype)
57
+ latents_std = self.latents_std.to(latent.device, latent.dtype)
58
+ return (latent - latents_mean) * self.scale_factor / latents_std
59
+
60
+ def process_out(self, latent):
61
+ latents_mean = self.latents_mean.to(latent.device, latent.dtype)
62
+ latents_std = self.latents_std.to(latent.device, latent.dtype)
63
+ return latent * latents_std / self.scale_factor + latents_mean
64
+
65
+
66
+ class SD_X4(LatentFormat):
67
+ def __init__(self):
68
+ self.scale_factor = 0.08333
69
+ self.latent_rgb_factors = [
70
+ [-0.2340, -0.3863, -0.3257],
71
+ [ 0.0994, 0.0885, -0.0908],
72
+ [-0.2833, -0.2349, -0.3741],
73
+ [ 0.2523, -0.0055, -0.1651]
74
+ ]
75
+
76
+ class SC_Prior(LatentFormat):
77
+ latent_channels = 16
78
+ def __init__(self):
79
+ self.scale_factor = 1.0
80
+ self.latent_rgb_factors = [
81
+ [-0.0326, -0.0204, -0.0127],
82
+ [-0.1592, -0.0427, 0.0216],
83
+ [ 0.0873, 0.0638, -0.0020],
84
+ [-0.0602, 0.0442, 0.1304],
85
+ [ 0.0800, -0.0313, -0.1796],
86
+ [-0.0810, -0.0638, -0.1581],
87
+ [ 0.1791, 0.1180, 0.0967],
88
+ [ 0.0740, 0.1416, 0.0432],
89
+ [-0.1745, -0.1888, -0.1373],
90
+ [ 0.2412, 0.1577, 0.0928],
91
+ [ 0.1908, 0.0998, 0.0682],
92
+ [ 0.0209, 0.0365, -0.0092],
93
+ [ 0.0448, -0.0650, -0.1728],
94
+ [-0.1658, -0.1045, -0.1308],
95
+ [ 0.0542, 0.1545, 0.1325],
96
+ [-0.0352, -0.1672, -0.2541]
97
+ ]
98
+
99
+ class SC_B(LatentFormat):
100
+ def __init__(self):
101
+ self.scale_factor = 1.0 / 0.43
102
+ self.latent_rgb_factors = [
103
+ [ 0.1121, 0.2006, 0.1023],
104
+ [-0.2093, -0.0222, -0.0195],
105
+ [-0.3087, -0.1535, 0.0366],
106
+ [ 0.0290, -0.1574, -0.4078]
107
+ ]
108
+
109
+ class SD3(LatentFormat):
110
+ latent_channels = 16
111
+ def __init__(self):
112
+ self.scale_factor = 1.5305
113
+ self.shift_factor = 0.0609
114
+ self.latent_rgb_factors = [
115
+ [-0.0645, 0.0177, 0.1052],
116
+ [ 0.0028, 0.0312, 0.0650],
117
+ [ 0.1848, 0.0762, 0.0360],
118
+ [ 0.0944, 0.0360, 0.0889],
119
+ [ 0.0897, 0.0506, -0.0364],
120
+ [-0.0020, 0.1203, 0.0284],
121
+ [ 0.0855, 0.0118, 0.0283],
122
+ [-0.0539, 0.0658, 0.1047],
123
+ [-0.0057, 0.0116, 0.0700],
124
+ [-0.0412, 0.0281, -0.0039],
125
+ [ 0.1106, 0.1171, 0.1220],
126
+ [-0.0248, 0.0682, -0.0481],
127
+ [ 0.0815, 0.0846, 0.1207],
128
+ [-0.0120, -0.0055, -0.0867],
129
+ [-0.0749, -0.0634, -0.0456],
130
+ [-0.1418, -0.1457, -0.1259]
131
+ ]
132
+ self.taesd_decoder_name = "taesd3_decoder"
133
+
134
+ def process_in(self, latent):
135
+ return (latent - self.shift_factor) * self.scale_factor
136
+
137
+ def process_out(self, latent):
138
+ return (latent / self.scale_factor) + self.shift_factor
139
+
140
+ class StableAudio1(LatentFormat):
141
+ latent_channels = 64
142
+
143
+ class Flux(SD3):
144
+ latent_channels = 16
145
+ def __init__(self):
146
+ self.scale_factor = 0.3611
147
+ self.shift_factor = 0.1159
148
+ self.latent_rgb_factors =[
149
+ [-0.0404, 0.0159, 0.0609],
150
+ [ 0.0043, 0.0298, 0.0850],
151
+ [ 0.0328, -0.0749, -0.0503],
152
+ [-0.0245, 0.0085, 0.0549],
153
+ [ 0.0966, 0.0894, 0.0530],
154
+ [ 0.0035, 0.0399, 0.0123],
155
+ [ 0.0583, 0.1184, 0.1262],
156
+ [-0.0191, -0.0206, -0.0306],
157
+ [-0.0324, 0.0055, 0.1001],
158
+ [ 0.0955, 0.0659, -0.0545],
159
+ [-0.0504, 0.0231, -0.0013],
160
+ [ 0.0500, -0.0008, -0.0088],
161
+ [ 0.0982, 0.0941, 0.0976],
162
+ [-0.1233, -0.0280, -0.0897],
163
+ [-0.0005, -0.0530, -0.0020],
164
+ [-0.1273, -0.0932, -0.0680]
165
+ ]
166
+ self.taesd_decoder_name = "taef1_decoder"
167
+
168
+ def process_in(self, latent):
169
+ return (latent - self.shift_factor) * self.scale_factor
170
+
171
+ def process_out(self, latent):
172
+ return (latent / self.scale_factor) + self.shift_factor
comfy/ldm/audio/autoencoder.py ADDED
@@ -0,0 +1,282 @@
1
+ # code adapted from: https://github.com/Stability-AI/stable-audio-tools
2
+
3
+ import torch
4
+ from torch import nn
5
+ from typing import Literal, Dict, Any
6
+ import math
7
+ import comfy.ops
8
+ ops = comfy.ops.disable_weight_init
9
+
10
+ def vae_sample(mean, scale):
11
+ stdev = nn.functional.softplus(scale) + 1e-4
12
+ var = stdev * stdev
13
+ logvar = torch.log(var)
14
+ latents = torch.randn_like(mean) * stdev + mean
15
+
16
+ kl = (mean * mean + var - logvar - 1).sum(1).mean()
17
+
18
+ return latents, kl
19
+
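vae_sample implements the reparameterization trick: the caller passes the mean and the pre-softplus scale halves of the encoder output, and it returns a sampled latent plus a KL-style penalty (written here without the usual 1/2 factor). A shape sketch under those assumptions:

import torch

encoder_out = torch.randn(2, 2 * 64, 100)    # (batch, 2 * latent_dim, time)
mean, scale = encoder_out.chunk(2, dim=1)
latents, kl = vae_sample(mean, scale)         # function defined above
assert latents.shape == (2, 64, 100) and kl.ndim == 0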
20
+ class VAEBottleneck(nn.Module):
21
+ def __init__(self):
22
+ super().__init__()
23
+ self.is_discrete = False
24
+
25
+ def encode(self, x, return_info=False, **kwargs):
26
+ info = {}
27
+
28
+ mean, scale = x.chunk(2, dim=1)
29
+
30
+ x, kl = vae_sample(mean, scale)
31
+
32
+ info["kl"] = kl
33
+
34
+ if return_info:
35
+ return x, info
36
+ else:
37
+ return x
38
+
39
+ def decode(self, x):
40
+ return x
41
+
42
+
43
+ def snake_beta(x, alpha, beta):
44
+ return x + (1.0 / (beta + 0.000000001)) * pow(torch.sin(x * alpha), 2)
45
+
46
+ # Adapted from https://github.com/NVIDIA/BigVGAN/blob/main/activations.py under MIT license
47
+ class SnakeBeta(nn.Module):
48
+
49
+ def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True):
50
+ super(SnakeBeta, self).__init__()
51
+ self.in_features = in_features
52
+
53
+ # initialize alpha
54
+ self.alpha_logscale = alpha_logscale
55
+ if self.alpha_logscale: # log scale alphas initialized to zeros
56
+ self.alpha = nn.Parameter(torch.zeros(in_features) * alpha)
57
+ self.beta = nn.Parameter(torch.zeros(in_features) * alpha)
58
+ else: # linear scale alphas initialized to ones
59
+ self.alpha = nn.Parameter(torch.ones(in_features) * alpha)
60
+ self.beta = nn.Parameter(torch.ones(in_features) * alpha)
61
+
62
+ # self.alpha.requires_grad = alpha_trainable
63
+ # self.beta.requires_grad = alpha_trainable
64
+
65
+ self.no_div_by_zero = 0.000000001
66
+
67
+ def forward(self, x):
68
+ alpha = self.alpha.unsqueeze(0).unsqueeze(-1).to(x.device) # line up with x to [B, C, T]
69
+ beta = self.beta.unsqueeze(0).unsqueeze(-1).to(x.device)
70
+ if self.alpha_logscale:
71
+ alpha = torch.exp(alpha)
72
+ beta = torch.exp(beta)
73
+ x = snake_beta(x, alpha, beta)
74
+
75
+ return x
76
+
77
+ def WNConv1d(*args, **kwargs):
78
+ try:
79
+ return torch.nn.utils.parametrizations.weight_norm(ops.Conv1d(*args, **kwargs))
80
+ except:
81
+ return torch.nn.utils.weight_norm(ops.Conv1d(*args, **kwargs)) #support pytorch 2.1 and older
82
+
83
+ def WNConvTranspose1d(*args, **kwargs):
84
+ try:
85
+ return torch.nn.utils.parametrizations.weight_norm(ops.ConvTranspose1d(*args, **kwargs))
86
+ except:
87
+ return torch.nn.utils.weight_norm(ops.ConvTranspose1d(*args, **kwargs)) #support pytorch 2.1 and older
88
+
89
+ def get_activation(activation: Literal["elu", "snake", "none"], antialias=False, channels=None) -> nn.Module:
90
+ if activation == "elu":
91
+ act = torch.nn.ELU()
92
+ elif activation == "snake":
93
+ act = SnakeBeta(channels)
94
+ elif activation == "none":
95
+ act = torch.nn.Identity()
96
+ else:
97
+ raise ValueError(f"Unknown activation {activation}")
98
+
99
+ if antialias:
100
+ act = Activation1d(act)
101
+
102
+ return act
103
+
104
+
105
+ class ResidualUnit(nn.Module):
106
+ def __init__(self, in_channels, out_channels, dilation, use_snake=False, antialias_activation=False):
107
+ super().__init__()
108
+
109
+ self.dilation = dilation
110
+
111
+ padding = (dilation * (7-1)) // 2
112
+
113
+ self.layers = nn.Sequential(
114
+ get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=out_channels),
115
+ WNConv1d(in_channels=in_channels, out_channels=out_channels,
116
+ kernel_size=7, dilation=dilation, padding=padding),
117
+ get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=out_channels),
118
+ WNConv1d(in_channels=out_channels, out_channels=out_channels,
119
+ kernel_size=1)
120
+ )
121
+
122
+ def forward(self, x):
123
+ res = x
124
+
125
+ #x = checkpoint(self.layers, x)
126
+ x = self.layers(x)
127
+
128
+ return x + res
129
+
130
+ class EncoderBlock(nn.Module):
131
+ def __init__(self, in_channels, out_channels, stride, use_snake=False, antialias_activation=False):
132
+ super().__init__()
133
+
134
+ self.layers = nn.Sequential(
135
+ ResidualUnit(in_channels=in_channels,
136
+ out_channels=in_channels, dilation=1, use_snake=use_snake),
137
+ ResidualUnit(in_channels=in_channels,
138
+ out_channels=in_channels, dilation=3, use_snake=use_snake),
139
+ ResidualUnit(in_channels=in_channels,
140
+ out_channels=in_channels, dilation=9, use_snake=use_snake),
141
+ get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=in_channels),
142
+ WNConv1d(in_channels=in_channels, out_channels=out_channels,
143
+ kernel_size=2*stride, stride=stride, padding=math.ceil(stride/2)),
144
+ )
145
+
146
+ def forward(self, x):
147
+ return self.layers(x)
148
+
149
+ class DecoderBlock(nn.Module):
150
+ def __init__(self, in_channels, out_channels, stride, use_snake=False, antialias_activation=False, use_nearest_upsample=False):
151
+ super().__init__()
152
+
153
+ if use_nearest_upsample:
154
+ upsample_layer = nn.Sequential(
155
+ nn.Upsample(scale_factor=stride, mode="nearest"),
156
+ WNConv1d(in_channels=in_channels,
157
+ out_channels=out_channels,
158
+ kernel_size=2*stride,
159
+ stride=1,
160
+ bias=False,
161
+ padding='same')
162
+ )
163
+ else:
164
+ upsample_layer = WNConvTranspose1d(in_channels=in_channels,
165
+ out_channels=out_channels,
166
+ kernel_size=2*stride, stride=stride, padding=math.ceil(stride/2))
167
+
168
+ self.layers = nn.Sequential(
169
+ get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=in_channels),
170
+ upsample_layer,
171
+ ResidualUnit(in_channels=out_channels, out_channels=out_channels,
172
+ dilation=1, use_snake=use_snake),
173
+ ResidualUnit(in_channels=out_channels, out_channels=out_channels,
174
+ dilation=3, use_snake=use_snake),
175
+ ResidualUnit(in_channels=out_channels, out_channels=out_channels,
176
+ dilation=9, use_snake=use_snake),
177
+ )
178
+
179
+ def forward(self, x):
180
+ return self.layers(x)
181
+
182
+ class OobleckEncoder(nn.Module):
183
+ def __init__(self,
184
+ in_channels=2,
185
+ channels=128,
186
+ latent_dim=32,
187
+ c_mults = [1, 2, 4, 8],
188
+ strides = [2, 4, 8, 8],
189
+ use_snake=False,
190
+ antialias_activation=False
191
+ ):
192
+ super().__init__()
193
+
194
+ c_mults = [1] + c_mults
195
+
196
+ self.depth = len(c_mults)
197
+
198
+ layers = [
199
+ WNConv1d(in_channels=in_channels, out_channels=c_mults[0] * channels, kernel_size=7, padding=3)
200
+ ]
201
+
202
+ for i in range(self.depth-1):
203
+ layers += [EncoderBlock(in_channels=c_mults[i]*channels, out_channels=c_mults[i+1]*channels, stride=strides[i], use_snake=use_snake)]
204
+
205
+ layers += [
206
+ get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=c_mults[-1] * channels),
207
+ WNConv1d(in_channels=c_mults[-1]*channels, out_channels=latent_dim, kernel_size=3, padding=1)
208
+ ]
209
+
210
+ self.layers = nn.Sequential(*layers)
211
+
212
+ def forward(self, x):
213
+ return self.layers(x)
214
+
215
+
216
+ class OobleckDecoder(nn.Module):
217
+ def __init__(self,
218
+ out_channels=2,
219
+ channels=128,
220
+ latent_dim=32,
221
+ c_mults = [1, 2, 4, 8],
222
+ strides = [2, 4, 8, 8],
223
+ use_snake=False,
224
+ antialias_activation=False,
225
+ use_nearest_upsample=False,
226
+ final_tanh=True):
227
+ super().__init__()
228
+
229
+ c_mults = [1] + c_mults
230
+
231
+ self.depth = len(c_mults)
232
+
233
+ layers = [
234
+ WNConv1d(in_channels=latent_dim, out_channels=c_mults[-1]*channels, kernel_size=7, padding=3),
235
+ ]
236
+
237
+ for i in range(self.depth-1, 0, -1):
238
+ layers += [DecoderBlock(
239
+ in_channels=c_mults[i]*channels,
240
+ out_channels=c_mults[i-1]*channels,
241
+ stride=strides[i-1],
242
+ use_snake=use_snake,
243
+ antialias_activation=antialias_activation,
244
+ use_nearest_upsample=use_nearest_upsample
245
+ )
246
+ ]
247
+
248
+ layers += [
249
+ get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=c_mults[0] * channels),
250
+ WNConv1d(in_channels=c_mults[0] * channels, out_channels=out_channels, kernel_size=7, padding=3, bias=False),
251
+ nn.Tanh() if final_tanh else nn.Identity()
252
+ ]
253
+
254
+ self.layers = nn.Sequential(*layers)
255
+
256
+ def forward(self, x):
257
+ return self.layers(x)
258
+
259
+
260
+ class AudioOobleckVAE(nn.Module):
261
+ def __init__(self,
262
+ in_channels=2,
263
+ channels=128,
264
+ latent_dim=64,
265
+ c_mults = [1, 2, 4, 8, 16],
266
+ strides = [2, 4, 4, 8, 8],
267
+ use_snake=True,
268
+ antialias_activation=False,
269
+ use_nearest_upsample=False,
270
+ final_tanh=False):
271
+ super().__init__()
272
+ self.encoder = OobleckEncoder(in_channels, channels, latent_dim * 2, c_mults, strides, use_snake, antialias_activation)
273
+ self.decoder = OobleckDecoder(in_channels, channels, latent_dim, c_mults, strides, use_snake, antialias_activation,
274
+ use_nearest_upsample=use_nearest_upsample, final_tanh=final_tanh)
275
+ self.bottleneck = VAEBottleneck()
276
+
277
+ def encode(self, x):
278
+ return self.bottleneck.encode(self.encoder(x))
279
+
280
+ def decode(self, x):
281
+ return self.decoder(self.bottleneck.decode(x))
282
+
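With the defaults above (strides [2, 4, 4, 8, 8]) the encoder downsamples the waveform by prod(strides) = 2048 and emits 2 * latent_dim channels, which the bottleneck samples down to latent_dim = 64. A rough usage sketch; it instantiates the module with uninitialized weights, so only the shapes are meaningful:

import torch

vae = AudioOobleckVAE()
audio = torch.randn(1, 2, 8192)          # (batch, channels, samples)
latents = vae.encode(audio)              # -> (1, 64, 8192 // 2048) = (1, 64, 4)
recon = vae.decode(latents)              # -> (1, 2, 8192), garbage values without trained weights
assert latents.shape == (1, 64, 4)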
comfy/ldm/audio/dit.py ADDED
@@ -0,0 +1,891 @@
1
+ # code adapted from: https://github.com/Stability-AI/stable-audio-tools
2
+
3
+ from comfy.ldm.modules.attention import optimized_attention
4
+ import typing as tp
5
+
6
+ import torch
7
+
8
+ from einops import rearrange
9
+ from torch import nn
10
+ from torch.nn import functional as F
11
+ import math
12
+ import comfy.ops
13
+
14
+ class FourierFeatures(nn.Module):
15
+ def __init__(self, in_features, out_features, std=1., dtype=None, device=None):
16
+ super().__init__()
17
+ assert out_features % 2 == 0
18
+ self.weight = nn.Parameter(torch.empty(
19
+ [out_features // 2, in_features], dtype=dtype, device=device))
20
+
21
+ def forward(self, input):
22
+ f = 2 * math.pi * input @ comfy.ops.cast_to_input(self.weight.T, input)
23
+ return torch.cat([f.cos(), f.sin()], dim=-1)
24
+
25
+ # norms
26
+ class LayerNorm(nn.Module):
27
+ def __init__(self, dim, bias=False, fix_scale=False, dtype=None, device=None):
28
+ """
29
+ Bias-less layernorm has been shown to be more stable. Most newer models have moved towards rmsnorm, which is also bias-less.
30
+ """
31
+ super().__init__()
32
+
33
+ self.gamma = nn.Parameter(torch.empty(dim, dtype=dtype, device=device))
34
+
35
+ if bias:
36
+ self.beta = nn.Parameter(torch.empty(dim, dtype=dtype, device=device))
37
+ else:
38
+ self.beta = None
39
+
40
+ def forward(self, x):
41
+ beta = self.beta
42
+ if beta is not None:
43
+ beta = comfy.ops.cast_to_input(beta, x)
44
+ return F.layer_norm(x, x.shape[-1:], weight=comfy.ops.cast_to_input(self.gamma, x), bias=beta)
45
+
46
+ class GLU(nn.Module):
47
+ def __init__(
48
+ self,
49
+ dim_in,
50
+ dim_out,
51
+ activation,
52
+ use_conv = False,
53
+ conv_kernel_size = 3,
54
+ dtype=None,
55
+ device=None,
56
+ operations=None,
57
+ ):
58
+ super().__init__()
59
+ self.act = activation
60
+ self.proj = operations.Linear(dim_in, dim_out * 2, dtype=dtype, device=device) if not use_conv else operations.Conv1d(dim_in, dim_out * 2, conv_kernel_size, padding = (conv_kernel_size // 2), dtype=dtype, device=device)
61
+ self.use_conv = use_conv
62
+
63
+ def forward(self, x):
64
+ if self.use_conv:
65
+ x = rearrange(x, 'b n d -> b d n')
66
+ x = self.proj(x)
67
+ x = rearrange(x, 'b d n -> b n d')
68
+ else:
69
+ x = self.proj(x)
70
+
71
+ x, gate = x.chunk(2, dim = -1)
72
+ return x * self.act(gate)
73
+
74
+ class AbsolutePositionalEmbedding(nn.Module):
75
+ def __init__(self, dim, max_seq_len):
76
+ super().__init__()
77
+ self.scale = dim ** -0.5
78
+ self.max_seq_len = max_seq_len
79
+ self.emb = nn.Embedding(max_seq_len, dim)
80
+
81
+ def forward(self, x, pos = None, seq_start_pos = None):
82
+ seq_len, device = x.shape[1], x.device
83
+ assert seq_len <= self.max_seq_len, f'you are passing in a sequence length of {seq_len} but your absolute positional embedding has a max sequence length of {self.max_seq_len}'
84
+
85
+ if pos is None:
86
+ pos = torch.arange(seq_len, device = device)
87
+
88
+ if seq_start_pos is not None:
89
+ pos = (pos - seq_start_pos[..., None]).clamp(min = 0)
90
+
91
+ pos_emb = self.emb(pos)
92
+ pos_emb = pos_emb * self.scale
93
+ return pos_emb
94
+
95
+ class ScaledSinusoidalEmbedding(nn.Module):
96
+ def __init__(self, dim, theta = 10000):
97
+ super().__init__()
98
+ assert (dim % 2) == 0, 'dimension must be divisible by 2'
99
+ self.scale = nn.Parameter(torch.ones(1) * dim ** -0.5)
100
+
101
+ half_dim = dim // 2
102
+ freq_seq = torch.arange(half_dim).float() / half_dim
103
+ inv_freq = theta ** -freq_seq
104
+ self.register_buffer('inv_freq', inv_freq, persistent = False)
105
+
106
+ def forward(self, x, pos = None, seq_start_pos = None):
107
+ seq_len, device = x.shape[1], x.device
108
+
109
+ if pos is None:
110
+ pos = torch.arange(seq_len, device = device)
111
+
112
+ if seq_start_pos is not None:
113
+ pos = pos - seq_start_pos[..., None]
114
+
115
+ emb = torch.einsum('i, j -> i j', pos, self.inv_freq)
116
+ emb = torch.cat((emb.sin(), emb.cos()), dim = -1)
117
+ return emb * self.scale
118
+
119
+ class RotaryEmbedding(nn.Module):
120
+ def __init__(
121
+ self,
122
+ dim,
123
+ use_xpos = False,
124
+ scale_base = 512,
125
+ interpolation_factor = 1.,
126
+ base = 10000,
127
+ base_rescale_factor = 1.,
128
+ dtype=None,
129
+ device=None,
130
+ ):
131
+ super().__init__()
132
+ # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
133
+ # has some connection to NTK literature
134
+ # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
135
+ base *= base_rescale_factor ** (dim / (dim - 2))
136
+
137
+ # inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim))
138
+ self.register_buffer('inv_freq', torch.empty((dim // 2,), device=device, dtype=dtype))
139
+
140
+ assert interpolation_factor >= 1.
141
+ self.interpolation_factor = interpolation_factor
142
+
143
+ if not use_xpos:
144
+ self.register_buffer('scale', None)
145
+ return
146
+
147
+ scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)
148
+
149
+ self.scale_base = scale_base
150
+ self.register_buffer('scale', scale)
151
+
152
+ def forward_from_seq_len(self, seq_len, device, dtype):
153
+ # device = self.inv_freq.device
154
+
155
+ t = torch.arange(seq_len, device=device, dtype=dtype)
156
+ return self.forward(t)
157
+
158
+ def forward(self, t):
159
+ # device = self.inv_freq.device
160
+ device = t.device
161
+ dtype = t.dtype
162
+
163
+ # t = t.to(torch.float32)
164
+
165
+ t = t / self.interpolation_factor
166
+
167
+ freqs = torch.einsum('i , j -> i j', t, comfy.ops.cast_to_input(self.inv_freq, t))
168
+ freqs = torch.cat((freqs, freqs), dim = -1)
169
+
170
+ if self.scale is None:
171
+ return freqs, 1.
172
+
173
+ seq_len = t.shape[-1] # not defined in this scope in the original; derived from t so the xpos branch can run
+ power = (torch.arange(seq_len, device = device) - (seq_len // 2)) / self.scale_base
174
+ scale = comfy.ops.cast_to_input(self.scale, t) ** rearrange(power, 'n -> n 1')
175
+ scale = torch.cat((scale, scale), dim = -1)
176
+
177
+ return freqs, scale
178
+
179
+ def rotate_half(x):
180
+ x = rearrange(x, '... (j d) -> ... j d', j = 2)
181
+ x1, x2 = x.unbind(dim = -2)
182
+ return torch.cat((-x2, x1), dim = -1)
183
+
184
+ def apply_rotary_pos_emb(t, freqs, scale = 1):
185
+ out_dtype = t.dtype
186
+
187
+ # cast to float32 if necessary for numerical stability
188
+ dtype = t.dtype #reduce(torch.promote_types, (t.dtype, freqs.dtype, torch.float32))
189
+ rot_dim, seq_len = freqs.shape[-1], t.shape[-2]
190
+ freqs, t = freqs.to(dtype), t.to(dtype)
191
+ freqs = freqs[-seq_len:, :]
192
+
193
+ if t.ndim == 4 and freqs.ndim == 3:
194
+ freqs = rearrange(freqs, 'b n d -> b 1 n d')
195
+
196
+ # partial rotary embeddings, Wang et al. GPT-J
197
+ t, t_unrotated = t[..., :rot_dim], t[..., rot_dim:]
198
+ t = (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale)
199
+
200
+ t, t_unrotated = t.to(out_dtype), t_unrotated.to(out_dtype)
201
+
202
+ return torch.cat((t, t_unrotated), dim = -1)
203
+
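rotate_half and apply_rotary_pos_emb implement standard rotary position embeddings: feature dimensions are rotated in pairs (here the first and second halves of the rotary dims) by position-dependent angles given in freqs. A compact sanity check under that convention, using zero angles so the rotation must be a no-op:

import torch

t = torch.randn(1, 8, 16, 64)     # (batch, heads, seq, head_dim)
freqs = torch.zeros(16, 64)       # one angle per (position, rotated dim); zeros = no rotation
out = apply_rotary_pos_emb(t, freqs)
assert torch.allclose(out, t)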
204
+ class FeedForward(nn.Module):
205
+ def __init__(
206
+ self,
207
+ dim,
208
+ dim_out = None,
209
+ mult = 4,
210
+ no_bias = False,
211
+ glu = True,
212
+ use_conv = False,
213
+ conv_kernel_size = 3,
214
+ zero_init_output = True,
215
+ dtype=None,
216
+ device=None,
217
+ operations=None,
218
+ ):
219
+ super().__init__()
220
+ inner_dim = int(dim * mult)
221
+
222
+ # Default to SwiGLU
223
+
224
+ activation = nn.SiLU()
225
+
226
+ dim_out = dim if dim_out is None else dim_out
227
+
228
+ if glu:
229
+ linear_in = GLU(dim, inner_dim, activation, dtype=dtype, device=device, operations=operations)
230
+ else:
231
+ linear_in = nn.Sequential(
232
+ Rearrange('b n d -> b d n') if use_conv else nn.Identity(),
233
+ operations.Linear(dim, inner_dim, bias = not no_bias, dtype=dtype, device=device) if not use_conv else operations.Conv1d(dim, inner_dim, conv_kernel_size, padding = (conv_kernel_size // 2), bias = not no_bias, dtype=dtype, device=device),
234
+ Rearrange('b n d -> b d n') if use_conv else nn.Identity(),
235
+ activation
236
+ )
237
+
238
+ linear_out = operations.Linear(inner_dim, dim_out, bias = not no_bias, dtype=dtype, device=device) if not use_conv else operations.Conv1d(inner_dim, dim_out, conv_kernel_size, padding = (conv_kernel_size // 2), bias = not no_bias, dtype=dtype, device=device)
239
+
240
+ # # init last linear layer to 0
241
+ # if zero_init_output:
242
+ # nn.init.zeros_(linear_out.weight)
243
+ # if not no_bias:
244
+ # nn.init.zeros_(linear_out.bias)
245
+
246
+
247
+ self.ff = nn.Sequential(
248
+ linear_in,
249
+ Rearrange('b d n -> b n d') if use_conv else nn.Identity(),
250
+ linear_out,
251
+ Rearrange('b n d -> b d n') if use_conv else nn.Identity(),
252
+ )
253
+
254
+ def forward(self, x):
255
+ return self.ff(x)
256
+
257
+ class Attention(nn.Module):
258
+ def __init__(
259
+ self,
260
+ dim,
261
+ dim_heads = 64,
262
+ dim_context = None,
263
+ causal = False,
264
+ zero_init_output=True,
265
+ qk_norm = False,
266
+ natten_kernel_size = None,
267
+ dtype=None,
268
+ device=None,
269
+ operations=None,
270
+ ):
271
+ super().__init__()
272
+ self.dim = dim
273
+ self.dim_heads = dim_heads
274
+ self.causal = causal
275
+
276
+ dim_kv = dim_context if dim_context is not None else dim
277
+
278
+ self.num_heads = dim // dim_heads
279
+ self.kv_heads = dim_kv // dim_heads
280
+
281
+ if dim_context is not None:
282
+ self.to_q = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)
283
+ self.to_kv = operations.Linear(dim_kv, dim_kv * 2, bias=False, dtype=dtype, device=device)
284
+ else:
285
+ self.to_qkv = operations.Linear(dim, dim * 3, bias=False, dtype=dtype, device=device)
286
+
287
+ self.to_out = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)
288
+
289
+ # if zero_init_output:
290
+ # nn.init.zeros_(self.to_out.weight)
291
+
292
+ self.qk_norm = qk_norm
293
+
294
+
295
+ def forward(
296
+ self,
297
+ x,
298
+ context = None,
299
+ mask = None,
300
+ context_mask = None,
301
+ rotary_pos_emb = None,
302
+ causal = None
303
+ ):
304
+ h, kv_h, has_context = self.num_heads, self.kv_heads, context is not None
305
+
306
+ kv_input = context if has_context else x
307
+
308
+ if hasattr(self, 'to_q'):
309
+ # Use separate linear projections for q and k/v
310
+ q = self.to_q(x)
311
+ q = rearrange(q, 'b n (h d) -> b h n d', h = h)
312
+
313
+ k, v = self.to_kv(kv_input).chunk(2, dim=-1)
314
+
315
+ k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = kv_h), (k, v))
316
+ else:
317
+ # Use fused linear projection
318
+ q, k, v = self.to_qkv(x).chunk(3, dim=-1)
319
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))
320
+
321
+ # Normalize q and k for cosine sim attention
322
+ if self.qk_norm:
323
+ q = F.normalize(q, dim=-1)
324
+ k = F.normalize(k, dim=-1)
325
+
326
+ if rotary_pos_emb is not None and not has_context:
327
+ freqs, _ = rotary_pos_emb
328
+
329
+ q_dtype = q.dtype
330
+ k_dtype = k.dtype
331
+
332
+ q = q.to(torch.float32)
333
+ k = k.to(torch.float32)
334
+ freqs = freqs.to(torch.float32)
335
+
336
+ q = apply_rotary_pos_emb(q, freqs)
337
+ k = apply_rotary_pos_emb(k, freqs)
338
+
339
+ q = q.to(q_dtype)
340
+ k = k.to(k_dtype)
341
+
342
+ input_mask = context_mask
343
+
344
+ if input_mask is None and not has_context:
345
+ input_mask = mask
346
+
347
+ # determine masking
348
+ masks = []
349
+ final_attn_mask = None # The mask that will be applied to the attention matrix, taking all masks into account
350
+
351
+ if input_mask is not None:
352
+ input_mask = rearrange(input_mask, 'b j -> b 1 1 j')
353
+ masks.append(~input_mask)
354
+
355
+ # Other masks will be added here later
356
+
357
+ if len(masks) > 0:
358
+ final_attn_mask = ~or_reduce(masks)
359
+
360
+ n, device = q.shape[-2], q.device
361
+
362
+ causal = self.causal if causal is None else causal
363
+
364
+ if n == 1 and causal:
365
+ causal = False
366
+
367
+ if h != kv_h:
368
+ # Repeat interleave kv_heads to match q_heads
369
+ heads_per_kv_head = h // kv_h
370
+ k, v = map(lambda t: t.repeat_interleave(heads_per_kv_head, dim = 1), (k, v))
371
+
372
+ out = optimized_attention(q, k, v, h, skip_reshape=True)
373
+ out = self.to_out(out)
374
+
375
+ if mask is not None:
376
+ mask = rearrange(mask, 'b n -> b n 1')
377
+ out = out.masked_fill(~mask, 0.)
378
+
379
+ return out
380
+
381
+ class ConformerModule(nn.Module):
382
+ def __init__(
383
+ self,
384
+ dim,
385
+ norm_kwargs = {},
386
+ ):
387
+
388
+ super().__init__()
389
+
390
+ self.dim = dim
391
+
392
+ self.in_norm = LayerNorm(dim, **norm_kwargs)
393
+ self.pointwise_conv = nn.Conv1d(dim, dim, kernel_size=1, bias=False)
394
+ self.glu = GLU(dim, dim, nn.SiLU())
395
+ self.depthwise_conv = nn.Conv1d(dim, dim, kernel_size=17, groups=dim, padding=8, bias=False)
396
+ self.mid_norm = LayerNorm(dim, **norm_kwargs) # This is a batch norm in the original but I don't like batch norm
397
+ self.swish = nn.SiLU()
398
+ self.pointwise_conv_2 = nn.Conv1d(dim, dim, kernel_size=1, bias=False)
399
+
400
+ def forward(self, x):
401
+ x = self.in_norm(x)
402
+ x = rearrange(x, 'b n d -> b d n')
403
+ x = self.pointwise_conv(x)
404
+ x = rearrange(x, 'b d n -> b n d')
405
+ x = self.glu(x)
406
+ x = rearrange(x, 'b n d -> b d n')
407
+ x = self.depthwise_conv(x)
408
+ x = rearrange(x, 'b d n -> b n d')
409
+ x = self.mid_norm(x)
410
+ x = self.swish(x)
411
+ x = rearrange(x, 'b n d -> b d n')
412
+ x = self.pointwise_conv_2(x)
413
+ x = rearrange(x, 'b d n -> b n d')
414
+
415
+ return x
416
+
417
+ class TransformerBlock(nn.Module):
418
+ def __init__(
419
+ self,
420
+ dim,
421
+ dim_heads = 64,
422
+ cross_attend = False,
423
+ dim_context = None,
424
+ global_cond_dim = None,
425
+ causal = False,
426
+ zero_init_branch_outputs = True,
427
+ conformer = False,
428
+ layer_ix = -1,
429
+ remove_norms = False,
430
+ attn_kwargs = {},
431
+ ff_kwargs = {},
432
+ norm_kwargs = {},
433
+ dtype=None,
434
+ device=None,
435
+ operations=None,
436
+ ):
437
+
438
+ super().__init__()
439
+ self.dim = dim
440
+ self.dim_heads = dim_heads
441
+ self.cross_attend = cross_attend
442
+ self.dim_context = dim_context
443
+ self.causal = causal
444
+
445
+ self.pre_norm = LayerNorm(dim, dtype=dtype, device=device, **norm_kwargs) if not remove_norms else nn.Identity()
446
+
447
+ self.self_attn = Attention(
448
+ dim,
449
+ dim_heads = dim_heads,
450
+ causal = causal,
451
+ zero_init_output=zero_init_branch_outputs,
452
+ dtype=dtype,
453
+ device=device,
454
+ operations=operations,
455
+ **attn_kwargs
456
+ )
457
+
458
+ if cross_attend:
459
+ self.cross_attend_norm = LayerNorm(dim, dtype=dtype, device=device, **norm_kwargs) if not remove_norms else nn.Identity()
460
+ self.cross_attn = Attention(
461
+ dim,
462
+ dim_heads = dim_heads,
463
+ dim_context=dim_context,
464
+ causal = causal,
465
+ zero_init_output=zero_init_branch_outputs,
466
+ dtype=dtype,
467
+ device=device,
468
+ operations=operations,
469
+ **attn_kwargs
470
+ )
471
+
472
+ self.ff_norm = LayerNorm(dim, dtype=dtype, device=device, **norm_kwargs) if not remove_norms else nn.Identity()
473
+ self.ff = FeedForward(dim, zero_init_output=zero_init_branch_outputs, dtype=dtype, device=device, operations=operations,**ff_kwargs)
474
+
475
+ self.layer_ix = layer_ix
476
+
477
+ self.conformer = ConformerModule(dim, norm_kwargs=norm_kwargs) if conformer else None
478
+
479
+ self.global_cond_dim = global_cond_dim
480
+
481
+ if global_cond_dim is not None:
482
+ self.to_scale_shift_gate = nn.Sequential(
483
+ nn.SiLU(),
484
+ nn.Linear(global_cond_dim, dim * 6, bias=False)
485
+ )
486
+
487
+ nn.init.zeros_(self.to_scale_shift_gate[1].weight)
488
+ #nn.init.zeros_(self.to_scale_shift_gate_self[1].bias)
489
+
490
+ def forward(
491
+ self,
492
+ x,
493
+ context = None,
494
+ global_cond=None,
495
+ mask = None,
496
+ context_mask = None,
497
+ rotary_pos_emb = None
498
+ ):
499
+ if self.global_cond_dim is not None and self.global_cond_dim > 0 and global_cond is not None:
500
+
501
+ scale_self, shift_self, gate_self, scale_ff, shift_ff, gate_ff = self.to_scale_shift_gate(global_cond).unsqueeze(1).chunk(6, dim = -1)
502
+
503
+ # self-attention with adaLN
504
+ residual = x
505
+ x = self.pre_norm(x)
506
+ x = x * (1 + scale_self) + shift_self
507
+ x = self.self_attn(x, mask = mask, rotary_pos_emb = rotary_pos_emb)
508
+ x = x * torch.sigmoid(1 - gate_self)
509
+ x = x + residual
510
+
511
+ if context is not None:
512
+ x = x + self.cross_attn(self.cross_attend_norm(x), context = context, context_mask = context_mask)
513
+
514
+ if self.conformer is not None:
515
+ x = x + self.conformer(x)
516
+
517
+ # feedforward with adaLN
518
+ residual = x
519
+ x = self.ff_norm(x)
520
+ x = x * (1 + scale_ff) + shift_ff
521
+ x = self.ff(x)
522
+ x = x * torch.sigmoid(1 - gate_ff)
523
+ x = x + residual
524
+
525
+ else:
526
+ x = x + self.self_attn(self.pre_norm(x), mask = mask, rotary_pos_emb = rotary_pos_emb)
527
+
528
+ if context is not None:
529
+ x = x + self.cross_attn(self.cross_attend_norm(x), context = context, context_mask = context_mask)
530
+
531
+ if self.conformer is not None:
532
+ x = x + self.conformer(x)
533
+
534
+ x = x + self.ff(self.ff_norm(x))
535
+
536
+ return x
537
+
538
+ class ContinuousTransformer(nn.Module):
539
+ def __init__(
540
+ self,
541
+ dim,
542
+ depth,
543
+ *,
544
+ dim_in = None,
545
+ dim_out = None,
546
+ dim_heads = 64,
547
+ cross_attend=False,
548
+ cond_token_dim=None,
549
+ global_cond_dim=None,
550
+ causal=False,
551
+ rotary_pos_emb=True,
552
+ zero_init_branch_outputs=True,
553
+ conformer=False,
554
+ use_sinusoidal_emb=False,
555
+ use_abs_pos_emb=False,
556
+ abs_pos_emb_max_length=10000,
557
+ dtype=None,
558
+ device=None,
559
+ operations=None,
560
+ **kwargs
561
+ ):
562
+
563
+ super().__init__()
564
+
565
+ self.dim = dim
566
+ self.depth = depth
567
+ self.causal = causal
568
+ self.layers = nn.ModuleList([])
569
+
570
+ self.project_in = operations.Linear(dim_in, dim, bias=False, dtype=dtype, device=device) if dim_in is not None else nn.Identity()
571
+ self.project_out = operations.Linear(dim, dim_out, bias=False, dtype=dtype, device=device) if dim_out is not None else nn.Identity()
572
+
573
+ if rotary_pos_emb:
574
+ self.rotary_pos_emb = RotaryEmbedding(max(dim_heads // 2, 32), device=device, dtype=dtype)
575
+ else:
576
+ self.rotary_pos_emb = None
577
+
578
+ self.use_sinusoidal_emb = use_sinusoidal_emb
579
+ if use_sinusoidal_emb:
580
+ self.pos_emb = ScaledSinusoidalEmbedding(dim)
581
+
582
+ self.use_abs_pos_emb = use_abs_pos_emb
583
+ if use_abs_pos_emb:
584
+ self.pos_emb = AbsolutePositionalEmbedding(dim, abs_pos_emb_max_length)
585
+
586
+ for i in range(depth):
587
+ self.layers.append(
588
+ TransformerBlock(
589
+ dim,
590
+ dim_heads = dim_heads,
591
+ cross_attend = cross_attend,
592
+ dim_context = cond_token_dim,
593
+ global_cond_dim = global_cond_dim,
594
+ causal = causal,
595
+ zero_init_branch_outputs = zero_init_branch_outputs,
596
+ conformer=conformer,
597
+ layer_ix=i,
598
+ dtype=dtype,
599
+ device=device,
600
+ operations=operations,
601
+ **kwargs
602
+ )
603
+ )
604
+
605
+ def forward(
606
+ self,
607
+ x,
608
+ mask = None,
609
+ prepend_embeds = None,
610
+ prepend_mask = None,
611
+ global_cond = None,
612
+ return_info = False,
613
+ **kwargs
614
+ ):
615
+ batch, seq, device = *x.shape[:2], x.device
616
+
617
+ info = {
618
+ "hidden_states": [],
619
+ }
620
+
621
+ x = self.project_in(x)
622
+
623
+ if prepend_embeds is not None:
624
+ prepend_length, prepend_dim = prepend_embeds.shape[1:]
625
+
626
+ assert prepend_dim == x.shape[-1], 'prepend dimension must match sequence dimension'
627
+
628
+ x = torch.cat((prepend_embeds, x), dim = -2)
629
+
630
+ if prepend_mask is not None or mask is not None:
631
+ mask = mask if mask is not None else torch.ones((batch, seq), device = device, dtype = torch.bool)
632
+ prepend_mask = prepend_mask if prepend_mask is not None else torch.ones((batch, prepend_length), device = device, dtype = torch.bool)
633
+
634
+ mask = torch.cat((prepend_mask, mask), dim = -1)
635
+
636
+ # Attention layers
637
+
638
+ if self.rotary_pos_emb is not None:
639
+ rotary_pos_emb = self.rotary_pos_emb.forward_from_seq_len(x.shape[1], dtype=x.dtype, device=x.device)
640
+ else:
641
+ rotary_pos_emb = None
642
+
643
+ if self.use_sinusoidal_emb or self.use_abs_pos_emb:
644
+ x = x + self.pos_emb(x)
645
+
646
+ # Iterate over the transformer layers
647
+ for layer in self.layers:
648
+ x = layer(x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, **kwargs)
649
+ # x = checkpoint(layer, x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, **kwargs)
650
+
651
+ if return_info:
652
+ info["hidden_states"].append(x)
653
+
654
+ x = self.project_out(x)
655
+
656
+ if return_info:
657
+ return x, info
658
+
659
+ return x
660
+
661
+ class AudioDiffusionTransformer(nn.Module):
662
+ def __init__(self,
663
+ io_channels=64,
664
+ patch_size=1,
665
+ embed_dim=1536,
666
+ cond_token_dim=768,
667
+ project_cond_tokens=False,
668
+ global_cond_dim=1536,
669
+ project_global_cond=True,
670
+ input_concat_dim=0,
671
+ prepend_cond_dim=0,
672
+ depth=24,
673
+ num_heads=24,
674
+ transformer_type: tp.Literal["continuous_transformer"] = "continuous_transformer",
675
+ global_cond_type: tp.Literal["prepend", "adaLN"] = "prepend",
676
+ audio_model="",
677
+ dtype=None,
678
+ device=None,
679
+ operations=None,
680
+ **kwargs):
681
+
682
+ super().__init__()
683
+
684
+ self.dtype = dtype
685
+ self.cond_token_dim = cond_token_dim
686
+
687
+ # Timestep embeddings
688
+ timestep_features_dim = 256
689
+
690
+ self.timestep_features = FourierFeatures(1, timestep_features_dim, dtype=dtype, device=device)
691
+
692
+ self.to_timestep_embed = nn.Sequential(
693
+ operations.Linear(timestep_features_dim, embed_dim, bias=True, dtype=dtype, device=device),
694
+ nn.SiLU(),
695
+ operations.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device),
696
+ )
697
+
698
+ if cond_token_dim > 0:
699
+ # Conditioning tokens
700
+
701
+ cond_embed_dim = cond_token_dim if not project_cond_tokens else embed_dim
702
+ self.to_cond_embed = nn.Sequential(
703
+ operations.Linear(cond_token_dim, cond_embed_dim, bias=False, dtype=dtype, device=device),
704
+ nn.SiLU(),
705
+ operations.Linear(cond_embed_dim, cond_embed_dim, bias=False, dtype=dtype, device=device)
706
+ )
707
+ else:
708
+ cond_embed_dim = 0
709
+
710
+ if global_cond_dim > 0:
711
+ # Global conditioning
712
+ global_embed_dim = global_cond_dim if not project_global_cond else embed_dim
713
+ self.to_global_embed = nn.Sequential(
714
+ operations.Linear(global_cond_dim, global_embed_dim, bias=False, dtype=dtype, device=device),
715
+ nn.SiLU(),
716
+ operations.Linear(global_embed_dim, global_embed_dim, bias=False, dtype=dtype, device=device)
717
+ )
718
+
719
+ if prepend_cond_dim > 0:
720
+ # Prepend conditioning
721
+ self.to_prepend_embed = nn.Sequential(
722
+ operations.Linear(prepend_cond_dim, embed_dim, bias=False, dtype=dtype, device=device),
723
+ nn.SiLU(),
724
+ operations.Linear(embed_dim, embed_dim, bias=False, dtype=dtype, device=device)
725
+ )
726
+
727
+ self.input_concat_dim = input_concat_dim
728
+
729
+ dim_in = io_channels + self.input_concat_dim
730
+
731
+ self.patch_size = patch_size
732
+
733
+ # Transformer
734
+
735
+ self.transformer_type = transformer_type
736
+
737
+ self.global_cond_type = global_cond_type
738
+
739
+ if self.transformer_type == "continuous_transformer":
740
+
741
+ global_dim = None
742
+
743
+ if self.global_cond_type == "adaLN":
744
+ # The global conditioning is projected to the embed_dim already at this point
745
+ global_dim = embed_dim
746
+
747
+ self.transformer = ContinuousTransformer(
748
+ dim=embed_dim,
749
+ depth=depth,
750
+ dim_heads=embed_dim // num_heads,
751
+ dim_in=dim_in * patch_size,
752
+ dim_out=io_channels * patch_size,
753
+ cross_attend = cond_token_dim > 0,
754
+ cond_token_dim = cond_embed_dim,
755
+ global_cond_dim=global_dim,
756
+ dtype=dtype,
757
+ device=device,
758
+ operations=operations,
759
+ **kwargs
760
+ )
761
+ else:
762
+ raise ValueError(f"Unknown transformer type: {self.transformer_type}")
763
+
764
+ self.preprocess_conv = operations.Conv1d(dim_in, dim_in, 1, bias=False, dtype=dtype, device=device)
765
+ self.postprocess_conv = operations.Conv1d(io_channels, io_channels, 1, bias=False, dtype=dtype, device=device)
766
+
767
+ def _forward(
768
+ self,
769
+ x,
770
+ t,
771
+ mask=None,
772
+ cross_attn_cond=None,
773
+ cross_attn_cond_mask=None,
774
+ input_concat_cond=None,
775
+ global_embed=None,
776
+ prepend_cond=None,
777
+ prepend_cond_mask=None,
778
+ return_info=False,
779
+ **kwargs):
780
+
781
+ if cross_attn_cond is not None:
782
+ cross_attn_cond = self.to_cond_embed(cross_attn_cond)
783
+
784
+ if global_embed is not None:
785
+ # Project the global conditioning to the embedding dimension
786
+ global_embed = self.to_global_embed(global_embed)
787
+
788
+ prepend_inputs = None
789
+ prepend_mask = None
790
+ prepend_length = 0
791
+ if prepend_cond is not None:
792
+ # Project the prepend conditioning to the embedding dimension
793
+ prepend_cond = self.to_prepend_embed(prepend_cond)
794
+
795
+ prepend_inputs = prepend_cond
796
+ if prepend_cond_mask is not None:
797
+ prepend_mask = prepend_cond_mask
798
+
799
+ if input_concat_cond is not None:
800
+
801
+ # Interpolate input_concat_cond to the same length as x
802
+ if input_concat_cond.shape[2] != x.shape[2]:
803
+ input_concat_cond = F.interpolate(input_concat_cond, (x.shape[2], ), mode='nearest')
804
+
805
+ x = torch.cat([x, input_concat_cond], dim=1)
806
+
807
+ # Get the batch of timestep embeddings
808
+ timestep_embed = self.to_timestep_embed(self.timestep_features(t[:, None]).to(x.dtype)) # (b, embed_dim)
809
+
810
+ # Timestep embedding is considered a global embedding. Add to the global conditioning if it exists
811
+ if global_embed is not None:
812
+ global_embed = global_embed + timestep_embed
813
+ else:
814
+ global_embed = timestep_embed
815
+
816
+ # Add the global_embed to the prepend inputs if there is no global conditioning support in the transformer
817
+ if self.global_cond_type == "prepend":
818
+ if prepend_inputs is None:
819
+ # Prepend inputs are just the global embed, and the mask is all ones
820
+ prepend_inputs = global_embed.unsqueeze(1)
821
+ prepend_mask = torch.ones((x.shape[0], 1), device=x.device, dtype=torch.bool)
822
+ else:
823
+ # Prepend inputs are the prepend conditioning + the global embed
824
+ prepend_inputs = torch.cat([prepend_inputs, global_embed.unsqueeze(1)], dim=1)
825
+ prepend_mask = torch.cat([prepend_mask, torch.ones((x.shape[0], 1), device=x.device, dtype=torch.bool)], dim=1)
826
+
827
+ prepend_length = prepend_inputs.shape[1]
828
+
829
+ x = self.preprocess_conv(x) + x
830
+
831
+ x = rearrange(x, "b c t -> b t c")
832
+
833
+ extra_args = {}
834
+
835
+ if self.global_cond_type == "adaLN":
836
+ extra_args["global_cond"] = global_embed
837
+
838
+ if self.patch_size > 1:
839
+ x = rearrange(x, "b (t p) c -> b t (c p)", p=self.patch_size)
840
+
841
+ if self.transformer_type == "x-transformers":
842
+ output = self.transformer(x, prepend_embeds=prepend_inputs, context=cross_attn_cond, context_mask=cross_attn_cond_mask, mask=mask, prepend_mask=prepend_mask, **extra_args, **kwargs)
843
+ elif self.transformer_type == "continuous_transformer":
844
+ output = self.transformer(x, prepend_embeds=prepend_inputs, context=cross_attn_cond, context_mask=cross_attn_cond_mask, mask=mask, prepend_mask=prepend_mask, return_info=return_info, **extra_args, **kwargs)
845
+
846
+ if return_info:
847
+ output, info = output
848
+ elif self.transformer_type == "mm_transformer":
849
+ output = self.transformer(x, context=cross_attn_cond, mask=mask, context_mask=cross_attn_cond_mask, **extra_args, **kwargs)
850
+
851
+ output = rearrange(output, "b t c -> b c t")[:,:,prepend_length:]
852
+
853
+ if self.patch_size > 1:
854
+ output = rearrange(output, "b (c p) t -> b c (t p)", p=self.patch_size)
855
+
856
+ output = self.postprocess_conv(output) + output
857
+
858
+ if return_info:
859
+ return output, info
860
+
861
+ return output
862
+
863
+ def forward(
864
+ self,
865
+ x,
866
+ timestep,
867
+ context=None,
868
+ context_mask=None,
869
+ input_concat_cond=None,
870
+ global_embed=None,
871
+ negative_global_embed=None,
872
+ prepend_cond=None,
873
+ prepend_cond_mask=None,
874
+ mask=None,
875
+ return_info=False,
876
+ control=None,
877
+ transformer_options={},
878
+ **kwargs):
879
+ return self._forward(
880
+ x,
881
+ timestep,
882
+ cross_attn_cond=context,
883
+ cross_attn_cond_mask=context_mask,
884
+ input_concat_cond=input_concat_cond,
885
+ global_embed=global_embed,
886
+ prepend_cond=prepend_cond,
887
+ prepend_cond_mask=prepend_cond_mask,
888
+ mask=mask,
889
+ return_info=return_info,
890
+ **kwargs
891
+ )
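The adaLN path in TransformerBlock above projects the global conditioning vector to 6 * dim and splits it into (scale, shift, gate) values for the self-attention and feed-forward branches. A minimal standalone sketch of just that modulate-then-gate arithmetic, with the norm and attention calls elided and all sizes made up for illustration:

import torch
import torch.nn as nn

# Illustrative sizes only; not taken from any model config.
dim, batch, seq = 8, 2, 16
to_scale_shift_gate = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6, bias=False))

global_cond = torch.randn(batch, dim)
scale_self, shift_self, gate_self, scale_ff, shift_ff, gate_ff = \
    to_scale_shift_gate(global_cond).unsqueeze(1).chunk(6, dim=-1)

x = torch.randn(batch, seq, dim)
h = x * (1 + scale_self) + shift_self   # modulation applied after the pre-norm
h = h * torch.sigmoid(1 - gate_self)    # gating, as written in the block above
x = x + h                               # residual connection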
comfy/ldm/audio/embedders.py ADDED
@@ -0,0 +1,108 @@
1
+ # code adapted from: https://github.com/Stability-AI/stable-audio-tools
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from torch import Tensor, einsum
6
+ from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, TypeVar, Union
7
+ from einops import rearrange
8
+ import math
9
+ import comfy.ops
10
+
11
+ class LearnedPositionalEmbedding(nn.Module):
12
+ """Used for continuous time"""
13
+
14
+ def __init__(self, dim: int):
15
+ super().__init__()
16
+ assert (dim % 2) == 0
17
+ half_dim = dim // 2
18
+ self.weights = nn.Parameter(torch.empty(half_dim))
19
+
20
+ def forward(self, x: Tensor) -> Tensor:
21
+ x = rearrange(x, "b -> b 1")
22
+ freqs = x * rearrange(self.weights, "d -> 1 d") * 2 * math.pi
23
+ fouriered = torch.cat((freqs.sin(), freqs.cos()), dim=-1)
24
+ fouriered = torch.cat((x, fouriered), dim=-1)
25
+ return fouriered
26
+
27
+ def TimePositionalEmbedding(dim: int, out_features: int) -> nn.Module:
28
+ return nn.Sequential(
29
+ LearnedPositionalEmbedding(dim),
30
+ comfy.ops.manual_cast.Linear(in_features=dim + 1, out_features=out_features),
31
+ )
32
+
33
+
34
+ class NumberEmbedder(nn.Module):
35
+ def __init__(
36
+ self,
37
+ features: int,
38
+ dim: int = 256,
39
+ ):
40
+ super().__init__()
41
+ self.features = features
42
+ self.embedding = TimePositionalEmbedding(dim=dim, out_features=features)
43
+
44
+ def forward(self, x: Union[List[float], Tensor]) -> Tensor:
45
+ if not torch.is_tensor(x):
46
+ device = next(self.embedding.parameters()).device
47
+ x = torch.tensor(x, device=device)
48
+ assert isinstance(x, Tensor)
49
+ shape = x.shape
50
+ x = rearrange(x, "... -> (...)")
51
+ embedding = self.embedding(x)
52
+ x = embedding.view(*shape, self.features)
53
+ return x # type: ignore
54
+
55
+
56
+ class Conditioner(nn.Module):
57
+ def __init__(
58
+ self,
59
+ dim: int,
60
+ output_dim: int,
61
+ project_out: bool = False
62
+ ):
63
+
64
+ super().__init__()
65
+
66
+ self.dim = dim
67
+ self.output_dim = output_dim
68
+ self.proj_out = nn.Linear(dim, output_dim) if (dim != output_dim or project_out) else nn.Identity()
69
+
70
+ def forward(self, x):
71
+ raise NotImplementedError()
72
+
73
+ class NumberConditioner(Conditioner):
74
+ '''
75
+ Conditioner that takes a list of floats, clamps and normalizes them to a given range, and returns a list of embeddings
76
+ '''
77
+ def __init__(self,
78
+ output_dim: int,
79
+ min_val: float=0,
80
+ max_val: float=1
81
+ ):
82
+ super().__init__(output_dim, output_dim)
83
+
84
+ self.min_val = min_val
85
+ self.max_val = max_val
86
+
87
+ self.embedder = NumberEmbedder(features=output_dim)
88
+
89
+ def forward(self, floats, device=None):
90
+ # Cast the inputs to floats
91
+ floats = [float(x) for x in floats]
92
+
93
+ if device is None:
94
+ device = next(self.embedder.parameters()).device
95
+
96
+ floats = torch.tensor(floats).to(device)
97
+
98
+ floats = floats.clamp(self.min_val, self.max_val)
99
+
100
+ normalized_floats = (floats - self.min_val) / (self.max_val - self.min_val)
101
+
102
+ # Cast floats to same type as embedder
103
+ embedder_dtype = next(self.embedder.parameters()).dtype
104
+ normalized_floats = normalized_floats.to(embedder_dtype)
105
+
106
+ float_embeds = self.embedder(normalized_floats).unsqueeze(1)
107
+
108
+ return [float_embeds, torch.ones(float_embeds.shape[0], 1).to(device)]
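A hypothetical usage sketch of NumberConditioner above (the 768-dimensional output and the 0-512 range are illustrative assumptions, not values from a config file):

from comfy.ldm.audio.embedders import NumberConditioner  # module path as added in this commit

# Hypothetical configuration: embed scalars clamped and normalized to [0, 512].
seconds_cond = NumberConditioner(output_dim=768, min_val=0, max_val=512)
embeds, mask = seconds_cond([47.5, 180.0])   # batch of two float values
# embeds: (2, 1, 768) embeddings; mask: (2, 1) tensor of ones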
comfy/ldm/aura/mmdit.py ADDED
@@ -0,0 +1,478 @@
1
+ #AuraFlow MMDiT
2
+ #Originally written by the AuraFlow Authors
3
+
4
+ import math
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+
10
+ from comfy.ldm.modules.attention import optimized_attention
11
+ import comfy.ops
12
+ import comfy.ldm.common_dit
13
+
14
+ def modulate(x, shift, scale):
15
+ return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
16
+
17
+
18
+ def find_multiple(n: int, k: int) -> int:
19
+ if n % k == 0:
20
+ return n
21
+ return n + k - (n % k)
22
+
23
+
24
+ class MLP(nn.Module):
25
+ def __init__(self, dim, hidden_dim=None, dtype=None, device=None, operations=None) -> None:
26
+ super().__init__()
27
+ if hidden_dim is None:
28
+ hidden_dim = 4 * dim
29
+
30
+ n_hidden = int(2 * hidden_dim / 3)
31
+ n_hidden = find_multiple(n_hidden, 256)
32
+
33
+ self.c_fc1 = operations.Linear(dim, n_hidden, bias=False, dtype=dtype, device=device)
34
+ self.c_fc2 = operations.Linear(dim, n_hidden, bias=False, dtype=dtype, device=device)
35
+ self.c_proj = operations.Linear(n_hidden, dim, bias=False, dtype=dtype, device=device)
36
+
37
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
38
+ x = F.silu(self.c_fc1(x)) * self.c_fc2(x)
39
+ x = self.c_proj(x)
40
+ return x
41
+
42
+
43
+ class MultiHeadLayerNorm(nn.Module):
44
+ def __init__(self, hidden_size=None, eps=1e-5, dtype=None, device=None):
45
+ # Copy pasta from https://github.com/huggingface/transformers/blob/e5f71ecaae50ea476d1e12351003790273c4b2ed/src/transformers/models/cohere/modeling_cohere.py#L78
46
+
47
+ super().__init__()
48
+ self.weight = nn.Parameter(torch.empty(hidden_size, dtype=dtype, device=device))
49
+ self.variance_epsilon = eps
50
+
51
+ def forward(self, hidden_states):
52
+ input_dtype = hidden_states.dtype
53
+ hidden_states = hidden_states.to(torch.float32)
54
+ mean = hidden_states.mean(-1, keepdim=True)
55
+ variance = (hidden_states - mean).pow(2).mean(-1, keepdim=True)
56
+ hidden_states = (hidden_states - mean) * torch.rsqrt(
57
+ variance + self.variance_epsilon
58
+ )
59
+ hidden_states = self.weight.to(torch.float32) * hidden_states
60
+ return hidden_states.to(input_dtype)
61
+
62
+ class SingleAttention(nn.Module):
63
+ def __init__(self, dim, n_heads, mh_qknorm=False, dtype=None, device=None, operations=None):
64
+ super().__init__()
65
+
66
+ self.n_heads = n_heads
67
+ self.head_dim = dim // n_heads
68
+
69
+ # this is for cond
70
+ self.w1q = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)
71
+ self.w1k = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)
72
+ self.w1v = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)
73
+ self.w1o = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)
74
+
75
+ self.q_norm1 = (
76
+ MultiHeadLayerNorm((self.n_heads, self.head_dim), dtype=dtype, device=device)
77
+ if mh_qknorm
78
+ else operations.LayerNorm(self.head_dim, elementwise_affine=False, dtype=dtype, device=device)
79
+ )
80
+ self.k_norm1 = (
81
+ MultiHeadLayerNorm((self.n_heads, self.head_dim), dtype=dtype, device=device)
82
+ if mh_qknorm
83
+ else operations.LayerNorm(self.head_dim, elementwise_affine=False, dtype=dtype, device=device)
84
+ )
85
+
86
+ #@torch.compile()
87
+ def forward(self, c):
88
+
89
+ bsz, seqlen1, _ = c.shape
90
+
91
+ q, k, v = self.w1q(c), self.w1k(c), self.w1v(c)
92
+ q = q.view(bsz, seqlen1, self.n_heads, self.head_dim)
93
+ k = k.view(bsz, seqlen1, self.n_heads, self.head_dim)
94
+ v = v.view(bsz, seqlen1, self.n_heads, self.head_dim)
95
+ q, k = self.q_norm1(q), self.k_norm1(k)
96
+
97
+ output = optimized_attention(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3), self.n_heads, skip_reshape=True)
98
+ c = self.w1o(output)
99
+ return c
100
+
101
+
102
+
103
+ class DoubleAttention(nn.Module):
104
+ def __init__(self, dim, n_heads, mh_qknorm=False, dtype=None, device=None, operations=None):
105
+ super().__init__()
106
+
107
+ self.n_heads = n_heads
108
+ self.head_dim = dim // n_heads
109
+
110
+ # this is for cond
111
+ self.w1q = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)
112
+ self.w1k = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)
113
+ self.w1v = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)
114
+ self.w1o = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)
115
+
116
+ # this is for x
117
+ self.w2q = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)
118
+ self.w2k = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)
119
+ self.w2v = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)
120
+ self.w2o = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)
121
+
122
+ self.q_norm1 = (
123
+ MultiHeadLayerNorm((self.n_heads, self.head_dim), dtype=dtype, device=device)
124
+ if mh_qknorm
125
+ else operations.LayerNorm(self.head_dim, elementwise_affine=False, dtype=dtype, device=device)
126
+ )
127
+ self.k_norm1 = (
128
+ MultiHeadLayerNorm((self.n_heads, self.head_dim), dtype=dtype, device=device)
129
+ if mh_qknorm
130
+ else operations.LayerNorm(self.head_dim, elementwise_affine=False, dtype=dtype, device=device)
131
+ )
132
+
133
+ self.q_norm2 = (
134
+ MultiHeadLayerNorm((self.n_heads, self.head_dim), dtype=dtype, device=device)
135
+ if mh_qknorm
136
+ else operations.LayerNorm(self.head_dim, elementwise_affine=False, dtype=dtype, device=device)
137
+ )
138
+ self.k_norm2 = (
139
+ MultiHeadLayerNorm((self.n_heads, self.head_dim), dtype=dtype, device=device)
140
+ if mh_qknorm
141
+ else operations.LayerNorm(self.head_dim, elementwise_affine=False, dtype=dtype, device=device)
142
+ )
143
+
144
+
145
+ #@torch.compile()
146
+ def forward(self, c, x):
147
+
148
+ bsz, seqlen1, _ = c.shape
149
+ bsz, seqlen2, _ = x.shape
150
+ seqlen = seqlen1 + seqlen2
151
+
152
+ cq, ck, cv = self.w1q(c), self.w1k(c), self.w1v(c)
153
+ cq = cq.view(bsz, seqlen1, self.n_heads, self.head_dim)
154
+ ck = ck.view(bsz, seqlen1, self.n_heads, self.head_dim)
155
+ cv = cv.view(bsz, seqlen1, self.n_heads, self.head_dim)
156
+ cq, ck = self.q_norm1(cq), self.k_norm1(ck)
157
+
158
+ xq, xk, xv = self.w2q(x), self.w2k(x), self.w2v(x)
159
+ xq = xq.view(bsz, seqlen2, self.n_heads, self.head_dim)
160
+ xk = xk.view(bsz, seqlen2, self.n_heads, self.head_dim)
161
+ xv = xv.view(bsz, seqlen2, self.n_heads, self.head_dim)
162
+ xq, xk = self.q_norm2(xq), self.k_norm2(xk)
163
+
164
+ # concat all
165
+ q, k, v = (
166
+ torch.cat([cq, xq], dim=1),
167
+ torch.cat([ck, xk], dim=1),
168
+ torch.cat([cv, xv], dim=1),
169
+ )
170
+
171
+ output = optimized_attention(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3), self.n_heads, skip_reshape=True)
172
+
173
+ c, x = output.split([seqlen1, seqlen2], dim=1)
174
+ c = self.w1o(c)
175
+ x = self.w2o(x)
176
+
177
+ return c, x
178
+
179
+
180
+ class MMDiTBlock(nn.Module):
181
+ def __init__(self, dim, heads=8, global_conddim=1024, is_last=False, dtype=None, device=None, operations=None):
182
+ super().__init__()
183
+
184
+ self.normC1 = operations.LayerNorm(dim, elementwise_affine=False, dtype=dtype, device=device)
185
+ self.normC2 = operations.LayerNorm(dim, elementwise_affine=False, dtype=dtype, device=device)
186
+ if not is_last:
187
+ self.mlpC = MLP(dim, hidden_dim=dim * 4, dtype=dtype, device=device, operations=operations)
188
+ self.modC = nn.Sequential(
189
+ nn.SiLU(),
190
+ operations.Linear(global_conddim, 6 * dim, bias=False, dtype=dtype, device=device),
191
+ )
192
+ else:
193
+ self.modC = nn.Sequential(
194
+ nn.SiLU(),
195
+ operations.Linear(global_conddim, 2 * dim, bias=False, dtype=dtype, device=device),
196
+ )
197
+
198
+ self.normX1 = operations.LayerNorm(dim, elementwise_affine=False, dtype=dtype, device=device)
199
+ self.normX2 = operations.LayerNorm(dim, elementwise_affine=False, dtype=dtype, device=device)
200
+ self.mlpX = MLP(dim, hidden_dim=dim * 4, dtype=dtype, device=device, operations=operations)
201
+ self.modX = nn.Sequential(
202
+ nn.SiLU(),
203
+ operations.Linear(global_conddim, 6 * dim, bias=False, dtype=dtype, device=device),
204
+ )
205
+
206
+ self.attn = DoubleAttention(dim, heads, dtype=dtype, device=device, operations=operations)
207
+ self.is_last = is_last
208
+
209
+ #@torch.compile()
210
+ def forward(self, c, x, global_cond, **kwargs):
211
+
212
+ cres, xres = c, x
213
+
214
+ cshift_msa, cscale_msa, cgate_msa, cshift_mlp, cscale_mlp, cgate_mlp = (
215
+ self.modC(global_cond).chunk(6, dim=1)
216
+ )
217
+
218
+ c = modulate(self.normC1(c), cshift_msa, cscale_msa)
219
+
220
+ # xpath
221
+ xshift_msa, xscale_msa, xgate_msa, xshift_mlp, xscale_mlp, xgate_mlp = (
222
+ self.modX(global_cond).chunk(6, dim=1)
223
+ )
224
+
225
+ x = modulate(self.normX1(x), xshift_msa, xscale_msa)
226
+
227
+ # attention
228
+ c, x = self.attn(c, x)
229
+
230
+
231
+ c = self.normC2(cres + cgate_msa.unsqueeze(1) * c)
232
+ c = cgate_mlp.unsqueeze(1) * self.mlpC(modulate(c, cshift_mlp, cscale_mlp))
233
+ c = cres + c
234
+
235
+ x = self.normX2(xres + xgate_msa.unsqueeze(1) * x)
236
+ x = xgate_mlp.unsqueeze(1) * self.mlpX(modulate(x, xshift_mlp, xscale_mlp))
237
+ x = xres + x
238
+
239
+ return c, x
240
+
241
+ class DiTBlock(nn.Module):
242
+ # like MMDiTBlock, but it only has X
243
+ def __init__(self, dim, heads=8, global_conddim=1024, dtype=None, device=None, operations=None):
244
+ super().__init__()
245
+
246
+ self.norm1 = operations.LayerNorm(dim, elementwise_affine=False, dtype=dtype, device=device)
247
+ self.norm2 = operations.LayerNorm(dim, elementwise_affine=False, dtype=dtype, device=device)
248
+
249
+ self.modCX = nn.Sequential(
250
+ nn.SiLU(),
251
+ operations.Linear(global_conddim, 6 * dim, bias=False, dtype=dtype, device=device),
252
+ )
253
+
254
+ self.attn = SingleAttention(dim, heads, dtype=dtype, device=device, operations=operations)
255
+ self.mlp = MLP(dim, hidden_dim=dim * 4, dtype=dtype, device=device, operations=operations)
256
+
257
+ #@torch.compile()
258
+ def forward(self, cx, global_cond, **kwargs):
259
+ cxres = cx
260
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.modCX(
261
+ global_cond
262
+ ).chunk(6, dim=1)
263
+ cx = modulate(self.norm1(cx), shift_msa, scale_msa)
264
+ cx = self.attn(cx)
265
+ cx = self.norm2(cxres + gate_msa.unsqueeze(1) * cx)
266
+ mlpout = self.mlp(modulate(cx, shift_mlp, scale_mlp))
267
+ cx = gate_mlp.unsqueeze(1) * mlpout
268
+
269
+ cx = cxres + cx
270
+
271
+ return cx
272
+
273
+
274
+
275
+ class TimestepEmbedder(nn.Module):
276
+ def __init__(self, hidden_size, frequency_embedding_size=256, dtype=None, device=None, operations=None):
277
+ super().__init__()
278
+ self.mlp = nn.Sequential(
279
+ operations.Linear(frequency_embedding_size, hidden_size, dtype=dtype, device=device),
280
+ nn.SiLU(),
281
+ operations.Linear(hidden_size, hidden_size, dtype=dtype, device=device),
282
+ )
283
+ self.frequency_embedding_size = frequency_embedding_size
284
+
285
+ @staticmethod
286
+ def timestep_embedding(t, dim, max_period=10000):
287
+ half = dim // 2
288
+ freqs = 1000 * torch.exp(
289
+ -math.log(max_period) * torch.arange(start=0, end=half) / half
290
+ ).to(t.device)
291
+ args = t[:, None] * freqs[None]
292
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
293
+ if dim % 2:
294
+ embedding = torch.cat(
295
+ [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
296
+ )
297
+ return embedding
298
+
299
+ #@torch.compile()
300
+ def forward(self, t, dtype):
301
+ t_freq = self.timestep_embedding(t, self.frequency_embedding_size).to(dtype)
302
+ t_emb = self.mlp(t_freq)
303
+ return t_emb
304
+
305
+
306
+ class MMDiT(nn.Module):
307
+ def __init__(
308
+ self,
309
+ in_channels=4,
310
+ out_channels=4,
311
+ patch_size=2,
312
+ dim=3072,
313
+ n_layers=36,
314
+ n_double_layers=4,
315
+ n_heads=12,
316
+ global_conddim=3072,
317
+ cond_seq_dim=2048,
318
+ max_seq=32 * 32,
319
+ device=None,
320
+ dtype=None,
321
+ operations=None,
322
+ ):
323
+ super().__init__()
324
+ self.dtype = dtype
325
+
326
+ self.t_embedder = TimestepEmbedder(global_conddim, dtype=dtype, device=device, operations=operations)
327
+
328
+ self.cond_seq_linear = operations.Linear(
329
+ cond_seq_dim, dim, bias=False, dtype=dtype, device=device
330
+ ) # linear for something like text sequence.
331
+ self.init_x_linear = operations.Linear(
332
+ patch_size * patch_size * in_channels, dim, dtype=dtype, device=device
333
+ ) # init linear for patchified image.
334
+
335
+ self.positional_encoding = nn.Parameter(torch.empty(1, max_seq, dim, dtype=dtype, device=device))
336
+ self.register_tokens = nn.Parameter(torch.empty(1, 8, dim, dtype=dtype, device=device))
337
+
338
+ self.double_layers = nn.ModuleList([])
339
+ self.single_layers = nn.ModuleList([])
340
+
341
+
342
+ for idx in range(n_double_layers):
343
+ self.double_layers.append(
344
+ MMDiTBlock(dim, n_heads, global_conddim, is_last=(idx == n_layers - 1), dtype=dtype, device=device, operations=operations)
345
+ )
346
+
347
+ for idx in range(n_double_layers, n_layers):
348
+ self.single_layers.append(
349
+ DiTBlock(dim, n_heads, global_conddim, dtype=dtype, device=device, operations=operations)
350
+ )
351
+
352
+
353
+ self.final_linear = operations.Linear(
354
+ dim, patch_size * patch_size * out_channels, bias=False, dtype=dtype, device=device
355
+ )
356
+
357
+ self.modF = nn.Sequential(
358
+ nn.SiLU(),
359
+ operations.Linear(global_conddim, 2 * dim, bias=False, dtype=dtype, device=device),
360
+ )
361
+
362
+ self.out_channels = out_channels
363
+ self.patch_size = patch_size
364
+ self.n_double_layers = n_double_layers
365
+ self.n_layers = n_layers
366
+
367
+ self.h_max = round(max_seq**0.5)
368
+ self.w_max = round(max_seq**0.5)
369
+
370
+ @torch.no_grad()
371
+ def extend_pe(self, init_dim=(16, 16), target_dim=(64, 64)):
372
+ # extend pe
373
+ pe_data = self.positional_encoding.data.squeeze(0)[: init_dim[0] * init_dim[1]]
374
+
375
+ pe_as_2d = pe_data.view(init_dim[0], init_dim[1], -1).permute(2, 0, 1)
376
+
377
+ # now we need to extend this to target_dim. for this we will use interpolation.
378
+ # we will use torch.nn.functional.interpolate
379
+ pe_as_2d = F.interpolate(
380
+ pe_as_2d.unsqueeze(0), size=target_dim, mode="bilinear"
381
+ )
382
+ pe_new = pe_as_2d.squeeze(0).permute(1, 2, 0).flatten(0, 1)
383
+ self.positional_encoding.data = pe_new.unsqueeze(0).contiguous()
384
+ self.h_max, self.w_max = target_dim
385
+ print("PE extended to", target_dim)
386
+
387
+ def pe_selection_index_based_on_dim(self, h, w):
388
+ h_p, w_p = h // self.patch_size, w // self.patch_size
389
+ original_pe_indexes = torch.arange(self.positional_encoding.shape[1])
390
+ original_pe_indexes = original_pe_indexes.view(self.h_max, self.w_max)
391
+ starth = self.h_max // 2 - h_p // 2
392
+ endh = starth + h_p
393
+ startw = self.w_max // 2 - w_p // 2
394
+ endw = startw + w_p
395
+ original_pe_indexes = original_pe_indexes[
396
+ starth:endh, startw:endw
397
+ ]
398
+ return original_pe_indexes.flatten()
399
+
400
+ def unpatchify(self, x, h, w):
401
+ c = self.out_channels
402
+ p = self.patch_size
403
+
404
+ x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
405
+ x = torch.einsum("nhwpqc->nchpwq", x)
406
+ imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
407
+ return imgs
408
+
409
+ def patchify(self, x):
410
+ B, C, H, W = x.size()
411
+ x = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size))
412
+ x = x.view(
413
+ B,
414
+ C,
415
+ (H + 1) // self.patch_size,
416
+ self.patch_size,
417
+ (W + 1) // self.patch_size,
418
+ self.patch_size,
419
+ )
420
+ x = x.permute(0, 2, 4, 1, 3, 5).flatten(-3).flatten(1, 2)
421
+ return x
422
+
423
+ def apply_pos_embeds(self, x, h, w):
424
+ h = (h + 1) // self.patch_size
425
+ w = (w + 1) // self.patch_size
426
+ max_dim = max(h, w)
427
+
428
+ cur_dim = self.h_max
429
+ pos_encoding = comfy.ops.cast_to_input(self.positional_encoding.reshape(1, cur_dim, cur_dim, -1), x)
430
+
431
+ if max_dim > cur_dim:
432
+ pos_encoding = F.interpolate(pos_encoding.movedim(-1, 1), (max_dim, max_dim), mode="bilinear").movedim(1, -1)
433
+ cur_dim = max_dim
434
+
435
+ from_h = (cur_dim - h) // 2
436
+ from_w = (cur_dim - w) // 2
437
+ pos_encoding = pos_encoding[:,from_h:from_h+h,from_w:from_w+w]
438
+ return x + pos_encoding.reshape(1, -1, self.positional_encoding.shape[-1])
439
+
440
+ def forward(self, x, timestep, context, **kwargs):
441
+ # patchify x, add PE
442
+ b, c, h, w = x.shape
443
+
444
+ # pe_indexes = self.pe_selection_index_based_on_dim(h, w)
445
+ # print(pe_indexes, pe_indexes.shape)
446
+
447
+ x = self.init_x_linear(self.patchify(x)) # B, T_x, D
448
+ x = self.apply_pos_embeds(x, h, w)
449
+ # x = x + self.positional_encoding[:, : x.size(1)].to(device=x.device, dtype=x.dtype)
450
+ # x = x + self.positional_encoding[:, pe_indexes].to(device=x.device, dtype=x.dtype)
451
+
452
+ # process conditions for MMDiT Blocks
453
+ c_seq = context # B, T_c, D_c
454
+ t = timestep
455
+
456
+ c = self.cond_seq_linear(c_seq) # B, T_c, D
457
+ c = torch.cat([comfy.ops.cast_to_input(self.register_tokens, c).repeat(c.size(0), 1, 1), c], dim=1)
458
+
459
+ global_cond = self.t_embedder(t, x.dtype) # B, D
460
+
461
+ if len(self.double_layers) > 0:
462
+ for layer in self.double_layers:
463
+ c, x = layer(c, x, global_cond, **kwargs)
464
+
465
+ if len(self.single_layers) > 0:
466
+ c_len = c.size(1)
467
+ cx = torch.cat([c, x], dim=1)
468
+ for layer in self.single_layers:
469
+ cx = layer(cx, global_cond, **kwargs)
470
+
471
+ x = cx[:, c_len:]
472
+
473
+ fshift, fscale = self.modF(global_cond).chunk(2, dim=1)
474
+
475
+ x = modulate(x, fshift, fscale)
476
+ x = self.final_linear(x)
477
+ x = self.unpatchify(x, (h + 1) // self.patch_size, (w + 1) // self.patch_size)[:,:,:h,:w]
478
+ return x
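A standalone shape check of the patchify rearrangement used above (illustrative only; an even latent size is chosen so no padding is needed):

import torch

B, C, H, W, p = 1, 4, 32, 32, 2
x = torch.randn(B, C, H, W)
tokens = (
    x.view(B, C, H // p, p, W // p, p)
     .permute(0, 2, 4, 1, 3, 5)   # B, H/p, W/p, C, p, p
     .flatten(-3)                 # fold each patch into one vector of size C * p * p
     .flatten(1, 2)               # flatten the spatial grid into a token axis
)
assert tokens.shape == (B, (H // p) * (W // p), C * p * p)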
comfy/ldm/cascade/common.py ADDED
@@ -0,0 +1,154 @@
1
+ """
2
+ This file is part of ComfyUI.
3
+ Copyright (C) 2024 Stability AI
4
+
5
+ This program is free software: you can redistribute it and/or modify
6
+ it under the terms of the GNU General Public License as published by
7
+ the Free Software Foundation, either version 3 of the License, or
8
+ (at your option) any later version.
9
+
10
+ This program is distributed in the hope that it will be useful,
11
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
12
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13
+ GNU General Public License for more details.
14
+
15
+ You should have received a copy of the GNU General Public License
16
+ along with this program. If not, see <https://www.gnu.org/licenses/>.
17
+ """
18
+
19
+ import torch
20
+ import torch.nn as nn
21
+ from comfy.ldm.modules.attention import optimized_attention
22
+ import comfy.ops
23
+
24
+ class OptimizedAttention(nn.Module):
25
+ def __init__(self, c, nhead, dropout=0.0, dtype=None, device=None, operations=None):
26
+ super().__init__()
27
+ self.heads = nhead
28
+
29
+ self.to_q = operations.Linear(c, c, bias=True, dtype=dtype, device=device)
30
+ self.to_k = operations.Linear(c, c, bias=True, dtype=dtype, device=device)
31
+ self.to_v = operations.Linear(c, c, bias=True, dtype=dtype, device=device)
32
+
33
+ self.out_proj = operations.Linear(c, c, bias=True, dtype=dtype, device=device)
34
+
35
+ def forward(self, q, k, v):
36
+ q = self.to_q(q)
37
+ k = self.to_k(k)
38
+ v = self.to_v(v)
39
+
40
+ out = optimized_attention(q, k, v, self.heads)
41
+
42
+ return self.out_proj(out)
43
+
44
+ class Attention2D(nn.Module):
45
+ def __init__(self, c, nhead, dropout=0.0, dtype=None, device=None, operations=None):
46
+ super().__init__()
47
+ self.attn = OptimizedAttention(c, nhead, dtype=dtype, device=device, operations=operations)
48
+ # self.attn = nn.MultiheadAttention(c, nhead, dropout=dropout, bias=True, batch_first=True, dtype=dtype, device=device)
49
+
50
+ def forward(self, x, kv, self_attn=False):
51
+ orig_shape = x.shape
52
+ x = x.view(x.size(0), x.size(1), -1).permute(0, 2, 1) # Bx4xHxW -> Bx(HxW)x4
53
+ if self_attn:
54
+ kv = torch.cat([x, kv], dim=1)
55
+ # x = self.attn(x, kv, kv, need_weights=False)[0]
56
+ x = self.attn(x, kv, kv)
57
+ x = x.permute(0, 2, 1).view(*orig_shape)
58
+ return x
59
+
60
+
61
+ def LayerNorm2d_op(operations):
62
+ class LayerNorm2d(operations.LayerNorm):
63
+ def __init__(self, *args, **kwargs):
64
+ super().__init__(*args, **kwargs)
65
+
66
+ def forward(self, x):
67
+ return super().forward(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
68
+ return LayerNorm2d
69
+
70
+ class GlobalResponseNorm(nn.Module):
71
+ "from https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105"
72
+ def __init__(self, dim, dtype=None, device=None):
73
+ super().__init__()
74
+ self.gamma = nn.Parameter(torch.empty(1, 1, 1, dim, dtype=dtype, device=device))
75
+ self.beta = nn.Parameter(torch.empty(1, 1, 1, dim, dtype=dtype, device=device))
76
+
77
+ def forward(self, x):
78
+ Gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True)
79
+ Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
80
+ return comfy.ops.cast_to_input(self.gamma, x) * (x * Nx) + comfy.ops.cast_to_input(self.beta, x) + x
81
+
82
+
83
+ class ResBlock(nn.Module):
84
+ def __init__(self, c, c_skip=0, kernel_size=3, dropout=0.0, dtype=None, device=None, operations=None): # , num_heads=4, expansion=2):
85
+ super().__init__()
86
+ self.depthwise = operations.Conv2d(c, c, kernel_size=kernel_size, padding=kernel_size // 2, groups=c, dtype=dtype, device=device)
87
+ # self.depthwise = SAMBlock(c, num_heads, expansion)
88
+ self.norm = LayerNorm2d_op(operations)(c, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
89
+ self.channelwise = nn.Sequential(
90
+ operations.Linear(c + c_skip, c * 4, dtype=dtype, device=device),
91
+ nn.GELU(),
92
+ GlobalResponseNorm(c * 4, dtype=dtype, device=device),
93
+ nn.Dropout(dropout),
94
+ operations.Linear(c * 4, c, dtype=dtype, device=device)
95
+ )
96
+
97
+ def forward(self, x, x_skip=None):
98
+ x_res = x
99
+ x = self.norm(self.depthwise(x))
100
+ if x_skip is not None:
101
+ x = torch.cat([x, x_skip], dim=1)
102
+ x = self.channelwise(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
103
+ return x + x_res
104
+
105
+
106
+ class AttnBlock(nn.Module):
107
+ def __init__(self, c, c_cond, nhead, self_attn=True, dropout=0.0, dtype=None, device=None, operations=None):
108
+ super().__init__()
109
+ self.self_attn = self_attn
110
+ self.norm = LayerNorm2d_op(operations)(c, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
111
+ self.attention = Attention2D(c, nhead, dropout, dtype=dtype, device=device, operations=operations)
112
+ self.kv_mapper = nn.Sequential(
113
+ nn.SiLU(),
114
+ operations.Linear(c_cond, c, dtype=dtype, device=device)
115
+ )
116
+
117
+ def forward(self, x, kv):
118
+ kv = self.kv_mapper(kv)
119
+ x = x + self.attention(self.norm(x), kv, self_attn=self.self_attn)
120
+ return x
121
+
122
+
123
+ class FeedForwardBlock(nn.Module):
124
+ def __init__(self, c, dropout=0.0, dtype=None, device=None, operations=None):
125
+ super().__init__()
126
+ self.norm = LayerNorm2d_op(operations)(c, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
127
+ self.channelwise = nn.Sequential(
128
+ operations.Linear(c, c * 4, dtype=dtype, device=device),
129
+ nn.GELU(),
130
+ GlobalResponseNorm(c * 4, dtype=dtype, device=device),
131
+ nn.Dropout(dropout),
132
+ operations.Linear(c * 4, c, dtype=dtype, device=device)
133
+ )
134
+
135
+ def forward(self, x):
136
+ x = x + self.channelwise(self.norm(x).permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
137
+ return x
138
+
139
+
140
+ class TimestepBlock(nn.Module):
141
+ def __init__(self, c, c_timestep, conds=['sca'], dtype=None, device=None, operations=None):
142
+ super().__init__()
143
+ self.mapper = operations.Linear(c_timestep, c * 2, dtype=dtype, device=device)
144
+ self.conds = conds
145
+ for cname in conds:
146
+ setattr(self, f"mapper_{cname}", operations.Linear(c_timestep, c * 2, dtype=dtype, device=device))
147
+
148
+ def forward(self, x, t):
149
+ t = t.chunk(len(self.conds) + 1, dim=1)
150
+ a, b = self.mapper(t[0])[:, :, None, None].chunk(2, dim=1)
151
+ for i, c in enumerate(self.conds):
152
+ ac, bc = getattr(self, f"mapper_{c}")(t[i + 1])[:, :, None, None].chunk(2, dim=1)
153
+ a, b = a + ac, b + bc
154
+ return x * (1 + a) + b
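TimestepBlock above expects its conditioning tensor to be the timestep embedding concatenated with one extra embedding per entry in conds, and each chunk contributes its own scale/shift. A standalone sketch with made-up sizes (only the first chunk is shown; the remaining chunks would add their own a/b terms):

import torch

c_timestep, c, batch = 64, 32, 2
conds = ['sca']
t = torch.randn(batch, c_timestep * (len(conds) + 1))   # concatenated embeddings
chunks = t.chunk(len(conds) + 1, dim=1)                 # one chunk per mapper
mapper = torch.nn.Linear(c_timestep, c * 2)             # stand-in for self.mapper
a, b = mapper(chunks[0])[:, :, None, None].chunk(2, dim=1)
x = torch.randn(batch, c, 8, 8)
out = x * (1 + a) + b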
comfy/ldm/cascade/controlnet.py ADDED
@@ -0,0 +1,93 @@
1
+ """
2
+ This file is part of ComfyUI.
3
+ Copyright (C) 2024 Stability AI
4
+
5
+ This program is free software: you can redistribute it and/or modify
6
+ it under the terms of the GNU General Public License as published by
7
+ the Free Software Foundation, either version 3 of the License, or
8
+ (at your option) any later version.
9
+
10
+ This program is distributed in the hope that it will be useful,
11
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
12
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13
+ GNU General Public License for more details.
14
+
15
+ You should have received a copy of the GNU General Public License
16
+ along with this program. If not, see <https://www.gnu.org/licenses/>.
17
+ """
18
+
19
+ import torch
20
+ import torchvision
21
+ from torch import nn
22
+ from .common import LayerNorm2d_op
23
+
24
+
25
+ class CNetResBlock(nn.Module):
26
+ def __init__(self, c, dtype=None, device=None, operations=None):
27
+ super().__init__()
28
+ self.blocks = nn.Sequential(
29
+ LayerNorm2d_op(operations)(c, dtype=dtype, device=device),
30
+ nn.GELU(),
31
+ operations.Conv2d(c, c, kernel_size=3, padding=1),
32
+ LayerNorm2d_op(operations)(c, dtype=dtype, device=device),
33
+ nn.GELU(),
34
+ operations.Conv2d(c, c, kernel_size=3, padding=1),
35
+ )
36
+
37
+ def forward(self, x):
38
+ return x + self.blocks(x)
39
+
40
+
41
+ class ControlNet(nn.Module):
42
+ def __init__(self, c_in=3, c_proj=2048, proj_blocks=None, bottleneck_mode=None, dtype=None, device=None, operations=nn):
43
+ super().__init__()
44
+ if bottleneck_mode is None:
45
+ bottleneck_mode = 'effnet'
46
+ self.proj_blocks = proj_blocks
47
+ if bottleneck_mode == 'effnet':
48
+ embd_channels = 1280
49
+ self.backbone = torchvision.models.efficientnet_v2_s().features.eval()
50
+ if c_in != 3:
51
+ in_weights = self.backbone[0][0].weight.data
52
+ self.backbone[0][0] = operations.Conv2d(c_in, 24, kernel_size=3, stride=2, bias=False, dtype=dtype, device=device)
53
+ if c_in > 3:
54
+ # nn.init.constant_(self.backbone[0][0].weight, 0)
55
+ self.backbone[0][0].weight.data[:, :3] = in_weights[:, :3].clone()
56
+ else:
57
+ self.backbone[0][0].weight.data = in_weights[:, :c_in].clone()
58
+ elif bottleneck_mode == 'simple':
59
+ embd_channels = c_in
60
+ self.backbone = nn.Sequential(
61
+ operations.Conv2d(embd_channels, embd_channels * 4, kernel_size=3, padding=1, dtype=dtype, device=device),
62
+ nn.LeakyReLU(0.2, inplace=True),
63
+ operations.Conv2d(embd_channels * 4, embd_channels, kernel_size=3, padding=1, dtype=dtype, device=device),
64
+ )
65
+ elif bottleneck_mode == 'large':
66
+ self.backbone = nn.Sequential(
67
+ operations.Conv2d(c_in, 4096 * 4, kernel_size=1, dtype=dtype, device=device),
68
+ nn.LeakyReLU(0.2, inplace=True),
69
+ operations.Conv2d(4096 * 4, 1024, kernel_size=1, dtype=dtype, device=device),
70
+ *[CNetResBlock(1024, dtype=dtype, device=device, operations=operations) for _ in range(8)],
71
+ operations.Conv2d(1024, 1280, kernel_size=1, dtype=dtype, device=device),
72
+ )
73
+ embd_channels = 1280
74
+ else:
75
+ raise ValueError(f'Unknown bottleneck mode: {bottleneck_mode}')
76
+ self.projections = nn.ModuleList()
77
+ for _ in range(len(proj_blocks)):
78
+ self.projections.append(nn.Sequential(
79
+ operations.Conv2d(embd_channels, embd_channels, kernel_size=1, bias=False, dtype=dtype, device=device),
80
+ nn.LeakyReLU(0.2, inplace=True),
81
+ operations.Conv2d(embd_channels, c_proj, kernel_size=1, bias=False, dtype=dtype, device=device),
82
+ ))
83
+ # nn.init.constant_(self.projections[-1][-1].weight, 0) # zero output projection
84
+ self.xl = False
85
+ self.input_channels = c_in
86
+ self.unshuffle_amount = 8
87
+
88
+ def forward(self, x):
89
+ x = self.backbone(x)
90
+ proj_outputs = [None for _ in range(max(self.proj_blocks) + 1)]
91
+ for i, idx in enumerate(self.proj_blocks):
92
+ proj_outputs[idx] = self.projections[i](x)
93
+ return {"input": proj_outputs[::-1]}
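The projection bookkeeping in forward above can be shown with a toy stand-in (the proj_blocks values are hypothetical):

# With proj_blocks = [0, 4, 8], positions 0, 4 and 8 of a 9-slot list are filled
# and the list is returned reversed under the "input" key.
proj_blocks = [0, 4, 8]
proj_outputs = [None] * (max(proj_blocks) + 1)
for i, idx in enumerate(proj_blocks):
    proj_outputs[idx] = f"projection_{i}"   # stands in for self.projections[i](x)
print(proj_outputs[::-1])   # ['projection_2', None, None, None, 'projection_1', None, None, None, 'projection_0']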
comfy/ldm/cascade/stage_a.py ADDED
@@ -0,0 +1,255 @@
1
+ """
2
+ This file is part of ComfyUI.
3
+ Copyright (C) 2024 Stability AI
4
+
5
+ This program is free software: you can redistribute it and/or modify
6
+ it under the terms of the GNU General Public License as published by
7
+ the Free Software Foundation, either version 3 of the License, or
8
+ (at your option) any later version.
9
+
10
+ This program is distributed in the hope that it will be useful,
11
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
12
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13
+ GNU General Public License for more details.
14
+
15
+ You should have received a copy of the GNU General Public License
16
+ along with this program. If not, see <https://www.gnu.org/licenses/>.
17
+ """
18
+
19
+ import torch
20
+ from torch import nn
21
+ from torch.autograd import Function
22
+
23
+ class vector_quantize(Function):
24
+ @staticmethod
25
+ def forward(ctx, x, codebook):
26
+ with torch.no_grad():
27
+ codebook_sqr = torch.sum(codebook ** 2, dim=1)
28
+ x_sqr = torch.sum(x ** 2, dim=1, keepdim=True)
29
+
30
+ dist = torch.addmm(codebook_sqr + x_sqr, x, codebook.t(), alpha=-2.0, beta=1.0)
31
+ _, indices = dist.min(dim=1)
32
+
33
+ ctx.save_for_backward(indices, codebook)
34
+ ctx.mark_non_differentiable(indices)
35
+
36
+ nn = torch.index_select(codebook, 0, indices)
37
+ return nn, indices
38
+
39
+ @staticmethod
40
+ def backward(ctx, grad_output, grad_indices):
41
+ grad_inputs, grad_codebook = None, None
42
+
43
+ if ctx.needs_input_grad[0]:
44
+ grad_inputs = grad_output.clone()
45
+ if ctx.needs_input_grad[1]:
46
+ # Gradient wrt. the codebook
47
+ indices, codebook = ctx.saved_tensors
48
+
49
+ grad_codebook = torch.zeros_like(codebook)
50
+ grad_codebook.index_add_(0, indices, grad_output)
51
+
52
+ return (grad_inputs, grad_codebook)
53
+
54
+
55
+ class VectorQuantize(nn.Module):
56
+ def __init__(self, embedding_size, k, ema_decay=0.99, ema_loss=False):
57
+ """
58
+ Takes an input of variable size (as long as the last dimension matches the embedding size).
59
+ Returns one tensor containing the nearest neighbour embeddings to each of the inputs,
60
+ with the same size as the input, vq and commitment components for the loss as a tuple
61
+ in the second output and the indices of the quantized vectors in the third:
62
+ quantized, (vq_loss, commit_loss), indices
63
+ """
64
+ super(VectorQuantize, self).__init__()
65
+
66
+ self.codebook = nn.Embedding(k, embedding_size)
67
+ self.codebook.weight.data.uniform_(-1./k, 1./k)
68
+ self.vq = vector_quantize.apply
69
+
70
+ self.ema_decay = ema_decay
71
+ self.ema_loss = ema_loss
72
+ if ema_loss:
73
+ self.register_buffer('ema_element_count', torch.ones(k))
74
+ self.register_buffer('ema_weight_sum', torch.zeros_like(self.codebook.weight))
75
+
76
+ def _laplace_smoothing(self, x, epsilon):
77
+ n = torch.sum(x)
78
+ return ((x + epsilon) / (n + x.size(0) * epsilon) * n)
79
+
80
+ def _updateEMA(self, z_e_x, indices):
81
+ mask = nn.functional.one_hot(indices, self.ema_element_count.size(0)).float()
82
+ elem_count = mask.sum(dim=0)
83
+ weight_sum = torch.mm(mask.t(), z_e_x)
84
+
85
+ self.ema_element_count = (self.ema_decay * self.ema_element_count) + ((1-self.ema_decay) * elem_count)
86
+ self.ema_element_count = self._laplace_smoothing(self.ema_element_count, 1e-5)
87
+ self.ema_weight_sum = (self.ema_decay * self.ema_weight_sum) + ((1-self.ema_decay) * weight_sum)
88
+
89
+ self.codebook.weight.data = self.ema_weight_sum / self.ema_element_count.unsqueeze(-1)
90
+
91
+ def idx2vq(self, idx, dim=-1):
92
+ q_idx = self.codebook(idx)
93
+ if dim != -1:
94
+ q_idx = q_idx.movedim(-1, dim)
95
+ return q_idx
96
+
97
+ def forward(self, x, get_losses=True, dim=-1):
98
+ if dim != -1:
99
+ x = x.movedim(dim, -1)
100
+ z_e_x = x.contiguous().view(-1, x.size(-1)) if len(x.shape) > 2 else x
101
+ z_q_x, indices = self.vq(z_e_x, self.codebook.weight.detach())
102
+ vq_loss, commit_loss = None, None
103
+ if self.ema_loss and self.training:
104
+ self._updateEMA(z_e_x.detach(), indices.detach())
105
+ # pick the graded embeddings after updating the codebook in order to have a more accurate commitment loss
106
+ z_q_x_grd = torch.index_select(self.codebook.weight, dim=0, index=indices)
107
+ if get_losses:
108
+ vq_loss = (z_q_x_grd - z_e_x.detach()).pow(2).mean()
109
+ commit_loss = (z_e_x - z_q_x_grd.detach()).pow(2).mean()
110
+
111
+ z_q_x = z_q_x.view(x.shape)
112
+ if dim != -1:
113
+ z_q_x = z_q_x.movedim(-1, dim)
114
+ return z_q_x, (vq_loss, commit_loss), indices.view(x.shape[:-1])
115
+
116
+
117
+ class ResBlock(nn.Module):
118
+ def __init__(self, c, c_hidden):
119
+ super().__init__()
120
+ # depthwise/attention
121
+ self.norm1 = nn.LayerNorm(c, elementwise_affine=False, eps=1e-6)
122
+ self.depthwise = nn.Sequential(
123
+ nn.ReplicationPad2d(1),
124
+ nn.Conv2d(c, c, kernel_size=3, groups=c)
125
+ )
126
+
127
+ # channelwise
128
+ self.norm2 = nn.LayerNorm(c, elementwise_affine=False, eps=1e-6)
129
+ self.channelwise = nn.Sequential(
130
+ nn.Linear(c, c_hidden),
131
+ nn.GELU(),
132
+ nn.Linear(c_hidden, c),
133
+ )
134
+
135
+ self.gammas = nn.Parameter(torch.zeros(6), requires_grad=True)
136
+
137
+ # Init weights
138
+ def _basic_init(module):
139
+ if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
140
+ torch.nn.init.xavier_uniform_(module.weight)
141
+ if module.bias is not None:
142
+ nn.init.constant_(module.bias, 0)
143
+
144
+ self.apply(_basic_init)
145
+
146
+ def _norm(self, x, norm):
147
+ return norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
148
+
149
+ def forward(self, x):
150
+ mods = self.gammas
151
+
152
+ x_temp = self._norm(x, self.norm1) * (1 + mods[0]) + mods[1]
153
+ try:
154
+ x = x + self.depthwise(x_temp) * mods[2]
155
+ except: #operation not implemented for bf16
156
+ x_temp = self.depthwise[0](x_temp.float()).to(x.dtype)
157
+ x = x + self.depthwise[1](x_temp) * mods[2]
158
+
159
+ x_temp = self._norm(x, self.norm2) * (1 + mods[3]) + mods[4]
160
+ x = x + self.channelwise(x_temp.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) * mods[5]
161
+
162
+ return x
163
+
164
+
165
+ class StageA(nn.Module):
166
+ def __init__(self, levels=2, bottleneck_blocks=12, c_hidden=384, c_latent=4, codebook_size=8192):
167
+ super().__init__()
168
+ self.c_latent = c_latent
169
+ c_levels = [c_hidden // (2 ** i) for i in reversed(range(levels))]
170
+
171
+ # Encoder blocks
172
+ self.in_block = nn.Sequential(
173
+ nn.PixelUnshuffle(2),
174
+ nn.Conv2d(3 * 4, c_levels[0], kernel_size=1)
175
+ )
176
+ down_blocks = []
177
+ for i in range(levels):
178
+ if i > 0:
179
+ down_blocks.append(nn.Conv2d(c_levels[i - 1], c_levels[i], kernel_size=4, stride=2, padding=1))
180
+ block = ResBlock(c_levels[i], c_levels[i] * 4)
181
+ down_blocks.append(block)
182
+ down_blocks.append(nn.Sequential(
183
+ nn.Conv2d(c_levels[-1], c_latent, kernel_size=1, bias=False),
184
+ nn.BatchNorm2d(c_latent), # then normalize them to have mean 0 and std 1
185
+ ))
186
+ self.down_blocks = nn.Sequential(*down_blocks)
187
+ self.down_blocks[0]
188
+
189
+ self.codebook_size = codebook_size
190
+ self.vquantizer = VectorQuantize(c_latent, k=codebook_size)
191
+
192
+ # Decoder blocks
193
+ up_blocks = [nn.Sequential(
194
+ nn.Conv2d(c_latent, c_levels[-1], kernel_size=1)
195
+ )]
196
+ for i in range(levels):
197
+ for j in range(bottleneck_blocks if i == 0 else 1):
198
+ block = ResBlock(c_levels[levels - 1 - i], c_levels[levels - 1 - i] * 4)
199
+ up_blocks.append(block)
200
+ if i < levels - 1:
201
+ up_blocks.append(
202
+ nn.ConvTranspose2d(c_levels[levels - 1 - i], c_levels[levels - 2 - i], kernel_size=4, stride=2,
203
+ padding=1))
204
+ self.up_blocks = nn.Sequential(*up_blocks)
205
+ self.out_block = nn.Sequential(
206
+ nn.Conv2d(c_levels[0], 3 * 4, kernel_size=1),
207
+ nn.PixelShuffle(2),
208
+ )
209
+
210
+ def encode(self, x, quantize=False):
211
+ x = self.in_block(x)
212
+ x = self.down_blocks(x)
213
+ if quantize:
214
+ qe, (vq_loss, commit_loss), indices = self.vquantizer.forward(x, dim=1)
215
+ return qe, x, indices, vq_loss + commit_loss * 0.25
216
+ else:
217
+ return x
218
+
219
+ def decode(self, x):
220
+ x = self.up_blocks(x)
221
+ x = self.out_block(x)
222
+ return x
223
+
224
+ def forward(self, x, quantize=False):
225
+ qe, x, _, vq_loss = self.encode(x, quantize)
226
+ x = self.decode(qe)
227
+ return x, vq_loss
228
+
229
+
230
+ class Discriminator(nn.Module):
231
+ def __init__(self, c_in=3, c_cond=0, c_hidden=512, depth=6):
232
+ super().__init__()
233
+ d = max(depth - 3, 3)
234
+ layers = [
235
+ nn.utils.spectral_norm(nn.Conv2d(c_in, c_hidden // (2 ** d), kernel_size=3, stride=2, padding=1)),
236
+ nn.LeakyReLU(0.2),
237
+ ]
238
+ for i in range(depth - 1):
239
+ c_in = c_hidden // (2 ** max((d - i), 0))
240
+ c_out = c_hidden // (2 ** max((d - 1 - i), 0))
241
+ layers.append(nn.utils.spectral_norm(nn.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1)))
242
+ layers.append(nn.InstanceNorm2d(c_out))
243
+ layers.append(nn.LeakyReLU(0.2))
244
+ self.encoder = nn.Sequential(*layers)
245
+ self.shuffle = nn.Conv2d((c_hidden + c_cond) if c_cond > 0 else c_hidden, 1, kernel_size=1)
246
+ self.logits = nn.Sigmoid()
247
+
248
+ def forward(self, x, cond=None):
249
+ x = self.encoder(x)
250
+ if cond is not None:
251
+ cond = cond.view(cond.size(0), cond.size(1), 1, 1, ).expand(-1, -1, x.size(-2), x.size(-1))
252
+ x = torch.cat([x, cond], dim=1)
253
+ x = self.shuffle(x)
254
+ x = self.logits(x)
255
+ return x
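A hypothetical usage sketch of VectorQuantize above (shapes are illustrative; embedding_size and k mirror StageA's c_latent and codebook_size defaults). The dimension selected by dim must equal embedding_size:

import torch
from comfy.ldm.cascade.stage_a import VectorQuantize  # module path as added in this commit

vq = VectorQuantize(embedding_size=4, k=8192)   # mirrors StageA's c_latent / codebook_size defaults
latent = torch.randn(2, 4, 16, 16)              # B, C, H, W latent, channels on dim 1
quantized, (vq_loss, commit_loss), indices = vq(latent, dim=1)
# quantized: (2, 4, 16, 16), same shape as the input; indices: (2, 16, 16) codebook ids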
comfy/ldm/cascade/stage_b.py ADDED
@@ -0,0 +1,256 @@
1
+ """
2
+ This file is part of ComfyUI.
3
+ Copyright (C) 2024 Stability AI
4
+
5
+ This program is free software: you can redistribute it and/or modify
6
+ it under the terms of the GNU General Public License as published by
7
+ the Free Software Foundation, either version 3 of the License, or
8
+ (at your option) any later version.
9
+
10
+ This program is distributed in the hope that it will be useful,
11
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
12
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13
+ GNU General Public License for more details.
14
+
15
+ You should have received a copy of the GNU General Public License
16
+ along with this program. If not, see <https://www.gnu.org/licenses/>.
17
+ """
18
+
19
+ import math
20
+ import torch
21
+ from torch import nn
22
+ from .common import AttnBlock, LayerNorm2d_op, ResBlock, FeedForwardBlock, TimestepBlock
23
+
24
+ class StageB(nn.Module):
25
+ def __init__(self, c_in=4, c_out=4, c_r=64, patch_size=2, c_cond=1280, c_hidden=[320, 640, 1280, 1280],
26
+ nhead=[-1, -1, 20, 20], blocks=[[2, 6, 28, 6], [6, 28, 6, 2]],
27
+ block_repeat=[[1, 1, 1, 1], [3, 3, 2, 2]], level_config=['CT', 'CT', 'CTA', 'CTA'], c_clip=1280,
28
+ c_clip_seq=4, c_effnet=16, c_pixels=3, kernel_size=3, dropout=[0, 0, 0.0, 0.0], self_attn=True,
29
+ t_conds=['sca'], stable_cascade_stage=None, dtype=None, device=None, operations=None):
30
+ super().__init__()
31
+ self.dtype = dtype
32
+ self.c_r = c_r
33
+ self.t_conds = t_conds
34
+ self.c_clip_seq = c_clip_seq
35
+ if not isinstance(dropout, list):
36
+ dropout = [dropout] * len(c_hidden)
37
+ if not isinstance(self_attn, list):
38
+ self_attn = [self_attn] * len(c_hidden)
39
+
40
+ # CONDITIONING
41
+ self.effnet_mapper = nn.Sequential(
42
+ operations.Conv2d(c_effnet, c_hidden[0] * 4, kernel_size=1, dtype=dtype, device=device),
43
+ nn.GELU(),
44
+ operations.Conv2d(c_hidden[0] * 4, c_hidden[0], kernel_size=1, dtype=dtype, device=device),
45
+ LayerNorm2d_op(operations)(c_hidden[0], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
46
+ )
47
+ self.pixels_mapper = nn.Sequential(
48
+ operations.Conv2d(c_pixels, c_hidden[0] * 4, kernel_size=1, dtype=dtype, device=device),
49
+ nn.GELU(),
50
+ operations.Conv2d(c_hidden[0] * 4, c_hidden[0], kernel_size=1, dtype=dtype, device=device),
51
+ LayerNorm2d_op(operations)(c_hidden[0], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
52
+ )
53
+ self.clip_mapper = operations.Linear(c_clip, c_cond * c_clip_seq, dtype=dtype, device=device)
54
+ self.clip_norm = operations.LayerNorm(c_cond, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
55
+
56
+ self.embedding = nn.Sequential(
57
+ nn.PixelUnshuffle(patch_size),
58
+ operations.Conv2d(c_in * (patch_size ** 2), c_hidden[0], kernel_size=1, dtype=dtype, device=device),
59
+ LayerNorm2d_op(operations)(c_hidden[0], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
60
+ )
61
+
62
+ def get_block(block_type, c_hidden, nhead, c_skip=0, dropout=0, self_attn=True):
63
+ if block_type == 'C':
64
+ return ResBlock(c_hidden, c_skip, kernel_size=kernel_size, dropout=dropout, dtype=dtype, device=device, operations=operations)
65
+ elif block_type == 'A':
66
+ return AttnBlock(c_hidden, c_cond, nhead, self_attn=self_attn, dropout=dropout, dtype=dtype, device=device, operations=operations)
67
+ elif block_type == 'F':
68
+ return FeedForwardBlock(c_hidden, dropout=dropout, dtype=dtype, device=device, operations=operations)
69
+ elif block_type == 'T':
70
+ return TimestepBlock(c_hidden, c_r, conds=t_conds, dtype=dtype, device=device, operations=operations)
71
+ else:
72
+ raise Exception(f'Block type {block_type} not supported')
73
+
74
+ # BLOCKS
75
+ # -- down blocks
76
+ self.down_blocks = nn.ModuleList()
77
+ self.down_downscalers = nn.ModuleList()
78
+ self.down_repeat_mappers = nn.ModuleList()
79
+ for i in range(len(c_hidden)):
80
+ if i > 0:
81
+ self.down_downscalers.append(nn.Sequential(
82
+ LayerNorm2d_op(operations)(c_hidden[i - 1], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device),
83
+ operations.Conv2d(c_hidden[i - 1], c_hidden[i], kernel_size=2, stride=2, dtype=dtype, device=device),
84
+ ))
85
+ else:
86
+ self.down_downscalers.append(nn.Identity())
87
+ down_block = nn.ModuleList()
88
+ for _ in range(blocks[0][i]):
89
+ for block_type in level_config[i]:
90
+ block = get_block(block_type, c_hidden[i], nhead[i], dropout=dropout[i], self_attn=self_attn[i])
91
+ down_block.append(block)
92
+ self.down_blocks.append(down_block)
93
+ if block_repeat is not None:
94
+ block_repeat_mappers = nn.ModuleList()
95
+ for _ in range(block_repeat[0][i] - 1):
96
+ block_repeat_mappers.append(operations.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1, dtype=dtype, device=device))
97
+ self.down_repeat_mappers.append(block_repeat_mappers)
98
+
99
+ # -- up blocks
100
+ self.up_blocks = nn.ModuleList()
101
+ self.up_upscalers = nn.ModuleList()
102
+ self.up_repeat_mappers = nn.ModuleList()
103
+ for i in reversed(range(len(c_hidden))):
104
+ if i > 0:
105
+ self.up_upscalers.append(nn.Sequential(
106
+ LayerNorm2d_op(operations)(c_hidden[i], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device),
107
+ operations.ConvTranspose2d(c_hidden[i], c_hidden[i - 1], kernel_size=2, stride=2, dtype=dtype, device=device),
108
+ ))
109
+ else:
110
+ self.up_upscalers.append(nn.Identity())
111
+ up_block = nn.ModuleList()
112
+ for j in range(blocks[1][::-1][i]):
113
+ for k, block_type in enumerate(level_config[i]):
114
+ c_skip = c_hidden[i] if i < len(c_hidden) - 1 and j == k == 0 else 0
115
+ block = get_block(block_type, c_hidden[i], nhead[i], c_skip=c_skip, dropout=dropout[i],
116
+ self_attn=self_attn[i])
117
+ up_block.append(block)
118
+ self.up_blocks.append(up_block)
119
+ if block_repeat is not None:
120
+ block_repeat_mappers = nn.ModuleList()
121
+ for _ in range(block_repeat[1][::-1][i] - 1):
122
+ block_repeat_mappers.append(operations.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1, dtype=dtype, device=device))
123
+ self.up_repeat_mappers.append(block_repeat_mappers)
124
+
125
+ # OUTPUT
126
+ self.clf = nn.Sequential(
127
+ LayerNorm2d_op(operations)(c_hidden[0], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device),
128
+ operations.Conv2d(c_hidden[0], c_out * (patch_size ** 2), kernel_size=1, dtype=dtype, device=device),
129
+ nn.PixelShuffle(patch_size),
130
+ )
131
+
132
+ # --- WEIGHT INIT ---
133
+ # self.apply(self._init_weights) # General init
134
+ # nn.init.normal_(self.clip_mapper.weight, std=0.02) # conditionings
135
+ # nn.init.normal_(self.effnet_mapper[0].weight, std=0.02) # conditionings
136
+ # nn.init.normal_(self.effnet_mapper[2].weight, std=0.02) # conditionings
137
+ # nn.init.normal_(self.pixels_mapper[0].weight, std=0.02) # conditionings
138
+ # nn.init.normal_(self.pixels_mapper[2].weight, std=0.02) # conditionings
139
+ # torch.nn.init.xavier_uniform_(self.embedding[1].weight, 0.02) # inputs
140
+ # nn.init.constant_(self.clf[1].weight, 0) # outputs
141
+ #
142
+ # # blocks
143
+ # for level_block in self.down_blocks + self.up_blocks:
144
+ # for block in level_block:
145
+ # if isinstance(block, ResBlock) or isinstance(block, FeedForwardBlock):
146
+ # block.channelwise[-1].weight.data *= np.sqrt(1 / sum(blocks[0]))
147
+ # elif isinstance(block, TimestepBlock):
148
+ # for layer in block.modules():
149
+ # if isinstance(layer, nn.Linear):
150
+ # nn.init.constant_(layer.weight, 0)
151
+ #
152
+ # def _init_weights(self, m):
153
+ # if isinstance(m, (nn.Conv2d, nn.Linear)):
154
+ # torch.nn.init.xavier_uniform_(m.weight)
155
+ # if m.bias is not None:
156
+ # nn.init.constant_(m.bias, 0)
157
+
158
+ def gen_r_embedding(self, r, max_positions=10000):
159
+ r = r * max_positions
160
+ half_dim = self.c_r // 2
161
+ emb = math.log(max_positions) / (half_dim - 1)
162
+ emb = torch.arange(half_dim, device=r.device).float().mul(-emb).exp()
163
+ emb = r[:, None] * emb[None, :]
164
+ emb = torch.cat([emb.sin(), emb.cos()], dim=1)
165
+ if self.c_r % 2 == 1: # zero pad
166
+ emb = nn.functional.pad(emb, (0, 1), mode='constant')
167
+ return emb
168
+
169
+ def gen_c_embeddings(self, clip):
170
+ if len(clip.shape) == 2:
171
+ clip = clip.unsqueeze(1)
172
+ clip = self.clip_mapper(clip).view(clip.size(0), clip.size(1) * self.c_clip_seq, -1)
173
+ clip = self.clip_norm(clip)
174
+ return clip
175
+
176
+ def _down_encode(self, x, r_embed, clip):
177
+ level_outputs = []
178
+ block_group = zip(self.down_blocks, self.down_downscalers, self.down_repeat_mappers)
179
+ for down_block, downscaler, repmap in block_group:
180
+ x = downscaler(x)
181
+ for i in range(len(repmap) + 1):
182
+ for block in down_block:
183
+ if isinstance(block, ResBlock) or (
184
+ hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
185
+ ResBlock)):
186
+ x = block(x)
187
+ elif isinstance(block, AttnBlock) or (
188
+ hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
189
+ AttnBlock)):
190
+ x = block(x, clip)
191
+ elif isinstance(block, TimestepBlock) or (
192
+ hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
193
+ TimestepBlock)):
194
+ x = block(x, r_embed)
195
+ else:
196
+ x = block(x)
197
+ if i < len(repmap):
198
+ x = repmap[i](x)
199
+ level_outputs.insert(0, x)
200
+ return level_outputs
201
+
202
+ def _up_decode(self, level_outputs, r_embed, clip):
203
+ x = level_outputs[0]
204
+ block_group = zip(self.up_blocks, self.up_upscalers, self.up_repeat_mappers)
205
+ for i, (up_block, upscaler, repmap) in enumerate(block_group):
206
+ for j in range(len(repmap) + 1):
207
+ for k, block in enumerate(up_block):
208
+ if isinstance(block, ResBlock) or (
209
+ hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
210
+ ResBlock)):
211
+ skip = level_outputs[i] if k == 0 and i > 0 else None
212
+ if skip is not None and (x.size(-1) != skip.size(-1) or x.size(-2) != skip.size(-2)):
213
+ x = torch.nn.functional.interpolate(x, skip.shape[-2:], mode='bilinear',
214
+ align_corners=True)
215
+ x = block(x, skip)
216
+ elif isinstance(block, AttnBlock) or (
217
+ hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
218
+ AttnBlock)):
219
+ x = block(x, clip)
220
+ elif isinstance(block, TimestepBlock) or (
221
+ hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
222
+ TimestepBlock)):
223
+ x = block(x, r_embed)
224
+ else:
225
+ x = block(x)
226
+ if j < len(repmap):
227
+ x = repmap[j](x)
228
+ x = upscaler(x)
229
+ return x
230
+
231
+ def forward(self, x, r, effnet, clip, pixels=None, **kwargs):
232
+ if pixels is None:
233
+ pixels = x.new_zeros(x.size(0), 3, 8, 8)
234
+
235
+ # Process the conditioning embeddings
236
+ r_embed = self.gen_r_embedding(r).to(dtype=x.dtype)
237
+ for c in self.t_conds:
238
+ t_cond = kwargs.get(c, torch.zeros_like(r))
239
+ r_embed = torch.cat([r_embed, self.gen_r_embedding(t_cond).to(dtype=x.dtype)], dim=1)
240
+ clip = self.gen_c_embeddings(clip)
241
+
242
+ # Model Blocks
243
+ x = self.embedding(x)
244
+ x = x + self.effnet_mapper(
245
+ nn.functional.interpolate(effnet, size=x.shape[-2:], mode='bilinear', align_corners=True))
246
+ x = x + nn.functional.interpolate(self.pixels_mapper(pixels), size=x.shape[-2:], mode='bilinear',
247
+ align_corners=True)
248
+ level_outputs = self._down_encode(x, r_embed, clip)
249
+ x = self._up_decode(level_outputs, r_embed, clip)
250
+ return self.clf(x)
251
+
252
+ def update_weights_ema(self, src_model, beta=0.999):
253
+ for self_params, src_params in zip(self.parameters(), src_model.parameters()):
254
+ self_params.data = self_params.data * beta + src_params.data.clone().to(self_params.device) * (1 - beta)
255
+ for self_buffers, src_buffers in zip(self.buffers(), src_model.buffers()):
256
+ self_buffers.data = self_buffers.data * beta + src_buffers.data.clone().to(self_buffers.device) * (1 - beta)
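StageB.update_weights_ema above is a plain exponential moving average over both parameters and buffers. A minimal sketch of the same update rule on toy modules (beta=0.999 matches the method's default; the Linear layers are stand-ins, not ComfyUI modules):

import torch
from torch import nn

ema_model, src_model = nn.Linear(4, 4), nn.Linear(4, 4)
beta = 0.999

with torch.no_grad():
    for ema_p, src_p in zip(ema_model.parameters(), src_model.parameters()):
        # keep 99.9% of the EMA weight, blend in 0.1% of the source weight
        ema_p.data = ema_p.data * beta + src_p.data.clone() * (1 - beta)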
comfy/ldm/cascade/stage_c.py ADDED
@@ -0,0 +1,273 @@
1
+ """
2
+ This file is part of ComfyUI.
3
+ Copyright (C) 2024 Stability AI
4
+
5
+ This program is free software: you can redistribute it and/or modify
6
+ it under the terms of the GNU General Public License as published by
7
+ the Free Software Foundation, either version 3 of the License, or
8
+ (at your option) any later version.
9
+
10
+ This program is distributed in the hope that it will be useful,
11
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
12
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13
+ GNU General Public License for more details.
14
+
15
+ You should have received a copy of the GNU General Public License
16
+ along with this program. If not, see <https://www.gnu.org/licenses/>.
17
+ """
18
+
19
+ import torch
20
+ from torch import nn
21
+ import math
22
+ from .common import AttnBlock, LayerNorm2d_op, ResBlock, FeedForwardBlock, TimestepBlock
23
+ # from .controlnet import ControlNetDeliverer
24
+
25
+ class UpDownBlock2d(nn.Module):
26
+ def __init__(self, c_in, c_out, mode, enabled=True, dtype=None, device=None, operations=None):
27
+ super().__init__()
28
+ assert mode in ['up', 'down']
29
+ interpolation = nn.Upsample(scale_factor=2 if mode == 'up' else 0.5, mode='bilinear',
30
+ align_corners=True) if enabled else nn.Identity()
31
+ mapping = operations.Conv2d(c_in, c_out, kernel_size=1, dtype=dtype, device=device)
32
+ self.blocks = nn.ModuleList([interpolation, mapping] if mode == 'up' else [mapping, interpolation])
33
+
34
+ def forward(self, x):
35
+ for block in self.blocks:
36
+ x = block(x)
37
+ return x
38
+
39
+
40
+ class StageC(nn.Module):
41
+ def __init__(self, c_in=16, c_out=16, c_r=64, patch_size=1, c_cond=2048, c_hidden=[2048, 2048], nhead=[32, 32],
42
+ blocks=[[8, 24], [24, 8]], block_repeat=[[1, 1], [1, 1]], level_config=['CTA', 'CTA'],
43
+ c_clip_text=1280, c_clip_text_pooled=1280, c_clip_img=768, c_clip_seq=4, kernel_size=3,
44
+ dropout=[0.0, 0.0], self_attn=True, t_conds=['sca', 'crp'], switch_level=[False], stable_cascade_stage=None,
45
+ dtype=None, device=None, operations=None):
46
+ super().__init__()
47
+ self.dtype = dtype
48
+ self.c_r = c_r
49
+ self.t_conds = t_conds
50
+ self.c_clip_seq = c_clip_seq
51
+ if not isinstance(dropout, list):
52
+ dropout = [dropout] * len(c_hidden)
53
+ if not isinstance(self_attn, list):
54
+ self_attn = [self_attn] * len(c_hidden)
55
+
56
+ # CONDITIONING
57
+ self.clip_txt_mapper = operations.Linear(c_clip_text, c_cond, dtype=dtype, device=device)
58
+ self.clip_txt_pooled_mapper = operations.Linear(c_clip_text_pooled, c_cond * c_clip_seq, dtype=dtype, device=device)
59
+ self.clip_img_mapper = operations.Linear(c_clip_img, c_cond * c_clip_seq, dtype=dtype, device=device)
60
+ self.clip_norm = operations.LayerNorm(c_cond, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
61
+
62
+ self.embedding = nn.Sequential(
63
+ nn.PixelUnshuffle(patch_size),
64
+ operations.Conv2d(c_in * (patch_size ** 2), c_hidden[0], kernel_size=1, dtype=dtype, device=device),
65
+ LayerNorm2d_op(operations)(c_hidden[0], elementwise_affine=False, eps=1e-6)
66
+ )
67
+
68
+ def get_block(block_type, c_hidden, nhead, c_skip=0, dropout=0, self_attn=True):
69
+ if block_type == 'C':
70
+ return ResBlock(c_hidden, c_skip, kernel_size=kernel_size, dropout=dropout, dtype=dtype, device=device, operations=operations)
71
+ elif block_type == 'A':
72
+ return AttnBlock(c_hidden, c_cond, nhead, self_attn=self_attn, dropout=dropout, dtype=dtype, device=device, operations=operations)
73
+ elif block_type == 'F':
74
+ return FeedForwardBlock(c_hidden, dropout=dropout, dtype=dtype, device=device, operations=operations)
75
+ elif block_type == 'T':
76
+ return TimestepBlock(c_hidden, c_r, conds=t_conds, dtype=dtype, device=device, operations=operations)
77
+ else:
78
+ raise Exception(f'Block type {block_type} not supported')
79
+
80
+ # BLOCKS
81
+ # -- down blocks
82
+ self.down_blocks = nn.ModuleList()
83
+ self.down_downscalers = nn.ModuleList()
84
+ self.down_repeat_mappers = nn.ModuleList()
85
+ for i in range(len(c_hidden)):
86
+ if i > 0:
87
+ self.down_downscalers.append(nn.Sequential(
88
+ LayerNorm2d_op(operations)(c_hidden[i - 1], elementwise_affine=False, eps=1e-6),
89
+ UpDownBlock2d(c_hidden[i - 1], c_hidden[i], mode='down', enabled=switch_level[i - 1], dtype=dtype, device=device, operations=operations)
90
+ ))
91
+ else:
92
+ self.down_downscalers.append(nn.Identity())
93
+ down_block = nn.ModuleList()
94
+ for _ in range(blocks[0][i]):
95
+ for block_type in level_config[i]:
96
+ block = get_block(block_type, c_hidden[i], nhead[i], dropout=dropout[i], self_attn=self_attn[i])
97
+ down_block.append(block)
98
+ self.down_blocks.append(down_block)
99
+ if block_repeat is not None:
100
+ block_repeat_mappers = nn.ModuleList()
101
+ for _ in range(block_repeat[0][i] - 1):
102
+ block_repeat_mappers.append(operations.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1, dtype=dtype, device=device))
103
+ self.down_repeat_mappers.append(block_repeat_mappers)
104
+
105
+ # -- up blocks
106
+ self.up_blocks = nn.ModuleList()
107
+ self.up_upscalers = nn.ModuleList()
108
+ self.up_repeat_mappers = nn.ModuleList()
109
+ for i in reversed(range(len(c_hidden))):
110
+ if i > 0:
111
+ self.up_upscalers.append(nn.Sequential(
112
+ LayerNorm2d_op(operations)(c_hidden[i], elementwise_affine=False, eps=1e-6),
113
+ UpDownBlock2d(c_hidden[i], c_hidden[i - 1], mode='up', enabled=switch_level[i - 1], dtype=dtype, device=device, operations=operations)
114
+ ))
115
+ else:
116
+ self.up_upscalers.append(nn.Identity())
117
+ up_block = nn.ModuleList()
118
+ for j in range(blocks[1][::-1][i]):
119
+ for k, block_type in enumerate(level_config[i]):
120
+ c_skip = c_hidden[i] if i < len(c_hidden) - 1 and j == k == 0 else 0
121
+ block = get_block(block_type, c_hidden[i], nhead[i], c_skip=c_skip, dropout=dropout[i],
122
+ self_attn=self_attn[i])
123
+ up_block.append(block)
124
+ self.up_blocks.append(up_block)
125
+ if block_repeat is not None:
126
+ block_repeat_mappers = nn.ModuleList()
127
+ for _ in range(block_repeat[1][::-1][i] - 1):
128
+ block_repeat_mappers.append(operations.Conv2d(c_hidden[i], c_hidden[i], kernel_size=1, dtype=dtype, device=device))
129
+ self.up_repeat_mappers.append(block_repeat_mappers)
130
+
131
+ # OUTPUT
132
+ self.clf = nn.Sequential(
133
+ LayerNorm2d_op(operations)(c_hidden[0], elementwise_affine=False, eps=1e-6, dtype=dtype, device=device),
134
+ operations.Conv2d(c_hidden[0], c_out * (patch_size ** 2), kernel_size=1, dtype=dtype, device=device),
135
+ nn.PixelShuffle(patch_size),
136
+ )
137
+
138
+ # --- WEIGHT INIT ---
139
+ # self.apply(self._init_weights) # General init
140
+ # nn.init.normal_(self.clip_txt_mapper.weight, std=0.02) # conditionings
141
+ # nn.init.normal_(self.clip_txt_pooled_mapper.weight, std=0.02) # conditionings
142
+ # nn.init.normal_(self.clip_img_mapper.weight, std=0.02) # conditionings
143
+ # torch.nn.init.xavier_uniform_(self.embedding[1].weight, 0.02) # inputs
144
+ # nn.init.constant_(self.clf[1].weight, 0) # outputs
145
+ #
146
+ # # blocks
147
+ # for level_block in self.down_blocks + self.up_blocks:
148
+ # for block in level_block:
149
+ # if isinstance(block, ResBlock) or isinstance(block, FeedForwardBlock):
150
+ # block.channelwise[-1].weight.data *= np.sqrt(1 / sum(blocks[0]))
151
+ # elif isinstance(block, TimestepBlock):
152
+ # for layer in block.modules():
153
+ # if isinstance(layer, nn.Linear):
154
+ # nn.init.constant_(layer.weight, 0)
155
+ #
156
+ # def _init_weights(self, m):
157
+ # if isinstance(m, (nn.Conv2d, nn.Linear)):
158
+ # torch.nn.init.xavier_uniform_(m.weight)
159
+ # if m.bias is not None:
160
+ # nn.init.constant_(m.bias, 0)
161
+
162
+ def gen_r_embedding(self, r, max_positions=10000):
163
+ r = r * max_positions
164
+ half_dim = self.c_r // 2
165
+ emb = math.log(max_positions) / (half_dim - 1)
166
+ emb = torch.arange(half_dim, device=r.device).float().mul(-emb).exp()
167
+ emb = r[:, None] * emb[None, :]
168
+ emb = torch.cat([emb.sin(), emb.cos()], dim=1)
169
+ if self.c_r % 2 == 1: # zero pad
170
+ emb = nn.functional.pad(emb, (0, 1), mode='constant')
171
+ return emb
172
+
173
+ def gen_c_embeddings(self, clip_txt, clip_txt_pooled, clip_img):
174
+ clip_txt = self.clip_txt_mapper(clip_txt)
175
+ if len(clip_txt_pooled.shape) == 2:
176
+ clip_txt_pooled = clip_txt_pooled.unsqueeze(1)
177
+ if len(clip_img.shape) == 2:
178
+ clip_img = clip_img.unsqueeze(1)
179
+ clip_txt_pool = self.clip_txt_pooled_mapper(clip_txt_pooled).view(clip_txt_pooled.size(0), clip_txt_pooled.size(1) * self.c_clip_seq, -1)
180
+ clip_img = self.clip_img_mapper(clip_img).view(clip_img.size(0), clip_img.size(1) * self.c_clip_seq, -1)
181
+ clip = torch.cat([clip_txt, clip_txt_pool, clip_img], dim=1)
182
+ clip = self.clip_norm(clip)
183
+ return clip
184
+
185
+ def _down_encode(self, x, r_embed, clip, cnet=None):
186
+ level_outputs = []
187
+ block_group = zip(self.down_blocks, self.down_downscalers, self.down_repeat_mappers)
188
+ for down_block, downscaler, repmap in block_group:
189
+ x = downscaler(x)
190
+ for i in range(len(repmap) + 1):
191
+ for block in down_block:
192
+ if isinstance(block, ResBlock) or (
193
+ hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
194
+ ResBlock)):
195
+ if cnet is not None:
196
+ next_cnet = cnet.pop()
197
+ if next_cnet is not None:
198
+ x = x + nn.functional.interpolate(next_cnet, size=x.shape[-2:], mode='bilinear',
199
+ align_corners=True).to(x.dtype)
200
+ x = block(x)
201
+ elif isinstance(block, AttnBlock) or (
202
+ hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
203
+ AttnBlock)):
204
+ x = block(x, clip)
205
+ elif isinstance(block, TimestepBlock) or (
206
+ hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
207
+ TimestepBlock)):
208
+ x = block(x, r_embed)
209
+ else:
210
+ x = block(x)
211
+ if i < len(repmap):
212
+ x = repmap[i](x)
213
+ level_outputs.insert(0, x)
214
+ return level_outputs
215
+
216
+ def _up_decode(self, level_outputs, r_embed, clip, cnet=None):
217
+ x = level_outputs[0]
218
+ block_group = zip(self.up_blocks, self.up_upscalers, self.up_repeat_mappers)
219
+ for i, (up_block, upscaler, repmap) in enumerate(block_group):
220
+ for j in range(len(repmap) + 1):
221
+ for k, block in enumerate(up_block):
222
+ if isinstance(block, ResBlock) or (
223
+ hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
224
+ ResBlock)):
225
+ skip = level_outputs[i] if k == 0 and i > 0 else None
226
+ if skip is not None and (x.size(-1) != skip.size(-1) or x.size(-2) != skip.size(-2)):
227
+ x = torch.nn.functional.interpolate(x, skip.shape[-2:], mode='bilinear',
228
+ align_corners=True)
229
+ if cnet is not None:
230
+ next_cnet = cnet.pop()
231
+ if next_cnet is not None:
232
+ x = x + nn.functional.interpolate(next_cnet, size=x.shape[-2:], mode='bilinear',
233
+ align_corners=True).to(x.dtype)
234
+ x = block(x, skip)
235
+ elif isinstance(block, AttnBlock) or (
236
+ hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
237
+ AttnBlock)):
238
+ x = block(x, clip)
239
+ elif isinstance(block, TimestepBlock) or (
240
+ hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
241
+ TimestepBlock)):
242
+ x = block(x, r_embed)
243
+ else:
244
+ x = block(x)
245
+ if j < len(repmap):
246
+ x = repmap[j](x)
247
+ x = upscaler(x)
248
+ return x
249
+
250
+ def forward(self, x, r, clip_text, clip_text_pooled, clip_img, control=None, **kwargs):
251
+ # Process the conditioning embeddings
252
+ r_embed = self.gen_r_embedding(r).to(dtype=x.dtype)
253
+ for c in self.t_conds:
254
+ t_cond = kwargs.get(c, torch.zeros_like(r))
255
+ r_embed = torch.cat([r_embed, self.gen_r_embedding(t_cond).to(dtype=x.dtype)], dim=1)
256
+ clip = self.gen_c_embeddings(clip_text, clip_text_pooled, clip_img)
257
+
258
+ if control is not None:
259
+ cnet = control.get("input")
260
+ else:
261
+ cnet = None
262
+
263
+ # Model Blocks
264
+ x = self.embedding(x)
265
+ level_outputs = self._down_encode(x, r_embed, clip, cnet)
266
+ x = self._up_decode(level_outputs, r_embed, clip, cnet)
267
+ return self.clf(x)
268
+
269
+ def update_weights_ema(self, src_model, beta=0.999):
270
+ for self_params, src_params in zip(self.parameters(), src_model.parameters()):
271
+ self_params.data = self_params.data * beta + src_params.data.clone().to(self_params.device) * (1 - beta)
272
+ for self_buffers, src_buffers in zip(self.buffers(), src_model.buffers()):
273
+ self_buffers.data = self_buffers.data * beta + src_buffers.data.clone().to(self_buffers.device) * (1 - beta)
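Both cascade stages build their timestep conditioning with gen_r_embedding, a standard sinusoidal embedding of the scaled noise level. A standalone sketch mirroring the computation above (c_r=64 and max_positions=10000 match the defaults; only the even-c_r path is shown):

import math
import torch

def r_embedding(r, c_r=64, max_positions=10000):
    r = r * max_positions
    half_dim = c_r // 2
    freq = torch.arange(half_dim, dtype=torch.float32)
    freq = torch.exp(-freq * (math.log(max_positions) / (half_dim - 1)))
    emb = r[:, None] * freq[None, :]
    return torch.cat([emb.sin(), emb.cos()], dim=1)  # (batch, c_r)

print(r_embedding(torch.rand(2)).shape)  # torch.Size([2, 64])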
comfy/ldm/cascade/stage_c_coder.py ADDED
@@ -0,0 +1,95 @@
1
+ """
2
+ This file is part of ComfyUI.
3
+ Copyright (C) 2024 Stability AI
4
+
5
+ This program is free software: you can redistribute it and/or modify
6
+ it under the terms of the GNU General Public License as published by
7
+ the Free Software Foundation, either version 3 of the License, or
8
+ (at your option) any later version.
9
+
10
+ This program is distributed in the hope that it will be useful,
11
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
12
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13
+ GNU General Public License for more details.
14
+
15
+ You should have received a copy of the GNU General Public License
16
+ along with this program. If not, see <https://www.gnu.org/licenses/>.
17
+ """
18
+ import torch
19
+ import torchvision
20
+ from torch import nn
21
+
22
+
23
+ # EfficientNet
24
+ class EfficientNetEncoder(nn.Module):
25
+ def __init__(self, c_latent=16):
26
+ super().__init__()
27
+ self.backbone = torchvision.models.efficientnet_v2_s().features.eval()
28
+ self.mapper = nn.Sequential(
29
+ nn.Conv2d(1280, c_latent, kernel_size=1, bias=False),
30
+ nn.BatchNorm2d(c_latent, affine=False), # then normalize them to have mean 0 and std 1
31
+ )
32
+ self.mean = nn.Parameter(torch.tensor([0.485, 0.456, 0.406]))
33
+ self.std = nn.Parameter(torch.tensor([0.229, 0.224, 0.225]))
34
+
35
+ def forward(self, x):
36
+ x = x * 0.5 + 0.5
37
+ x = (x - self.mean.view([3,1,1])) / self.std.view([3,1,1])
38
+ o = self.mapper(self.backbone(x))
39
+ return o
40
+
41
+
42
+ # Fast Decoder for Stage C latents. E.g. 16 x 24 x 24 -> 3 x 192 x 192
43
+ class Previewer(nn.Module):
44
+ def __init__(self, c_in=16, c_hidden=512, c_out=3):
45
+ super().__init__()
46
+ self.blocks = nn.Sequential(
47
+ nn.Conv2d(c_in, c_hidden, kernel_size=1), # 16 channels to 512 channels
48
+ nn.GELU(),
49
+ nn.BatchNorm2d(c_hidden),
50
+
51
+ nn.Conv2d(c_hidden, c_hidden, kernel_size=3, padding=1),
52
+ nn.GELU(),
53
+ nn.BatchNorm2d(c_hidden),
54
+
55
+ nn.ConvTranspose2d(c_hidden, c_hidden // 2, kernel_size=2, stride=2), # 16 -> 32
56
+ nn.GELU(),
57
+ nn.BatchNorm2d(c_hidden // 2),
58
+
59
+ nn.Conv2d(c_hidden // 2, c_hidden // 2, kernel_size=3, padding=1),
60
+ nn.GELU(),
61
+ nn.BatchNorm2d(c_hidden // 2),
62
+
63
+ nn.ConvTranspose2d(c_hidden // 2, c_hidden // 4, kernel_size=2, stride=2), # 32 -> 64
64
+ nn.GELU(),
65
+ nn.BatchNorm2d(c_hidden // 4),
66
+
67
+ nn.Conv2d(c_hidden // 4, c_hidden // 4, kernel_size=3, padding=1),
68
+ nn.GELU(),
69
+ nn.BatchNorm2d(c_hidden // 4),
70
+
71
+ nn.ConvTranspose2d(c_hidden // 4, c_hidden // 4, kernel_size=2, stride=2), # 64 -> 128
72
+ nn.GELU(),
73
+ nn.BatchNorm2d(c_hidden // 4),
74
+
75
+ nn.Conv2d(c_hidden // 4, c_hidden // 4, kernel_size=3, padding=1),
76
+ nn.GELU(),
77
+ nn.BatchNorm2d(c_hidden // 4),
78
+
79
+ nn.Conv2d(c_hidden // 4, c_out, kernel_size=1),
80
+ )
81
+
82
+ def forward(self, x):
83
+ return (self.blocks(x) - 0.5) * 2.0
84
+
85
+ class StageC_coder(nn.Module):
86
+ def __init__(self):
87
+ super().__init__()
88
+ self.previewer = Previewer()
89
+ self.encoder = EfficientNetEncoder()
90
+
91
+ def encode(self, x):
92
+ return self.encoder(x)
93
+
94
+ def decode(self, x):
95
+ return self.previewer(x)
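EfficientNetEncoder.forward above rescales inputs from [-1, 1] back to [0, 1] and then applies the usual ImageNet mean/std normalization before the backbone, while the Previewer upscales 16×24×24 latents to 3×192×192 previews through three stride-2 transposed convolutions (an 8× spatial factor). A quick illustration of just that input preprocessing, using the constants from the code above and a dummy tensor:

import torch

mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

x = torch.rand(1, 3, 192, 192) * 2 - 1  # dummy image in [-1, 1]
x = x * 0.5 + 0.5                       # rescale to [0, 1]
x = (x - mean) / std                    # ImageNet normalization
print(x.shape)                          # torch.Size([1, 3, 192, 192])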
comfy/ldm/common_dit.py ADDED
@@ -0,0 +1,21 @@
1
+ import torch
2
+ import comfy.ops
3
+
4
+ def pad_to_patch_size(img, patch_size=(2, 2), padding_mode="circular"):
5
+ if padding_mode == "circular" and (torch.jit.is_tracing() or torch.jit.is_scripting()):
6
+ padding_mode = "reflect"
7
+ pad_h = (patch_size[0] - img.shape[-2] % patch_size[0]) % patch_size[0]
8
+ pad_w = (patch_size[1] - img.shape[-1] % patch_size[1]) % patch_size[1]
9
+ return torch.nn.functional.pad(img, (0, pad_w, 0, pad_h), mode=padding_mode)
10
+
11
+ try:
12
+ rms_norm_torch = torch.nn.functional.rms_norm
13
+ except AttributeError: # torch.nn.functional.rms_norm is only available in newer PyTorch releases
14
+ rms_norm_torch = None
15
+
16
+ def rms_norm(x, weight, eps=1e-6):
17
+ if rms_norm_torch is not None and not (torch.jit.is_tracing() or torch.jit.is_scripting()):
18
+ return rms_norm_torch(x, weight.shape, weight=comfy.ops.cast_to(weight, dtype=x.dtype, device=x.device), eps=eps)
19
+ else:
20
+ rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + eps)
21
+ return (x * rrms) * comfy.ops.cast_to(weight, dtype=x.dtype, device=x.device)
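pad_to_patch_size above pads the last two dimensions up to the next multiple of the patch size, padding only on the bottom/right edge. A small standalone sketch of that arithmetic (dummy 17×13 input, 2×2 patches):

import torch

img = torch.randn(1, 16, 17, 13)
patch_size = (2, 2)
pad_h = (patch_size[0] - img.shape[-2] % patch_size[0]) % patch_size[0]  # 1
pad_w = (patch_size[1] - img.shape[-1] % patch_size[1]) % patch_size[1]  # 1
padded = torch.nn.functional.pad(img, (0, pad_w, 0, pad_h), mode="circular")
print(padded.shape)  # torch.Size([1, 16, 18, 14])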
comfy/ldm/flux/controlnet.py ADDED
@@ -0,0 +1,205 @@
1
+ #Original code can be found on: https://github.com/XLabs-AI/x-flux/blob/main/src/flux/controlnet.py
2
+ #modified to support different types of flux controlnets
3
+
4
+ import torch
5
+ import math
6
+ from torch import Tensor, nn
7
+ from einops import rearrange, repeat
8
+
9
+ from .layers import (DoubleStreamBlock, EmbedND, LastLayer,
10
+ MLPEmbedder, SingleStreamBlock,
11
+ timestep_embedding)
12
+
13
+ from .model import Flux
14
+ import comfy.ldm.common_dit
15
+
16
+ class MistolineCondDownsamplBlock(nn.Module):
17
+ def __init__(self, dtype=None, device=None, operations=None):
18
+ super().__init__()
19
+ self.encoder = nn.Sequential(
20
+ operations.Conv2d(3, 16, 3, padding=1, dtype=dtype, device=device),
21
+ nn.SiLU(),
22
+ operations.Conv2d(16, 16, 1, dtype=dtype, device=device),
23
+ nn.SiLU(),
24
+ operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
25
+ nn.SiLU(),
26
+ operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
27
+ nn.SiLU(),
28
+ operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
29
+ nn.SiLU(),
30
+ operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
31
+ nn.SiLU(),
32
+ operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
33
+ nn.SiLU(),
34
+ operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
35
+ nn.SiLU(),
36
+ operations.Conv2d(16, 16, 1, dtype=dtype, device=device),
37
+ nn.SiLU(),
38
+ operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device)
39
+ )
40
+
41
+ def forward(self, x):
42
+ return self.encoder(x)
43
+
44
+ class MistolineControlnetBlock(nn.Module):
45
+ def __init__(self, hidden_size, dtype=None, device=None, operations=None):
46
+ super().__init__()
47
+ self.linear = operations.Linear(hidden_size, hidden_size, dtype=dtype, device=device)
48
+ self.act = nn.SiLU()
49
+
50
+ def forward(self, x):
51
+ return self.act(self.linear(x))
52
+
53
+
54
+ class ControlNetFlux(Flux):
55
+ def __init__(self, latent_input=False, num_union_modes=0, mistoline=False, control_latent_channels=None, image_model=None, dtype=None, device=None, operations=None, **kwargs):
56
+ super().__init__(final_layer=False, dtype=dtype, device=device, operations=operations, **kwargs)
57
+
58
+ self.main_model_double = 19
59
+ self.main_model_single = 38
60
+
61
+ self.mistoline = mistoline
62
+ # add ControlNet blocks
63
+ if self.mistoline:
64
+ control_block = lambda : MistolineControlnetBlock(self.hidden_size, dtype=dtype, device=device, operations=operations)
65
+ else:
66
+ control_block = lambda : operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device)
67
+
68
+ self.controlnet_blocks = nn.ModuleList([])
69
+ for _ in range(self.params.depth):
70
+ self.controlnet_blocks.append(control_block())
71
+
72
+ self.controlnet_single_blocks = nn.ModuleList([])
73
+ for _ in range(self.params.depth_single_blocks):
74
+ self.controlnet_single_blocks.append(control_block())
75
+
76
+ self.num_union_modes = num_union_modes
77
+ self.controlnet_mode_embedder = None
78
+ if self.num_union_modes > 0:
79
+ self.controlnet_mode_embedder = operations.Embedding(self.num_union_modes, self.hidden_size, dtype=dtype, device=device)
80
+
81
+ self.gradient_checkpointing = False
82
+ self.latent_input = latent_input
83
+ if control_latent_channels is None:
84
+ control_latent_channels = self.in_channels
85
+ else:
86
+ control_latent_channels *= 2 * 2 #patch size
87
+
88
+ self.pos_embed_input = operations.Linear(control_latent_channels, self.hidden_size, bias=True, dtype=dtype, device=device)
89
+ if not self.latent_input:
90
+ if self.mistoline:
91
+ self.input_cond_block = MistolineCondDownsamplBlock(dtype=dtype, device=device, operations=operations)
92
+ else:
93
+ self.input_hint_block = nn.Sequential(
94
+ operations.Conv2d(3, 16, 3, padding=1, dtype=dtype, device=device),
95
+ nn.SiLU(),
96
+ operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
97
+ nn.SiLU(),
98
+ operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
99
+ nn.SiLU(),
100
+ operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
101
+ nn.SiLU(),
102
+ operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
103
+ nn.SiLU(),
104
+ operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
105
+ nn.SiLU(),
106
+ operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
107
+ nn.SiLU(),
108
+ operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device)
109
+ )
110
+
111
+ def forward_orig(
112
+ self,
113
+ img: Tensor,
114
+ img_ids: Tensor,
115
+ controlnet_cond: Tensor,
116
+ txt: Tensor,
117
+ txt_ids: Tensor,
118
+ timesteps: Tensor,
119
+ y: Tensor,
120
+ guidance: Tensor = None,
121
+ control_type: Tensor = None,
122
+ ) -> Tensor:
123
+ if img.ndim != 3 or txt.ndim != 3:
124
+ raise ValueError("Input img and txt tensors must have 3 dimensions.")
125
+
126
+ # running on sequences img
127
+ img = self.img_in(img)
128
+
129
+ controlnet_cond = self.pos_embed_input(controlnet_cond)
130
+ img = img + controlnet_cond
131
+ vec = self.time_in(timestep_embedding(timesteps, 256))
132
+ if self.params.guidance_embed:
133
+ vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
134
+ vec = vec + self.vector_in(y)
135
+ txt = self.txt_in(txt)
136
+
137
+ if self.controlnet_mode_embedder is not None and len(control_type) > 0:
138
+ control_cond = self.controlnet_mode_embedder(torch.tensor(control_type, device=img.device), out_dtype=img.dtype).unsqueeze(0).repeat((txt.shape[0], 1, 1))
139
+ txt = torch.cat([control_cond, txt], dim=1)
140
+ txt_ids = torch.cat([txt_ids[:,:1], txt_ids], dim=1)
141
+
142
+ ids = torch.cat((txt_ids, img_ids), dim=1)
143
+ pe = self.pe_embedder(ids)
144
+
145
+ controlnet_double = ()
146
+
147
+ for i in range(len(self.double_blocks)):
148
+ img, txt = self.double_blocks[i](img=img, txt=txt, vec=vec, pe=pe)
149
+ controlnet_double = controlnet_double + (self.controlnet_blocks[i](img),)
150
+
151
+ img = torch.cat((txt, img), 1)
152
+
153
+ controlnet_single = ()
154
+
155
+ for i in range(len(self.single_blocks)):
156
+ img = self.single_blocks[i](img, vec=vec, pe=pe)
157
+ controlnet_single = controlnet_single + (self.controlnet_single_blocks[i](img[:, txt.shape[1] :, ...]),)
158
+
159
+ repeat = math.ceil(self.main_model_double / len(controlnet_double))
160
+ if self.latent_input:
161
+ out_input = ()
162
+ for x in controlnet_double:
163
+ out_input += (x,) * repeat
164
+ else:
165
+ out_input = (controlnet_double * repeat)
166
+
167
+ out = {"input": out_input[:self.main_model_double]}
168
+ if len(controlnet_single) > 0:
169
+ repeat = math.ceil(self.main_model_single / len(controlnet_single))
170
+ out_output = ()
171
+ if self.latent_input:
172
+ for x in controlnet_single:
173
+ out_output += (x,) * repeat
174
+ else:
175
+ out_output = (controlnet_single * repeat)
176
+ out["output"] = out_output[:self.main_model_single]
177
+ return out
178
+
179
+ def forward(self, x, timesteps, context, y, guidance=None, hint=None, **kwargs):
180
+ patch_size = 2
181
+ if self.latent_input:
182
+ hint = comfy.ldm.common_dit.pad_to_patch_size(hint, (patch_size, patch_size))
183
+ elif self.mistoline:
184
+ hint = hint * 2.0 - 1.0
185
+ hint = self.input_cond_block(hint)
186
+ else:
187
+ hint = hint * 2.0 - 1.0
188
+ hint = self.input_hint_block(hint)
189
+
190
+ hint = rearrange(hint, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
191
+
192
+ bs, c, h, w = x.shape
193
+ x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
194
+
195
+ img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
196
+
197
+ h_len = ((h + (patch_size // 2)) // patch_size)
198
+ w_len = ((w + (patch_size // 2)) // patch_size)
199
+ img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
200
+ img_ids[..., 1] = img_ids[..., 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype)[:, None]
201
+ img_ids[..., 2] = img_ids[..., 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype)[None, :]
202
+ img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
203
+
204
+ txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
205
+ return self.forward_orig(img, img_ids, hint, context, txt_ids, timesteps, y, guidance, control_type=kwargs.get("control_type", []))
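ControlNetFlux.forward above patchifies the latent into 2×2 patches and builds one (batch, h_len*w_len, 3) grid of positional ids, where channels 1 and 2 hold each patch's row and column index. A standalone sketch of that id construction (dummy sizes; patch_size=2 as in the code):

import torch
from einops import repeat

bs, h, w, patch_size = 2, 64, 64, 2
h_len = (h + (patch_size // 2)) // patch_size
w_len = (w + (patch_size // 2)) // patch_size

img_ids = torch.zeros((h_len, w_len, 3))
img_ids[..., 1] = torch.linspace(0, h_len - 1, steps=h_len)[:, None]  # row index per patch
img_ids[..., 2] = torch.linspace(0, w_len - 1, steps=w_len)[None, :]  # column index per patch
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
print(img_ids.shape)  # torch.Size([2, 1024, 3])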