diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..5ed75580d18ef6d4a19bfaee219379194cbcbb9a 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +comfyui_screenshot.png filter=lfs diff=lfs merge=lfs -text diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 0000000000000000000000000000000000000000..048f127e72ddc351451a9612d4c715f603393a84 --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,41 @@ +# Contributing to ComfyUI + +Welcome, and thank you for your interest in contributing to ComfyUI! + +There are several ways in which you can contribute, beyond writing code. The goal of this document is to provide a high-level overview of how you can get involved. + +## Asking Questions + +Have a question? Instead of opening an issue, please ask on [Discord](https://comfy.org/discord) or [Matrix](https://app.element.io/#/room/%23comfyui_space%3Amatrix.org) channels. Our team and the community will help you. + +## Providing Feedback + +Your comments and feedback are welcome, and the development team is available via a handful of different channels. + +See the `#bug-report`, `#feature-request` and `#feedback` channels on Discord. + +## Reporting Issues + +Have you identified a reproducible problem in ComfyUI? Do you have a feature request? We want to hear about it! Here's how you can report your issue as effectively as possible. + + +### Look For an Existing Issue + +Before you create a new issue, please do a search in [open issues](https://github.com/comfyanonymous/ComfyUI/issues) to see if the issue or feature request has already been filed. + +If you find your issue already exists, make relevant comments and add your [reaction](https://github.com/blog/2119-add-reactions-to-pull-requests-issues-and-comments). Use a reaction in place of a "+1" comment: + +* 👍 - upvote +* 👎 - downvote + +If you cannot find an existing issue that describes your bug or feature, create a new issue. We have an issue template in place to organize new issues. + + +### Creating Pull Requests + +* Please refer to the article on [creating pull requests](https://github.com/comfyanonymous/ComfyUI/wiki/How-to-Contribute-Code) and contributing to this project. + + +## Thank You + +Your contributions to open source, large or small, make great projects like this possible. Thank you for taking the time to contribute. 
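The new `comfy_execution/caching.py` below keys cached node outputs on a recursive signature of each node's class, IS_CHANGED result, and inputs. As a minimal sketch of the idea behind its `to_hashable` helper (illustrative only, not part of the patch):

    # Mappings become frozensets of (key, value) pairs and sequences become
    # frozensets of (index, item) pairs, so nested prompt inputs turn into
    # hashable cache keys.
    inputs = {"seed": 5, "size": [512, 512]}
    key = frozenset([
        ("seed", 5),
        ("size", frozenset([(0, 512), (1, 512)])),
    ])
    cache = {key: "node output"}    # frozensets are valid dict keys
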
diff --git a/comfy_execution/caching.py b/comfy_execution/caching.py new file mode 100644 index 0000000000000000000000000000000000000000..630f280fc5e94ae7fd1cfa81f290a49f2a506e8a --- /dev/null +++ b/comfy_execution/caching.py @@ -0,0 +1,318 @@ +import itertools +from typing import Sequence, Mapping, Dict +from comfy_execution.graph import DynamicPrompt + +import nodes + +from comfy_execution.graph_utils import is_link + +NODE_CLASS_CONTAINS_UNIQUE_ID: Dict[str, bool] = {} + + +def include_unique_id_in_input(class_type: str) -> bool: + if class_type in NODE_CLASS_CONTAINS_UNIQUE_ID: + return NODE_CLASS_CONTAINS_UNIQUE_ID[class_type] + class_def = nodes.NODE_CLASS_MAPPINGS[class_type] + NODE_CLASS_CONTAINS_UNIQUE_ID[class_type] = "UNIQUE_ID" in class_def.INPUT_TYPES().get("hidden", {}).values() + return NODE_CLASS_CONTAINS_UNIQUE_ID[class_type] + +class CacheKeySet: + def __init__(self, dynprompt, node_ids, is_changed_cache): + self.keys = {} + self.subcache_keys = {} + + def add_keys(self, node_ids): + raise NotImplementedError() + + def all_node_ids(self): + return set(self.keys.keys()) + + def get_used_keys(self): + return self.keys.values() + + def get_used_subcache_keys(self): + return self.subcache_keys.values() + + def get_data_key(self, node_id): + return self.keys.get(node_id, None) + + def get_subcache_key(self, node_id): + return self.subcache_keys.get(node_id, None) + +class Unhashable: + def __init__(self): + self.value = float("NaN") + +def to_hashable(obj): + # So that we don't infinitely recurse since frozenset and tuples + # are Sequences. + if isinstance(obj, (int, float, str, bool, type(None))): + return obj + elif isinstance(obj, Mapping): + return frozenset([(to_hashable(k), to_hashable(v)) for k, v in sorted(obj.items())]) + elif isinstance(obj, Sequence): + return frozenset(zip(itertools.count(), [to_hashable(i) for i in obj])) + else: + # TODO - Support other objects like tensors? 
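+        # A fresh Unhashable instance never compares equal to another (and it
+        # wraps NaN, which is not even equal to itself), so any signature
+        # containing one can never match a cached key; the node re-executes.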
+ return Unhashable() + +class CacheKeySetID(CacheKeySet): + def __init__(self, dynprompt, node_ids, is_changed_cache): + super().__init__(dynprompt, node_ids, is_changed_cache) + self.dynprompt = dynprompt + self.add_keys(node_ids) + + def add_keys(self, node_ids): + for node_id in node_ids: + if node_id in self.keys: + continue + if not self.dynprompt.has_node(node_id): + continue + node = self.dynprompt.get_node(node_id) + self.keys[node_id] = (node_id, node["class_type"]) + self.subcache_keys[node_id] = (node_id, node["class_type"]) + +class CacheKeySetInputSignature(CacheKeySet): + def __init__(self, dynprompt, node_ids, is_changed_cache): + super().__init__(dynprompt, node_ids, is_changed_cache) + self.dynprompt = dynprompt + self.is_changed_cache = is_changed_cache + self.add_keys(node_ids) + + def include_node_id_in_input(self) -> bool: + return False + + def add_keys(self, node_ids): + for node_id in node_ids: + if node_id in self.keys: + continue + if not self.dynprompt.has_node(node_id): + continue + node = self.dynprompt.get_node(node_id) + self.keys[node_id] = self.get_node_signature(self.dynprompt, node_id) + self.subcache_keys[node_id] = (node_id, node["class_type"]) + + def get_node_signature(self, dynprompt, node_id): + signature = [] + ancestors, order_mapping = self.get_ordered_ancestry(dynprompt, node_id) + signature.append(self.get_immediate_node_signature(dynprompt, node_id, order_mapping)) + for ancestor_id in ancestors: + signature.append(self.get_immediate_node_signature(dynprompt, ancestor_id, order_mapping)) + return to_hashable(signature) + + def get_immediate_node_signature(self, dynprompt, node_id, ancestor_order_mapping): + if not dynprompt.has_node(node_id): + # This node doesn't exist -- we can't cache it. + return [float("NaN")] + node = dynprompt.get_node(node_id) + class_type = node["class_type"] + class_def = nodes.NODE_CLASS_MAPPINGS[class_type] + signature = [class_type, self.is_changed_cache.get(node_id)] + if self.include_node_id_in_input() or (hasattr(class_def, "NOT_IDEMPOTENT") and class_def.NOT_IDEMPOTENT) or include_unique_id_in_input(class_type): + signature.append(node_id) + inputs = node["inputs"] + for key in sorted(inputs.keys()): + if is_link(inputs[key]): + (ancestor_id, ancestor_socket) = inputs[key] + ancestor_index = ancestor_order_mapping[ancestor_id] + signature.append((key,("ANCESTOR", ancestor_index, ancestor_socket))) + else: + signature.append((key, inputs[key])) + return signature + + # This function returns a list of all ancestors of the given node. The order of the list is + # deterministic based on which specific inputs the ancestor is connected by. 
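+    # For example, a node with inputs {"clip": ["4", 0], "model": ["2", 0]} maps
+    # ancestor "4" to index 0 and ancestor "2" to index 1, because input names
+    # are walked in sorted order ("clip" sorts before "model").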
+ def get_ordered_ancestry(self, dynprompt, node_id): + ancestors = [] + order_mapping = {} + self.get_ordered_ancestry_internal(dynprompt, node_id, ancestors, order_mapping) + return ancestors, order_mapping + + def get_ordered_ancestry_internal(self, dynprompt, node_id, ancestors, order_mapping): + if not dynprompt.has_node(node_id): + return + inputs = dynprompt.get_node(node_id)["inputs"] + input_keys = sorted(inputs.keys()) + for key in input_keys: + if is_link(inputs[key]): + ancestor_id = inputs[key][0] + if ancestor_id not in order_mapping: + ancestors.append(ancestor_id) + order_mapping[ancestor_id] = len(ancestors) - 1 + self.get_ordered_ancestry_internal(dynprompt, ancestor_id, ancestors, order_mapping) + +class BasicCache: + def __init__(self, key_class): + self.key_class = key_class + self.initialized = False + self.dynprompt: DynamicPrompt + self.cache_key_set: CacheKeySet + self.cache = {} + self.subcaches = {} + + def set_prompt(self, dynprompt, node_ids, is_changed_cache): + self.dynprompt = dynprompt + self.cache_key_set = self.key_class(dynprompt, node_ids, is_changed_cache) + self.is_changed_cache = is_changed_cache + self.initialized = True + + def all_node_ids(self): + assert self.initialized + node_ids = self.cache_key_set.all_node_ids() + for subcache in self.subcaches.values(): + node_ids = node_ids.union(subcache.all_node_ids()) + return node_ids + + def _clean_cache(self): + preserve_keys = set(self.cache_key_set.get_used_keys()) + to_remove = [] + for key in self.cache: + if key not in preserve_keys: + to_remove.append(key) + for key in to_remove: + del self.cache[key] + + def _clean_subcaches(self): + preserve_subcaches = set(self.cache_key_set.get_used_subcache_keys()) + + to_remove = [] + for key in self.subcaches: + if key not in preserve_subcaches: + to_remove.append(key) + for key in to_remove: + del self.subcaches[key] + + def clean_unused(self): + assert self.initialized + self._clean_cache() + self._clean_subcaches() + + def _set_immediate(self, node_id, value): + assert self.initialized + cache_key = self.cache_key_set.get_data_key(node_id) + self.cache[cache_key] = value + + def _get_immediate(self, node_id): + if not self.initialized: + return None + cache_key = self.cache_key_set.get_data_key(node_id) + if cache_key in self.cache: + return self.cache[cache_key] + else: + return None + + def _ensure_subcache(self, node_id, children_ids): + subcache_key = self.cache_key_set.get_subcache_key(node_id) + subcache = self.subcaches.get(subcache_key, None) + if subcache is None: + subcache = BasicCache(self.key_class) + self.subcaches[subcache_key] = subcache + subcache.set_prompt(self.dynprompt, children_ids, self.is_changed_cache) + return subcache + + def _get_subcache(self, node_id): + assert self.initialized + subcache_key = self.cache_key_set.get_subcache_key(node_id) + if subcache_key in self.subcaches: + return self.subcaches[subcache_key] + else: + return None + + def recursive_debug_dump(self): + result = [] + for key in self.cache: + result.append({"key": key, "value": self.cache[key]}) + for key in self.subcaches: + result.append({"subcache_key": key, "subcache": self.subcaches[key].recursive_debug_dump()}) + return result + +class HierarchicalCache(BasicCache): + def __init__(self, key_class): + super().__init__(key_class) + + def _get_cache_for(self, node_id): + assert self.dynprompt is not None + parent_id = self.dynprompt.get_parent_node_id(node_id) + if parent_id is None: + return self + + hierarchy = [] + while parent_id is not None: + 
hierarchy.append(parent_id) + parent_id = self.dynprompt.get_parent_node_id(parent_id) + + cache = self + for parent_id in reversed(hierarchy): + cache = cache._get_subcache(parent_id) + if cache is None: + return None + return cache + + def get(self, node_id): + cache = self._get_cache_for(node_id) + if cache is None: + return None + return cache._get_immediate(node_id) + + def set(self, node_id, value): + cache = self._get_cache_for(node_id) + assert cache is not None + cache._set_immediate(node_id, value) + + def ensure_subcache_for(self, node_id, children_ids): + cache = self._get_cache_for(node_id) + assert cache is not None + return cache._ensure_subcache(node_id, children_ids) + +class LRUCache(BasicCache): + def __init__(self, key_class, max_size=100): + super().__init__(key_class) + self.max_size = max_size + self.min_generation = 0 + self.generation = 0 + self.used_generation = {} + self.children = {} + + def set_prompt(self, dynprompt, node_ids, is_changed_cache): + super().set_prompt(dynprompt, node_ids, is_changed_cache) + self.generation += 1 + for node_id in node_ids: + self._mark_used(node_id) + + def clean_unused(self): + while len(self.cache) > self.max_size and self.min_generation < self.generation: + self.min_generation += 1 + to_remove = [key for key in self.cache if self.used_generation[key] < self.min_generation] + for key in to_remove: + del self.cache[key] + del self.used_generation[key] + if key in self.children: + del self.children[key] + self._clean_subcaches() + + def get(self, node_id): + self._mark_used(node_id) + return self._get_immediate(node_id) + + def _mark_used(self, node_id): + cache_key = self.cache_key_set.get_data_key(node_id) + if cache_key is not None: + self.used_generation[cache_key] = self.generation + + def set(self, node_id, value): + self._mark_used(node_id) + return self._set_immediate(node_id, value) + + def ensure_subcache_for(self, node_id, children_ids): + # Just uses subcaches for tracking 'live' nodes + super()._ensure_subcache(node_id, children_ids) + + self.cache_key_set.add_keys(children_ids) + self._mark_used(node_id) + cache_key = self.cache_key_set.get_data_key(node_id) + self.children[cache_key] = [] + for child_id in children_ids: + self._mark_used(child_id) + self.children[cache_key].append(self.cache_key_set.get_data_key(child_id)) + return self + diff --git a/comfy_execution/graph.py b/comfy_execution/graph.py new file mode 100644 index 0000000000000000000000000000000000000000..0b5bf189906994a3b6c5db61929a174d4cc7b109 --- /dev/null +++ b/comfy_execution/graph.py @@ -0,0 +1,270 @@ +import nodes + +from comfy_execution.graph_utils import is_link + +class DependencyCycleError(Exception): + pass + +class NodeInputError(Exception): + pass + +class NodeNotFoundError(Exception): + pass + +class DynamicPrompt: + def __init__(self, original_prompt): + # The original prompt provided by the user + self.original_prompt = original_prompt + # Any extra pieces of the graph created during execution + self.ephemeral_prompt = {} + self.ephemeral_parents = {} + self.ephemeral_display = {} + + def get_node(self, node_id): + if node_id in self.ephemeral_prompt: + return self.ephemeral_prompt[node_id] + if node_id in self.original_prompt: + return self.original_prompt[node_id] + raise NodeNotFoundError(f"Node {node_id} not found") + + def has_node(self, node_id): + return node_id in self.original_prompt or node_id in self.ephemeral_prompt + + def add_ephemeral_node(self, node_id, node_info, parent_id, display_id): + 
self.ephemeral_prompt[node_id] = node_info + self.ephemeral_parents[node_id] = parent_id + self.ephemeral_display[node_id] = display_id + + def get_real_node_id(self, node_id): + while node_id in self.ephemeral_parents: + node_id = self.ephemeral_parents[node_id] + return node_id + + def get_parent_node_id(self, node_id): + return self.ephemeral_parents.get(node_id, None) + + def get_display_node_id(self, node_id): + while node_id in self.ephemeral_display: + node_id = self.ephemeral_display[node_id] + return node_id + + def all_node_ids(self): + return set(self.original_prompt.keys()).union(set(self.ephemeral_prompt.keys())) + + def get_original_prompt(self): + return self.original_prompt + +def get_input_info(class_def, input_name): + valid_inputs = class_def.INPUT_TYPES() + input_info = None + input_category = None + if "required" in valid_inputs and input_name in valid_inputs["required"]: + input_category = "required" + input_info = valid_inputs["required"][input_name] + elif "optional" in valid_inputs and input_name in valid_inputs["optional"]: + input_category = "optional" + input_info = valid_inputs["optional"][input_name] + elif "hidden" in valid_inputs and input_name in valid_inputs["hidden"]: + input_category = "hidden" + input_info = valid_inputs["hidden"][input_name] + if input_info is None: + return None, None, None + input_type = input_info[0] + if len(input_info) > 1: + extra_info = input_info[1] + else: + extra_info = {} + return input_type, input_category, extra_info + +class TopologicalSort: + def __init__(self, dynprompt): + self.dynprompt = dynprompt + self.pendingNodes = {} + self.blockCount = {} # Number of nodes this node is directly blocked by + self.blocking = {} # Which nodes are blocked by this node + + def get_input_info(self, unique_id, input_name): + class_type = self.dynprompt.get_node(unique_id)["class_type"] + class_def = nodes.NODE_CLASS_MAPPINGS[class_type] + return get_input_info(class_def, input_name) + + def make_input_strong_link(self, to_node_id, to_input): + inputs = self.dynprompt.get_node(to_node_id)["inputs"] + if to_input not in inputs: + raise NodeInputError(f"Node {to_node_id} says it needs input {to_input}, but there is no input to that node at all") + value = inputs[to_input] + if not is_link(value): + raise NodeInputError(f"Node {to_node_id} says it needs input {to_input}, but that value is a constant") + from_node_id, from_socket = value + self.add_strong_link(from_node_id, from_socket, to_node_id) + + def add_strong_link(self, from_node_id, from_socket, to_node_id): + if not self.is_cached(from_node_id): + self.add_node(from_node_id) + if to_node_id not in self.blocking[from_node_id]: + self.blocking[from_node_id][to_node_id] = {} + self.blockCount[to_node_id] += 1 + self.blocking[from_node_id][to_node_id][from_socket] = True + + def add_node(self, node_unique_id, include_lazy=False, subgraph_nodes=None): + node_ids = [node_unique_id] + links = [] + + while len(node_ids) > 0: + unique_id = node_ids.pop() + if unique_id in self.pendingNodes: + continue + + self.pendingNodes[unique_id] = True + self.blockCount[unique_id] = 0 + self.blocking[unique_id] = {} + + inputs = self.dynprompt.get_node(unique_id)["inputs"] + for input_name in inputs: + value = inputs[input_name] + if is_link(value): + from_node_id, from_socket = value + if subgraph_nodes is not None and from_node_id not in subgraph_nodes: + continue + input_type, input_category, input_info = self.get_input_info(unique_id, input_name) + is_lazy = input_info is not None and "lazy" in 
input_info and input_info["lazy"] + if (include_lazy or not is_lazy) and not self.is_cached(from_node_id): + node_ids.append(from_node_id) + links.append((from_node_id, from_socket, unique_id)) + + for link in links: + self.add_strong_link(*link) + + def is_cached(self, node_id): + return False + + def get_ready_nodes(self): + return [node_id for node_id in self.pendingNodes if self.blockCount[node_id] == 0] + + def pop_node(self, unique_id): + del self.pendingNodes[unique_id] + for blocked_node_id in self.blocking[unique_id]: + self.blockCount[blocked_node_id] -= 1 + del self.blocking[unique_id] + + def is_empty(self): + return len(self.pendingNodes) == 0 + +class ExecutionList(TopologicalSort): + """ + ExecutionList implements a topological dissolve of the graph. After a node is staged for execution, + it can still be returned to the graph after having further dependencies added. + """ + def __init__(self, dynprompt, output_cache): + super().__init__(dynprompt) + self.output_cache = output_cache + self.staged_node_id = None + + def is_cached(self, node_id): + return self.output_cache.get(node_id) is not None + + def stage_node_execution(self): + assert self.staged_node_id is None + if self.is_empty(): + return None, None, None + available = self.get_ready_nodes() + if len(available) == 0: + cycled_nodes = self.get_nodes_in_cycle() + # Because cycles composed entirely of static nodes are caught during initial validation, + # we will 'blame' the first node in the cycle that is not a static node. + blamed_node = cycled_nodes[0] + for node_id in cycled_nodes: + display_node_id = self.dynprompt.get_display_node_id(node_id) + if display_node_id != node_id: + blamed_node = display_node_id + break + ex = DependencyCycleError("Dependency cycle detected") + error_details = { + "node_id": blamed_node, + "exception_message": str(ex), + "exception_type": "graph.DependencyCycleError", + "traceback": [], + "current_inputs": [] + } + return None, error_details, ex + + self.staged_node_id = self.ux_friendly_pick_node(available) + return self.staged_node_id, None, None + + def ux_friendly_pick_node(self, node_list): + # If an output node is available, do that first. + # Technically this has no effect on the overall length of execution, but it feels better as a user + # for a PreviewImage to display a result as soon as it can + # Some other heuristics could probably be used here to improve the UX further. 
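+        # e.g. in a KSampler -> VAEDecode -> PreviewImage chain, VAEDecode is
+        # preferred as soon as it is ready, because the PreviewImage output
+        # node is directly blocked by it (the second loop below).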
+ def is_output(node_id): + class_type = self.dynprompt.get_node(node_id)["class_type"] + class_def = nodes.NODE_CLASS_MAPPINGS[class_type] + if hasattr(class_def, 'OUTPUT_NODE') and class_def.OUTPUT_NODE == True: + return True + return False + + for node_id in node_list: + if is_output(node_id): + return node_id + + #This should handle the VAEDecode -> preview case + for node_id in node_list: + for blocked_node_id in self.blocking[node_id]: + if is_output(blocked_node_id): + return node_id + + #This should handle the VAELoader -> VAEDecode -> preview case + for node_id in node_list: + for blocked_node_id in self.blocking[node_id]: + for blocked_node_id1 in self.blocking[blocked_node_id]: + if is_output(blocked_node_id1): + return node_id + + #TODO: this function should be improved + return node_list[0] + + def unstage_node_execution(self): + assert self.staged_node_id is not None + self.staged_node_id = None + + def complete_node_execution(self): + node_id = self.staged_node_id + self.pop_node(node_id) + self.staged_node_id = None + + def get_nodes_in_cycle(self): + # We'll dissolve the graph in reverse topological order to leave only the nodes in the cycle. + # We're skipping some of the performance optimizations from the original TopologicalSort to keep + # the code simple (and because having a cycle in the first place is a catastrophic error) + blocked_by = { node_id: {} for node_id in self.pendingNodes } + for from_node_id in self.blocking: + for to_node_id in self.blocking[from_node_id]: + if True in self.blocking[from_node_id][to_node_id].values(): + blocked_by[to_node_id][from_node_id] = True + to_remove = [node_id for node_id in blocked_by if len(blocked_by[node_id]) == 0] + while len(to_remove) > 0: + for node_id in to_remove: + for to_node_id in blocked_by: + if node_id in blocked_by[to_node_id]: + del blocked_by[to_node_id][node_id] + del blocked_by[node_id] + to_remove = [node_id for node_id in blocked_by if len(blocked_by[node_id]) == 0] + return list(blocked_by.keys()) + +class ExecutionBlocker: + """ + Return this from a node and any users will be blocked with the given error message. + If the message is None, execution will be blocked silently instead. + Generally, you should avoid using this functionality unless absolutely necessary. Whenever it's + possible, a lazy input will be more efficient and have a better user experience. + This functionality is useful in two cases: + 1. You want to conditionally prevent an output node from executing. (Particularly a built-in node + like SaveImage. For your own output nodes, I would recommend just adding a BOOL input and using + lazy evaluation to let it conditionally disable itself.) + 2. You have a node with multiple possible outputs, some of which are invalid and should not be used. + (I would recommend not making nodes like this in the future -- instead, make multiple nodes with + different outputs. Unfortunately, there are several popular existing nodes using this pattern.) 
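+
+    Illustrative use inside a node's function (names are hypothetical):
+        return (images if images is not None else ExecutionBlocker("no images"),)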
+ """ + def __init__(self, message): + self.message = message + diff --git a/comfy_execution/graph_utils.py b/comfy_execution/graph_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..8595e942d32160152c3c0163c85ad9bcd2e68d45 --- /dev/null +++ b/comfy_execution/graph_utils.py @@ -0,0 +1,139 @@ +def is_link(obj): + if not isinstance(obj, list): + return False + if len(obj) != 2: + return False + if not isinstance(obj[0], str): + return False + if not isinstance(obj[1], int) and not isinstance(obj[1], float): + return False + return True + +# The GraphBuilder is just a utility class that outputs graphs in the form expected by the ComfyUI back-end +class GraphBuilder: + _default_prefix_root = "" + _default_prefix_call_index = 0 + _default_prefix_graph_index = 0 + + def __init__(self, prefix = None): + if prefix is None: + self.prefix = GraphBuilder.alloc_prefix() + else: + self.prefix = prefix + self.nodes = {} + self.id_gen = 1 + + @classmethod + def set_default_prefix(cls, prefix_root, call_index, graph_index = 0): + cls._default_prefix_root = prefix_root + cls._default_prefix_call_index = call_index + cls._default_prefix_graph_index = graph_index + + @classmethod + def alloc_prefix(cls, root=None, call_index=None, graph_index=None): + if root is None: + root = GraphBuilder._default_prefix_root + if call_index is None: + call_index = GraphBuilder._default_prefix_call_index + if graph_index is None: + graph_index = GraphBuilder._default_prefix_graph_index + result = f"{root}.{call_index}.{graph_index}." + GraphBuilder._default_prefix_graph_index += 1 + return result + + def node(self, class_type, id=None, **kwargs): + if id is None: + id = str(self.id_gen) + self.id_gen += 1 + id = self.prefix + id + if id in self.nodes: + return self.nodes[id] + + node = Node(id, class_type, kwargs) + self.nodes[id] = node + return node + + def lookup_node(self, id): + id = self.prefix + id + return self.nodes.get(id) + + def finalize(self): + output = {} + for node_id, node in self.nodes.items(): + output[node_id] = node.serialize() + return output + + def replace_node_output(self, node_id, index, new_value): + node_id = self.prefix + node_id + to_remove = [] + for node in self.nodes.values(): + for key, value in node.inputs.items(): + if is_link(value) and value[0] == node_id and value[1] == index: + if new_value is None: + to_remove.append((node, key)) + else: + node.inputs[key] = new_value + for node, key in to_remove: + del node.inputs[key] + + def remove_node(self, id): + id = self.prefix + id + del self.nodes[id] + +class Node: + def __init__(self, id, class_type, inputs): + self.id = id + self.class_type = class_type + self.inputs = inputs + self.override_display_id = None + + def out(self, index): + return [self.id, index] + + def set_input(self, key, value): + if value is None: + if key in self.inputs: + del self.inputs[key] + else: + self.inputs[key] = value + + def get_input(self, key): + return self.inputs.get(key) + + def set_override_display_id(self, override_display_id): + self.override_display_id = override_display_id + + def serialize(self): + serialized = { + "class_type": self.class_type, + "inputs": self.inputs + } + if self.override_display_id is not None: + serialized["override_display_id"] = self.override_display_id + return serialized + +def add_graph_prefix(graph, outputs, prefix): + # Change the node IDs and any internal links + new_graph = {} + for node_id, node_info in graph.items(): + # Make sure the added nodes have unique IDs + new_node_id = prefix + 
node_id + new_node = { "class_type": node_info["class_type"], "inputs": {} } + for input_name, input_value in node_info.get("inputs", {}).items(): + if is_link(input_value): + new_node["inputs"][input_name] = [prefix + input_value[0], input_value[1]] + else: + new_node["inputs"][input_name] = input_value + new_graph[new_node_id] = new_node + + # Change the node IDs in the outputs + new_outputs = [] + for n in range(len(outputs)): + output = outputs[n] + if is_link(output): + new_outputs.append([prefix + output[0], output[1]]) + else: + new_outputs.append(output) + + return new_graph, tuple(new_outputs) + diff --git a/comfy_extras/chainner_models/model_loading.py b/comfy_extras/chainner_models/model_loading.py new file mode 100644 index 0000000000000000000000000000000000000000..d48bc238ccc6efbb383232c46d83d9c29c54ce5d --- /dev/null +++ b/comfy_extras/chainner_models/model_loading.py @@ -0,0 +1,5 @@ +from spandrel import ModelLoader + +def load_state_dict(state_dict): + print("WARNING: comfy_extras.chainner_models is deprecated and has been replaced by the spandrel library.") + return ModelLoader().load_from_state_dict(state_dict).eval() diff --git a/comfy_extras/nodes_advanced_samplers.py b/comfy_extras/nodes_advanced_samplers.py new file mode 100644 index 0000000000000000000000000000000000000000..820c250ef3aa4b3099ab7e20c4443bd6ed9bff6b --- /dev/null +++ b/comfy_extras/nodes_advanced_samplers.py @@ -0,0 +1,112 @@ +import comfy.samplers +import comfy.utils +import torch +import numpy as np +from tqdm.auto import trange, tqdm +import math + + +@torch.no_grad() +def sample_lcm_upscale(model, x, sigmas, extra_args=None, callback=None, disable=None, total_upscale=2.0, upscale_method="bislerp", upscale_steps=None): + extra_args = {} if extra_args is None else extra_args + + if upscale_steps is None: + upscale_steps = max(len(sigmas) // 2 + 1, 2) + else: + upscale_steps += 1 + upscale_steps = min(upscale_steps, len(sigmas) + 1) + + upscales = np.linspace(1.0, total_upscale, upscale_steps)[1:] + + orig_shape = x.size() + s_in = x.new_ones([x.shape[0]]) + for i in trange(len(sigmas) - 1, disable=disable): + denoised = model(x, sigmas[i] * s_in, **extra_args) + if callback is not None: + callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) + + x = denoised + if i < len(upscales): + x = comfy.utils.common_upscale(x, round(orig_shape[-1] * upscales[i]), round(orig_shape[-2] * upscales[i]), upscale_method, "disabled") + + if sigmas[i + 1] > 0: + x += sigmas[i + 1] * torch.randn_like(x) + return x + + +class SamplerLCMUpscale: + upscale_methods = ["bislerp", "nearest-exact", "bilinear", "area", "bicubic"] + + @classmethod + def INPUT_TYPES(s): + return {"required": + {"scale_ratio": ("FLOAT", {"default": 1.0, "min": 0.1, "max": 20.0, "step": 0.01}), + "scale_steps": ("INT", {"default": -1, "min": -1, "max": 1000, "step": 1}), + "upscale_method": (s.upscale_methods,), + } + } + RETURN_TYPES = ("SAMPLER",) + CATEGORY = "sampling/custom_sampling/samplers" + + FUNCTION = "get_sampler" + + def get_sampler(self, scale_ratio, scale_steps, upscale_method): + if scale_steps < 0: + scale_steps = None + sampler = comfy.samplers.KSAMPLER(sample_lcm_upscale, extra_options={"total_upscale": scale_ratio, "upscale_steps": scale_steps, "upscale_method": upscale_method}) + return (sampler, ) + +from comfy.k_diffusion.sampling import to_d +import comfy.model_patcher + +@torch.no_grad() +def sample_euler_pp(model, x, sigmas, extra_args=None, callback=None, disable=None): + extra_args 
= {} if extra_args is None else extra_args + + temp = [0] + def post_cfg_function(args): + temp[0] = args["uncond_denoised"] + return args["denoised"] + + model_options = extra_args.get("model_options", {}).copy() + extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True) + + s_in = x.new_ones([x.shape[0]]) + for i in trange(len(sigmas) - 1, disable=disable): + sigma_hat = sigmas[i] + denoised = model(x, sigma_hat * s_in, **extra_args) + d = to_d(x - denoised + temp[0], sigmas[i], denoised) + if callback is not None: + callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised}) + dt = sigmas[i + 1] - sigma_hat + x = x + d * dt + return x + + +class SamplerEulerCFGpp: + @classmethod + def INPUT_TYPES(s): + return {"required": + {"version": (["regular", "alternative"],),} + } + RETURN_TYPES = ("SAMPLER",) + # CATEGORY = "sampling/custom_sampling/samplers" + CATEGORY = "_for_testing" + + FUNCTION = "get_sampler" + + def get_sampler(self, version): + if version == "alternative": + sampler = comfy.samplers.KSAMPLER(sample_euler_pp) + else: + sampler = comfy.samplers.ksampler("euler_cfg_pp") + return (sampler, ) + +NODE_CLASS_MAPPINGS = { + "SamplerLCMUpscale": SamplerLCMUpscale, + "SamplerEulerCFGpp": SamplerEulerCFGpp, +} + +NODE_DISPLAY_NAME_MAPPINGS = { + "SamplerEulerCFGpp": "SamplerEulerCFG++", +} diff --git a/comfy_extras/nodes_align_your_steps.py b/comfy_extras/nodes_align_your_steps.py new file mode 100644 index 0000000000000000000000000000000000000000..3ffe531878521bdb81e7fee4e9d03f7c37636abc --- /dev/null +++ b/comfy_extras/nodes_align_your_steps.py @@ -0,0 +1,53 @@ +#from: https://research.nvidia.com/labs/toronto-ai/AlignYourSteps/howto.html +import numpy as np +import torch + +def loglinear_interp(t_steps, num_steps): + """ + Performs log-linear interpolation of a given array of decreasing numbers. 
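+
+    The input is reversed into increasing order, interpolated linearly in
+    log-space onto num_steps evenly spaced points, then exponentiated and
+    reversed back into decreasing order.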
+ """ + xs = np.linspace(0, 1, len(t_steps)) + ys = np.log(t_steps[::-1]) + + new_xs = np.linspace(0, 1, num_steps) + new_ys = np.interp(new_xs, xs, ys) + + interped_ys = np.exp(new_ys)[::-1].copy() + return interped_ys + +NOISE_LEVELS = {"SD1": [14.6146412293, 6.4745760956, 3.8636745985, 2.6946151520, 1.8841921177, 1.3943805092, 0.9642583904, 0.6523686016, 0.3977456272, 0.1515232662, 0.0291671582], + "SDXL":[14.6146412293, 6.3184485287, 3.7681790315, 2.1811480769, 1.3405244945, 0.8620721141, 0.5550693289, 0.3798540708, 0.2332364134, 0.1114188177, 0.0291671582], + "SVD": [700.00, 54.5, 15.886, 7.977, 4.248, 1.789, 0.981, 0.403, 0.173, 0.034, 0.002]} + +class AlignYourStepsScheduler: + @classmethod + def INPUT_TYPES(s): + return {"required": + {"model_type": (["SD1", "SDXL", "SVD"], ), + "steps": ("INT", {"default": 10, "min": 10, "max": 10000}), + "denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), + } + } + RETURN_TYPES = ("SIGMAS",) + CATEGORY = "sampling/custom_sampling/schedulers" + + FUNCTION = "get_sigmas" + + def get_sigmas(self, model_type, steps, denoise): + total_steps = steps + if denoise < 1.0: + if denoise <= 0.0: + return (torch.FloatTensor([]),) + total_steps = round(steps * denoise) + + sigmas = NOISE_LEVELS[model_type][:] + if (steps + 1) != len(sigmas): + sigmas = loglinear_interp(sigmas, steps + 1) + + sigmas = sigmas[-(total_steps + 1):] + sigmas[-1] = 0 + return (torch.FloatTensor(sigmas), ) + +NODE_CLASS_MAPPINGS = { + "AlignYourStepsScheduler": AlignYourStepsScheduler, +} diff --git a/comfy_extras/nodes_attention_multiply.py b/comfy_extras/nodes_attention_multiply.py new file mode 100644 index 0000000000000000000000000000000000000000..4747eb39568afe9e19d34fd583531fcab1acba69 --- /dev/null +++ b/comfy_extras/nodes_attention_multiply.py @@ -0,0 +1,120 @@ + +def attention_multiply(attn, model, q, k, v, out): + m = model.clone() + sd = model.model_state_dict() + + for key in sd: + if key.endswith("{}.to_q.bias".format(attn)) or key.endswith("{}.to_q.weight".format(attn)): + m.add_patches({key: (None,)}, 0.0, q) + if key.endswith("{}.to_k.bias".format(attn)) or key.endswith("{}.to_k.weight".format(attn)): + m.add_patches({key: (None,)}, 0.0, k) + if key.endswith("{}.to_v.bias".format(attn)) or key.endswith("{}.to_v.weight".format(attn)): + m.add_patches({key: (None,)}, 0.0, v) + if key.endswith("{}.to_out.0.bias".format(attn)) or key.endswith("{}.to_out.0.weight".format(attn)): + m.add_patches({key: (None,)}, 0.0, out) + + return m + + +class UNetSelfAttentionMultiply: + @classmethod + def INPUT_TYPES(s): + return {"required": { "model": ("MODEL",), + "q": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), + "k": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), + "v": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), + "out": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), + }} + RETURN_TYPES = ("MODEL",) + FUNCTION = "patch" + + CATEGORY = "_for_testing/attention_experiments" + + def patch(self, model, q, k, v, out): + m = attention_multiply("attn1", model, q, k, v, out) + return (m, ) + +class UNetCrossAttentionMultiply: + @classmethod + def INPUT_TYPES(s): + return {"required": { "model": ("MODEL",), + "q": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), + "k": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), + "v": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), + "out": ("FLOAT", {"default": 1.0, "min": 0.0, 
"max": 10.0, "step": 0.01}), + }} + RETURN_TYPES = ("MODEL",) + FUNCTION = "patch" + + CATEGORY = "_for_testing/attention_experiments" + + def patch(self, model, q, k, v, out): + m = attention_multiply("attn2", model, q, k, v, out) + return (m, ) + +class CLIPAttentionMultiply: + @classmethod + def INPUT_TYPES(s): + return {"required": { "clip": ("CLIP",), + "q": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), + "k": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), + "v": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), + "out": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), + }} + RETURN_TYPES = ("CLIP",) + FUNCTION = "patch" + + CATEGORY = "_for_testing/attention_experiments" + + def patch(self, clip, q, k, v, out): + m = clip.clone() + sd = m.patcher.model_state_dict() + + for key in sd: + if key.endswith("self_attn.q_proj.weight") or key.endswith("self_attn.q_proj.bias"): + m.add_patches({key: (None,)}, 0.0, q) + if key.endswith("self_attn.k_proj.weight") or key.endswith("self_attn.k_proj.bias"): + m.add_patches({key: (None,)}, 0.0, k) + if key.endswith("self_attn.v_proj.weight") or key.endswith("self_attn.v_proj.bias"): + m.add_patches({key: (None,)}, 0.0, v) + if key.endswith("self_attn.out_proj.weight") or key.endswith("self_attn.out_proj.bias"): + m.add_patches({key: (None,)}, 0.0, out) + return (m, ) + +class UNetTemporalAttentionMultiply: + @classmethod + def INPUT_TYPES(s): + return {"required": { "model": ("MODEL",), + "self_structural": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), + "self_temporal": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), + "cross_structural": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), + "cross_temporal": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), + }} + RETURN_TYPES = ("MODEL",) + FUNCTION = "patch" + + CATEGORY = "_for_testing/attention_experiments" + + def patch(self, model, self_structural, self_temporal, cross_structural, cross_temporal): + m = model.clone() + sd = model.model_state_dict() + + for k in sd: + if (k.endswith("attn1.to_out.0.bias") or k.endswith("attn1.to_out.0.weight")): + if '.time_stack.' in k: + m.add_patches({k: (None,)}, 0.0, self_temporal) + else: + m.add_patches({k: (None,)}, 0.0, self_structural) + elif (k.endswith("attn2.to_out.0.bias") or k.endswith("attn2.to_out.0.weight")): + if '.time_stack.' 
in k:
+                m.add_patches({k: (None,)}, 0.0, cross_temporal)
+            else:
+                m.add_patches({k: (None,)}, 0.0, cross_structural)
+        return (m, )
+
+NODE_CLASS_MAPPINGS = {
+    "UNetSelfAttentionMultiply": UNetSelfAttentionMultiply,
+    "UNetCrossAttentionMultiply": UNetCrossAttentionMultiply,
+    "CLIPAttentionMultiply": CLIPAttentionMultiply,
+    "UNetTemporalAttentionMultiply": UNetTemporalAttentionMultiply,
+}
diff --git a/comfy_extras/nodes_audio.py b/comfy_extras/nodes_audio.py
new file mode 100644
index 0000000000000000000000000000000000000000..1c763f259185795fbbe64ae560fe30ef1b38a008
--- /dev/null
+++ b/comfy_extras/nodes_audio.py
@@ -0,0 +1,227 @@
+import torchaudio
+import torch
+import comfy.model_management
+import folder_paths
+import os
+import io
+import json
+import struct
+import random
+import hashlib
+from comfy.cli_args import args
+
+class EmptyLatentAudio:
+    def __init__(self):
+        self.device = comfy.model_management.intermediate_device()
+
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": {"seconds": ("FLOAT", {"default": 47.6, "min": 1.0, "max": 1000.0, "step": 0.1})}}
+    RETURN_TYPES = ("LATENT",)
+    FUNCTION = "generate"
+
+    CATEGORY = "latent/audio"
+
+    def generate(self, seconds):
+        batch_size = 1
+        length = round((seconds * 44100 / 2048) / 2) * 2
+        latent = torch.zeros([batch_size, 64, length], device=self.device)
+        return ({"samples":latent, "type": "audio"}, )
+
+class VAEEncodeAudio:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": { "audio": ("AUDIO", ), "vae": ("VAE", )}}
+    RETURN_TYPES = ("LATENT",)
+    FUNCTION = "encode"
+
+    CATEGORY = "latent/audio"
+
+    def encode(self, vae, audio):
+        sample_rate = audio["sample_rate"]
+        if 44100 != sample_rate:
+            waveform = torchaudio.functional.resample(audio["waveform"], sample_rate, 44100)
+        else:
+            waveform = audio["waveform"]
+
+        t = vae.encode(waveform.movedim(1, -1))
+        return ({"samples":t}, )
+
+class VAEDecodeAudio:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": { "samples": ("LATENT", ), "vae": ("VAE", )}}
+    RETURN_TYPES = ("AUDIO",)
+    FUNCTION = "decode"
+
+    CATEGORY = "latent/audio"
+
+    def decode(self, vae, samples):
+        audio = vae.decode(samples["samples"]).movedim(-1, 1)
+        std = torch.std(audio, dim=[1,2], keepdim=True) * 5.0
+        std[std < 1.0] = 1.0
+        audio /= std
+        return ({"waveform": audio, "sample_rate": 44100}, )
+
+
+def create_vorbis_comment_block(comment_dict, last_block):
+    vendor_string = b'ComfyUI'
+    vendor_length = len(vendor_string)
+
+    comments = []
+    for key, value in comment_dict.items():
+        comment = f"{key}={value}".encode('utf-8')
+        comments.append(struct.pack('<I', len(comment)) + comment)
+
+    user_comment_list_length = len(comments)
+    user_comments = b''.join(comments)
+
+    comment_data = struct.pack('<I', vendor_length) + vendor_string + struct.pack('<I', user_comment_list_length) + user_comments
+    if last_block:
+        id = b'\x84'
+    else:
+        id = b'\x04'
+    comment_block = id + struct.pack('>I', len(comment_data))[1:] + comment_data
+
+    return comment_block
+
+def insert_or_replace_vorbis_comment(flac_io, comment_dict):
+    if len(comment_dict) == 0:
+        return flac_io
+
+    flac_io.seek(4)
+
+    blocks = []
+    last_block = False
+
+    while not last_block:
+        header = flac_io.read(4)
+        last_block = (header[0] & 0x80) != 0
+        block_type = header[0] & 0x7F
+        block_length = struct.unpack('>I', b'\x00' + header[1:])[0]
+        block_data = flac_io.read(block_length)
+
+        if block_type == 4 or block_type == 1:
+            pass
+        else:
+            header = bytes([(header[0] & (~0x80))]) + header[1:]
+            blocks.append(header + block_data)
+
+    blocks.append(create_vorbis_comment_block(comment_dict, last_block=True))
+
+    new_flac_io = io.BytesIO()
+    new_flac_io.write(b'fLaC')
+    for block in blocks:
+        new_flac_io.write(block)
+
+    new_flac_io.write(flac_io.read())
+    return new_flac_io
+
+
+class SaveAudio:
+    def __init__(self):
+        self.output_dir =
folder_paths.get_output_directory() + self.type = "output" + self.prefix_append = "" + + @classmethod + def INPUT_TYPES(s): + return {"required": { "audio": ("AUDIO", ), + "filename_prefix": ("STRING", {"default": "audio/ComfyUI"})}, + "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}, + } + + RETURN_TYPES = () + FUNCTION = "save_audio" + + OUTPUT_NODE = True + + CATEGORY = "audio" + + def save_audio(self, audio, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None): + filename_prefix += self.prefix_append + full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir) + results = list() + + metadata = {} + if not args.disable_metadata: + if prompt is not None: + metadata["prompt"] = json.dumps(prompt) + if extra_pnginfo is not None: + for x in extra_pnginfo: + metadata[x] = json.dumps(extra_pnginfo[x]) + + for (batch_number, waveform) in enumerate(audio["waveform"].cpu()): + filename_with_batch_num = filename.replace("%batch_num%", str(batch_number)) + file = f"{filename_with_batch_num}_{counter:05}_.flac" + + buff = io.BytesIO() + torchaudio.save(buff, waveform, audio["sample_rate"], format="FLAC") + + buff = insert_or_replace_vorbis_comment(buff, metadata) + + with open(os.path.join(full_output_folder, file), 'wb') as f: + f.write(buff.getbuffer()) + + results.append({ + "filename": file, + "subfolder": subfolder, + "type": self.type + }) + counter += 1 + + return { "ui": { "audio": results } } + +class PreviewAudio(SaveAudio): + def __init__(self): + self.output_dir = folder_paths.get_temp_directory() + self.type = "temp" + self.prefix_append = "_temp_" + ''.join(random.choice("abcdefghijklmnopqrstupvxyz") for x in range(5)) + + @classmethod + def INPUT_TYPES(s): + return {"required": + {"audio": ("AUDIO", ), }, + "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}, + } + +class LoadAudio: + @classmethod + def INPUT_TYPES(s): + input_dir = folder_paths.get_input_directory() + files = folder_paths.filter_files_content_types(os.listdir(input_dir), ["audio", "video"]) + return {"required": {"audio": (sorted(files), {"audio_upload": True})}} + + CATEGORY = "audio" + + RETURN_TYPES = ("AUDIO", ) + FUNCTION = "load" + + def load(self, audio): + audio_path = folder_paths.get_annotated_filepath(audio) + waveform, sample_rate = torchaudio.load(audio_path) + audio = {"waveform": waveform.unsqueeze(0), "sample_rate": sample_rate} + return (audio, ) + + @classmethod + def IS_CHANGED(s, audio): + image_path = folder_paths.get_annotated_filepath(audio) + m = hashlib.sha256() + with open(image_path, 'rb') as f: + m.update(f.read()) + return m.digest().hex() + + @classmethod + def VALIDATE_INPUTS(s, audio): + if not folder_paths.exists_annotated_filepath(audio): + return "Invalid audio file: {}".format(audio) + return True + +NODE_CLASS_MAPPINGS = { + "EmptyLatentAudio": EmptyLatentAudio, + "VAEEncodeAudio": VAEEncodeAudio, + "VAEDecodeAudio": VAEDecodeAudio, + "SaveAudio": SaveAudio, + "LoadAudio": LoadAudio, + "PreviewAudio": PreviewAudio, +} diff --git a/comfy_extras/nodes_canny.py b/comfy_extras/nodes_canny.py new file mode 100644 index 0000000000000000000000000000000000000000..d85e6b85691dcdcc55e31705039934846d466fc1 --- /dev/null +++ b/comfy_extras/nodes_canny.py @@ -0,0 +1,25 @@ +from kornia.filters import canny +import comfy.model_management + + +class Canny: + @classmethod + def INPUT_TYPES(s): + return {"required": {"image": ("IMAGE",), + "low_threshold": ("FLOAT", {"default": 0.4, 
"min": 0.01, "max": 0.99, "step": 0.01}), + "high_threshold": ("FLOAT", {"default": 0.8, "min": 0.01, "max": 0.99, "step": 0.01}) + }} + + RETURN_TYPES = ("IMAGE",) + FUNCTION = "detect_edge" + + CATEGORY = "image/preprocessors" + + def detect_edge(self, image, low_threshold, high_threshold): + output = canny(image.to(comfy.model_management.get_torch_device()).movedim(-1, 1), low_threshold, high_threshold) + img_out = output[1].to(comfy.model_management.intermediate_device()).repeat(1, 3, 1, 1).movedim(1, -1) + return (img_out,) + +NODE_CLASS_MAPPINGS = { + "Canny": Canny, +} diff --git a/comfy_extras/nodes_clip_sdxl.py b/comfy_extras/nodes_clip_sdxl.py new file mode 100644 index 0000000000000000000000000000000000000000..3087b917b41dba951d69029bdd278c258d97e40f --- /dev/null +++ b/comfy_extras/nodes_clip_sdxl.py @@ -0,0 +1,56 @@ +import torch +from nodes import MAX_RESOLUTION + +class CLIPTextEncodeSDXLRefiner: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "ascore": ("FLOAT", {"default": 6.0, "min": 0.0, "max": 1000.0, "step": 0.01}), + "width": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}), + "height": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}), + "text": ("STRING", {"multiline": True, "dynamicPrompts": True}), "clip": ("CLIP", ), + }} + RETURN_TYPES = ("CONDITIONING",) + FUNCTION = "encode" + + CATEGORY = "advanced/conditioning" + + def encode(self, clip, ascore, width, height, text): + tokens = clip.tokenize(text) + cond, pooled = clip.encode_from_tokens(tokens, return_pooled=True) + return ([[cond, {"pooled_output": pooled, "aesthetic_score": ascore, "width": width,"height": height}]], ) + +class CLIPTextEncodeSDXL: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "width": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}), + "height": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}), + "crop_w": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION}), + "crop_h": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION}), + "target_width": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}), + "target_height": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}), + "text_g": ("STRING", {"multiline": True, "dynamicPrompts": True}), "clip": ("CLIP", ), + "text_l": ("STRING", {"multiline": True, "dynamicPrompts": True}), "clip": ("CLIP", ), + }} + RETURN_TYPES = ("CONDITIONING",) + FUNCTION = "encode" + + CATEGORY = "advanced/conditioning" + + def encode(self, clip, width, height, crop_w, crop_h, target_width, target_height, text_g, text_l): + tokens = clip.tokenize(text_g) + tokens["l"] = clip.tokenize(text_l)["l"] + if len(tokens["l"]) != len(tokens["g"]): + empty = clip.tokenize("") + while len(tokens["l"]) < len(tokens["g"]): + tokens["l"] += empty["l"] + while len(tokens["l"]) > len(tokens["g"]): + tokens["g"] += empty["g"] + cond, pooled = clip.encode_from_tokens(tokens, return_pooled=True) + return ([[cond, {"pooled_output": pooled, "width": width, "height": height, "crop_w": crop_w, "crop_h": crop_h, "target_width": target_width, "target_height": target_height}]], ) + +NODE_CLASS_MAPPINGS = { + "CLIPTextEncodeSDXLRefiner": CLIPTextEncodeSDXLRefiner, + "CLIPTextEncodeSDXL": CLIPTextEncodeSDXL, +} diff --git a/comfy_extras/nodes_compositing.py b/comfy_extras/nodes_compositing.py new file mode 100644 index 0000000000000000000000000000000000000000..48fe5e3ddc60e1e150299c35f044cae831c02300 --- /dev/null +++ b/comfy_extras/nodes_compositing.py @@ -0,0 +1,215 @@ +import numpy as 
np +import torch +import comfy.utils +from enum import Enum + +def resize_mask(mask, shape): + return torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(shape[0], shape[1]), mode="bilinear").squeeze(1) + +class PorterDuffMode(Enum): + ADD = 0 + CLEAR = 1 + DARKEN = 2 + DST = 3 + DST_ATOP = 4 + DST_IN = 5 + DST_OUT = 6 + DST_OVER = 7 + LIGHTEN = 8 + MULTIPLY = 9 + OVERLAY = 10 + SCREEN = 11 + SRC = 12 + SRC_ATOP = 13 + SRC_IN = 14 + SRC_OUT = 15 + SRC_OVER = 16 + XOR = 17 + + +def porter_duff_composite(src_image: torch.Tensor, src_alpha: torch.Tensor, dst_image: torch.Tensor, dst_alpha: torch.Tensor, mode: PorterDuffMode): + # convert mask to alpha + src_alpha = 1 - src_alpha + dst_alpha = 1 - dst_alpha + # premultiply alpha + src_image = src_image * src_alpha + dst_image = dst_image * dst_alpha + + # composite ops below assume alpha-premultiplied images + if mode == PorterDuffMode.ADD: + out_alpha = torch.clamp(src_alpha + dst_alpha, 0, 1) + out_image = torch.clamp(src_image + dst_image, 0, 1) + elif mode == PorterDuffMode.CLEAR: + out_alpha = torch.zeros_like(dst_alpha) + out_image = torch.zeros_like(dst_image) + elif mode == PorterDuffMode.DARKEN: + out_alpha = src_alpha + dst_alpha - src_alpha * dst_alpha + out_image = (1 - dst_alpha) * src_image + (1 - src_alpha) * dst_image + torch.min(src_image, dst_image) + elif mode == PorterDuffMode.DST: + out_alpha = dst_alpha + out_image = dst_image + elif mode == PorterDuffMode.DST_ATOP: + out_alpha = src_alpha + out_image = src_alpha * dst_image + (1 - dst_alpha) * src_image + elif mode == PorterDuffMode.DST_IN: + out_alpha = src_alpha * dst_alpha + out_image = dst_image * src_alpha + elif mode == PorterDuffMode.DST_OUT: + out_alpha = (1 - src_alpha) * dst_alpha + out_image = (1 - src_alpha) * dst_image + elif mode == PorterDuffMode.DST_OVER: + out_alpha = dst_alpha + (1 - dst_alpha) * src_alpha + out_image = dst_image + (1 - dst_alpha) * src_image + elif mode == PorterDuffMode.LIGHTEN: + out_alpha = src_alpha + dst_alpha - src_alpha * dst_alpha + out_image = (1 - dst_alpha) * src_image + (1 - src_alpha) * dst_image + torch.max(src_image, dst_image) + elif mode == PorterDuffMode.MULTIPLY: + out_alpha = src_alpha * dst_alpha + out_image = src_image * dst_image + elif mode == PorterDuffMode.OVERLAY: + out_alpha = src_alpha + dst_alpha - src_alpha * dst_alpha + out_image = torch.where(2 * dst_image < dst_alpha, 2 * src_image * dst_image, + src_alpha * dst_alpha - 2 * (dst_alpha - src_image) * (src_alpha - dst_image)) + elif mode == PorterDuffMode.SCREEN: + out_alpha = src_alpha + dst_alpha - src_alpha * dst_alpha + out_image = src_image + dst_image - src_image * dst_image + elif mode == PorterDuffMode.SRC: + out_alpha = src_alpha + out_image = src_image + elif mode == PorterDuffMode.SRC_ATOP: + out_alpha = dst_alpha + out_image = dst_alpha * src_image + (1 - src_alpha) * dst_image + elif mode == PorterDuffMode.SRC_IN: + out_alpha = src_alpha * dst_alpha + out_image = src_image * dst_alpha + elif mode == PorterDuffMode.SRC_OUT: + out_alpha = (1 - dst_alpha) * src_alpha + out_image = (1 - dst_alpha) * src_image + elif mode == PorterDuffMode.SRC_OVER: + out_alpha = src_alpha + (1 - src_alpha) * dst_alpha + out_image = src_image + (1 - src_alpha) * dst_image + elif mode == PorterDuffMode.XOR: + out_alpha = (1 - dst_alpha) * src_alpha + (1 - src_alpha) * dst_alpha + out_image = (1 - dst_alpha) * src_image + (1 - src_alpha) * dst_image + else: + return None, None + + # back to non-premultiplied alpha + 
out_image = torch.where(out_alpha > 1e-5, out_image / out_alpha, torch.zeros_like(out_image)) + out_image = torch.clamp(out_image, 0, 1) + # convert alpha to mask + out_alpha = 1 - out_alpha + return out_image, out_alpha + + +class PorterDuffImageComposite: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "source": ("IMAGE",), + "source_alpha": ("MASK",), + "destination": ("IMAGE",), + "destination_alpha": ("MASK",), + "mode": ([mode.name for mode in PorterDuffMode], {"default": PorterDuffMode.DST.name}), + }, + } + + RETURN_TYPES = ("IMAGE", "MASK") + FUNCTION = "composite" + CATEGORY = "mask/compositing" + + def composite(self, source: torch.Tensor, source_alpha: torch.Tensor, destination: torch.Tensor, destination_alpha: torch.Tensor, mode): + batch_size = min(len(source), len(source_alpha), len(destination), len(destination_alpha)) + out_images = [] + out_alphas = [] + + for i in range(batch_size): + src_image = source[i] + dst_image = destination[i] + + assert src_image.shape[2] == dst_image.shape[2] # inputs need to have same number of channels + + src_alpha = source_alpha[i].unsqueeze(2) + dst_alpha = destination_alpha[i].unsqueeze(2) + + if dst_alpha.shape[:2] != dst_image.shape[:2]: + upscale_input = dst_alpha.unsqueeze(0).permute(0, 3, 1, 2) + upscale_output = comfy.utils.common_upscale(upscale_input, dst_image.shape[1], dst_image.shape[0], upscale_method='bicubic', crop='center') + dst_alpha = upscale_output.permute(0, 2, 3, 1).squeeze(0) + if src_image.shape != dst_image.shape: + upscale_input = src_image.unsqueeze(0).permute(0, 3, 1, 2) + upscale_output = comfy.utils.common_upscale(upscale_input, dst_image.shape[1], dst_image.shape[0], upscale_method='bicubic', crop='center') + src_image = upscale_output.permute(0, 2, 3, 1).squeeze(0) + if src_alpha.shape != dst_alpha.shape: + upscale_input = src_alpha.unsqueeze(0).permute(0, 3, 1, 2) + upscale_output = comfy.utils.common_upscale(upscale_input, dst_alpha.shape[1], dst_alpha.shape[0], upscale_method='bicubic', crop='center') + src_alpha = upscale_output.permute(0, 2, 3, 1).squeeze(0) + + out_image, out_alpha = porter_duff_composite(src_image, src_alpha, dst_image, dst_alpha, PorterDuffMode[mode]) + + out_images.append(out_image) + out_alphas.append(out_alpha.squeeze(2)) + + result = (torch.stack(out_images), torch.stack(out_alphas)) + return result + + +class SplitImageWithAlpha: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "image": ("IMAGE",), + } + } + + CATEGORY = "mask/compositing" + RETURN_TYPES = ("IMAGE", "MASK") + FUNCTION = "split_image_with_alpha" + + def split_image_with_alpha(self, image: torch.Tensor): + out_images = [i[:,:,:3] for i in image] + out_alphas = [i[:,:,3] if i.shape[2] > 3 else torch.ones_like(i[:,:,0]) for i in image] + result = (torch.stack(out_images), 1.0 - torch.stack(out_alphas)) + return result + + +class JoinImageWithAlpha: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "image": ("IMAGE",), + "alpha": ("MASK",), + } + } + + CATEGORY = "mask/compositing" + RETURN_TYPES = ("IMAGE",) + FUNCTION = "join_image_with_alpha" + + def join_image_with_alpha(self, image: torch.Tensor, alpha: torch.Tensor): + batch_size = min(len(image), len(alpha)) + out_images = [] + + alpha = 1.0 - resize_mask(alpha, image.shape[1:]) + for i in range(batch_size): + out_images.append(torch.cat((image[i][:,:,:3], alpha[i].unsqueeze(2)), dim=2)) + + result = (torch.stack(out_images),) + return result + + +NODE_CLASS_MAPPINGS = { + "PorterDuffImageComposite": 
PorterDuffImageComposite, + "SplitImageWithAlpha": SplitImageWithAlpha, + "JoinImageWithAlpha": JoinImageWithAlpha, +} + + +NODE_DISPLAY_NAME_MAPPINGS = { + "PorterDuffImageComposite": "Porter-Duff Image Composite", + "SplitImageWithAlpha": "Split Image with Alpha", + "JoinImageWithAlpha": "Join Image with Alpha", +} diff --git a/comfy_extras/nodes_cond.py b/comfy_extras/nodes_cond.py new file mode 100644 index 0000000000000000000000000000000000000000..4c3a1d5bf63e732c1d738b49e47da8ff7714f8a1 --- /dev/null +++ b/comfy_extras/nodes_cond.py @@ -0,0 +1,25 @@ + + +class CLIPTextEncodeControlnet: + @classmethod + def INPUT_TYPES(s): + return {"required": {"clip": ("CLIP", ), "conditioning": ("CONDITIONING", ), "text": ("STRING", {"multiline": True, "dynamicPrompts": True})}} + RETURN_TYPES = ("CONDITIONING",) + FUNCTION = "encode" + + CATEGORY = "_for_testing/conditioning" + + def encode(self, clip, conditioning, text): + tokens = clip.tokenize(text) + cond, pooled = clip.encode_from_tokens(tokens, return_pooled=True) + c = [] + for t in conditioning: + n = [t[0], t[1].copy()] + n[1]['cross_attn_controlnet'] = cond + n[1]['pooled_output_controlnet'] = pooled + c.append(n) + return (c, ) + +NODE_CLASS_MAPPINGS = { + "CLIPTextEncodeControlnet": CLIPTextEncodeControlnet +} diff --git a/comfy_extras/nodes_controlnet.py b/comfy_extras/nodes_controlnet.py new file mode 100644 index 0000000000000000000000000000000000000000..2d20e1fed7c26c8115f4b9878b0b54911e7a2e7f --- /dev/null +++ b/comfy_extras/nodes_controlnet.py @@ -0,0 +1,60 @@ +from comfy.cldm.control_types import UNION_CONTROLNET_TYPES +import nodes +import comfy.utils + +class SetUnionControlNetType: + @classmethod + def INPUT_TYPES(s): + return {"required": {"control_net": ("CONTROL_NET", ), + "type": (["auto"] + list(UNION_CONTROLNET_TYPES.keys()),) + }} + + CATEGORY = "conditioning/controlnet" + RETURN_TYPES = ("CONTROL_NET",) + + FUNCTION = "set_controlnet_type" + + def set_controlnet_type(self, control_net, type): + control_net = control_net.copy() + type_number = UNION_CONTROLNET_TYPES.get(type, -1) + if type_number >= 0: + control_net.set_extra_arg("control_type", [type_number]) + else: + control_net.set_extra_arg("control_type", []) + + return (control_net,) + +class ControlNetInpaintingAliMamaApply(nodes.ControlNetApplyAdvanced): + @classmethod + def INPUT_TYPES(s): + return {"required": {"positive": ("CONDITIONING", ), + "negative": ("CONDITIONING", ), + "control_net": ("CONTROL_NET", ), + "vae": ("VAE", ), + "image": ("IMAGE", ), + "mask": ("MASK", ), + "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), + "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}), + "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}) + }} + + FUNCTION = "apply_inpaint_controlnet" + + CATEGORY = "conditioning/controlnet" + + def apply_inpaint_controlnet(self, positive, negative, control_net, vae, image, mask, strength, start_percent, end_percent): + extra_concat = [] + if control_net.concat_mask: + mask = 1.0 - mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])) + mask_apply = comfy.utils.common_upscale(mask, image.shape[2], image.shape[1], "bilinear", "center").round() + image = image * mask_apply.movedim(1, -1).repeat(1, 1, 1, image.shape[3]) + extra_concat = [mask] + + return self.apply_controlnet(positive, negative, control_net, image, strength, start_percent, end_percent, vae=vae, extra_concat=extra_concat) + + + +NODE_CLASS_MAPPINGS = { + 
"SetUnionControlNetType": SetUnionControlNetType, + "ControlNetInpaintingAliMamaApply": ControlNetInpaintingAliMamaApply, +} diff --git a/comfy_extras/nodes_custom_sampler.py b/comfy_extras/nodes_custom_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..219975e2787e414c93b41a7beb783bcad23019d5 --- /dev/null +++ b/comfy_extras/nodes_custom_sampler.py @@ -0,0 +1,703 @@ +import comfy.samplers +import comfy.sample +from comfy.k_diffusion import sampling as k_diffusion_sampling +import latent_preview +import torch +import comfy.utils +import node_helpers + + +class BasicScheduler: + @classmethod + def INPUT_TYPES(s): + return {"required": + {"model": ("MODEL",), + "scheduler": (comfy.samplers.SCHEDULER_NAMES, ), + "steps": ("INT", {"default": 20, "min": 1, "max": 10000}), + "denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), + } + } + RETURN_TYPES = ("SIGMAS",) + CATEGORY = "sampling/custom_sampling/schedulers" + + FUNCTION = "get_sigmas" + + def get_sigmas(self, model, scheduler, steps, denoise): + total_steps = steps + if denoise < 1.0: + if denoise <= 0.0: + return (torch.FloatTensor([]),) + total_steps = int(steps/denoise) + + sigmas = comfy.samplers.calculate_sigmas(model.get_model_object("model_sampling"), scheduler, total_steps).cpu() + sigmas = sigmas[-(steps + 1):] + return (sigmas, ) + + +class KarrasScheduler: + @classmethod + def INPUT_TYPES(s): + return {"required": + {"steps": ("INT", {"default": 20, "min": 1, "max": 10000}), + "sigma_max": ("FLOAT", {"default": 14.614642, "min": 0.0, "max": 5000.0, "step":0.01, "round": False}), + "sigma_min": ("FLOAT", {"default": 0.0291675, "min": 0.0, "max": 5000.0, "step":0.01, "round": False}), + "rho": ("FLOAT", {"default": 7.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), + } + } + RETURN_TYPES = ("SIGMAS",) + CATEGORY = "sampling/custom_sampling/schedulers" + + FUNCTION = "get_sigmas" + + def get_sigmas(self, steps, sigma_max, sigma_min, rho): + sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, rho=rho) + return (sigmas, ) + +class ExponentialScheduler: + @classmethod + def INPUT_TYPES(s): + return {"required": + {"steps": ("INT", {"default": 20, "min": 1, "max": 10000}), + "sigma_max": ("FLOAT", {"default": 14.614642, "min": 0.0, "max": 5000.0, "step":0.01, "round": False}), + "sigma_min": ("FLOAT", {"default": 0.0291675, "min": 0.0, "max": 5000.0, "step":0.01, "round": False}), + } + } + RETURN_TYPES = ("SIGMAS",) + CATEGORY = "sampling/custom_sampling/schedulers" + + FUNCTION = "get_sigmas" + + def get_sigmas(self, steps, sigma_max, sigma_min): + sigmas = k_diffusion_sampling.get_sigmas_exponential(n=steps, sigma_min=sigma_min, sigma_max=sigma_max) + return (sigmas, ) + +class PolyexponentialScheduler: + @classmethod + def INPUT_TYPES(s): + return {"required": + {"steps": ("INT", {"default": 20, "min": 1, "max": 10000}), + "sigma_max": ("FLOAT", {"default": 14.614642, "min": 0.0, "max": 5000.0, "step":0.01, "round": False}), + "sigma_min": ("FLOAT", {"default": 0.0291675, "min": 0.0, "max": 5000.0, "step":0.01, "round": False}), + "rho": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), + } + } + RETURN_TYPES = ("SIGMAS",) + CATEGORY = "sampling/custom_sampling/schedulers" + + FUNCTION = "get_sigmas" + + def get_sigmas(self, steps, sigma_max, sigma_min, rho): + sigmas = k_diffusion_sampling.get_sigmas_polyexponential(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, rho=rho) + return 
(sigmas, ) + +class SDTurboScheduler: + @classmethod + def INPUT_TYPES(s): + return {"required": + {"model": ("MODEL",), + "steps": ("INT", {"default": 1, "min": 1, "max": 10}), + "denoise": ("FLOAT", {"default": 1.0, "min": 0, "max": 1.0, "step": 0.01}), + } + } + RETURN_TYPES = ("SIGMAS",) + CATEGORY = "sampling/custom_sampling/schedulers" + + FUNCTION = "get_sigmas" + + def get_sigmas(self, model, steps, denoise): + start_step = 10 - int(10 * denoise) + timesteps = torch.flip(torch.arange(1, 11) * 100 - 1, (0,))[start_step:start_step + steps] + sigmas = model.get_model_object("model_sampling").sigma(timesteps) + sigmas = torch.cat([sigmas, sigmas.new_zeros([1])]) + return (sigmas, ) + +class BetaSamplingScheduler: + @classmethod + def INPUT_TYPES(s): + return {"required": + {"model": ("MODEL",), + "steps": ("INT", {"default": 20, "min": 1, "max": 10000}), + "alpha": ("FLOAT", {"default": 0.6, "min": 0.0, "max": 50.0, "step":0.01, "round": False}), + "beta": ("FLOAT", {"default": 0.6, "min": 0.0, "max": 50.0, "step":0.01, "round": False}), + } + } + RETURN_TYPES = ("SIGMAS",) + CATEGORY = "sampling/custom_sampling/schedulers" + + FUNCTION = "get_sigmas" + + def get_sigmas(self, model, steps, alpha, beta): + sigmas = comfy.samplers.beta_scheduler(model.get_model_object("model_sampling"), steps, alpha=alpha, beta=beta) + return (sigmas, ) + +class VPScheduler: + @classmethod + def INPUT_TYPES(s): + return {"required": + {"steps": ("INT", {"default": 20, "min": 1, "max": 10000}), + "beta_d": ("FLOAT", {"default": 19.9, "min": 0.0, "max": 5000.0, "step":0.01, "round": False}), #TODO: fix default values + "beta_min": ("FLOAT", {"default": 0.1, "min": 0.0, "max": 5000.0, "step":0.01, "round": False}), + "eps_s": ("FLOAT", {"default": 0.001, "min": 0.0, "max": 1.0, "step":0.0001, "round": False}), + } + } + RETURN_TYPES = ("SIGMAS",) + CATEGORY = "sampling/custom_sampling/schedulers" + + FUNCTION = "get_sigmas" + + def get_sigmas(self, steps, beta_d, beta_min, eps_s): + sigmas = k_diffusion_sampling.get_sigmas_vp(n=steps, beta_d=beta_d, beta_min=beta_min, eps_s=eps_s) + return (sigmas, ) + +class SplitSigmas: + @classmethod + def INPUT_TYPES(s): + return {"required": + {"sigmas": ("SIGMAS", ), + "step": ("INT", {"default": 0, "min": 0, "max": 10000}), + } + } + RETURN_TYPES = ("SIGMAS","SIGMAS") + RETURN_NAMES = ("high_sigmas", "low_sigmas") + CATEGORY = "sampling/custom_sampling/sigmas" + + FUNCTION = "get_sigmas" + + def get_sigmas(self, sigmas, step): + sigmas1 = sigmas[:step + 1] + sigmas2 = sigmas[step:] + return (sigmas1, sigmas2) + +class SplitSigmasDenoise: + @classmethod + def INPUT_TYPES(s): + return {"required": + {"sigmas": ("SIGMAS", ), + "denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), + } + } + RETURN_TYPES = ("SIGMAS","SIGMAS") + RETURN_NAMES = ("high_sigmas", "low_sigmas") + CATEGORY = "sampling/custom_sampling/sigmas" + + FUNCTION = "get_sigmas" + + def get_sigmas(self, sigmas, denoise): + steps = max(sigmas.shape[-1] - 1, 0) + total_steps = round(steps * denoise) + sigmas1 = sigmas[:-(total_steps)] + sigmas2 = sigmas[-(total_steps + 1):] + return (sigmas1, sigmas2) + +class FlipSigmas: + @classmethod + def INPUT_TYPES(s): + return {"required": + {"sigmas": ("SIGMAS", ), + } + } + RETURN_TYPES = ("SIGMAS",) + CATEGORY = "sampling/custom_sampling/sigmas" + + FUNCTION = "get_sigmas" + + def get_sigmas(self, sigmas): + if len(sigmas) == 0: + return (sigmas,) + + sigmas = sigmas.flip(0) + if sigmas[0] == 0: + sigmas[0] = 0.0001 + return (sigmas,) + 
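SplitSigmas above duplicates the boundary element between its two outputs so that chained sampling stages line up exactly. A minimal sketch of that invariant, assuming it runs inside a ComfyUI environment where comfy.samplers is importable and model_sampling comes from a loaded model (the helper name two_stage_sigmas is hypothetical):

import torch
import comfy.samplers

def two_stage_sigmas(model_sampling, scheduler="normal", steps=20, boundary=12):
    # Build a full schedule the way BasicScheduler does, then cut it the way
    # SplitSigmas does: `high` keeps steps [0, boundary], `low` starts at
    # `boundary`, and the two share the boundary sigma.
    sigmas = comfy.samplers.calculate_sigmas(model_sampling, scheduler, steps).cpu()
    high, low = sigmas[:boundary + 1], sigmas[boundary:]
    assert torch.equal(high[-1:], low[:1])  # shared boundary element
    return high, low

Feeding the two halves into consecutive SamplerCustom passes then resumes the second pass at exactly the sigma where the first one stopped.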
+class KSamplerSelect: + @classmethod + def INPUT_TYPES(s): + return {"required": + {"sampler_name": (comfy.samplers.SAMPLER_NAMES, ), + } + } + RETURN_TYPES = ("SAMPLER",) + CATEGORY = "sampling/custom_sampling/samplers" + + FUNCTION = "get_sampler" + + def get_sampler(self, sampler_name): + sampler = comfy.samplers.sampler_object(sampler_name) + return (sampler, ) + +class SamplerDPMPP_3M_SDE: + @classmethod + def INPUT_TYPES(s): + return {"required": + {"eta": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), + "s_noise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), + "noise_device": (['gpu', 'cpu'], ), + } + } + RETURN_TYPES = ("SAMPLER",) + CATEGORY = "sampling/custom_sampling/samplers" + + FUNCTION = "get_sampler" + + def get_sampler(self, eta, s_noise, noise_device): + if noise_device == 'cpu': + sampler_name = "dpmpp_3m_sde" + else: + sampler_name = "dpmpp_3m_sde_gpu" + sampler = comfy.samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise}) + return (sampler, ) + +class SamplerDPMPP_2M_SDE: + @classmethod + def INPUT_TYPES(s): + return {"required": + {"solver_type": (['midpoint', 'heun'], ), + "eta": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), + "s_noise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), + "noise_device": (['gpu', 'cpu'], ), + } + } + RETURN_TYPES = ("SAMPLER",) + CATEGORY = "sampling/custom_sampling/samplers" + + FUNCTION = "get_sampler" + + def get_sampler(self, solver_type, eta, s_noise, noise_device): + if noise_device == 'cpu': + sampler_name = "dpmpp_2m_sde" + else: + sampler_name = "dpmpp_2m_sde_gpu" + sampler = comfy.samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise, "solver_type": solver_type}) + return (sampler, ) + + +class SamplerDPMPP_SDE: + @classmethod + def INPUT_TYPES(s): + return {"required": + {"eta": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), + "s_noise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), + "r": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), + "noise_device": (['gpu', 'cpu'], ), + } + } + RETURN_TYPES = ("SAMPLER",) + CATEGORY = "sampling/custom_sampling/samplers" + + FUNCTION = "get_sampler" + + def get_sampler(self, eta, s_noise, r, noise_device): + if noise_device == 'cpu': + sampler_name = "dpmpp_sde" + else: + sampler_name = "dpmpp_sde_gpu" + sampler = comfy.samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise, "r": r}) + return (sampler, ) + +class SamplerDPMPP_2S_Ancestral: + @classmethod + def INPUT_TYPES(s): + return {"required": + {"eta": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), + "s_noise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), + } + } + RETURN_TYPES = ("SAMPLER",) + CATEGORY = "sampling/custom_sampling/samplers" + + FUNCTION = "get_sampler" + + def get_sampler(self, eta, s_noise): + sampler = comfy.samplers.ksampler("dpmpp_2s_ancestral", {"eta": eta, "s_noise": s_noise}) + return (sampler, ) + +class SamplerEulerAncestral: + @classmethod + def INPUT_TYPES(s): + return {"required": + {"eta": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), + "s_noise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), + } + } + RETURN_TYPES = ("SAMPLER",) + CATEGORY = 
"sampling/custom_sampling/samplers" + + FUNCTION = "get_sampler" + + def get_sampler(self, eta, s_noise): + sampler = comfy.samplers.ksampler("euler_ancestral", {"eta": eta, "s_noise": s_noise}) + return (sampler, ) + +class SamplerEulerAncestralCFGPP: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "eta": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step":0.01, "round": False}), + "s_noise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step":0.01, "round": False}), + }} + RETURN_TYPES = ("SAMPLER",) + CATEGORY = "sampling/custom_sampling/samplers" + + FUNCTION = "get_sampler" + + def get_sampler(self, eta, s_noise): + sampler = comfy.samplers.ksampler( + "euler_ancestral_cfg_pp", + {"eta": eta, "s_noise": s_noise}) + return (sampler, ) + +class SamplerLMS: + @classmethod + def INPUT_TYPES(s): + return {"required": + {"order": ("INT", {"default": 4, "min": 1, "max": 100}), + } + } + RETURN_TYPES = ("SAMPLER",) + CATEGORY = "sampling/custom_sampling/samplers" + + FUNCTION = "get_sampler" + + def get_sampler(self, order): + sampler = comfy.samplers.ksampler("lms", {"order": order}) + return (sampler, ) + +class SamplerDPMAdaptative: + @classmethod + def INPUT_TYPES(s): + return {"required": + {"order": ("INT", {"default": 3, "min": 2, "max": 3}), + "rtol": ("FLOAT", {"default": 0.05, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), + "atol": ("FLOAT", {"default": 0.0078, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), + "h_init": ("FLOAT", {"default": 0.05, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), + "pcoeff": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), + "icoeff": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), + "dcoeff": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), + "accept_safety": ("FLOAT", {"default": 0.81, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), + "eta": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), + "s_noise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), + } + } + RETURN_TYPES = ("SAMPLER",) + CATEGORY = "sampling/custom_sampling/samplers" + + FUNCTION = "get_sampler" + + def get_sampler(self, order, rtol, atol, h_init, pcoeff, icoeff, dcoeff, accept_safety, eta, s_noise): + sampler = comfy.samplers.ksampler("dpm_adaptive", {"order": order, "rtol": rtol, "atol": atol, "h_init": h_init, "pcoeff": pcoeff, + "icoeff": icoeff, "dcoeff": dcoeff, "accept_safety": accept_safety, "eta": eta, + "s_noise":s_noise }) + return (sampler, ) + +class Noise_EmptyNoise: + def __init__(self): + self.seed = 0 + + def generate_noise(self, input_latent): + latent_image = input_latent["samples"] + return torch.zeros(latent_image.shape, dtype=latent_image.dtype, layout=latent_image.layout, device="cpu") + + +class Noise_RandomNoise: + def __init__(self, seed): + self.seed = seed + + def generate_noise(self, input_latent): + latent_image = input_latent["samples"] + batch_inds = input_latent["batch_index"] if "batch_index" in input_latent else None + return comfy.sample.prepare_noise(latent_image, self.seed, batch_inds) + +class SamplerCustom: + @classmethod + def INPUT_TYPES(s): + return {"required": + {"model": ("MODEL",), + "add_noise": ("BOOLEAN", {"default": True}), + "noise_seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}), + "cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}), 
+ "positive": ("CONDITIONING", ), + "negative": ("CONDITIONING", ), + "sampler": ("SAMPLER", ), + "sigmas": ("SIGMAS", ), + "latent_image": ("LATENT", ), + } + } + + RETURN_TYPES = ("LATENT","LATENT") + RETURN_NAMES = ("output", "denoised_output") + + FUNCTION = "sample" + + CATEGORY = "sampling/custom_sampling" + + def sample(self, model, add_noise, noise_seed, cfg, positive, negative, sampler, sigmas, latent_image): + latent = latent_image + latent_image = latent["samples"] + latent = latent.copy() + latent_image = comfy.sample.fix_empty_latent_channels(model, latent_image) + latent["samples"] = latent_image + + if not add_noise: + noise = Noise_EmptyNoise().generate_noise(latent) + else: + noise = Noise_RandomNoise(noise_seed).generate_noise(latent) + + noise_mask = None + if "noise_mask" in latent: + noise_mask = latent["noise_mask"] + + x0_output = {} + callback = latent_preview.prepare_callback(model, sigmas.shape[-1] - 1, x0_output) + + disable_pbar = not comfy.utils.PROGRESS_BAR_ENABLED + samples = comfy.sample.sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent_image, noise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=noise_seed) + + out = latent.copy() + out["samples"] = samples + if "x0" in x0_output: + out_denoised = latent.copy() + out_denoised["samples"] = model.model.process_latent_out(x0_output["x0"].cpu()) + else: + out_denoised = out + return (out, out_denoised) + +class Guider_Basic(comfy.samplers.CFGGuider): + def set_conds(self, positive): + self.inner_set_conds({"positive": positive}) + +class BasicGuider: + @classmethod + def INPUT_TYPES(s): + return {"required": + {"model": ("MODEL",), + "conditioning": ("CONDITIONING", ), + } + } + + RETURN_TYPES = ("GUIDER",) + + FUNCTION = "get_guider" + CATEGORY = "sampling/custom_sampling/guiders" + + def get_guider(self, model, conditioning): + guider = Guider_Basic(model) + guider.set_conds(conditioning) + return (guider,) + +class CFGGuider: + @classmethod + def INPUT_TYPES(s): + return {"required": + {"model": ("MODEL",), + "positive": ("CONDITIONING", ), + "negative": ("CONDITIONING", ), + "cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}), + } + } + + RETURN_TYPES = ("GUIDER",) + + FUNCTION = "get_guider" + CATEGORY = "sampling/custom_sampling/guiders" + + def get_guider(self, model, positive, negative, cfg): + guider = comfy.samplers.CFGGuider(model) + guider.set_conds(positive, negative) + guider.set_cfg(cfg) + return (guider,) + +class Guider_DualCFG(comfy.samplers.CFGGuider): + def set_cfg(self, cfg1, cfg2): + self.cfg1 = cfg1 + self.cfg2 = cfg2 + + def set_conds(self, positive, middle, negative): + middle = node_helpers.conditioning_set_values(middle, {"prompt_type": "negative"}) + self.inner_set_conds({"positive": positive, "middle": middle, "negative": negative}) + + def predict_noise(self, x, timestep, model_options={}, seed=None): + negative_cond = self.conds.get("negative", None) + middle_cond = self.conds.get("middle", None) + + out = comfy.samplers.calc_cond_batch(self.inner_model, [negative_cond, middle_cond, self.conds.get("positive", None)], x, timestep, model_options) + return comfy.samplers.cfg_function(self.inner_model, out[1], out[0], self.cfg2, x, timestep, model_options=model_options, cond=middle_cond, uncond=negative_cond) + (out[2] - out[1]) * self.cfg1 + +class DualCFGGuider: + @classmethod + def INPUT_TYPES(s): + return {"required": + {"model": ("MODEL",), + "cond1": ("CONDITIONING", ), + "cond2": ("CONDITIONING", 
), + "negative": ("CONDITIONING", ), + "cfg_conds": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}), + "cfg_cond2_negative": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}), + } + } + + RETURN_TYPES = ("GUIDER",) + + FUNCTION = "get_guider" + CATEGORY = "sampling/custom_sampling/guiders" + + def get_guider(self, model, cond1, cond2, negative, cfg_conds, cfg_cond2_negative): + guider = Guider_DualCFG(model) + guider.set_conds(cond1, cond2, negative) + guider.set_cfg(cfg_conds, cfg_cond2_negative) + return (guider,) + +class DisableNoise: + @classmethod + def INPUT_TYPES(s): + return {"required":{ + } + } + + RETURN_TYPES = ("NOISE",) + FUNCTION = "get_noise" + CATEGORY = "sampling/custom_sampling/noise" + + def get_noise(self): + return (Noise_EmptyNoise(),) + + +class RandomNoise(DisableNoise): + @classmethod + def INPUT_TYPES(s): + return {"required":{ + "noise_seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}), + } + } + + def get_noise(self, noise_seed): + return (Noise_RandomNoise(noise_seed),) + + +class SamplerCustomAdvanced: + @classmethod + def INPUT_TYPES(s): + return {"required": + {"noise": ("NOISE", ), + "guider": ("GUIDER", ), + "sampler": ("SAMPLER", ), + "sigmas": ("SIGMAS", ), + "latent_image": ("LATENT", ), + } + } + + RETURN_TYPES = ("LATENT","LATENT") + RETURN_NAMES = ("output", "denoised_output") + + FUNCTION = "sample" + + CATEGORY = "sampling/custom_sampling" + + def sample(self, noise, guider, sampler, sigmas, latent_image): + latent = latent_image + latent_image = latent["samples"] + latent = latent.copy() + latent_image = comfy.sample.fix_empty_latent_channels(guider.model_patcher, latent_image) + latent["samples"] = latent_image + + noise_mask = None + if "noise_mask" in latent: + noise_mask = latent["noise_mask"] + + x0_output = {} + callback = latent_preview.prepare_callback(guider.model_patcher, sigmas.shape[-1] - 1, x0_output) + + disable_pbar = not comfy.utils.PROGRESS_BAR_ENABLED + samples = guider.sample(noise.generate_noise(latent), latent_image, sampler, sigmas, denoise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=noise.seed) + samples = samples.to(comfy.model_management.intermediate_device()) + + out = latent.copy() + out["samples"] = samples + if "x0" in x0_output: + out_denoised = latent.copy() + out_denoised["samples"] = guider.model_patcher.model.process_latent_out(x0_output["x0"].cpu()) + else: + out_denoised = out + return (out, out_denoised) + +class AddNoise: + @classmethod + def INPUT_TYPES(s): + return {"required": + {"model": ("MODEL",), + "noise": ("NOISE", ), + "sigmas": ("SIGMAS", ), + "latent_image": ("LATENT", ), + } + } + + RETURN_TYPES = ("LATENT",) + + FUNCTION = "add_noise" + + CATEGORY = "_for_testing/custom_sampling/noise" + + def add_noise(self, model, noise, sigmas, latent_image): + if len(sigmas) == 0: + return latent_image + + latent = latent_image + latent_image = latent["samples"] + + noisy = noise.generate_noise(latent) + + model_sampling = model.get_model_object("model_sampling") + process_latent_out = model.get_model_object("process_latent_out") + process_latent_in = model.get_model_object("process_latent_in") + + if len(sigmas) > 1: + scale = torch.abs(sigmas[0] - sigmas[-1]) + else: + scale = sigmas[0] + + if torch.count_nonzero(latent_image) > 0: #Don't shift the empty latent image. 
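+            # process_latent_in() maps the latent into the model's internal
+            # space and can apply a shift/scale; running it on an all-zero
+            # "empty" latent would move it off zero, so the mapping below is
+            # only applied to latents that actually contain data.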
+ latent_image = process_latent_in(latent_image) + noisy = model_sampling.noise_scaling(scale, noisy, latent_image) + noisy = process_latent_out(noisy) + noisy = torch.nan_to_num(noisy, nan=0.0, posinf=0.0, neginf=0.0) + + out = latent.copy() + out["samples"] = noisy + return (out,) + + +NODE_CLASS_MAPPINGS = { + "SamplerCustom": SamplerCustom, + "BasicScheduler": BasicScheduler, + "KarrasScheduler": KarrasScheduler, + "ExponentialScheduler": ExponentialScheduler, + "PolyexponentialScheduler": PolyexponentialScheduler, + "VPScheduler": VPScheduler, + "BetaSamplingScheduler": BetaSamplingScheduler, + "SDTurboScheduler": SDTurboScheduler, + "KSamplerSelect": KSamplerSelect, + "SamplerEulerAncestral": SamplerEulerAncestral, + "SamplerEulerAncestralCFGPP": SamplerEulerAncestralCFGPP, + "SamplerLMS": SamplerLMS, + "SamplerDPMPP_3M_SDE": SamplerDPMPP_3M_SDE, + "SamplerDPMPP_2M_SDE": SamplerDPMPP_2M_SDE, + "SamplerDPMPP_SDE": SamplerDPMPP_SDE, + "SamplerDPMPP_2S_Ancestral": SamplerDPMPP_2S_Ancestral, + "SamplerDPMAdaptative": SamplerDPMAdaptative, + "SplitSigmas": SplitSigmas, + "SplitSigmasDenoise": SplitSigmasDenoise, + "FlipSigmas": FlipSigmas, + + "CFGGuider": CFGGuider, + "DualCFGGuider": DualCFGGuider, + "BasicGuider": BasicGuider, + "RandomNoise": RandomNoise, + "DisableNoise": DisableNoise, + "AddNoise": AddNoise, + "SamplerCustomAdvanced": SamplerCustomAdvanced, +} + +NODE_DISPLAY_NAME_MAPPINGS = { + "SamplerEulerAncestralCFGPP": "SamplerEulerAncestralCFG++", +} diff --git a/comfy_extras/nodes_differential_diffusion.py b/comfy_extras/nodes_differential_diffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..98dbbf102dac861cfb65ed19ad1af499abf7465d --- /dev/null +++ b/comfy_extras/nodes_differential_diffusion.py @@ -0,0 +1,42 @@ +# code adapted from https://github.com/exx8/differential-diffusion + +import torch + +class DifferentialDiffusion(): + @classmethod + def INPUT_TYPES(s): + return {"required": {"model": ("MODEL", ), + }} + RETURN_TYPES = ("MODEL",) + FUNCTION = "apply" + CATEGORY = "_for_testing" + INIT = False + + def apply(self, model): + model = model.clone() + model.set_model_denoise_mask_function(self.forward) + return (model,) + + def forward(self, sigma: torch.Tensor, denoise_mask: torch.Tensor, extra_options: dict): + model = extra_options["model"] + step_sigmas = extra_options["sigmas"] + sigma_to = model.inner_model.model_sampling.sigma_min + if step_sigmas[-1] > sigma_to: + sigma_to = step_sigmas[-1] + sigma_from = step_sigmas[0] + + ts_from = model.inner_model.model_sampling.timestep(sigma_from) + ts_to = model.inner_model.model_sampling.timestep(sigma_to) + current_ts = model.inner_model.model_sampling.timestep(sigma[0]) + + threshold = (current_ts - ts_to) / (ts_from - ts_to) + + return (denoise_mask >= threshold).to(denoise_mask.dtype) + + +NODE_CLASS_MAPPINGS = { + "DifferentialDiffusion": DifferentialDiffusion, +} +NODE_DISPLAY_NAME_MAPPINGS = { + "DifferentialDiffusion": "Differential Diffusion", +} diff --git a/comfy_extras/nodes_flux.py b/comfy_extras/nodes_flux.py new file mode 100644 index 0000000000000000000000000000000000000000..b690432b55bfc0249db3daa924069bb9f0456d7c --- /dev/null +++ b/comfy_extras/nodes_flux.py @@ -0,0 +1,47 @@ +import node_helpers + +class CLIPTextEncodeFlux: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "clip": ("CLIP", ), + "clip_l": ("STRING", {"multiline": True, "dynamicPrompts": True}), + "t5xxl": ("STRING", {"multiline": True, "dynamicPrompts": True}), + "guidance": ("FLOAT", 
{"default": 3.5, "min": 0.0, "max": 100.0, "step": 0.1}), + }} + RETURN_TYPES = ("CONDITIONING",) + FUNCTION = "encode" + + CATEGORY = "advanced/conditioning/flux" + + def encode(self, clip, clip_l, t5xxl, guidance): + tokens = clip.tokenize(clip_l) + tokens["t5xxl"] = clip.tokenize(t5xxl)["t5xxl"] + + output = clip.encode_from_tokens(tokens, return_pooled=True, return_dict=True) + cond = output.pop("cond") + output["guidance"] = guidance + return ([[cond, output]], ) + +class FluxGuidance: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "conditioning": ("CONDITIONING", ), + "guidance": ("FLOAT", {"default": 3.5, "min": 0.0, "max": 100.0, "step": 0.1}), + }} + + RETURN_TYPES = ("CONDITIONING",) + FUNCTION = "append" + + CATEGORY = "advanced/conditioning/flux" + + def append(self, conditioning, guidance): + c = node_helpers.conditioning_set_values(conditioning, {"guidance": guidance}) + return (c, ) + + +NODE_CLASS_MAPPINGS = { + "CLIPTextEncodeFlux": CLIPTextEncodeFlux, + "FluxGuidance": FluxGuidance, +} diff --git a/comfy_extras/nodes_freelunch.py b/comfy_extras/nodes_freelunch.py new file mode 100644 index 0000000000000000000000000000000000000000..e3ac58447b29f604debb5bfc0aed3a5f100a4ae9 --- /dev/null +++ b/comfy_extras/nodes_freelunch.py @@ -0,0 +1,113 @@ +#code originally taken from: https://github.com/ChenyangSi/FreeU (under MIT License) + +import torch +import logging + +def Fourier_filter(x, threshold, scale): + # FFT + x_freq = torch.fft.fftn(x.float(), dim=(-2, -1)) + x_freq = torch.fft.fftshift(x_freq, dim=(-2, -1)) + + B, C, H, W = x_freq.shape + mask = torch.ones((B, C, H, W), device=x.device) + + crow, ccol = H // 2, W //2 + mask[..., crow - threshold:crow + threshold, ccol - threshold:ccol + threshold] = scale + x_freq = x_freq * mask + + # IFFT + x_freq = torch.fft.ifftshift(x_freq, dim=(-2, -1)) + x_filtered = torch.fft.ifftn(x_freq, dim=(-2, -1)).real + + return x_filtered.to(x.dtype) + + +class FreeU: + @classmethod + def INPUT_TYPES(s): + return {"required": { "model": ("MODEL",), + "b1": ("FLOAT", {"default": 1.1, "min": 0.0, "max": 10.0, "step": 0.01}), + "b2": ("FLOAT", {"default": 1.2, "min": 0.0, "max": 10.0, "step": 0.01}), + "s1": ("FLOAT", {"default": 0.9, "min": 0.0, "max": 10.0, "step": 0.01}), + "s2": ("FLOAT", {"default": 0.2, "min": 0.0, "max": 10.0, "step": 0.01}), + }} + RETURN_TYPES = ("MODEL",) + FUNCTION = "patch" + + CATEGORY = "model_patches/unet" + + def patch(self, model, b1, b2, s1, s2): + model_channels = model.model.model_config.unet_config["model_channels"] + scale_dict = {model_channels * 4: (b1, s1), model_channels * 2: (b2, s2)} + on_cpu_devices = {} + + def output_block_patch(h, hsp, transformer_options): + scale = scale_dict.get(int(h.shape[1]), None) + if scale is not None: + h[:,:h.shape[1] // 2] = h[:,:h.shape[1] // 2] * scale[0] + if hsp.device not in on_cpu_devices: + try: + hsp = Fourier_filter(hsp, threshold=1, scale=scale[1]) + except: + logging.warning("Device {} does not support the torch.fft functions used in the FreeU node, switching to CPU.".format(hsp.device)) + on_cpu_devices[hsp.device] = True + hsp = Fourier_filter(hsp.cpu(), threshold=1, scale=scale[1]).to(hsp.device) + else: + hsp = Fourier_filter(hsp.cpu(), threshold=1, scale=scale[1]).to(hsp.device) + + return h, hsp + + m = model.clone() + m.set_model_output_block_patch(output_block_patch) + return (m, ) + +class FreeU_V2: + @classmethod + def INPUT_TYPES(s): + return {"required": { "model": ("MODEL",), + "b1": ("FLOAT", {"default": 1.3, "min": 0.0, 
"max": 10.0, "step": 0.01}), + "b2": ("FLOAT", {"default": 1.4, "min": 0.0, "max": 10.0, "step": 0.01}), + "s1": ("FLOAT", {"default": 0.9, "min": 0.0, "max": 10.0, "step": 0.01}), + "s2": ("FLOAT", {"default": 0.2, "min": 0.0, "max": 10.0, "step": 0.01}), + }} + RETURN_TYPES = ("MODEL",) + FUNCTION = "patch" + + CATEGORY = "model_patches/unet" + + def patch(self, model, b1, b2, s1, s2): + model_channels = model.model.model_config.unet_config["model_channels"] + scale_dict = {model_channels * 4: (b1, s1), model_channels * 2: (b2, s2)} + on_cpu_devices = {} + + def output_block_patch(h, hsp, transformer_options): + scale = scale_dict.get(int(h.shape[1]), None) + if scale is not None: + hidden_mean = h.mean(1).unsqueeze(1) + B = hidden_mean.shape[0] + hidden_max, _ = torch.max(hidden_mean.view(B, -1), dim=-1, keepdim=True) + hidden_min, _ = torch.min(hidden_mean.view(B, -1), dim=-1, keepdim=True) + hidden_mean = (hidden_mean - hidden_min.unsqueeze(2).unsqueeze(3)) / (hidden_max - hidden_min).unsqueeze(2).unsqueeze(3) + + h[:,:h.shape[1] // 2] = h[:,:h.shape[1] // 2] * ((scale[0] - 1 ) * hidden_mean + 1) + + if hsp.device not in on_cpu_devices: + try: + hsp = Fourier_filter(hsp, threshold=1, scale=scale[1]) + except: + logging.warning("Device {} does not support the torch.fft functions used in the FreeU node, switching to CPU.".format(hsp.device)) + on_cpu_devices[hsp.device] = True + hsp = Fourier_filter(hsp.cpu(), threshold=1, scale=scale[1]).to(hsp.device) + else: + hsp = Fourier_filter(hsp.cpu(), threshold=1, scale=scale[1]).to(hsp.device) + + return h, hsp + + m = model.clone() + m.set_model_output_block_patch(output_block_patch) + return (m, ) + +NODE_CLASS_MAPPINGS = { + "FreeU": FreeU, + "FreeU_V2": FreeU_V2, +} diff --git a/comfy_extras/nodes_gits.py b/comfy_extras/nodes_gits.py new file mode 100644 index 0000000000000000000000000000000000000000..7bfae4ce688721c3b671676c2fc92ed0708f499f --- /dev/null +++ b/comfy_extras/nodes_gits.py @@ -0,0 +1,369 @@ +# from https://github.com/zju-pi/diff-sampler/tree/main/gits-main +import numpy as np +import torch + +def loglinear_interp(t_steps, num_steps): + """ + Performs log-linear interpolation of a given array of decreasing numbers. 
+ """ + xs = np.linspace(0, 1, len(t_steps)) + ys = np.log(t_steps[::-1]) + + new_xs = np.linspace(0, 1, num_steps) + new_ys = np.interp(new_xs, xs, ys) + + interped_ys = np.exp(new_ys)[::-1].copy() + return interped_ys + +NOISE_LEVELS = { + 0.80: [ + [14.61464119, 7.49001646, 0.02916753], + [14.61464119, 11.54541874, 6.77309084, 0.02916753], + [14.61464119, 11.54541874, 7.49001646, 3.07277966, 0.02916753], + [14.61464119, 11.54541874, 7.49001646, 5.85520077, 2.05039096, 0.02916753], + [14.61464119, 12.2308979, 8.75849152, 7.49001646, 5.85520077, 2.05039096, 0.02916753], + [14.61464119, 12.2308979, 8.75849152, 7.49001646, 5.85520077, 3.07277966, 1.56271636, 0.02916753], + [14.61464119, 12.96784878, 11.54541874, 8.75849152, 7.49001646, 5.85520077, 3.07277966, 1.56271636, 0.02916753], + [14.61464119, 13.76078796, 12.2308979, 10.90732002, 8.75849152, 7.49001646, 5.85520077, 3.07277966, 1.56271636, 0.02916753], + [14.61464119, 13.76078796, 12.96784878, 12.2308979, 10.90732002, 8.75849152, 7.49001646, 5.85520077, 3.07277966, 1.56271636, 0.02916753], + [14.61464119, 13.76078796, 12.96784878, 12.2308979, 10.90732002, 9.24142551, 8.30717278, 7.49001646, 5.85520077, 3.07277966, 1.56271636, 0.02916753], + [14.61464119, 13.76078796, 12.96784878, 12.2308979, 10.90732002, 9.24142551, 8.30717278, 7.49001646, 6.14220476, 4.86714602, 3.07277966, 1.56271636, 0.02916753], + [14.61464119, 13.76078796, 12.96784878, 12.2308979, 11.54541874, 10.31284904, 9.24142551, 8.30717278, 7.49001646, 6.14220476, 4.86714602, 3.07277966, 1.56271636, 0.02916753], + [14.61464119, 13.76078796, 12.96784878, 12.2308979, 11.54541874, 10.90732002, 10.31284904, 9.24142551, 8.30717278, 7.49001646, 6.14220476, 4.86714602, 3.07277966, 1.56271636, 0.02916753], + [14.61464119, 13.76078796, 12.96784878, 12.2308979, 11.54541874, 10.90732002, 10.31284904, 9.24142551, 8.75849152, 8.30717278, 7.49001646, 6.14220476, 4.86714602, 3.07277966, 1.56271636, 0.02916753], + [14.61464119, 13.76078796, 12.96784878, 12.2308979, 11.54541874, 10.90732002, 10.31284904, 9.24142551, 8.75849152, 8.30717278, 7.49001646, 6.14220476, 4.86714602, 3.1956799, 1.98035145, 0.86115354, 0.02916753], + [14.61464119, 13.76078796, 12.96784878, 12.2308979, 11.54541874, 10.90732002, 10.31284904, 9.75859547, 9.24142551, 8.75849152, 8.30717278, 7.49001646, 6.14220476, 4.86714602, 3.1956799, 1.98035145, 0.86115354, 0.02916753], + [14.61464119, 13.76078796, 12.96784878, 12.2308979, 11.54541874, 10.90732002, 10.31284904, 9.75859547, 9.24142551, 8.75849152, 8.30717278, 7.49001646, 6.77309084, 5.85520077, 4.65472794, 3.07277966, 1.84880662, 0.83188516, 0.02916753], + [14.61464119, 13.76078796, 12.96784878, 12.2308979, 11.54541874, 10.90732002, 10.31284904, 9.75859547, 9.24142551, 8.75849152, 8.30717278, 7.88507891, 7.49001646, 6.77309084, 5.85520077, 4.65472794, 3.07277966, 1.84880662, 0.83188516, 0.02916753], + [14.61464119, 13.76078796, 12.96784878, 12.2308979, 11.54541874, 10.90732002, 10.31284904, 9.75859547, 9.24142551, 8.75849152, 8.30717278, 7.88507891, 7.49001646, 6.77309084, 5.85520077, 4.86714602, 3.75677586, 2.84484982, 1.78698075, 0.803307, 0.02916753], + ], + 0.85: [ + [14.61464119, 7.49001646, 0.02916753], + [14.61464119, 7.49001646, 1.84880662, 0.02916753], + [14.61464119, 11.54541874, 6.77309084, 1.56271636, 0.02916753], + [14.61464119, 11.54541874, 7.11996698, 3.07277966, 1.24153244, 0.02916753], + [14.61464119, 11.54541874, 7.49001646, 5.09240818, 2.84484982, 0.95350921, 0.02916753], + [14.61464119, 12.2308979, 8.75849152, 7.49001646, 5.09240818, 2.84484982, 
0.95350921, 0.02916753], + [14.61464119, 12.2308979, 8.75849152, 7.49001646, 5.58536053, 3.1956799, 1.84880662, 0.803307, 0.02916753], + [14.61464119, 12.96784878, 11.54541874, 8.75849152, 7.49001646, 5.58536053, 3.1956799, 1.84880662, 0.803307, 0.02916753], + [14.61464119, 12.96784878, 11.54541874, 8.75849152, 7.49001646, 6.14220476, 4.65472794, 3.07277966, 1.84880662, 0.803307, 0.02916753], + [14.61464119, 13.76078796, 12.2308979, 10.90732002, 8.75849152, 7.49001646, 6.14220476, 4.65472794, 3.07277966, 1.84880662, 0.803307, 0.02916753], + [14.61464119, 13.76078796, 12.2308979, 10.90732002, 9.24142551, 8.30717278, 7.49001646, 6.14220476, 4.65472794, 3.07277966, 1.84880662, 0.803307, 0.02916753], + [14.61464119, 13.76078796, 12.96784878, 12.2308979, 10.90732002, 9.24142551, 8.30717278, 7.49001646, 6.14220476, 4.65472794, 3.07277966, 1.84880662, 0.803307, 0.02916753], + [14.61464119, 13.76078796, 12.96784878, 12.2308979, 11.54541874, 10.31284904, 9.24142551, 8.30717278, 7.49001646, 6.14220476, 4.65472794, 3.07277966, 1.84880662, 0.803307, 0.02916753], + [14.61464119, 13.76078796, 12.96784878, 12.2308979, 11.54541874, 10.31284904, 9.24142551, 8.30717278, 7.49001646, 6.14220476, 4.86714602, 3.60512662, 2.6383388, 1.56271636, 0.72133851, 0.02916753], + [14.61464119, 13.76078796, 12.96784878, 12.2308979, 11.54541874, 10.31284904, 9.24142551, 8.30717278, 7.49001646, 6.77309084, 5.85520077, 4.65472794, 3.46139455, 2.45070267, 1.56271636, 0.72133851, 0.02916753], + [14.61464119, 13.76078796, 12.96784878, 12.2308979, 11.54541874, 10.31284904, 9.24142551, 8.75849152, 8.30717278, 7.49001646, 6.77309084, 5.85520077, 4.65472794, 3.46139455, 2.45070267, 1.56271636, 0.72133851, 0.02916753], + [14.61464119, 13.76078796, 12.96784878, 12.2308979, 11.54541874, 10.90732002, 10.31284904, 9.24142551, 8.75849152, 8.30717278, 7.49001646, 6.77309084, 5.85520077, 4.65472794, 3.46139455, 2.45070267, 1.56271636, 0.72133851, 0.02916753], + [14.61464119, 13.76078796, 12.96784878, 12.2308979, 11.54541874, 10.90732002, 10.31284904, 9.75859547, 9.24142551, 8.75849152, 8.30717278, 7.49001646, 6.77309084, 5.85520077, 4.65472794, 3.46139455, 2.45070267, 1.56271636, 0.72133851, 0.02916753], + [14.61464119, 13.76078796, 12.96784878, 12.2308979, 11.54541874, 10.90732002, 10.31284904, 9.75859547, 9.24142551, 8.75849152, 8.30717278, 7.88507891, 7.49001646, 6.77309084, 5.85520077, 4.65472794, 3.46139455, 2.45070267, 1.56271636, 0.72133851, 0.02916753], + ], + 0.90: [ + [14.61464119, 6.77309084, 0.02916753], + [14.61464119, 7.49001646, 1.56271636, 0.02916753], + [14.61464119, 7.49001646, 3.07277966, 0.95350921, 0.02916753], + [14.61464119, 7.49001646, 4.86714602, 2.54230714, 0.89115214, 0.02916753], + [14.61464119, 11.54541874, 7.49001646, 4.86714602, 2.54230714, 0.89115214, 0.02916753], + [14.61464119, 11.54541874, 7.49001646, 5.09240818, 3.07277966, 1.61558151, 0.69515091, 0.02916753], + [14.61464119, 12.2308979, 8.75849152, 7.11996698, 4.86714602, 3.07277966, 1.61558151, 0.69515091, 0.02916753], + [14.61464119, 12.2308979, 8.75849152, 7.49001646, 5.85520077, 4.45427561, 2.95596409, 1.61558151, 0.69515091, 0.02916753], + [14.61464119, 12.2308979, 8.75849152, 7.49001646, 5.85520077, 4.45427561, 3.1956799, 2.19988537, 1.24153244, 0.57119018, 0.02916753], + [14.61464119, 12.96784878, 10.90732002, 8.75849152, 7.49001646, 5.85520077, 4.45427561, 3.1956799, 2.19988537, 1.24153244, 0.57119018, 0.02916753], + [14.61464119, 12.96784878, 11.54541874, 9.24142551, 8.30717278, 7.49001646, 5.85520077, 4.45427561, 3.1956799, 2.19988537, 
1.24153244, 0.57119018, 0.02916753], + [14.61464119, 12.96784878, 11.54541874, 9.24142551, 8.30717278, 7.49001646, 6.14220476, 4.86714602, 3.75677586, 2.84484982, 1.84880662, 1.08895338, 0.52423614, 0.02916753], + [14.61464119, 13.76078796, 12.2308979, 10.90732002, 9.24142551, 8.30717278, 7.49001646, 6.14220476, 4.86714602, 3.75677586, 2.84484982, 1.84880662, 1.08895338, 0.52423614, 0.02916753], + [14.61464119, 13.76078796, 12.2308979, 10.90732002, 9.24142551, 8.30717278, 7.49001646, 6.44769001, 5.58536053, 4.45427561, 3.32507086, 2.45070267, 1.61558151, 0.95350921, 0.45573691, 0.02916753], + [14.61464119, 13.76078796, 12.96784878, 12.2308979, 10.90732002, 9.24142551, 8.30717278, 7.49001646, 6.44769001, 5.58536053, 4.45427561, 3.32507086, 2.45070267, 1.61558151, 0.95350921, 0.45573691, 0.02916753], + [14.61464119, 13.76078796, 12.96784878, 12.2308979, 10.90732002, 9.24142551, 8.30717278, 7.49001646, 6.77309084, 5.85520077, 4.86714602, 3.91689563, 3.07277966, 2.27973175, 1.56271636, 0.95350921, 0.45573691, 0.02916753], + [14.61464119, 13.76078796, 12.96784878, 12.2308979, 11.54541874, 10.31284904, 9.24142551, 8.30717278, 7.49001646, 6.77309084, 5.85520077, 4.86714602, 3.91689563, 3.07277966, 2.27973175, 1.56271636, 0.95350921, 0.45573691, 0.02916753], + [14.61464119, 13.76078796, 12.96784878, 12.2308979, 11.54541874, 10.31284904, 9.24142551, 8.75849152, 8.30717278, 7.49001646, 6.77309084, 5.85520077, 4.86714602, 3.91689563, 3.07277966, 2.27973175, 1.56271636, 0.95350921, 0.45573691, 0.02916753], + [14.61464119, 13.76078796, 12.96784878, 12.2308979, 11.54541874, 10.31284904, 9.24142551, 8.75849152, 8.30717278, 7.49001646, 6.77309084, 5.85520077, 5.09240818, 4.45427561, 3.60512662, 2.95596409, 2.19988537, 1.51179266, 0.89115214, 0.43325692, 0.02916753], + ], + 0.95: [ + [14.61464119, 6.77309084, 0.02916753], + [14.61464119, 6.77309084, 1.56271636, 0.02916753], + [14.61464119, 7.49001646, 2.84484982, 0.89115214, 0.02916753], + [14.61464119, 7.49001646, 4.86714602, 2.36326075, 0.803307, 0.02916753], + [14.61464119, 7.49001646, 4.86714602, 2.95596409, 1.56271636, 0.64427125, 0.02916753], + [14.61464119, 11.54541874, 7.49001646, 4.86714602, 2.95596409, 1.56271636, 0.64427125, 0.02916753], + [14.61464119, 11.54541874, 7.49001646, 4.86714602, 3.07277966, 1.91321158, 1.08895338, 0.50118381, 0.02916753], + [14.61464119, 11.54541874, 7.49001646, 5.85520077, 4.45427561, 3.07277966, 1.91321158, 1.08895338, 0.50118381, 0.02916753], + [14.61464119, 12.2308979, 8.75849152, 7.49001646, 5.85520077, 4.45427561, 3.07277966, 1.91321158, 1.08895338, 0.50118381, 0.02916753], + [14.61464119, 12.2308979, 8.75849152, 7.49001646, 5.85520077, 4.45427561, 3.1956799, 2.19988537, 1.41535246, 0.803307, 0.38853383, 0.02916753], + [14.61464119, 12.2308979, 8.75849152, 7.49001646, 5.85520077, 4.65472794, 3.46139455, 2.6383388, 1.84880662, 1.24153244, 0.72133851, 0.34370604, 0.02916753], + [14.61464119, 12.96784878, 10.90732002, 8.75849152, 7.49001646, 5.85520077, 4.65472794, 3.46139455, 2.6383388, 1.84880662, 1.24153244, 0.72133851, 0.34370604, 0.02916753], + [14.61464119, 12.96784878, 10.90732002, 8.75849152, 7.49001646, 6.14220476, 4.86714602, 3.75677586, 2.95596409, 2.19988537, 1.56271636, 1.05362725, 0.64427125, 0.32104823, 0.02916753], + [14.61464119, 12.96784878, 10.90732002, 8.75849152, 7.49001646, 6.44769001, 5.58536053, 4.65472794, 3.60512662, 2.95596409, 2.19988537, 1.56271636, 1.05362725, 0.64427125, 0.32104823, 0.02916753], + [14.61464119, 12.96784878, 11.54541874, 9.24142551, 8.30717278, 7.49001646, 6.44769001, 
5.58536053, 4.65472794, 3.60512662, 2.95596409, 2.19988537, 1.56271636, 1.05362725, 0.64427125, 0.32104823, 0.02916753], + [14.61464119, 12.96784878, 11.54541874, 9.24142551, 8.30717278, 7.49001646, 6.44769001, 5.58536053, 4.65472794, 3.75677586, 3.07277966, 2.45070267, 1.78698075, 1.24153244, 0.83188516, 0.50118381, 0.22545385, 0.02916753], + [14.61464119, 12.96784878, 11.54541874, 9.24142551, 8.30717278, 7.49001646, 6.77309084, 5.85520077, 5.09240818, 4.45427561, 3.60512662, 2.95596409, 2.36326075, 1.72759056, 1.24153244, 0.83188516, 0.50118381, 0.22545385, 0.02916753], + [14.61464119, 13.76078796, 12.2308979, 10.90732002, 9.24142551, 8.30717278, 7.49001646, 6.77309084, 5.85520077, 5.09240818, 4.45427561, 3.60512662, 2.95596409, 2.36326075, 1.72759056, 1.24153244, 0.83188516, 0.50118381, 0.22545385, 0.02916753], + [14.61464119, 13.76078796, 12.2308979, 10.90732002, 9.24142551, 8.30717278, 7.49001646, 6.77309084, 5.85520077, 5.09240818, 4.45427561, 3.75677586, 3.07277966, 2.45070267, 1.91321158, 1.46270394, 1.05362725, 0.72133851, 0.43325692, 0.19894916, 0.02916753], + ], + 1.00: [ + [14.61464119, 1.56271636, 0.02916753], + [14.61464119, 6.77309084, 0.95350921, 0.02916753], + [14.61464119, 6.77309084, 2.36326075, 0.803307, 0.02916753], + [14.61464119, 7.11996698, 3.07277966, 1.56271636, 0.59516323, 0.02916753], + [14.61464119, 7.49001646, 4.86714602, 2.84484982, 1.41535246, 0.57119018, 0.02916753], + [14.61464119, 7.49001646, 4.86714602, 2.84484982, 1.61558151, 0.86115354, 0.38853383, 0.02916753], + [14.61464119, 11.54541874, 7.49001646, 4.86714602, 2.84484982, 1.61558151, 0.86115354, 0.38853383, 0.02916753], + [14.61464119, 11.54541874, 7.49001646, 4.86714602, 3.07277966, 1.98035145, 1.24153244, 0.72133851, 0.34370604, 0.02916753], + [14.61464119, 11.54541874, 7.49001646, 5.85520077, 4.45427561, 3.07277966, 1.98035145, 1.24153244, 0.72133851, 0.34370604, 0.02916753], + [14.61464119, 11.54541874, 7.49001646, 5.85520077, 4.45427561, 3.1956799, 2.27973175, 1.51179266, 0.95350921, 0.54755926, 0.25053367, 0.02916753], + [14.61464119, 11.54541874, 7.49001646, 5.85520077, 4.45427561, 3.1956799, 2.36326075, 1.61558151, 1.08895338, 0.72133851, 0.41087446, 0.17026083, 0.02916753], + [14.61464119, 11.54541874, 8.75849152, 7.49001646, 5.85520077, 4.45427561, 3.1956799, 2.36326075, 1.61558151, 1.08895338, 0.72133851, 0.41087446, 0.17026083, 0.02916753], + [14.61464119, 11.54541874, 8.75849152, 7.49001646, 5.85520077, 4.65472794, 3.60512662, 2.84484982, 2.12350607, 1.56271636, 1.08895338, 0.72133851, 0.41087446, 0.17026083, 0.02916753], + [14.61464119, 11.54541874, 8.75849152, 7.49001646, 5.85520077, 4.65472794, 3.60512662, 2.84484982, 2.19988537, 1.61558151, 1.162866, 0.803307, 0.50118381, 0.27464288, 0.09824532, 0.02916753], + [14.61464119, 11.54541874, 8.75849152, 7.49001646, 5.85520077, 4.65472794, 3.75677586, 3.07277966, 2.45070267, 1.84880662, 1.36964464, 1.01931262, 0.72133851, 0.45573691, 0.25053367, 0.09824532, 0.02916753], + [14.61464119, 11.54541874, 8.75849152, 7.49001646, 6.14220476, 5.09240818, 4.26497746, 3.46139455, 2.84484982, 2.19988537, 1.67050016, 1.24153244, 0.92192322, 0.64427125, 0.43325692, 0.25053367, 0.09824532, 0.02916753], + [14.61464119, 11.54541874, 8.75849152, 7.49001646, 6.14220476, 5.09240818, 4.26497746, 3.60512662, 2.95596409, 2.45070267, 1.91321158, 1.51179266, 1.12534678, 0.83188516, 0.59516323, 0.38853383, 0.22545385, 0.09824532, 0.02916753], + [14.61464119, 12.2308979, 9.24142551, 8.30717278, 7.49001646, 6.14220476, 5.09240818, 4.26497746, 3.60512662, 2.95596409, 
2.45070267, 1.91321158, 1.51179266, 1.12534678, 0.83188516, 0.59516323, 0.38853383, 0.22545385, 0.09824532, 0.02916753], + [14.61464119, 12.2308979, 9.24142551, 8.30717278, 7.49001646, 6.77309084, 5.85520077, 5.09240818, 4.26497746, 3.60512662, 2.95596409, 2.45070267, 1.91321158, 1.51179266, 1.12534678, 0.83188516, 0.59516323, 0.38853383, 0.22545385, 0.09824532, 0.02916753], + ], + 1.05: [ + [14.61464119, 0.95350921, 0.02916753], + [14.61464119, 6.77309084, 0.89115214, 0.02916753], + [14.61464119, 6.77309084, 2.05039096, 0.72133851, 0.02916753], + [14.61464119, 6.77309084, 2.84484982, 1.28281462, 0.52423614, 0.02916753], + [14.61464119, 6.77309084, 3.07277966, 1.61558151, 0.803307, 0.34370604, 0.02916753], + [14.61464119, 7.49001646, 4.86714602, 2.84484982, 1.56271636, 0.803307, 0.34370604, 0.02916753], + [14.61464119, 7.49001646, 4.86714602, 2.84484982, 1.61558151, 0.95350921, 0.52423614, 0.22545385, 0.02916753], + [14.61464119, 7.49001646, 4.86714602, 3.07277966, 1.98035145, 1.24153244, 0.74807048, 0.41087446, 0.17026083, 0.02916753], + [14.61464119, 7.49001646, 4.86714602, 3.1956799, 2.27973175, 1.51179266, 0.95350921, 0.59516323, 0.34370604, 0.13792117, 0.02916753], + [14.61464119, 7.49001646, 5.09240818, 3.46139455, 2.45070267, 1.61558151, 1.08895338, 0.72133851, 0.45573691, 0.25053367, 0.09824532, 0.02916753], + [14.61464119, 11.54541874, 7.49001646, 5.09240818, 3.46139455, 2.45070267, 1.61558151, 1.08895338, 0.72133851, 0.45573691, 0.25053367, 0.09824532, 0.02916753], + [14.61464119, 11.54541874, 7.49001646, 5.85520077, 4.45427561, 3.1956799, 2.36326075, 1.61558151, 1.08895338, 0.72133851, 0.45573691, 0.25053367, 0.09824532, 0.02916753], + [14.61464119, 11.54541874, 7.49001646, 5.85520077, 4.45427561, 3.1956799, 2.45070267, 1.72759056, 1.24153244, 0.86115354, 0.59516323, 0.38853383, 0.22545385, 0.09824532, 0.02916753], + [14.61464119, 11.54541874, 7.49001646, 5.85520077, 4.65472794, 3.60512662, 2.84484982, 2.19988537, 1.61558151, 1.162866, 0.83188516, 0.59516323, 0.38853383, 0.22545385, 0.09824532, 0.02916753], + [14.61464119, 11.54541874, 7.49001646, 5.85520077, 4.65472794, 3.60512662, 2.84484982, 2.19988537, 1.67050016, 1.28281462, 0.95350921, 0.72133851, 0.52423614, 0.34370604, 0.19894916, 0.09824532, 0.02916753], + [14.61464119, 11.54541874, 7.49001646, 5.85520077, 4.65472794, 3.60512662, 2.95596409, 2.36326075, 1.84880662, 1.41535246, 1.08895338, 0.83188516, 0.61951244, 0.45573691, 0.32104823, 0.19894916, 0.09824532, 0.02916753], + [14.61464119, 11.54541874, 7.49001646, 5.85520077, 4.65472794, 3.60512662, 2.95596409, 2.45070267, 1.91321158, 1.51179266, 1.20157266, 0.95350921, 0.74807048, 0.57119018, 0.43325692, 0.29807833, 0.19894916, 0.09824532, 0.02916753], + [14.61464119, 11.54541874, 8.30717278, 7.11996698, 5.85520077, 4.65472794, 3.60512662, 2.95596409, 2.45070267, 1.91321158, 1.51179266, 1.20157266, 0.95350921, 0.74807048, 0.57119018, 0.43325692, 0.29807833, 0.19894916, 0.09824532, 0.02916753], + [14.61464119, 11.54541874, 8.30717278, 7.11996698, 5.85520077, 4.65472794, 3.60512662, 2.95596409, 2.45070267, 1.98035145, 1.61558151, 1.32549286, 1.08895338, 0.86115354, 0.69515091, 0.54755926, 0.41087446, 0.29807833, 0.19894916, 0.09824532, 0.02916753], + ], + 1.10: [ + [14.61464119, 0.89115214, 0.02916753], + [14.61464119, 2.36326075, 0.72133851, 0.02916753], + [14.61464119, 5.85520077, 1.61558151, 0.57119018, 0.02916753], + [14.61464119, 6.77309084, 2.45070267, 1.08895338, 0.45573691, 0.02916753], + [14.61464119, 6.77309084, 2.95596409, 1.56271636, 0.803307, 0.34370604, 
0.02916753], + [14.61464119, 6.77309084, 3.07277966, 1.61558151, 0.89115214, 0.4783645, 0.19894916, 0.02916753], + [14.61464119, 6.77309084, 3.07277966, 1.84880662, 1.08895338, 0.64427125, 0.34370604, 0.13792117, 0.02916753], + [14.61464119, 7.49001646, 4.86714602, 2.84484982, 1.61558151, 0.95350921, 0.54755926, 0.27464288, 0.09824532, 0.02916753], + [14.61464119, 7.49001646, 4.86714602, 2.95596409, 1.91321158, 1.24153244, 0.803307, 0.4783645, 0.25053367, 0.09824532, 0.02916753], + [14.61464119, 7.49001646, 4.86714602, 3.07277966, 2.05039096, 1.41535246, 0.95350921, 0.64427125, 0.41087446, 0.22545385, 0.09824532, 0.02916753], + [14.61464119, 7.49001646, 4.86714602, 3.1956799, 2.27973175, 1.61558151, 1.12534678, 0.803307, 0.54755926, 0.36617002, 0.22545385, 0.09824532, 0.02916753], + [14.61464119, 7.49001646, 4.86714602, 3.32507086, 2.45070267, 1.72759056, 1.24153244, 0.89115214, 0.64427125, 0.45573691, 0.32104823, 0.19894916, 0.09824532, 0.02916753], + [14.61464119, 7.49001646, 5.09240818, 3.60512662, 2.84484982, 2.05039096, 1.51179266, 1.08895338, 0.803307, 0.59516323, 0.43325692, 0.29807833, 0.19894916, 0.09824532, 0.02916753], + [14.61464119, 7.49001646, 5.09240818, 3.60512662, 2.84484982, 2.12350607, 1.61558151, 1.24153244, 0.95350921, 0.72133851, 0.54755926, 0.41087446, 0.29807833, 0.19894916, 0.09824532, 0.02916753], + [14.61464119, 7.49001646, 5.85520077, 4.45427561, 3.1956799, 2.45070267, 1.84880662, 1.41535246, 1.08895338, 0.83188516, 0.64427125, 0.50118381, 0.36617002, 0.25053367, 0.17026083, 0.09824532, 0.02916753], + [14.61464119, 7.49001646, 5.85520077, 4.45427561, 3.1956799, 2.45070267, 1.91321158, 1.51179266, 1.20157266, 0.95350921, 0.74807048, 0.59516323, 0.45573691, 0.34370604, 0.25053367, 0.17026083, 0.09824532, 0.02916753], + [14.61464119, 7.49001646, 5.85520077, 4.45427561, 3.46139455, 2.84484982, 2.19988537, 1.72759056, 1.36964464, 1.08895338, 0.86115354, 0.69515091, 0.54755926, 0.43325692, 0.34370604, 0.25053367, 0.17026083, 0.09824532, 0.02916753], + [14.61464119, 11.54541874, 7.49001646, 5.85520077, 4.45427561, 3.46139455, 2.84484982, 2.19988537, 1.72759056, 1.36964464, 1.08895338, 0.86115354, 0.69515091, 0.54755926, 0.43325692, 0.34370604, 0.25053367, 0.17026083, 0.09824532, 0.02916753], + [14.61464119, 11.54541874, 7.49001646, 5.85520077, 4.45427561, 3.46139455, 2.84484982, 2.19988537, 1.72759056, 1.36964464, 1.08895338, 0.89115214, 0.72133851, 0.59516323, 0.4783645, 0.38853383, 0.29807833, 0.22545385, 0.17026083, 0.09824532, 0.02916753], + ], + 1.15: [ + [14.61464119, 0.83188516, 0.02916753], + [14.61464119, 1.84880662, 0.59516323, 0.02916753], + [14.61464119, 5.85520077, 1.56271636, 0.52423614, 0.02916753], + [14.61464119, 5.85520077, 1.91321158, 0.83188516, 0.34370604, 0.02916753], + [14.61464119, 5.85520077, 2.45070267, 1.24153244, 0.59516323, 0.25053367, 0.02916753], + [14.61464119, 5.85520077, 2.84484982, 1.51179266, 0.803307, 0.41087446, 0.17026083, 0.02916753], + [14.61464119, 5.85520077, 2.84484982, 1.56271636, 0.89115214, 0.50118381, 0.25053367, 0.09824532, 0.02916753], + [14.61464119, 6.77309084, 3.07277966, 1.84880662, 1.12534678, 0.72133851, 0.43325692, 0.22545385, 0.09824532, 0.02916753], + [14.61464119, 6.77309084, 3.07277966, 1.91321158, 1.24153244, 0.803307, 0.52423614, 0.34370604, 0.19894916, 0.09824532, 0.02916753], + [14.61464119, 7.49001646, 4.86714602, 2.95596409, 1.91321158, 1.24153244, 0.803307, 0.52423614, 0.34370604, 0.19894916, 0.09824532, 0.02916753], + [14.61464119, 7.49001646, 4.86714602, 3.07277966, 2.05039096, 1.36964464, 
0.95350921, 0.69515091, 0.4783645, 0.32104823, 0.19894916, 0.09824532, 0.02916753], + [14.61464119, 7.49001646, 4.86714602, 3.07277966, 2.12350607, 1.51179266, 1.08895338, 0.803307, 0.59516323, 0.43325692, 0.29807833, 0.19894916, 0.09824532, 0.02916753], + [14.61464119, 7.49001646, 4.86714602, 3.07277966, 2.12350607, 1.51179266, 1.08895338, 0.803307, 0.59516323, 0.45573691, 0.34370604, 0.25053367, 0.17026083, 0.09824532, 0.02916753], + [14.61464119, 7.49001646, 4.86714602, 3.07277966, 2.19988537, 1.61558151, 1.24153244, 0.95350921, 0.74807048, 0.59516323, 0.45573691, 0.34370604, 0.25053367, 0.17026083, 0.09824532, 0.02916753], + [14.61464119, 7.49001646, 4.86714602, 3.1956799, 2.45070267, 1.78698075, 1.32549286, 1.01931262, 0.803307, 0.64427125, 0.50118381, 0.38853383, 0.29807833, 0.22545385, 0.17026083, 0.09824532, 0.02916753], + [14.61464119, 7.49001646, 4.86714602, 3.1956799, 2.45070267, 1.78698075, 1.32549286, 1.01931262, 0.803307, 0.64427125, 0.52423614, 0.41087446, 0.32104823, 0.25053367, 0.19894916, 0.13792117, 0.09824532, 0.02916753], + [14.61464119, 7.49001646, 4.86714602, 3.1956799, 2.45070267, 1.84880662, 1.41535246, 1.12534678, 0.89115214, 0.72133851, 0.59516323, 0.4783645, 0.38853383, 0.32104823, 0.25053367, 0.19894916, 0.13792117, 0.09824532, 0.02916753], + [14.61464119, 7.49001646, 4.86714602, 3.1956799, 2.45070267, 1.84880662, 1.41535246, 1.12534678, 0.89115214, 0.72133851, 0.59516323, 0.50118381, 0.41087446, 0.34370604, 0.27464288, 0.22545385, 0.17026083, 0.13792117, 0.09824532, 0.02916753], + [14.61464119, 7.49001646, 4.86714602, 3.1956799, 2.45070267, 1.84880662, 1.41535246, 1.12534678, 0.89115214, 0.72133851, 0.59516323, 0.50118381, 0.41087446, 0.34370604, 0.29807833, 0.25053367, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753], + ], + 1.20: [ + [14.61464119, 0.803307, 0.02916753], + [14.61464119, 1.56271636, 0.52423614, 0.02916753], + [14.61464119, 2.36326075, 0.92192322, 0.36617002, 0.02916753], + [14.61464119, 2.84484982, 1.24153244, 0.59516323, 0.25053367, 0.02916753], + [14.61464119, 5.85520077, 2.05039096, 0.95350921, 0.45573691, 0.17026083, 0.02916753], + [14.61464119, 5.85520077, 2.45070267, 1.24153244, 0.64427125, 0.29807833, 0.09824532, 0.02916753], + [14.61464119, 5.85520077, 2.45070267, 1.36964464, 0.803307, 0.45573691, 0.25053367, 0.09824532, 0.02916753], + [14.61464119, 5.85520077, 2.84484982, 1.61558151, 0.95350921, 0.59516323, 0.36617002, 0.19894916, 0.09824532, 0.02916753], + [14.61464119, 5.85520077, 2.84484982, 1.67050016, 1.08895338, 0.74807048, 0.50118381, 0.32104823, 0.19894916, 0.09824532, 0.02916753], + [14.61464119, 5.85520077, 2.95596409, 1.84880662, 1.24153244, 0.83188516, 0.59516323, 0.41087446, 0.27464288, 0.17026083, 0.09824532, 0.02916753], + [14.61464119, 5.85520077, 3.07277966, 1.98035145, 1.36964464, 0.95350921, 0.69515091, 0.50118381, 0.36617002, 0.25053367, 0.17026083, 0.09824532, 0.02916753], + [14.61464119, 6.77309084, 3.46139455, 2.36326075, 1.56271636, 1.08895338, 0.803307, 0.59516323, 0.45573691, 0.34370604, 0.25053367, 0.17026083, 0.09824532, 0.02916753], + [14.61464119, 6.77309084, 3.46139455, 2.45070267, 1.61558151, 1.162866, 0.86115354, 0.64427125, 0.50118381, 0.38853383, 0.29807833, 0.22545385, 0.17026083, 0.09824532, 0.02916753], + [14.61464119, 7.49001646, 4.65472794, 3.07277966, 2.12350607, 1.51179266, 1.08895338, 0.83188516, 0.64427125, 0.50118381, 0.38853383, 0.29807833, 0.22545385, 0.17026083, 0.09824532, 0.02916753], + [14.61464119, 7.49001646, 4.65472794, 3.07277966, 2.12350607, 1.51179266, 1.08895338, 
0.83188516, 0.64427125, 0.50118381, 0.41087446, 0.32104823, 0.25053367, 0.19894916, 0.13792117, 0.09824532, 0.02916753], + [14.61464119, 7.49001646, 4.65472794, 3.07277966, 2.12350607, 1.51179266, 1.08895338, 0.83188516, 0.64427125, 0.50118381, 0.41087446, 0.34370604, 0.27464288, 0.22545385, 0.17026083, 0.13792117, 0.09824532, 0.02916753], + [14.61464119, 7.49001646, 4.65472794, 3.07277966, 2.19988537, 1.61558151, 1.20157266, 0.92192322, 0.72133851, 0.57119018, 0.45573691, 0.36617002, 0.29807833, 0.25053367, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753], + [14.61464119, 7.49001646, 4.65472794, 3.07277966, 2.19988537, 1.61558151, 1.24153244, 0.95350921, 0.74807048, 0.59516323, 0.4783645, 0.38853383, 0.32104823, 0.27464288, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753], + [14.61464119, 7.49001646, 4.65472794, 3.07277966, 2.19988537, 1.61558151, 1.24153244, 0.95350921, 0.74807048, 0.59516323, 0.50118381, 0.41087446, 0.34370604, 0.29807833, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753], + ], + 1.25: [ + [14.61464119, 0.72133851, 0.02916753], + [14.61464119, 1.56271636, 0.50118381, 0.02916753], + [14.61464119, 2.05039096, 0.803307, 0.32104823, 0.02916753], + [14.61464119, 2.36326075, 0.95350921, 0.43325692, 0.17026083, 0.02916753], + [14.61464119, 2.84484982, 1.24153244, 0.59516323, 0.27464288, 0.09824532, 0.02916753], + [14.61464119, 3.07277966, 1.51179266, 0.803307, 0.43325692, 0.22545385, 0.09824532, 0.02916753], + [14.61464119, 5.85520077, 2.36326075, 1.24153244, 0.72133851, 0.41087446, 0.22545385, 0.09824532, 0.02916753], + [14.61464119, 5.85520077, 2.45070267, 1.36964464, 0.83188516, 0.52423614, 0.34370604, 0.19894916, 0.09824532, 0.02916753], + [14.61464119, 5.85520077, 2.84484982, 1.61558151, 0.98595673, 0.64427125, 0.43325692, 0.27464288, 0.17026083, 0.09824532, 0.02916753], + [14.61464119, 5.85520077, 2.84484982, 1.67050016, 1.08895338, 0.74807048, 0.52423614, 0.36617002, 0.25053367, 0.17026083, 0.09824532, 0.02916753], + [14.61464119, 5.85520077, 2.84484982, 1.72759056, 1.162866, 0.803307, 0.59516323, 0.45573691, 0.34370604, 0.25053367, 0.17026083, 0.09824532, 0.02916753], + [14.61464119, 5.85520077, 2.95596409, 1.84880662, 1.24153244, 0.86115354, 0.64427125, 0.4783645, 0.36617002, 0.27464288, 0.19894916, 0.13792117, 0.09824532, 0.02916753], + [14.61464119, 5.85520077, 2.95596409, 1.84880662, 1.28281462, 0.92192322, 0.69515091, 0.52423614, 0.41087446, 0.32104823, 0.25053367, 0.19894916, 0.13792117, 0.09824532, 0.02916753], + [14.61464119, 5.85520077, 2.95596409, 1.91321158, 1.32549286, 0.95350921, 0.72133851, 0.54755926, 0.43325692, 0.34370604, 0.27464288, 0.22545385, 0.17026083, 0.13792117, 0.09824532, 0.02916753], + [14.61464119, 5.85520077, 2.95596409, 1.91321158, 1.32549286, 0.95350921, 0.72133851, 0.57119018, 0.45573691, 0.36617002, 0.29807833, 0.25053367, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753], + [14.61464119, 5.85520077, 2.95596409, 1.91321158, 1.32549286, 0.95350921, 0.74807048, 0.59516323, 0.4783645, 0.38853383, 0.32104823, 0.27464288, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753], + [14.61464119, 5.85520077, 3.07277966, 2.05039096, 1.41535246, 1.05362725, 0.803307, 0.61951244, 0.50118381, 0.41087446, 0.34370604, 0.29807833, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753], + [14.61464119, 5.85520077, 3.07277966, 2.05039096, 1.41535246, 1.05362725, 0.803307, 0.64427125, 0.52423614, 0.43325692, 0.36617002, 0.32104823, 
0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753], + [14.61464119, 5.85520077, 3.07277966, 2.05039096, 1.46270394, 1.08895338, 0.83188516, 0.66947293, 0.54755926, 0.45573691, 0.38853383, 0.34370604, 0.29807833, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753], + ], + 1.30: [ + [14.61464119, 0.72133851, 0.02916753], + [14.61464119, 1.24153244, 0.43325692, 0.02916753], + [14.61464119, 1.56271636, 0.59516323, 0.22545385, 0.02916753], + [14.61464119, 1.84880662, 0.803307, 0.36617002, 0.13792117, 0.02916753], + [14.61464119, 2.36326075, 1.01931262, 0.52423614, 0.25053367, 0.09824532, 0.02916753], + [14.61464119, 2.84484982, 1.36964464, 0.74807048, 0.41087446, 0.22545385, 0.09824532, 0.02916753], + [14.61464119, 3.07277966, 1.56271636, 0.89115214, 0.54755926, 0.34370604, 0.19894916, 0.09824532, 0.02916753], + [14.61464119, 3.07277966, 1.61558151, 0.95350921, 0.61951244, 0.41087446, 0.27464288, 0.17026083, 0.09824532, 0.02916753], + [14.61464119, 5.85520077, 2.45070267, 1.36964464, 0.83188516, 0.54755926, 0.36617002, 0.25053367, 0.17026083, 0.09824532, 0.02916753], + [14.61464119, 5.85520077, 2.45070267, 1.41535246, 0.92192322, 0.64427125, 0.45573691, 0.34370604, 0.25053367, 0.17026083, 0.09824532, 0.02916753], + [14.61464119, 5.85520077, 2.6383388, 1.56271636, 1.01931262, 0.72133851, 0.50118381, 0.36617002, 0.27464288, 0.19894916, 0.13792117, 0.09824532, 0.02916753], + [14.61464119, 5.85520077, 2.84484982, 1.61558151, 1.05362725, 0.74807048, 0.54755926, 0.41087446, 0.32104823, 0.25053367, 0.19894916, 0.13792117, 0.09824532, 0.02916753], + [14.61464119, 5.85520077, 2.84484982, 1.61558151, 1.08895338, 0.77538133, 0.57119018, 0.43325692, 0.34370604, 0.27464288, 0.22545385, 0.17026083, 0.13792117, 0.09824532, 0.02916753], + [14.61464119, 5.85520077, 2.84484982, 1.61558151, 1.08895338, 0.803307, 0.59516323, 0.45573691, 0.36617002, 0.29807833, 0.25053367, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753], + [14.61464119, 5.85520077, 2.84484982, 1.61558151, 1.08895338, 0.803307, 0.59516323, 0.4783645, 0.38853383, 0.32104823, 0.27464288, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753], + [14.61464119, 5.85520077, 2.84484982, 1.72759056, 1.162866, 0.83188516, 0.64427125, 0.50118381, 0.41087446, 0.34370604, 0.29807833, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753], + [14.61464119, 5.85520077, 2.84484982, 1.72759056, 1.162866, 0.83188516, 0.64427125, 0.52423614, 0.43325692, 0.36617002, 0.32104823, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753], + [14.61464119, 5.85520077, 2.84484982, 1.78698075, 1.24153244, 0.92192322, 0.72133851, 0.57119018, 0.45573691, 0.38853383, 0.34370604, 0.29807833, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753], + [14.61464119, 5.85520077, 2.84484982, 1.78698075, 1.24153244, 0.92192322, 0.72133851, 0.57119018, 0.4783645, 0.41087446, 0.36617002, 0.32104823, 0.29807833, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753], + ], + 1.35: [ + [14.61464119, 0.69515091, 0.02916753], + [14.61464119, 0.95350921, 0.34370604, 0.02916753], + [14.61464119, 1.56271636, 0.57119018, 0.19894916, 0.02916753], + [14.61464119, 1.61558151, 0.69515091, 0.29807833, 0.09824532, 0.02916753], + [14.61464119, 1.84880662, 0.83188516, 0.43325692, 0.22545385, 0.09824532, 0.02916753], + [14.61464119, 2.45070267, 
1.162866, 0.64427125, 0.36617002, 0.19894916, 0.09824532, 0.02916753], + [14.61464119, 2.84484982, 1.36964464, 0.803307, 0.50118381, 0.32104823, 0.19894916, 0.09824532, 0.02916753], + [14.61464119, 2.84484982, 1.41535246, 0.83188516, 0.54755926, 0.36617002, 0.25053367, 0.17026083, 0.09824532, 0.02916753], + [14.61464119, 2.84484982, 1.56271636, 0.95350921, 0.64427125, 0.45573691, 0.32104823, 0.22545385, 0.17026083, 0.09824532, 0.02916753], + [14.61464119, 2.84484982, 1.56271636, 0.95350921, 0.64427125, 0.45573691, 0.34370604, 0.25053367, 0.19894916, 0.13792117, 0.09824532, 0.02916753], + [14.61464119, 3.07277966, 1.61558151, 1.01931262, 0.72133851, 0.52423614, 0.38853383, 0.29807833, 0.22545385, 0.17026083, 0.13792117, 0.09824532, 0.02916753], + [14.61464119, 3.07277966, 1.61558151, 1.01931262, 0.72133851, 0.52423614, 0.41087446, 0.32104823, 0.25053367, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753], + [14.61464119, 3.07277966, 1.61558151, 1.05362725, 0.74807048, 0.54755926, 0.43325692, 0.34370604, 0.27464288, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753], + [14.61464119, 3.07277966, 1.72759056, 1.12534678, 0.803307, 0.59516323, 0.45573691, 0.36617002, 0.29807833, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753], + [14.61464119, 3.07277966, 1.72759056, 1.12534678, 0.803307, 0.59516323, 0.4783645, 0.38853383, 0.32104823, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753], + [14.61464119, 5.85520077, 2.45070267, 1.51179266, 1.01931262, 0.74807048, 0.57119018, 0.45573691, 0.36617002, 0.32104823, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753], + [14.61464119, 5.85520077, 2.6383388, 1.61558151, 1.08895338, 0.803307, 0.61951244, 0.50118381, 0.41087446, 0.34370604, 0.29807833, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753], + [14.61464119, 5.85520077, 2.6383388, 1.61558151, 1.08895338, 0.803307, 0.64427125, 0.52423614, 0.43325692, 0.36617002, 0.32104823, 0.29807833, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753], + [14.61464119, 5.85520077, 2.6383388, 1.61558151, 1.08895338, 0.803307, 0.64427125, 0.52423614, 0.45573691, 0.38853383, 0.34370604, 0.32104823, 0.29807833, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753], + ], + 1.40: [ + [14.61464119, 0.59516323, 0.02916753], + [14.61464119, 0.95350921, 0.34370604, 0.02916753], + [14.61464119, 1.08895338, 0.43325692, 0.13792117, 0.02916753], + [14.61464119, 1.56271636, 0.64427125, 0.27464288, 0.09824532, 0.02916753], + [14.61464119, 1.61558151, 0.803307, 0.43325692, 0.22545385, 0.09824532, 0.02916753], + [14.61464119, 2.05039096, 0.95350921, 0.54755926, 0.34370604, 0.19894916, 0.09824532, 0.02916753], + [14.61464119, 2.45070267, 1.24153244, 0.72133851, 0.43325692, 0.27464288, 0.17026083, 0.09824532, 0.02916753], + [14.61464119, 2.45070267, 1.24153244, 0.74807048, 0.50118381, 0.34370604, 0.25053367, 0.17026083, 0.09824532, 0.02916753], + [14.61464119, 2.45070267, 1.28281462, 0.803307, 0.52423614, 0.36617002, 0.27464288, 0.19894916, 0.13792117, 0.09824532, 0.02916753], + [14.61464119, 2.45070267, 1.28281462, 0.803307, 0.54755926, 0.38853383, 0.29807833, 0.22545385, 0.17026083, 0.13792117, 0.09824532, 0.02916753], + [14.61464119, 2.84484982, 1.41535246, 0.86115354, 0.59516323, 0.43325692, 0.32104823, 0.25053367, 0.19894916, 0.17026083, 0.13792117, 
0.09824532, 0.02916753], + [14.61464119, 2.84484982, 1.51179266, 0.95350921, 0.64427125, 0.45573691, 0.34370604, 0.27464288, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753], + [14.61464119, 2.84484982, 1.51179266, 0.95350921, 0.64427125, 0.4783645, 0.36617002, 0.29807833, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753], + [14.61464119, 2.84484982, 1.56271636, 0.98595673, 0.69515091, 0.52423614, 0.41087446, 0.34370604, 0.29807833, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753], + [14.61464119, 2.84484982, 1.56271636, 1.01931262, 0.72133851, 0.54755926, 0.43325692, 0.36617002, 0.32104823, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753], + [14.61464119, 2.84484982, 1.61558151, 1.05362725, 0.74807048, 0.57119018, 0.45573691, 0.38853383, 0.34370604, 0.29807833, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753], + [14.61464119, 2.84484982, 1.61558151, 1.08895338, 0.803307, 0.61951244, 0.50118381, 0.41087446, 0.36617002, 0.32104823, 0.29807833, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753], + [14.61464119, 2.84484982, 1.61558151, 1.08895338, 0.803307, 0.61951244, 0.50118381, 0.43325692, 0.38853383, 0.34370604, 0.32104823, 0.29807833, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753], + [14.61464119, 2.84484982, 1.61558151, 1.08895338, 0.803307, 0.64427125, 0.52423614, 0.45573691, 0.41087446, 0.36617002, 0.34370604, 0.32104823, 0.29807833, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753], + ], + 1.45: [ + [14.61464119, 0.59516323, 0.02916753], + [14.61464119, 0.803307, 0.25053367, 0.02916753], + [14.61464119, 0.95350921, 0.34370604, 0.09824532, 0.02916753], + [14.61464119, 1.24153244, 0.54755926, 0.25053367, 0.09824532, 0.02916753], + [14.61464119, 1.56271636, 0.72133851, 0.36617002, 0.19894916, 0.09824532, 0.02916753], + [14.61464119, 1.61558151, 0.803307, 0.45573691, 0.27464288, 0.17026083, 0.09824532, 0.02916753], + [14.61464119, 1.91321158, 0.95350921, 0.57119018, 0.36617002, 0.25053367, 0.17026083, 0.09824532, 0.02916753], + [14.61464119, 2.19988537, 1.08895338, 0.64427125, 0.41087446, 0.27464288, 0.19894916, 0.13792117, 0.09824532, 0.02916753], + [14.61464119, 2.45070267, 1.24153244, 0.74807048, 0.50118381, 0.34370604, 0.25053367, 0.19894916, 0.13792117, 0.09824532, 0.02916753], + [14.61464119, 2.45070267, 1.24153244, 0.74807048, 0.50118381, 0.36617002, 0.27464288, 0.22545385, 0.17026083, 0.13792117, 0.09824532, 0.02916753], + [14.61464119, 2.45070267, 1.28281462, 0.803307, 0.54755926, 0.41087446, 0.32104823, 0.25053367, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753], + [14.61464119, 2.45070267, 1.28281462, 0.803307, 0.57119018, 0.43325692, 0.34370604, 0.27464288, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753], + [14.61464119, 2.45070267, 1.28281462, 0.83188516, 0.59516323, 0.45573691, 0.36617002, 0.29807833, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753], + [14.61464119, 2.45070267, 1.28281462, 0.83188516, 0.59516323, 0.45573691, 0.36617002, 0.32104823, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753], + [14.61464119, 2.84484982, 1.51179266, 0.95350921, 0.69515091, 0.52423614, 0.41087446, 0.34370604, 0.29807833, 0.27464288, 0.25053367, 
0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753], + [14.61464119, 2.84484982, 1.51179266, 0.95350921, 0.69515091, 0.52423614, 0.43325692, 0.36617002, 0.32104823, 0.29807833, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753], + [14.61464119, 2.84484982, 1.56271636, 0.98595673, 0.72133851, 0.54755926, 0.45573691, 0.38853383, 0.34370604, 0.32104823, 0.29807833, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753], + [14.61464119, 2.84484982, 1.56271636, 1.01931262, 0.74807048, 0.57119018, 0.4783645, 0.41087446, 0.36617002, 0.34370604, 0.32104823, 0.29807833, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753], + [14.61464119, 2.84484982, 1.56271636, 1.01931262, 0.74807048, 0.59516323, 0.50118381, 0.43325692, 0.38853383, 0.36617002, 0.34370604, 0.32104823, 0.29807833, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753], + ], + 1.50: [ + [14.61464119, 0.54755926, 0.02916753], + [14.61464119, 0.803307, 0.25053367, 0.02916753], + [14.61464119, 0.86115354, 0.32104823, 0.09824532, 0.02916753], + [14.61464119, 1.24153244, 0.54755926, 0.25053367, 0.09824532, 0.02916753], + [14.61464119, 1.56271636, 0.72133851, 0.36617002, 0.19894916, 0.09824532, 0.02916753], + [14.61464119, 1.61558151, 0.803307, 0.45573691, 0.27464288, 0.17026083, 0.09824532, 0.02916753], + [14.61464119, 1.61558151, 0.83188516, 0.52423614, 0.34370604, 0.25053367, 0.17026083, 0.09824532, 0.02916753], + [14.61464119, 1.84880662, 0.95350921, 0.59516323, 0.38853383, 0.27464288, 0.19894916, 0.13792117, 0.09824532, 0.02916753], + [14.61464119, 1.84880662, 0.95350921, 0.59516323, 0.41087446, 0.29807833, 0.22545385, 0.17026083, 0.13792117, 0.09824532, 0.02916753], + [14.61464119, 1.84880662, 0.95350921, 0.61951244, 0.43325692, 0.32104823, 0.25053367, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753], + [14.61464119, 2.19988537, 1.12534678, 0.72133851, 0.50118381, 0.36617002, 0.27464288, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753], + [14.61464119, 2.19988537, 1.12534678, 0.72133851, 0.50118381, 0.36617002, 0.29807833, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753], + [14.61464119, 2.36326075, 1.24153244, 0.803307, 0.57119018, 0.43325692, 0.34370604, 0.29807833, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753], + [14.61464119, 2.36326075, 1.24153244, 0.803307, 0.57119018, 0.43325692, 0.34370604, 0.29807833, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753], + [14.61464119, 2.36326075, 1.24153244, 0.803307, 0.59516323, 0.45573691, 0.36617002, 0.32104823, 0.29807833, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753], + [14.61464119, 2.36326075, 1.24153244, 0.803307, 0.59516323, 0.45573691, 0.38853383, 0.34370604, 0.32104823, 0.29807833, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753], + [14.61464119, 2.45070267, 1.32549286, 0.86115354, 0.64427125, 0.50118381, 0.41087446, 0.36617002, 0.34370604, 0.32104823, 0.29807833, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753], + [14.61464119, 2.45070267, 1.36964464, 0.92192322, 0.69515091, 0.54755926, 0.45573691, 0.41087446, 0.36617002, 0.34370604, 0.32104823, 0.29807833, 0.27464288, 0.25053367, 0.22545385, 
0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753], + [14.61464119, 2.45070267, 1.41535246, 0.95350921, 0.72133851, 0.57119018, 0.4783645, 0.43325692, 0.38853383, 0.36617002, 0.34370604, 0.32104823, 0.29807833, 0.27464288, 0.25053367, 0.22545385, 0.19894916, 0.17026083, 0.13792117, 0.09824532, 0.02916753], + ], +} + +class GITSScheduler: + @classmethod + def INPUT_TYPES(s): + return {"required": + {"coeff": ("FLOAT", {"default": 1.20, "min": 0.80, "max": 1.50, "step": 0.05}), + "steps": ("INT", {"default": 10, "min": 2, "max": 1000}), + "denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), + } + } + RETURN_TYPES = ("SIGMAS",) + CATEGORY = "sampling/custom_sampling/schedulers" + + FUNCTION = "get_sigmas" + + def get_sigmas(self, coeff, steps, denoise): + total_steps = steps + if denoise < 1.0: + if denoise <= 0.0: + return (torch.FloatTensor([]),) + total_steps = round(steps * denoise) + + if steps <= 20: + sigmas = NOISE_LEVELS[round(coeff, 2)][steps-2][:] + else: + sigmas = NOISE_LEVELS[round(coeff, 2)][-1][:] + sigmas = loglinear_interp(sigmas, steps + 1) + + sigmas = sigmas[-(total_steps + 1):] + sigmas[-1] = 0 + return (torch.FloatTensor(sigmas), ) + +NODE_CLASS_MAPPINGS = { + "GITSScheduler": GITSScheduler, +} diff --git a/comfy_extras/nodes_hunyuan.py b/comfy_extras/nodes_hunyuan.py new file mode 100644 index 0000000000000000000000000000000000000000..b03eaf6a2044d6a9c05413b9ca9e0b985bbdd26e --- /dev/null +++ b/comfy_extras/nodes_hunyuan.py @@ -0,0 +1,25 @@ +class CLIPTextEncodeHunyuanDiT: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "clip": ("CLIP", ), + "bert": ("STRING", {"multiline": True, "dynamicPrompts": True}), + "mt5xl": ("STRING", {"multiline": True, "dynamicPrompts": True}), + }} + RETURN_TYPES = ("CONDITIONING",) + FUNCTION = "encode" + + CATEGORY = "advanced/conditioning" + + def encode(self, clip, bert, mt5xl): + tokens = clip.tokenize(bert) + tokens["mt5xl"] = clip.tokenize(mt5xl)["mt5xl"] + + output = clip.encode_from_tokens(tokens, return_pooled=True, return_dict=True) + cond = output.pop("cond") + return ([[cond, output]], ) + + +NODE_CLASS_MAPPINGS = { + "CLIPTextEncodeHunyuanDiT": CLIPTextEncodeHunyuanDiT, +} diff --git a/comfy_extras/nodes_hypernetwork.py b/comfy_extras/nodes_hypernetwork.py new file mode 100644 index 0000000000000000000000000000000000000000..cafafa6ab5596c9d59e98f7e70afa3d3546c42fd --- /dev/null +++ b/comfy_extras/nodes_hypernetwork.py @@ -0,0 +1,120 @@ +import comfy.utils +import folder_paths +import torch +import logging + +def load_hypernetwork_patch(path, strength): + sd = comfy.utils.load_torch_file(path, safe_load=True) + activation_func = sd.get('activation_func', 'linear') + is_layer_norm = sd.get('is_layer_norm', False) + use_dropout = sd.get('use_dropout', False) + activate_output = sd.get('activate_output', False) + last_layer_dropout = sd.get('last_layer_dropout', False) + + valid_activation = { + "linear": torch.nn.Identity, + "relu": torch.nn.ReLU, + "leakyrelu": torch.nn.LeakyReLU, + "elu": torch.nn.ELU, + "swish": torch.nn.Hardswish, + "tanh": torch.nn.Tanh, + "sigmoid": torch.nn.Sigmoid, + "softsign": torch.nn.Softsign, + "mish": torch.nn.Mish, + } + + if activation_func not in valid_activation: + logging.error("Unsupported Hypernetwork format, if you report it I might implement it. 
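# Editorial sketch, not part of the diff: GITSScheduler above looks up a
# precomputed sigma table for the chosen coeff (0.80-1.50 in 0.05 steps) and,
# for more than 20 steps, stretches the longest table with loglinear_interp,
# presumably defined earlier in this file. A plausible log-linear resampler
# with that behavior, assuming numpy; loglinear_interp_sketch is a
# hypothetical name:
import numpy as np

def loglinear_interp_sketch(t_steps, num_steps):
    # Resample a descending sigma schedule onto a new grid, linearly in log-space.
    xs = np.linspace(0, 1, len(t_steps))
    ys = np.log(np.asarray(t_steps)[::-1])   # ascending, as np.interp requires
    new_xs = np.linspace(0, 1, num_steps)
    return np.exp(np.interp(new_xs, xs, ys))[::-1].copy()  # back to descending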
{} {} {} {} {} {}".format(path, activation_func, is_layer_norm, use_dropout, activate_output, last_layer_dropout)) + return None + + out = {} + + for d in sd: + try: + dim = int(d) + except: + continue + + output = [] + for index in [0, 1]: + attn_weights = sd[dim][index] + keys = attn_weights.keys() + + linears = filter(lambda a: a.endswith(".weight"), keys) + linears = list(map(lambda a: a[:-len(".weight")], linears)) + layers = [] + + i = 0 + while i < len(linears): + lin_name = linears[i] + last_layer = (i == (len(linears) - 1)) + penultimate_layer = (i == (len(linears) - 2)) + + lin_weight = attn_weights['{}.weight'.format(lin_name)] + lin_bias = attn_weights['{}.bias'.format(lin_name)] + layer = torch.nn.Linear(lin_weight.shape[1], lin_weight.shape[0]) + layer.load_state_dict({"weight": lin_weight, "bias": lin_bias}) + layers.append(layer) + if activation_func != "linear": + if (not last_layer) or (activate_output): + layers.append(valid_activation[activation_func]()) + if is_layer_norm: + i += 1 + ln_name = linears[i] + ln_weight = attn_weights['{}.weight'.format(ln_name)] + ln_bias = attn_weights['{}.bias'.format(ln_name)] + ln = torch.nn.LayerNorm(ln_weight.shape[0]) + ln.load_state_dict({"weight": ln_weight, "bias": ln_bias}) + layers.append(ln) + if use_dropout: + if (not last_layer) and (not penultimate_layer or last_layer_dropout): + layers.append(torch.nn.Dropout(p=0.3)) + i += 1 + + output.append(torch.nn.Sequential(*layers)) + out[dim] = torch.nn.ModuleList(output) + + class hypernetwork_patch: + def __init__(self, hypernet, strength): + self.hypernet = hypernet + self.strength = strength + def __call__(self, q, k, v, extra_options): + dim = k.shape[-1] + if dim in self.hypernet: + hn = self.hypernet[dim] + k = k + hn[0](k) * self.strength + v = v + hn[1](v) * self.strength + + return q, k, v + + def to(self, device): + for d in self.hypernet.keys(): + self.hypernet[d] = self.hypernet[d].to(device) + return self + + return hypernetwork_patch(out, strength) + +class HypernetworkLoader: + @classmethod + def INPUT_TYPES(s): + return {"required": { "model": ("MODEL",), + "hypernetwork_name": (folder_paths.get_filename_list("hypernetworks"), ), + "strength": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}), + }} + RETURN_TYPES = ("MODEL",) + FUNCTION = "load_hypernetwork" + + CATEGORY = "loaders" + + def load_hypernetwork(self, model, hypernetwork_name, strength): + hypernetwork_path = folder_paths.get_full_path("hypernetworks", hypernetwork_name) + model_hypernetwork = model.clone() + patch = load_hypernetwork_patch(hypernetwork_path, strength) + if patch is not None: + model_hypernetwork.set_model_attn1_patch(patch) + model_hypernetwork.set_model_attn2_patch(patch) + return (model_hypernetwork,) + +NODE_CLASS_MAPPINGS = { + "HypernetworkLoader": HypernetworkLoader +} diff --git a/comfy_extras/nodes_hypertile.py b/comfy_extras/nodes_hypertile.py new file mode 100644 index 0000000000000000000000000000000000000000..227133f3978e156cebd7496c2f442916c35afda7 --- /dev/null +++ b/comfy_extras/nodes_hypertile.py @@ -0,0 +1,83 @@ +#Taken from: https://github.com/tfernd/HyperTile/ + +import math +from einops import rearrange +# Use torch rng for consistency across generations +from torch import randint + +def random_divisor(value: int, min_value: int, /, max_options: int = 1) -> int: + min_value = min(min_value, value) + + # All big divisors of value (inclusive) + divisors = [i for i in range(min_value, value + 1) if value % i == 0] + + ns = [value // i for i in 
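# Editorial note, not part of the diff: the callable returned by
# load_hypernetwork_patch above is registered on both attn1 and attn2 by
# HypernetworkLoader, and on each attention call it shifts keys and values
# through the two per-dimension MLPs built from the state dict:
#     k' = k + strength * hn[dim][0](k)
#     v' = v + strength * hn[dim][1](v)
# so strength = 0.0 is an exact no-op and negative strength inverts the
# effect; q passes through untouched.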
divisors[:max_options]] # has at least 1 element + + if len(ns) - 1 > 0: + idx = randint(low=0, high=len(ns) - 1, size=(1,)).item() + else: + idx = 0 + + return ns[idx] + +class HyperTile: + @classmethod + def INPUT_TYPES(s): + return {"required": { "model": ("MODEL",), + "tile_size": ("INT", {"default": 256, "min": 1, "max": 2048}), + "swap_size": ("INT", {"default": 2, "min": 1, "max": 128}), + "max_depth": ("INT", {"default": 0, "min": 0, "max": 10}), + "scale_depth": ("BOOLEAN", {"default": False}), + }} + RETURN_TYPES = ("MODEL",) + FUNCTION = "patch" + + CATEGORY = "model_patches/unet" + + def patch(self, model, tile_size, swap_size, max_depth, scale_depth): + model_channels = model.model.model_config.unet_config["model_channels"] + + latent_tile_size = max(32, tile_size) // 8 + self.temp = None + + def hypertile_in(q, k, v, extra_options): + model_chans = q.shape[-2] + orig_shape = extra_options['original_shape'] + apply_to = [] + for i in range(max_depth + 1): + apply_to.append((orig_shape[-2] / (2 ** i)) * (orig_shape[-1] / (2 ** i))) + + if model_chans in apply_to: + shape = extra_options["original_shape"] + aspect_ratio = shape[-1] / shape[-2] + + hw = q.size(1) + h, w = round(math.sqrt(hw * aspect_ratio)), round(math.sqrt(hw / aspect_ratio)) + + factor = (2 ** apply_to.index(model_chans)) if scale_depth else 1 + nh = random_divisor(h, latent_tile_size * factor, swap_size) + nw = random_divisor(w, latent_tile_size * factor, swap_size) + + if nh * nw > 1: + q = rearrange(q, "b (nh h nw w) c -> (b nh nw) (h w) c", h=h // nh, w=w // nw, nh=nh, nw=nw) + self.temp = (nh, nw, h, w) + return q, k, v + + return q, k, v + def hypertile_out(out, extra_options): + if self.temp is not None: + nh, nw, h, w = self.temp + self.temp = None + out = rearrange(out, "(b nh nw) hw c -> b nh nw hw c", nh=nh, nw=nw) + out = rearrange(out, "b nh nw (h w) c -> b (nh h nw w) c", h=h // nh, w=w // nw) + return out + + + m = model.clone() + m.set_model_attn1_patch(hypertile_in) + m.set_model_attn1_output_patch(hypertile_out) + return (m, ) + +NODE_CLASS_MAPPINGS = { + "HyperTile": HyperTile, +} diff --git a/comfy_extras/nodes_images.py b/comfy_extras/nodes_images.py new file mode 100644 index 0000000000000000000000000000000000000000..af37666b29fcf5d6a572b715682fd031e28639bc --- /dev/null +++ b/comfy_extras/nodes_images.py @@ -0,0 +1,195 @@ +import nodes +import folder_paths +from comfy.cli_args import args + +from PIL import Image +from PIL.PngImagePlugin import PngInfo + +import numpy as np +import json +import os + +MAX_RESOLUTION = nodes.MAX_RESOLUTION + +class ImageCrop: + @classmethod + def INPUT_TYPES(s): + return {"required": { "image": ("IMAGE",), + "width": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}), + "height": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}), + "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), + "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), + }} + RETURN_TYPES = ("IMAGE",) + FUNCTION = "crop" + + CATEGORY = "image/transform" + + def crop(self, image, width, height, x, y): + x = min(x, image.shape[2] - 1) + y = min(y, image.shape[1] - 1) + to_x = width + x + to_y = height + y + img = image[:,y:to_y, x:to_x, :] + return (img,) + +class RepeatImageBatch: + @classmethod + def INPUT_TYPES(s): + return {"required": { "image": ("IMAGE",), + "amount": ("INT", {"default": 1, "min": 1, "max": 4096}), + }} + RETURN_TYPES = ("IMAGE",) + FUNCTION = "repeat" + + CATEGORY = "image/batch" + + def 
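# Editorial sketch, not part of the diff: HyperTile's speedup comes from the
# rearrange above, which regroups the h*w attention tokens into nh*nw
# independent tiles folded into the batch, cutting self-attention cost from
# roughly O((h*w)^2) to O(nh*nw*((h/nh)*(w/nw))^2). A minimal shape check:
import torch
from einops import rearrange

q = torch.randn(1, 64 * 64, 320)  # (batch, tokens, channels) for a 64x64 latent
nh = nw = 2
tiled = rearrange(q, "b (nh h nw w) c -> (b nh nw) (h w) c",
                  h=64 // nh, w=64 // nw, nh=nh, nw=nw)
assert tiled.shape == (4, 32 * 32, 320)  # 4 tiles, each a quarter of the tokens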
repeat(self, image, amount): + s = image.repeat((amount, 1,1,1)) + return (s,) + +class ImageFromBatch: + @classmethod + def INPUT_TYPES(s): + return {"required": { "image": ("IMAGE",), + "batch_index": ("INT", {"default": 0, "min": 0, "max": 4095}), + "length": ("INT", {"default": 1, "min": 1, "max": 4096}), + }} + RETURN_TYPES = ("IMAGE",) + FUNCTION = "frombatch" + + CATEGORY = "image/batch" + + def frombatch(self, image, batch_index, length): + s_in = image + batch_index = min(s_in.shape[0] - 1, batch_index) + length = min(s_in.shape[0] - batch_index, length) + s = s_in[batch_index:batch_index + length].clone() + return (s,) + +class SaveAnimatedWEBP: + def __init__(self): + self.output_dir = folder_paths.get_output_directory() + self.type = "output" + self.prefix_append = "" + + methods = {"default": 4, "fastest": 0, "slowest": 6} + @classmethod + def INPUT_TYPES(s): + return {"required": + {"images": ("IMAGE", ), + "filename_prefix": ("STRING", {"default": "ComfyUI"}), + "fps": ("FLOAT", {"default": 6.0, "min": 0.01, "max": 1000.0, "step": 0.01}), + "lossless": ("BOOLEAN", {"default": True}), + "quality": ("INT", {"default": 80, "min": 0, "max": 100}), + "method": (list(s.methods.keys()),), + # "num_frames": ("INT", {"default": 0, "min": 0, "max": 8192}), + }, + "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}, + } + + RETURN_TYPES = () + FUNCTION = "save_images" + + OUTPUT_NODE = True + + CATEGORY = "image/animation" + + def save_images(self, images, fps, filename_prefix, lossless, quality, method, num_frames=0, prompt=None, extra_pnginfo=None): + method = self.methods.get(method) + filename_prefix += self.prefix_append + full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0]) + results = list() + pil_images = [] + for image in images: + i = 255. 
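# Editorial note, not part of the diff: the statement split around this point
# scales each [0, 1] float frame to uint8 and wraps it in a PIL image; the
# frames are then written below in groups of num_frames per .webp file, with
# the prompt stored in EXIF tag 0x0110 and extra_pnginfo entries in tags
# counting down from 0x010f, so a workflow can be recovered from the output.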
* image.cpu().numpy() + img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8)) + pil_images.append(img) + + metadata = pil_images[0].getexif() + if not args.disable_metadata: + if prompt is not None: + metadata[0x0110] = "prompt:{}".format(json.dumps(prompt)) + if extra_pnginfo is not None: + initial_exif = 0x010f + for x in extra_pnginfo: + metadata[initial_exif] = "{}:{}".format(x, json.dumps(extra_pnginfo[x])) + initial_exif -= 1 + + if num_frames == 0: + num_frames = len(pil_images) + + c = len(pil_images) + for i in range(0, c, num_frames): + file = f"{filename}_{counter:05}_.webp" + pil_images[i].save(os.path.join(full_output_folder, file), save_all=True, duration=int(1000.0/fps), append_images=pil_images[i + 1:i + num_frames], exif=metadata, lossless=lossless, quality=quality, method=method) + results.append({ + "filename": file, + "subfolder": subfolder, + "type": self.type + }) + counter += 1 + + animated = num_frames != 1 + return { "ui": { "images": results, "animated": (animated,) } } + +class SaveAnimatedPNG: + def __init__(self): + self.output_dir = folder_paths.get_output_directory() + self.type = "output" + self.prefix_append = "" + + @classmethod + def INPUT_TYPES(s): + return {"required": + {"images": ("IMAGE", ), + "filename_prefix": ("STRING", {"default": "ComfyUI"}), + "fps": ("FLOAT", {"default": 6.0, "min": 0.01, "max": 1000.0, "step": 0.01}), + "compress_level": ("INT", {"default": 4, "min": 0, "max": 9}) + }, + "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}, + } + + RETURN_TYPES = () + FUNCTION = "save_images" + + OUTPUT_NODE = True + + CATEGORY = "image/animation" + + def save_images(self, images, fps, compress_level, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None): + filename_prefix += self.prefix_append + full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0]) + results = list() + pil_images = [] + for image in images: + i = 255.
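# Editorial note, not part of the diff: SaveAnimatedPNG takes a different
# metadata route than the WEBP node: below, prompt and extra_pnginfo are
# serialized as private "comf" PNG chunks (key, NUL byte, JSON payload) placed
# after IDAT, and all frames go into a single APNG via save_all/append_images.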
* image.cpu().numpy() + img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8)) + pil_images.append(img) + + metadata = None + if not args.disable_metadata: + metadata = PngInfo() + if prompt is not None: + metadata.add(b"comf", "prompt".encode("latin-1", "strict") + b"\0" + json.dumps(prompt).encode("latin-1", "strict"), after_idat=True) + if extra_pnginfo is not None: + for x in extra_pnginfo: + metadata.add(b"comf", x.encode("latin-1", "strict") + b"\0" + json.dumps(extra_pnginfo[x]).encode("latin-1", "strict"), after_idat=True) + + file = f"{filename}_{counter:05}_.png" + pil_images[0].save(os.path.join(full_output_folder, file), pnginfo=metadata, compress_level=compress_level, save_all=True, duration=int(1000.0/fps), append_images=pil_images[1:]) + results.append({ + "filename": file, + "subfolder": subfolder, + "type": self.type + }) + + return { "ui": { "images": results, "animated": (True,)} } + +NODE_CLASS_MAPPINGS = { + "ImageCrop": ImageCrop, + "RepeatImageBatch": RepeatImageBatch, + "ImageFromBatch": ImageFromBatch, + "SaveAnimatedWEBP": SaveAnimatedWEBP, + "SaveAnimatedPNG": SaveAnimatedPNG, +} diff --git a/comfy_extras/nodes_ip2p.py b/comfy_extras/nodes_ip2p.py new file mode 100644 index 0000000000000000000000000000000000000000..c2e70a84c10ca5cc1b3ca853a97adc3c64fbb315 --- /dev/null +++ b/comfy_extras/nodes_ip2p.py @@ -0,0 +1,45 @@ +import torch + +class InstructPixToPixConditioning: + @classmethod + def INPUT_TYPES(s): + return {"required": {"positive": ("CONDITIONING", ), + "negative": ("CONDITIONING", ), + "vae": ("VAE", ), + "pixels": ("IMAGE", ), + }} + + RETURN_TYPES = ("CONDITIONING","CONDITIONING","LATENT") + RETURN_NAMES = ("positive", "negative", "latent") + FUNCTION = "encode" + + CATEGORY = "conditioning/instructpix2pix" + + def encode(self, positive, negative, pixels, vae): + x = (pixels.shape[1] // 8) * 8 + y = (pixels.shape[2] // 8) * 8 + + if pixels.shape[1] != x or pixels.shape[2] != y: + x_offset = (pixels.shape[1] % 8) // 2 + y_offset = (pixels.shape[2] % 8) // 2 + pixels = pixels[:,x_offset:x + x_offset, y_offset:y + y_offset,:] + + concat_latent = vae.encode(pixels) + + out_latent = {} + out_latent["samples"] = torch.zeros_like(concat_latent) + + out = [] + for conditioning in [positive, negative]: + c = [] + for t in conditioning: + d = t[1].copy() + d["concat_latent_image"] = concat_latent + n = [t[0], d] + c.append(n) + out.append(c) + return (out[0], out[1], out_latent) + +NODE_CLASS_MAPPINGS = { + "InstructPixToPixConditioning": InstructPixToPixConditioning, +} diff --git a/comfy_extras/nodes_latent.py b/comfy_extras/nodes_latent.py new file mode 100644 index 0000000000000000000000000000000000000000..eabae0885163bbbe4d3a6a6e6e1ac9c81ae795c4 --- /dev/null +++ b/comfy_extras/nodes_latent.py @@ -0,0 +1,155 @@ +import comfy.utils +import torch + +def reshape_latent_to(target_shape, latent): + if latent.shape[1:] != target_shape[1:]: + latent = comfy.utils.common_upscale(latent, target_shape[3], target_shape[2], "bilinear", "center") + return comfy.utils.repeat_to_batch_size(latent, target_shape[0]) + + +class LatentAdd: + @classmethod + def INPUT_TYPES(s): + return {"required": { "samples1": ("LATENT",), "samples2": ("LATENT",)}} + + RETURN_TYPES = ("LATENT",) + FUNCTION = "op" + + CATEGORY = "latent/advanced" + + def op(self, samples1, samples2): + samples_out = samples1.copy() + + s1 = samples1["samples"] + s2 = samples2["samples"] + + s2 = reshape_latent_to(s1.shape, s2) + samples_out["samples"] = s1 + s2 + return (samples_out,) + +class 
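# Editorial note, not part of the diff: every arithmetic node in this file
# funnels its second operand through reshape_latent_to above, which bilinearly
# rescales it to the first latent's spatial size and repeats it to the first
# latent's batch size, so LatentAdd (above) and LatentSubtract (below) stay
# element-wise even for mismatched inputs.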
LatentSubtract: + @classmethod + def INPUT_TYPES(s): + return {"required": { "samples1": ("LATENT",), "samples2": ("LATENT",)}} + + RETURN_TYPES = ("LATENT",) + FUNCTION = "op" + + CATEGORY = "latent/advanced" + + def op(self, samples1, samples2): + samples_out = samples1.copy() + + s1 = samples1["samples"] + s2 = samples2["samples"] + + s2 = reshape_latent_to(s1.shape, s2) + samples_out["samples"] = s1 - s2 + return (samples_out,) + +class LatentMultiply: + @classmethod + def INPUT_TYPES(s): + return {"required": { "samples": ("LATENT",), + "multiplier": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}), + }} + + RETURN_TYPES = ("LATENT",) + FUNCTION = "op" + + CATEGORY = "latent/advanced" + + def op(self, samples, multiplier): + samples_out = samples.copy() + + s1 = samples["samples"] + samples_out["samples"] = s1 * multiplier + return (samples_out,) + +class LatentInterpolate: + @classmethod + def INPUT_TYPES(s): + return {"required": { "samples1": ("LATENT",), + "samples2": ("LATENT",), + "ratio": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), + }} + + RETURN_TYPES = ("LATENT",) + FUNCTION = "op" + + CATEGORY = "latent/advanced" + + def op(self, samples1, samples2, ratio): + samples_out = samples1.copy() + + s1 = samples1["samples"] + s2 = samples2["samples"] + + s2 = reshape_latent_to(s1.shape, s2) + + m1 = torch.linalg.vector_norm(s1, dim=(1)) + m2 = torch.linalg.vector_norm(s2, dim=(1)) + + s1 = torch.nan_to_num(s1 / m1) + s2 = torch.nan_to_num(s2 / m2) + + t = (s1 * ratio + s2 * (1.0 - ratio)) + mt = torch.linalg.vector_norm(t, dim=(1)) + st = torch.nan_to_num(t / mt) + + samples_out["samples"] = st * (m1 * ratio + m2 * (1.0 - ratio)) + return (samples_out,) + +class LatentBatch: + @classmethod + def INPUT_TYPES(s): + return {"required": { "samples1": ("LATENT",), "samples2": ("LATENT",)}} + + RETURN_TYPES = ("LATENT",) + FUNCTION = "batch" + + CATEGORY = "latent/batch" + + def batch(self, samples1, samples2): + samples_out = samples1.copy() + s1 = samples1["samples"] + s2 = samples2["samples"] + + if s1.shape[1:] != s2.shape[1:]: + s2 = comfy.utils.common_upscale(s2, s1.shape[3], s1.shape[2], "bilinear", "center") + s = torch.cat((s1, s2), dim=0) + samples_out["samples"] = s + samples_out["batch_index"] = samples1.get("batch_index", [x for x in range(0, s1.shape[0])]) + samples2.get("batch_index", [x for x in range(0, s2.shape[0])]) + return (samples_out,) + +class LatentBatchSeedBehavior: + @classmethod + def INPUT_TYPES(s): + return {"required": { "samples": ("LATENT",), + "seed_behavior": (["random", "fixed"],{"default": "fixed"}),}} + + RETURN_TYPES = ("LATENT",) + FUNCTION = "op" + + CATEGORY = "latent/advanced" + + def op(self, samples, seed_behavior): + samples_out = samples.copy() + latent = samples["samples"] + if seed_behavior == "random": + if 'batch_index' in samples_out: + samples_out.pop('batch_index') + elif seed_behavior == "fixed": + batch_number = samples_out.get("batch_index", [0])[0] + samples_out["batch_index"] = [batch_number] * latent.shape[0] + + return (samples_out,) + +NODE_CLASS_MAPPINGS = { + "LatentAdd": LatentAdd, + "LatentSubtract": LatentSubtract, + "LatentMultiply": LatentMultiply, + "LatentInterpolate": LatentInterpolate, + "LatentBatch": LatentBatch, + "LatentBatchSeedBehavior": LatentBatchSeedBehavior, +} diff --git a/comfy_extras/nodes_lora_extract.py b/comfy_extras/nodes_lora_extract.py new file mode 100644 index 0000000000000000000000000000000000000000..3c2f179d30f857d0b414c67d1d3aaae122b56323 --- 
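# Editorial sketch, not part of the diff: LatentInterpolate above blends
# direction and magnitude separately: both latents are normalized per spatial
# location across channels, lerped, renormalized, then scaled by the lerped
# magnitudes. The same idea in standalone form (keepdim=True is used here so
# the division broadcasts for any batch size):
import torch

def latent_interp_sketch(s1, s2, ratio):
    m1 = torch.linalg.vector_norm(s1, dim=1, keepdim=True)
    m2 = torch.linalg.vector_norm(s2, dim=1, keepdim=True)
    t = torch.nan_to_num(s1 / m1) * ratio + torch.nan_to_num(s2 / m2) * (1.0 - ratio)
    t = torch.nan_to_num(t / torch.linalg.vector_norm(t, dim=1, keepdim=True))
    return t * (m1 * ratio + m2 * (1.0 - ratio))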
/dev/null +++ b/comfy_extras/nodes_lora_extract.py @@ -0,0 +1,115 @@ +import torch +import comfy.model_management +import comfy.utils +import folder_paths +import os +import logging +from enum import Enum + +CLAMP_QUANTILE = 0.99 + +def extract_lora(diff, rank): + conv2d = (len(diff.shape) == 4) + kernel_size = None if not conv2d else diff.size()[2:4] + conv2d_3x3 = conv2d and kernel_size != (1, 1) + out_dim, in_dim = diff.size()[0:2] + rank = min(rank, in_dim, out_dim) + + if conv2d: + if conv2d_3x3: + diff = diff.flatten(start_dim=1) + else: + diff = diff.squeeze() + + + U, S, Vh = torch.linalg.svd(diff.float()) + U = U[:, :rank] + S = S[:rank] + U = U @ torch.diag(S) + Vh = Vh[:rank, :] + + dist = torch.cat([U.flatten(), Vh.flatten()]) + hi_val = torch.quantile(dist, CLAMP_QUANTILE) + low_val = -hi_val + + U = U.clamp(low_val, hi_val) + Vh = Vh.clamp(low_val, hi_val) + if conv2d: + U = U.reshape(out_dim, rank, 1, 1) + Vh = Vh.reshape(rank, in_dim, kernel_size[0], kernel_size[1]) + return (U, Vh) + +class LORAType(Enum): + STANDARD = 0 + FULL_DIFF = 1 + +LORA_TYPES = {"standard": LORAType.STANDARD, + "full_diff": LORAType.FULL_DIFF} + +def calc_lora_model(model_diff, rank, prefix_model, prefix_lora, output_sd, lora_type, bias_diff=False): + comfy.model_management.load_models_gpu([model_diff], force_patch_weights=True) + sd = model_diff.model_state_dict(filter_prefix=prefix_model) + + for k in sd: + if k.endswith(".weight"): + weight_diff = sd[k] + if lora_type == LORAType.STANDARD: + if weight_diff.ndim < 2: + if bias_diff: + output_sd["{}{}.diff".format(prefix_lora, k[len(prefix_model):-7])] = weight_diff.contiguous().half().cpu() + continue + try: + out = extract_lora(weight_diff, rank) + output_sd["{}{}.lora_up.weight".format(prefix_lora, k[len(prefix_model):-7])] = out[0].contiguous().half().cpu() + output_sd["{}{}.lora_down.weight".format(prefix_lora, k[len(prefix_model):-7])] = out[1].contiguous().half().cpu() + except: + logging.warning("Could not generate lora weights for key {}, is the weight difference a zero?".format(k)) + elif lora_type == LORAType.FULL_DIFF: + output_sd["{}{}.diff".format(prefix_lora, k[len(prefix_model):-7])] = weight_diff.contiguous().half().cpu() + + elif bias_diff and k.endswith(".bias"): + output_sd["{}{}.diff_b".format(prefix_lora, k[len(prefix_model):-5])] = sd[k].contiguous().half().cpu() + return output_sd + +class LoraSave: + def __init__(self): + self.output_dir = folder_paths.get_output_directory() + + @classmethod + def INPUT_TYPES(s): + return {"required": {"filename_prefix": ("STRING", {"default": "loras/ComfyUI_extracted_lora"}), + "rank": ("INT", {"default": 8, "min": 1, "max": 4096, "step": 1}), + "lora_type": (tuple(LORA_TYPES.keys()),), + "bias_diff": ("BOOLEAN", {"default": True}), + }, + "optional": {"model_diff": ("MODEL",), + "text_encoder_diff": ("CLIP",)}, + } + RETURN_TYPES = () + FUNCTION = "save" + OUTPUT_NODE = True + + CATEGORY = "_for_testing" + + def save(self, filename_prefix, rank, lora_type, bias_diff, model_diff=None, text_encoder_diff=None): + if model_diff is None and text_encoder_diff is None: + return {} + + lora_type = LORA_TYPES.get(lora_type) + full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir) + + output_sd = {} + if model_diff is not None: + output_sd = calc_lora_model(model_diff, rank, "diffusion_model.", "diffusion_model.", output_sd, lora_type, bias_diff=bias_diff) + if text_encoder_diff is not None: + output_sd = 
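# Editorial sketch, not part of the diff: extract_lora above is a truncated
# SVD. Folding the singular values into U gives diff ~ lora_up @ lora_down,
# and both factors are clamped to the CLAMP_QUANTILE (0.99) quantile to tame
# outliers. A quick reconstruction check of the rank-r idea:
import torch

diff = torch.randn(64, 32)
U, S, Vh = torch.linalg.svd(diff.float())
rank = 8
approx = (U[:, :rank] @ torch.diag(S[:rank])) @ Vh[:rank, :]
rel_err = torch.linalg.matrix_norm(diff - approx) / torch.linalg.matrix_norm(diff)
# rel_err falls toward zero as rank approaches min(out_dim, in_dim)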
calc_lora_model(text_encoder_diff.patcher, rank, "", "text_encoders.", output_sd, lora_type, bias_diff=bias_diff) + + output_checkpoint = f"{filename}_{counter:05}_.safetensors" + output_checkpoint = os.path.join(full_output_folder, output_checkpoint) + + comfy.utils.save_torch_file(output_sd, output_checkpoint, metadata=None) + return {} + +NODE_CLASS_MAPPINGS = { + "LoraSave": LoraSave +} diff --git a/comfy_extras/nodes_mask.py b/comfy_extras/nodes_mask.py new file mode 100644 index 0000000000000000000000000000000000000000..29589b4abade581e4c60ccc2cab23b582dd2f61a --- /dev/null +++ b/comfy_extras/nodes_mask.py @@ -0,0 +1,382 @@ +import numpy as np +import scipy.ndimage +import torch +import comfy.utils + +from nodes import MAX_RESOLUTION + +def composite(destination, source, x, y, mask = None, multiplier = 8, resize_source = False): + source = source.to(destination.device) + if resize_source: + source = torch.nn.functional.interpolate(source, size=(destination.shape[2], destination.shape[3]), mode="bilinear") + + source = comfy.utils.repeat_to_batch_size(source, destination.shape[0]) + + x = max(-source.shape[3] * multiplier, min(x, destination.shape[3] * multiplier)) + y = max(-source.shape[2] * multiplier, min(y, destination.shape[2] * multiplier)) + + left, top = (x // multiplier, y // multiplier) + right, bottom = (left + source.shape[3], top + source.shape[2],) + + if mask is None: + mask = torch.ones_like(source) + else: + mask = mask.to(destination.device, copy=True) + mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(source.shape[2], source.shape[3]), mode="bilinear") + mask = comfy.utils.repeat_to_batch_size(mask, source.shape[0]) + + # calculate the bounds of the source that will be overlapping the destination + # this prevents the source trying to overwrite latent pixels that are out of bounds + # of the destination + visible_width, visible_height = (destination.shape[3] - left + min(0, x), destination.shape[2] - top + min(0, y),) + + mask = mask[:, :, :visible_height, :visible_width] + inverse_mask = torch.ones_like(mask) - mask + + source_portion = mask * source[:, :, :visible_height, :visible_width] + destination_portion = inverse_mask * destination[:, :, top:bottom, left:right] + + destination[:, :, top:bottom, left:right] = source_portion + destination_portion + return destination + +class LatentCompositeMasked: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "destination": ("LATENT",), + "source": ("LATENT",), + "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}), + "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}), + "resize_source": ("BOOLEAN", {"default": False}), + }, + "optional": { + "mask": ("MASK",), + } + } + RETURN_TYPES = ("LATENT",) + FUNCTION = "composite" + + CATEGORY = "latent" + + def composite(self, destination, source, x, y, resize_source, mask = None): + output = destination.copy() + destination = destination["samples"].clone() + source = source["samples"] + output["samples"] = composite(destination, source, x, y, mask, 8, resize_source) + return (output,) + +class ImageCompositeMasked: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "destination": ("IMAGE",), + "source": ("IMAGE",), + "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), + "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), + "resize_source": ("BOOLEAN", {"default": False}), + }, + "optional": { + "mask": ("MASK",), 
+ } + } + RETURN_TYPES = ("IMAGE",) + FUNCTION = "composite" + + CATEGORY = "image" + + def composite(self, destination, source, x, y, resize_source, mask = None): + destination = destination.clone().movedim(-1, 1) + output = composite(destination, source.movedim(-1, 1), x, y, mask, 1, resize_source).movedim(1, -1) + return (output,) + +class MaskToImage: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "mask": ("MASK",), + } + } + + CATEGORY = "mask" + + RETURN_TYPES = ("IMAGE",) + FUNCTION = "mask_to_image" + + def mask_to_image(self, mask): + result = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])).movedim(1, -1).expand(-1, -1, -1, 3) + return (result,) + +class ImageToMask: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "image": ("IMAGE",), + "channel": (["red", "green", "blue", "alpha"],), + } + } + + CATEGORY = "mask" + + RETURN_TYPES = ("MASK",) + FUNCTION = "image_to_mask" + + def image_to_mask(self, image, channel): + channels = ["red", "green", "blue", "alpha"] + mask = image[:, :, :, channels.index(channel)] + return (mask,) + +class ImageColorToMask: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "image": ("IMAGE",), + "color": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFF, "step": 1, "display": "color"}), + } + } + + CATEGORY = "mask" + + RETURN_TYPES = ("MASK",) + FUNCTION = "image_to_mask" + + def image_to_mask(self, image, color): + temp = (torch.clamp(image, 0, 1.0) * 255.0).round().to(torch.int) + temp = torch.bitwise_left_shift(temp[:,:,:,0], 16) + torch.bitwise_left_shift(temp[:,:,:,1], 8) + temp[:,:,:,2] + mask = torch.where(temp == color, 255, 0).float() + return (mask,) + +class SolidMask: + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "value": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), + "width": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}), + "height": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}), + } + } + + CATEGORY = "mask" + + RETURN_TYPES = ("MASK",) + + FUNCTION = "solid" + + def solid(self, value, width, height): + out = torch.full((1, height, width), value, dtype=torch.float32, device="cpu") + return (out,) + +class InvertMask: + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "mask": ("MASK",), + } + } + + CATEGORY = "mask" + + RETURN_TYPES = ("MASK",) + + FUNCTION = "invert" + + def invert(self, mask): + out = 1.0 - mask + return (out,) + +class CropMask: + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "mask": ("MASK",), + "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), + "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), + "width": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}), + "height": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}), + } + } + + CATEGORY = "mask" + + RETURN_TYPES = ("MASK",) + + FUNCTION = "crop" + + def crop(self, mask, x, y, width, height): + mask = mask.reshape((-1, mask.shape[-2], mask.shape[-1])) + out = mask[:, y:y + height, x:x + width] + return (out,) + +class MaskComposite: + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "destination": ("MASK",), + "source": ("MASK",), + "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), + "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), + "operation": (["multiply", "add", "subtract", "and", "or", "xor"],), + } + } + + CATEGORY = "mask" + 
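# Editorial note, not part of the diff: the composite() helper earlier in this
# file clips the paste region instead of erroring on out-of-range x/y:
# visible_width/visible_height measure how much of the source actually
# overlaps the destination (the min(0, x) terms handle negative offsets), and
# the mask blends the overlap as
#     dest[region] = mask * source + (1 - mask) * dest[region]
# MaskComposite below works on the same clipped-window principle for its
# multiply/add/subtract/and/or/xor modes.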
+ RETURN_TYPES = ("MASK",) + + FUNCTION = "combine" + + def combine(self, destination, source, x, y, operation): + output = destination.reshape((-1, destination.shape[-2], destination.shape[-1])).clone() + source = source.reshape((-1, source.shape[-2], source.shape[-1])) + + left, top = (x, y,) + right, bottom = (min(left + source.shape[-1], destination.shape[-1]), min(top + source.shape[-2], destination.shape[-2])) + visible_width, visible_height = (right - left, bottom - top,) + + source_portion = source[:, :visible_height, :visible_width] + destination_portion = destination[:, top:bottom, left:right] + + if operation == "multiply": + output[:, top:bottom, left:right] = destination_portion * source_portion + elif operation == "add": + output[:, top:bottom, left:right] = destination_portion + source_portion + elif operation == "subtract": + output[:, top:bottom, left:right] = destination_portion - source_portion + elif operation == "and": + output[:, top:bottom, left:right] = torch.bitwise_and(destination_portion.round().bool(), source_portion.round().bool()).float() + elif operation == "or": + output[:, top:bottom, left:right] = torch.bitwise_or(destination_portion.round().bool(), source_portion.round().bool()).float() + elif operation == "xor": + output[:, top:bottom, left:right] = torch.bitwise_xor(destination_portion.round().bool(), source_portion.round().bool()).float() + + output = torch.clamp(output, 0.0, 1.0) + + return (output,) + +class FeatherMask: + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "mask": ("MASK",), + "left": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), + "top": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), + "right": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), + "bottom": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), + } + } + + CATEGORY = "mask" + + RETURN_TYPES = ("MASK",) + + FUNCTION = "feather" + + def feather(self, mask, left, top, right, bottom): + output = mask.reshape((-1, mask.shape[-2], mask.shape[-1])).clone() + + left = min(left, output.shape[-1]) + right = min(right, output.shape[-1]) + top = min(top, output.shape[-2]) + bottom = min(bottom, output.shape[-2]) + + for x in range(left): + feather_rate = (x + 1.0) / left + output[:, :, x] *= feather_rate + + for x in range(right): + feather_rate = (x + 1) / right + output[:, :, -(x + 1)] *= feather_rate # count in from the right edge; -x would wrongly dim column 0 when x == 0 + + for y in range(top): + feather_rate = (y + 1) / top + output[:, y, :] *= feather_rate + + for y in range(bottom): + feather_rate = (y + 1) / bottom + output[:, -(y + 1), :] *= feather_rate # likewise, count in from the bottom edge + + return (output,) + +class GrowMask: + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "mask": ("MASK",), + "expand": ("INT", {"default": 0, "min": -MAX_RESOLUTION, "max": MAX_RESOLUTION, "step": 1}), + "tapered_corners": ("BOOLEAN", {"default": True}), + }, + } + + CATEGORY = "mask" + + RETURN_TYPES = ("MASK",) + + FUNCTION = "expand_mask" + + def expand_mask(self, mask, expand, tapered_corners): + c = 0 if tapered_corners else 1 + kernel = np.array([[c, 1, c], + [1, 1, 1], + [c, 1, c]]) + mask = mask.reshape((-1, mask.shape[-2], mask.shape[-1])) + out = [] + for m in mask: + output = m.numpy() + for _ in range(abs(expand)): + if expand < 0: + output = scipy.ndimage.grey_erosion(output, footprint=kernel) + else: + output = scipy.ndimage.grey_dilation(output, footprint=kernel) + output = torch.from_numpy(output) + out.append(output) + return (torch.stack(out, dim=0),) + +class ThresholdMask:
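# Editorial sketch, not part of the diff: GrowMask above runs |expand| rounds
# of 3x3 grey dilation (expand > 0) or erosion (expand < 0), moving the mask
# edge by about one pixel per round; tapered_corners zeroes the kernel corners
# so growth is rounded rather than square. One round in isolation:
import numpy as np
import scipy.ndimage

kernel = np.array([[0, 1, 0],
                   [1, 1, 1],
                   [0, 1, 0]])  # the tapered_corners=True footprint
m = np.zeros((7, 7))
m[3, 3] = 1.0
grown = scipy.ndimage.grey_dilation(m, footprint=kernel)  # plus-shaped, one pixel out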
+ @classmethod + def INPUT_TYPES(s): + return { + "required": { + "mask": ("MASK",), + "value": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}), + } + } + + CATEGORY = "mask" + + RETURN_TYPES = ("MASK",) + FUNCTION = "image_to_mask" + + def image_to_mask(self, mask, value): + mask = (mask > value).float() + return (mask,) + + +NODE_CLASS_MAPPINGS = { + "LatentCompositeMasked": LatentCompositeMasked, + "ImageCompositeMasked": ImageCompositeMasked, + "MaskToImage": MaskToImage, + "ImageToMask": ImageToMask, + "ImageColorToMask": ImageColorToMask, + "SolidMask": SolidMask, + "InvertMask": InvertMask, + "CropMask": CropMask, + "MaskComposite": MaskComposite, + "FeatherMask": FeatherMask, + "GrowMask": GrowMask, + "ThresholdMask": ThresholdMask, +} + +NODE_DISPLAY_NAME_MAPPINGS = { + "ImageToMask": "Convert Image to Mask", + "MaskToImage": "Convert Mask to Image", +} diff --git a/comfy_extras/nodes_model_advanced.py b/comfy_extras/nodes_model_advanced.py new file mode 100644 index 0000000000000000000000000000000000000000..918e6085aada9ef87ef04c47270317ab16246450 --- /dev/null +++ b/comfy_extras/nodes_model_advanced.py @@ -0,0 +1,326 @@ +import folder_paths +import comfy.sd +import comfy.model_sampling +import comfy.latent_formats +import nodes +import torch + +class LCM(comfy.model_sampling.EPS): + def calculate_denoised(self, sigma, model_output, model_input): + timestep = self.timestep(sigma).view(sigma.shape[:1] + (1,) * (model_output.ndim - 1)) + sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1)) + x0 = model_input - model_output * sigma + + sigma_data = 0.5 + scaled_timestep = timestep * 10.0 #timestep_scaling + + c_skip = sigma_data**2 / (scaled_timestep**2 + sigma_data**2) + c_out = scaled_timestep / (scaled_timestep**2 + sigma_data**2) ** 0.5 + + return c_out * x0 + c_skip * model_input + +class X0(comfy.model_sampling.EPS): + def calculate_denoised(self, sigma, model_output, model_input): + return model_output + +class ModelSamplingDiscreteDistilled(comfy.model_sampling.ModelSamplingDiscrete): + original_timesteps = 50 + + def __init__(self, model_config=None): + super().__init__(model_config) + + self.skip_steps = self.num_timesteps // self.original_timesteps + + sigmas_valid = torch.zeros((self.original_timesteps), dtype=torch.float32) + for x in range(self.original_timesteps): + sigmas_valid[self.original_timesteps - 1 - x] = self.sigmas[self.num_timesteps - 1 - x * self.skip_steps] + + self.set_sigmas(sigmas_valid) + + def timestep(self, sigma): + log_sigma = sigma.log() + dists = log_sigma.to(self.log_sigmas.device) - self.log_sigmas[:, None] + return (dists.abs().argmin(dim=0).view(sigma.shape) * self.skip_steps + (self.skip_steps - 1)).to(sigma.device) + + def sigma(self, timestep): + t = torch.clamp(((timestep.float().to(self.log_sigmas.device) - (self.skip_steps - 1)) / self.skip_steps).float(), min=0, max=(len(self.sigmas) - 1)) + low_idx = t.floor().long() + high_idx = t.ceil().long() + w = t.frac() + log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx] + return log_sigma.exp().to(timestep.device) + + +def rescale_zero_terminal_snr_sigmas(sigmas): + alphas_cumprod = 1 / ((sigmas * sigmas) + 1) + alphas_bar_sqrt = alphas_cumprod.sqrt() + + # Store old values. + alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() + alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() + + # Shift so the last timestep is zero. + alphas_bar_sqrt -= (alphas_bar_sqrt_T) + + # Scale so the first timestep is back to the old value. 
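# Editorial note, not part of the diff: rescale_zero_terminal_snr_sigmas, step
# by step: sigmas are converted to alpha_bar = 1 / (sigma^2 + 1),
# sqrt(alpha_bar) is shifted so its final entry hits zero (zero terminal SNR)
# and rescaled so its first entry is unchanged, and the result is mapped back
# through sigma = sqrt((1 - alpha_bar) / alpha_bar). The tiny 4.9e-08 floor on
# the final alpha_bar keeps that last sigma finite.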
+ alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) + + # Convert alphas_bar_sqrt to betas + alphas_bar = alphas_bar_sqrt**2 # Revert sqrt + alphas_bar[-1] = 4.8973451890853435e-08 + return ((1 - alphas_bar) / alphas_bar) ** 0.5 + +class ModelSamplingDiscrete: + @classmethod + def INPUT_TYPES(s): + return {"required": { "model": ("MODEL",), + "sampling": (["eps", "v_prediction", "lcm", "x0"],), + "zsnr": ("BOOLEAN", {"default": False}), + }} + + RETURN_TYPES = ("MODEL",) + FUNCTION = "patch" + + CATEGORY = "advanced/model" + + def patch(self, model, sampling, zsnr): + m = model.clone() + + sampling_base = comfy.model_sampling.ModelSamplingDiscrete + if sampling == "eps": + sampling_type = comfy.model_sampling.EPS + elif sampling == "v_prediction": + sampling_type = comfy.model_sampling.V_PREDICTION + elif sampling == "lcm": + sampling_type = LCM + sampling_base = ModelSamplingDiscreteDistilled + elif sampling == "x0": + sampling_type = X0 + + class ModelSamplingAdvanced(sampling_base, sampling_type): + pass + + model_sampling = ModelSamplingAdvanced(model.model.model_config) + if zsnr: + model_sampling.set_sigmas(rescale_zero_terminal_snr_sigmas(model_sampling.sigmas)) + + m.add_object_patch("model_sampling", model_sampling) + return (m, ) + +class ModelSamplingStableCascade: + @classmethod + def INPUT_TYPES(s): + return {"required": { "model": ("MODEL",), + "shift": ("FLOAT", {"default": 2.0, "min": 0.0, "max": 100.0, "step":0.01}), + }} + + RETURN_TYPES = ("MODEL",) + FUNCTION = "patch" + + CATEGORY = "advanced/model" + + def patch(self, model, shift): + m = model.clone() + + sampling_base = comfy.model_sampling.StableCascadeSampling + sampling_type = comfy.model_sampling.EPS + + class ModelSamplingAdvanced(sampling_base, sampling_type): + pass + + model_sampling = ModelSamplingAdvanced(model.model.model_config) + model_sampling.set_parameters(shift) + m.add_object_patch("model_sampling", model_sampling) + return (m, ) + +class ModelSamplingSD3: + @classmethod + def INPUT_TYPES(s): + return {"required": { "model": ("MODEL",), + "shift": ("FLOAT", {"default": 3.0, "min": 0.0, "max": 100.0, "step":0.01}), + }} + + RETURN_TYPES = ("MODEL",) + FUNCTION = "patch" + + CATEGORY = "advanced/model" + + def patch(self, model, shift, multiplier=1000): + m = model.clone() + + sampling_base = comfy.model_sampling.ModelSamplingDiscreteFlow + sampling_type = comfy.model_sampling.CONST + + class ModelSamplingAdvanced(sampling_base, sampling_type): + pass + + model_sampling = ModelSamplingAdvanced(model.model.model_config) + model_sampling.set_parameters(shift=shift, multiplier=multiplier) + m.add_object_patch("model_sampling", model_sampling) + return (m, ) + +class ModelSamplingAuraFlow(ModelSamplingSD3): + @classmethod + def INPUT_TYPES(s): + return {"required": { "model": ("MODEL",), + "shift": ("FLOAT", {"default": 1.73, "min": 0.0, "max": 100.0, "step":0.01}), + }} + + FUNCTION = "patch_aura" + + def patch_aura(self, model, shift): + return self.patch(model, shift, multiplier=1.0) + +class ModelSamplingFlux: + @classmethod + def INPUT_TYPES(s): + return {"required": { "model": ("MODEL",), + "max_shift": ("FLOAT", {"default": 1.15, "min": 0.0, "max": 100.0, "step":0.01}), + "base_shift": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 100.0, "step":0.01}), + "width": ("INT", {"default": 1024, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}), + "height": ("INT", {"default": 1024, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}), + }} + + RETURN_TYPES = ("MODEL",) + 
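# Editorial sketch, not part of the diff: the patch method below derives the
# flux shift from image area by putting a straight line through two anchors:
# 256 latent patches -> base_shift and 4096 patches -> max_shift, where the
# patch count is width * height / 256 (8x8 VAE downscale times 2x2 packing).
# flux_shift_sketch is a hypothetical name for the same arithmetic:
def flux_shift_sketch(width, height, base_shift=0.5, max_shift=1.15):
    mm = (max_shift - base_shift) / (4096 - 256)
    b = base_shift - mm * 256
    return (width * height / (8 * 8 * 2 * 2)) * mm + b

# flux_shift_sketch(1024, 1024) -> 1.15: a 1024x1024 image is exactly 4096
# patches, so it lands on max_shift (up to float rounding).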
FUNCTION = "patch" + + CATEGORY = "advanced/model" + + def patch(self, model, max_shift, base_shift, width, height): + m = model.clone() + + x1 = 256 + x2 = 4096 + mm = (max_shift - base_shift) / (x2 - x1) + b = base_shift - mm * x1 + shift = (width * height / (8 * 8 * 2 * 2)) * mm + b + + sampling_base = comfy.model_sampling.ModelSamplingFlux + sampling_type = comfy.model_sampling.CONST + + class ModelSamplingAdvanced(sampling_base, sampling_type): + pass + + model_sampling = ModelSamplingAdvanced(model.model.model_config) + model_sampling.set_parameters(shift=shift) + m.add_object_patch("model_sampling", model_sampling) + return (m, ) + + +class ModelSamplingContinuousEDM: + @classmethod + def INPUT_TYPES(s): + return {"required": { "model": ("MODEL",), + "sampling": (["v_prediction", "edm_playground_v2.5", "eps"],), + "sigma_max": ("FLOAT", {"default": 120.0, "min": 0.0, "max": 1000.0, "step":0.001, "round": False}), + "sigma_min": ("FLOAT", {"default": 0.002, "min": 0.0, "max": 1000.0, "step":0.001, "round": False}), + }} + + RETURN_TYPES = ("MODEL",) + FUNCTION = "patch" + + CATEGORY = "advanced/model" + + def patch(self, model, sampling, sigma_max, sigma_min): + m = model.clone() + + latent_format = None + sigma_data = 1.0 + if sampling == "eps": + sampling_type = comfy.model_sampling.EPS + elif sampling == "v_prediction": + sampling_type = comfy.model_sampling.V_PREDICTION + elif sampling == "edm_playground_v2.5": + sampling_type = comfy.model_sampling.EDM + sigma_data = 0.5 + latent_format = comfy.latent_formats.SDXL_Playground_2_5() + + class ModelSamplingAdvanced(comfy.model_sampling.ModelSamplingContinuousEDM, sampling_type): + pass + + model_sampling = ModelSamplingAdvanced(model.model.model_config) + model_sampling.set_parameters(sigma_min, sigma_max, sigma_data) + m.add_object_patch("model_sampling", model_sampling) + if latent_format is not None: + m.add_object_patch("latent_format", latent_format) + return (m, ) + +class ModelSamplingContinuousV: + @classmethod + def INPUT_TYPES(s): + return {"required": { "model": ("MODEL",), + "sampling": (["v_prediction"],), + "sigma_max": ("FLOAT", {"default": 500.0, "min": 0.0, "max": 1000.0, "step":0.001, "round": False}), + "sigma_min": ("FLOAT", {"default": 0.03, "min": 0.0, "max": 1000.0, "step":0.001, "round": False}), + }} + + RETURN_TYPES = ("MODEL",) + FUNCTION = "patch" + + CATEGORY = "advanced/model" + + def patch(self, model, sampling, sigma_max, sigma_min): + m = model.clone() + + latent_format = None + sigma_data = 1.0 + if sampling == "v_prediction": + sampling_type = comfy.model_sampling.V_PREDICTION + + class ModelSamplingAdvanced(comfy.model_sampling.ModelSamplingContinuousV, sampling_type): + pass + + model_sampling = ModelSamplingAdvanced(model.model.model_config) + model_sampling.set_parameters(sigma_min, sigma_max, sigma_data) + m.add_object_patch("model_sampling", model_sampling) + return (m, ) + +class RescaleCFG: + @classmethod + def INPUT_TYPES(s): + return {"required": { "model": ("MODEL",), + "multiplier": ("FLOAT", {"default": 0.7, "min": 0.0, "max": 1.0, "step": 0.01}), + }} + RETURN_TYPES = ("MODEL",) + FUNCTION = "patch" + + CATEGORY = "advanced/model" + + def patch(self, model, multiplier): + def rescale_cfg(args): + cond = args["cond"] + uncond = args["uncond"] + cond_scale = args["cond_scale"] + sigma = args["sigma"] + sigma = sigma.view(sigma.shape[:1] + (1,) * (cond.ndim - 1)) + x_orig = args["input"] + + #rescale cfg has to be done on v-pred model output + x = x_orig / (sigma * sigma + 1.0) + cond = 
((x - (x_orig - cond)) * (sigma ** 2 + 1.0) ** 0.5) / (sigma) + uncond = ((x - (x_orig - uncond)) * (sigma ** 2 + 1.0) ** 0.5) / (sigma) + + #rescalecfg + x_cfg = uncond + cond_scale * (cond - uncond) + ro_pos = torch.std(cond, dim=(1,2,3), keepdim=True) + ro_cfg = torch.std(x_cfg, dim=(1,2,3), keepdim=True) + + x_rescaled = x_cfg * (ro_pos / ro_cfg) + x_final = multiplier * x_rescaled + (1.0 - multiplier) * x_cfg + + return x_orig - (x - x_final * sigma / (sigma * sigma + 1.0) ** 0.5) + + m = model.clone() + m.set_model_sampler_cfg_function(rescale_cfg) + return (m, ) + +NODE_CLASS_MAPPINGS = { + "ModelSamplingDiscrete": ModelSamplingDiscrete, + "ModelSamplingContinuousEDM": ModelSamplingContinuousEDM, + "ModelSamplingContinuousV": ModelSamplingContinuousV, + "ModelSamplingStableCascade": ModelSamplingStableCascade, + "ModelSamplingSD3": ModelSamplingSD3, + "ModelSamplingAuraFlow": ModelSamplingAuraFlow, + "ModelSamplingFlux": ModelSamplingFlux, + "RescaleCFG": RescaleCFG, +} diff --git a/comfy_extras/nodes_model_downscale.py b/comfy_extras/nodes_model_downscale.py new file mode 100644 index 0000000000000000000000000000000000000000..58b5073ec08c9937017ef3a1ae5e09ffdd3f0c89 --- /dev/null +++ b/comfy_extras/nodes_model_downscale.py @@ -0,0 +1,54 @@ +import torch +import comfy.utils + +class PatchModelAddDownscale: + upscale_methods = ["bicubic", "nearest-exact", "bilinear", "area", "bislerp"] + @classmethod + def INPUT_TYPES(s): + return {"required": { "model": ("MODEL",), + "block_number": ("INT", {"default": 3, "min": 1, "max": 32, "step": 1}), + "downscale_factor": ("FLOAT", {"default": 2.0, "min": 0.1, "max": 9.0, "step": 0.001}), + "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}), + "end_percent": ("FLOAT", {"default": 0.35, "min": 0.0, "max": 1.0, "step": 0.001}), + "downscale_after_skip": ("BOOLEAN", {"default": True}), + "downscale_method": (s.upscale_methods,), + "upscale_method": (s.upscale_methods,), + }} + RETURN_TYPES = ("MODEL",) + FUNCTION = "patch" + + CATEGORY = "_for_testing" + + def patch(self, model, block_number, downscale_factor, start_percent, end_percent, downscale_after_skip, downscale_method, upscale_method): + model_sampling = model.get_model_object("model_sampling") + sigma_start = model_sampling.percent_to_sigma(start_percent) + sigma_end = model_sampling.percent_to_sigma(end_percent) + + def input_block_patch(h, transformer_options): + if transformer_options["block"][1] == block_number: + sigma = transformer_options["sigmas"][0].item() + if sigma <= sigma_start and sigma >= sigma_end: + h = comfy.utils.common_upscale(h, round(h.shape[-1] * (1.0 / downscale_factor)), round(h.shape[-2] * (1.0 / downscale_factor)), downscale_method, "disabled") + return h + + def output_block_patch(h, hsp, transformer_options): + if h.shape[2] != hsp.shape[2]: + h = comfy.utils.common_upscale(h, hsp.shape[-1], hsp.shape[-2], upscale_method, "disabled") + return h, hsp + + m = model.clone() + if downscale_after_skip: + m.set_model_input_block_patch_after_skip(input_block_patch) + else: + m.set_model_input_block_patch(input_block_patch) + m.set_model_output_block_patch(output_block_patch) + return (m, ) + +NODE_CLASS_MAPPINGS = { + "PatchModelAddDownscale": PatchModelAddDownscale, +} + +NODE_DISPLAY_NAME_MAPPINGS = { + # Sampling + "PatchModelAddDownscale": "PatchModelAddDownscale (Kohya Deep Shrink)", +} diff --git a/comfy_extras/nodes_model_merging.py b/comfy_extras/nodes_model_merging.py new file mode 100644 index 
0000000000000000000000000000000000000000..ccf601158d59ab707d07f1e438c739d4dbced131 --- /dev/null +++ b/comfy_extras/nodes_model_merging.py @@ -0,0 +1,371 @@ +import comfy.sd +import comfy.utils +import comfy.model_base +import comfy.model_management +import comfy.model_sampling + +import torch +import folder_paths +import json +import os + +from comfy.cli_args import args + +class ModelMergeSimple: + @classmethod + def INPUT_TYPES(s): + return {"required": { "model1": ("MODEL",), + "model2": ("MODEL",), + "ratio": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), + }} + RETURN_TYPES = ("MODEL",) + FUNCTION = "merge" + + CATEGORY = "advanced/model_merging" + + def merge(self, model1, model2, ratio): + m = model1.clone() + kp = model2.get_key_patches("diffusion_model.") + for k in kp: + m.add_patches({k: kp[k]}, 1.0 - ratio, ratio) + return (m, ) + +class ModelSubtract: + @classmethod + def INPUT_TYPES(s): + return {"required": { "model1": ("MODEL",), + "model2": ("MODEL",), + "multiplier": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}), + }} + RETURN_TYPES = ("MODEL",) + FUNCTION = "merge" + + CATEGORY = "advanced/model_merging" + + def merge(self, model1, model2, multiplier): + m = model1.clone() + kp = model2.get_key_patches("diffusion_model.") + for k in kp: + m.add_patches({k: kp[k]}, - multiplier, multiplier) + return (m, ) + +class ModelAdd: + @classmethod + def INPUT_TYPES(s): + return {"required": { "model1": ("MODEL",), + "model2": ("MODEL",), + }} + RETURN_TYPES = ("MODEL",) + FUNCTION = "merge" + + CATEGORY = "advanced/model_merging" + + def merge(self, model1, model2): + m = model1.clone() + kp = model2.get_key_patches("diffusion_model.") + for k in kp: + m.add_patches({k: kp[k]}, 1.0, 1.0) + return (m, ) + + +class CLIPMergeSimple: + @classmethod + def INPUT_TYPES(s): + return {"required": { "clip1": ("CLIP",), + "clip2": ("CLIP",), + "ratio": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), + }} + RETURN_TYPES = ("CLIP",) + FUNCTION = "merge" + + CATEGORY = "advanced/model_merging" + + def merge(self, clip1, clip2, ratio): + m = clip1.clone() + kp = clip2.get_key_patches() + for k in kp: + if k.endswith(".position_ids") or k.endswith(".logit_scale"): + continue + m.add_patches({k: kp[k]}, 1.0 - ratio, ratio) + return (m, ) + + +class CLIPSubtract: + @classmethod + def INPUT_TYPES(s): + return {"required": { "clip1": ("CLIP",), + "clip2": ("CLIP",), + "multiplier": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}), + }} + RETURN_TYPES = ("CLIP",) + FUNCTION = "merge" + + CATEGORY = "advanced/model_merging" + + def merge(self, clip1, clip2, multiplier): + m = clip1.clone() + kp = clip2.get_key_patches() + for k in kp: + if k.endswith(".position_ids") or k.endswith(".logit_scale"): + continue + m.add_patches({k: kp[k]}, - multiplier, multiplier) + return (m, ) + + +class CLIPAdd: + @classmethod + def INPUT_TYPES(s): + return {"required": { "clip1": ("CLIP",), + "clip2": ("CLIP",), + }} + RETURN_TYPES = ("CLIP",) + FUNCTION = "merge" + + CATEGORY = "advanced/model_merging" + + def merge(self, clip1, clip2): + m = clip1.clone() + kp = clip2.get_key_patches() + for k in kp: + if k.endswith(".position_ids") or k.endswith(".logit_scale"): + continue + m.add_patches({k: kp[k]}, 1.0, 1.0) + return (m, ) + + +class ModelMergeBlocks: + @classmethod + def INPUT_TYPES(s): + return {"required": { "model1": ("MODEL",), + "model2": ("MODEL",), + "input": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 
0.01}), + "middle": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), + "out": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}) + }} + RETURN_TYPES = ("MODEL",) + FUNCTION = "merge" + + CATEGORY = "advanced/model_merging" + + def merge(self, model1, model2, **kwargs): + m = model1.clone() + kp = model2.get_key_patches("diffusion_model.") + default_ratio = next(iter(kwargs.values())) + + for k in kp: + ratio = default_ratio + k_unet = k[len("diffusion_model."):] + + last_arg_size = 0 + for arg in kwargs: + if k_unet.startswith(arg) and last_arg_size < len(arg): + ratio = kwargs[arg] + last_arg_size = len(arg) + + m.add_patches({k: kp[k]}, 1.0 - ratio, ratio) + return (m, ) + +def save_checkpoint(model, clip=None, vae=None, clip_vision=None, filename_prefix=None, output_dir=None, prompt=None, extra_pnginfo=None): + full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, output_dir) + prompt_info = "" + if prompt is not None: + prompt_info = json.dumps(prompt) + + metadata = {} + + enable_modelspec = True + if isinstance(model.model, comfy.model_base.SDXL): + if isinstance(model.model, comfy.model_base.SDXL_instructpix2pix): + metadata["modelspec.architecture"] = "stable-diffusion-xl-v1-edit" + else: + metadata["modelspec.architecture"] = "stable-diffusion-xl-v1-base" + elif isinstance(model.model, comfy.model_base.SDXLRefiner): + metadata["modelspec.architecture"] = "stable-diffusion-xl-v1-refiner" + elif isinstance(model.model, comfy.model_base.SVD_img2vid): + metadata["modelspec.architecture"] = "stable-video-diffusion-img2vid-v1" + elif isinstance(model.model, comfy.model_base.SD3): + metadata["modelspec.architecture"] = "stable-diffusion-v3-medium" #TODO: other SD3 variants + else: + enable_modelspec = False + + if enable_modelspec: + metadata["modelspec.sai_model_spec"] = "1.0.0" + metadata["modelspec.implementation"] = "sgm" + metadata["modelspec.title"] = "{} {}".format(filename, counter) + + #TODO: + # "stable-diffusion-v1", "stable-diffusion-v1-inpainting", "stable-diffusion-v2-512", + # "stable-diffusion-v2-768-v", "stable-diffusion-v2-unclip-l", "stable-diffusion-v2-unclip-h", + # "v2-inpainting" + + extra_keys = {} + model_sampling = model.get_model_object("model_sampling") + if isinstance(model_sampling, comfy.model_sampling.ModelSamplingContinuousEDM): + if isinstance(model_sampling, comfy.model_sampling.V_PREDICTION): + extra_keys["edm_vpred.sigma_max"] = torch.tensor(model_sampling.sigma_max).float() + extra_keys["edm_vpred.sigma_min"] = torch.tensor(model_sampling.sigma_min).float() + + if model.model.model_type == comfy.model_base.ModelType.EPS: + metadata["modelspec.predict_key"] = "epsilon" + elif model.model.model_type == comfy.model_base.ModelType.V_PREDICTION: + metadata["modelspec.predict_key"] = "v" + + if not args.disable_metadata: + metadata["prompt"] = prompt_info + if extra_pnginfo is not None: + for x in extra_pnginfo: + metadata[x] = json.dumps(extra_pnginfo[x]) + + output_checkpoint = f"{filename}_{counter:05}_.safetensors" + output_checkpoint = os.path.join(full_output_folder, output_checkpoint) + + comfy.sd.save_checkpoint(output_checkpoint, model, clip, vae, clip_vision, metadata=metadata, extra_keys=extra_keys) + +class CheckpointSave: + def __init__(self): + self.output_dir = folder_paths.get_output_directory() + + @classmethod + def INPUT_TYPES(s): + return {"required": { "model": ("MODEL",), + "clip": ("CLIP",), + "vae": ("VAE",), + "filename_prefix": 
("STRING", {"default": "checkpoints/ComfyUI"}),}, + "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},} + RETURN_TYPES = () + FUNCTION = "save" + OUTPUT_NODE = True + + CATEGORY = "advanced/model_merging" + + def save(self, model, clip, vae, filename_prefix, prompt=None, extra_pnginfo=None): + save_checkpoint(model, clip=clip, vae=vae, filename_prefix=filename_prefix, output_dir=self.output_dir, prompt=prompt, extra_pnginfo=extra_pnginfo) + return {} + +class CLIPSave: + def __init__(self): + self.output_dir = folder_paths.get_output_directory() + + @classmethod + def INPUT_TYPES(s): + return {"required": { "clip": ("CLIP",), + "filename_prefix": ("STRING", {"default": "clip/ComfyUI"}),}, + "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},} + RETURN_TYPES = () + FUNCTION = "save" + OUTPUT_NODE = True + + CATEGORY = "advanced/model_merging" + + def save(self, clip, filename_prefix, prompt=None, extra_pnginfo=None): + prompt_info = "" + if prompt is not None: + prompt_info = json.dumps(prompt) + + metadata = {} + if not args.disable_metadata: + metadata["format"] = "pt" + metadata["prompt"] = prompt_info + if extra_pnginfo is not None: + for x in extra_pnginfo: + metadata[x] = json.dumps(extra_pnginfo[x]) + + comfy.model_management.load_models_gpu([clip.load_model()], force_patch_weights=True) + clip_sd = clip.get_sd() + + for prefix in ["clip_l.", "clip_g.", ""]: + k = list(filter(lambda a: a.startswith(prefix), clip_sd.keys())) + current_clip_sd = {} + for x in k: + current_clip_sd[x] = clip_sd.pop(x) + if len(current_clip_sd) == 0: + continue + + p = prefix[:-1] + replace_prefix = {} + filename_prefix_ = filename_prefix + if len(p) > 0: + filename_prefix_ = "{}_{}".format(filename_prefix_, p) + replace_prefix[prefix] = "" + replace_prefix["transformer."] = "" + + full_output_folder, filename, counter, subfolder, filename_prefix_ = folder_paths.get_save_image_path(filename_prefix_, self.output_dir) + + output_checkpoint = f"{filename}_{counter:05}_.safetensors" + output_checkpoint = os.path.join(full_output_folder, output_checkpoint) + + current_clip_sd = comfy.utils.state_dict_prefix_replace(current_clip_sd, replace_prefix) + + comfy.utils.save_torch_file(current_clip_sd, output_checkpoint, metadata=metadata) + return {} + +class VAESave: + def __init__(self): + self.output_dir = folder_paths.get_output_directory() + + @classmethod + def INPUT_TYPES(s): + return {"required": { "vae": ("VAE",), + "filename_prefix": ("STRING", {"default": "vae/ComfyUI_vae"}),}, + "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},} + RETURN_TYPES = () + FUNCTION = "save" + OUTPUT_NODE = True + + CATEGORY = "advanced/model_merging" + + def save(self, vae, filename_prefix, prompt=None, extra_pnginfo=None): + full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir) + prompt_info = "" + if prompt is not None: + prompt_info = json.dumps(prompt) + + metadata = {} + if not args.disable_metadata: + metadata["prompt"] = prompt_info + if extra_pnginfo is not None: + for x in extra_pnginfo: + metadata[x] = json.dumps(extra_pnginfo[x]) + + output_checkpoint = f"{filename}_{counter:05}_.safetensors" + output_checkpoint = os.path.join(full_output_folder, output_checkpoint) + + comfy.utils.save_torch_file(vae.get_sd(), output_checkpoint, metadata=metadata) + return {} + +class ModelSave: + def __init__(self): + self.output_dir = folder_paths.get_output_directory() + + @classmethod + def 
INPUT_TYPES(s): + return {"required": { "model": ("MODEL",), + "filename_prefix": ("STRING", {"default": "diffusion_models/ComfyUI"}),}, + "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},} + RETURN_TYPES = () + FUNCTION = "save" + OUTPUT_NODE = True + + CATEGORY = "advanced/model_merging" + + def save(self, model, filename_prefix, prompt=None, extra_pnginfo=None): + save_checkpoint(model, filename_prefix=filename_prefix, output_dir=self.output_dir, prompt=prompt, extra_pnginfo=extra_pnginfo) + return {} + +NODE_CLASS_MAPPINGS = { + "ModelMergeSimple": ModelMergeSimple, + "ModelMergeBlocks": ModelMergeBlocks, + "ModelMergeSubtract": ModelSubtract, + "ModelMergeAdd": ModelAdd, + "CheckpointSave": CheckpointSave, + "CLIPMergeSimple": CLIPMergeSimple, + "CLIPMergeSubtract": CLIPSubtract, + "CLIPMergeAdd": CLIPAdd, + "CLIPSave": CLIPSave, + "VAESave": VAESave, + "ModelSave": ModelSave, +} + +NODE_DISPLAY_NAME_MAPPINGS = { + "CheckpointSave": "Save Checkpoint", +} diff --git a/comfy_extras/nodes_model_merging_model_specific.py b/comfy_extras/nodes_model_merging_model_specific.py new file mode 100644 index 0000000000000000000000000000000000000000..9557d9b1fa225b9cd62b24d0678dafafbc20d19b --- /dev/null +++ b/comfy_extras/nodes_model_merging_model_specific.py @@ -0,0 +1,110 @@ +import comfy_extras.nodes_model_merging + +class ModelMergeSD1(comfy_extras.nodes_model_merging.ModelMergeBlocks): + CATEGORY = "advanced/model_merging/model_specific" + @classmethod + def INPUT_TYPES(s): + arg_dict = { "model1": ("MODEL",), + "model2": ("MODEL",)} + + argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}) + + arg_dict["time_embed."] = argument + arg_dict["label_emb."] = argument + + for i in range(12): + arg_dict["input_blocks.{}.".format(i)] = argument + + for i in range(3): + arg_dict["middle_block.{}.".format(i)] = argument + + for i in range(12): + arg_dict["output_blocks.{}.".format(i)] = argument + + arg_dict["out."] = argument + + return {"required": arg_dict} + + +class ModelMergeSDXL(comfy_extras.nodes_model_merging.ModelMergeBlocks): + CATEGORY = "advanced/model_merging/model_specific" + + @classmethod + def INPUT_TYPES(s): + arg_dict = { "model1": ("MODEL",), + "model2": ("MODEL",)} + + argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}) + + arg_dict["time_embed."] = argument + arg_dict["label_emb."] = argument + + for i in range(9): + arg_dict["input_blocks.{}".format(i)] = argument + + for i in range(3): + arg_dict["middle_block.{}".format(i)] = argument + + for i in range(9): + arg_dict["output_blocks.{}".format(i)] = argument + + arg_dict["out."] = argument + + return {"required": arg_dict} + +class ModelMergeSD3_2B(comfy_extras.nodes_model_merging.ModelMergeBlocks): + CATEGORY = "advanced/model_merging/model_specific" + + @classmethod + def INPUT_TYPES(s): + arg_dict = { "model1": ("MODEL",), + "model2": ("MODEL",)} + + argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}) + + arg_dict["pos_embed."] = argument + arg_dict["x_embedder."] = argument + arg_dict["context_embedder."] = argument + arg_dict["y_embedder."] = argument + arg_dict["t_embedder."] = argument + + for i in range(24): + arg_dict["joint_blocks.{}.".format(i)] = argument + + arg_dict["final_layer."] = argument + + return {"required": arg_dict} + +class ModelMergeFlux1(comfy_extras.nodes_model_merging.ModelMergeBlocks): + CATEGORY = "advanced/model_merging/model_specific" + + @classmethod + def INPUT_TYPES(s): + arg_dict = { "model1": 
("MODEL",), + "model2": ("MODEL",)} + + argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}) + + arg_dict["img_in."] = argument + arg_dict["time_in."] = argument + arg_dict["guidance_in"] = argument + arg_dict["vector_in."] = argument + arg_dict["txt_in."] = argument + + for i in range(19): + arg_dict["double_blocks.{}.".format(i)] = argument + + for i in range(38): + arg_dict["single_blocks.{}.".format(i)] = argument + + arg_dict["final_layer."] = argument + + return {"required": arg_dict} + +NODE_CLASS_MAPPINGS = { + "ModelMergeSD1": ModelMergeSD1, + "ModelMergeSD2": ModelMergeSD1, #SD1 and SD2 have the same blocks + "ModelMergeSDXL": ModelMergeSDXL, + "ModelMergeSD3_2B": ModelMergeSD3_2B, + "ModelMergeFlux1": ModelMergeFlux1, +} diff --git a/comfy_extras/nodes_morphology.py b/comfy_extras/nodes_morphology.py new file mode 100644 index 0000000000000000000000000000000000000000..071521d879afe8f5cc956bd8b80d5dddd3f23d1a --- /dev/null +++ b/comfy_extras/nodes_morphology.py @@ -0,0 +1,49 @@ +import torch +import comfy.model_management + +from kornia.morphology import dilation, erosion, opening, closing, gradient, top_hat, bottom_hat + + +class Morphology: + @classmethod + def INPUT_TYPES(s): + return {"required": {"image": ("IMAGE",), + "operation": (["erode", "dilate", "open", "close", "gradient", "bottom_hat", "top_hat"],), + "kernel_size": ("INT", {"default": 3, "min": 3, "max": 999, "step": 1}), + }} + + RETURN_TYPES = ("IMAGE",) + FUNCTION = "process" + + CATEGORY = "image/postprocessing" + + def process(self, image, operation, kernel_size): + device = comfy.model_management.get_torch_device() + kernel = torch.ones(kernel_size, kernel_size, device=device) + image_k = image.to(device).movedim(-1, 1) + if operation == "erode": + output = erosion(image_k, kernel) + elif operation == "dilate": + output = dilation(image_k, kernel) + elif operation == "open": + output = opening(image_k, kernel) + elif operation == "close": + output = closing(image_k, kernel) + elif operation == "gradient": + output = gradient(image_k, kernel) + elif operation == "top_hat": + output = top_hat(image_k, kernel) + elif operation == "bottom_hat": + output = bottom_hat(image_k, kernel) + else: + raise ValueError(f"Invalid operation {operation} for morphology. Must be one of 'erode', 'dilate', 'open', 'close', 'gradient', 'tophat', 'bottomhat'") + img_out = output.to(comfy.model_management.intermediate_device()).movedim(1, -1) + return (img_out,) + +NODE_CLASS_MAPPINGS = { + "Morphology": Morphology, +} + +NODE_DISPLAY_NAME_MAPPINGS = { + "Morphology": "ImageMorphology", +} \ No newline at end of file diff --git a/comfy_extras/nodes_pag.py b/comfy_extras/nodes_pag.py new file mode 100644 index 0000000000000000000000000000000000000000..eb28196f41c56fd45fda051a42d0814b96558fb8 --- /dev/null +++ b/comfy_extras/nodes_pag.py @@ -0,0 +1,56 @@ +#Modified/simplified version of the node from: https://github.com/pamparamm/sd-perturbed-attention +#If you want the one with more options see the above repo. + +#My modified one here is more basic but has less chances of breaking with ComfyUI updates. 
+ +import comfy.model_patcher +import comfy.samplers + +class PerturbedAttentionGuidance: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "model": ("MODEL",), + "scale": ("FLOAT", {"default": 3.0, "min": 0.0, "max": 100.0, "step": 0.01, "round": 0.01}), + } + } + + RETURN_TYPES = ("MODEL",) + FUNCTION = "patch" + + CATEGORY = "model_patches/unet" + + def patch(self, model, scale): + unet_block = "middle" + unet_block_id = 0 + m = model.clone() + + def perturbed_attention(q, k, v, extra_options, mask=None): + return v + + def post_cfg_function(args): + model = args["model"] + cond_pred = args["cond_denoised"] + cond = args["cond"] + cfg_result = args["denoised"] + sigma = args["sigma"] + model_options = args["model_options"].copy() + x = args["input"] + + if scale == 0: + return cfg_result + + # Replace Self-attention with PAG + model_options = comfy.model_patcher.set_model_options_patch_replace(model_options, perturbed_attention, "attn1", unet_block, unet_block_id) + (pag,) = comfy.samplers.calc_cond_batch(model, [cond], x, sigma, model_options) + + return cfg_result + (cond_pred - pag) * scale + + m.set_model_sampler_post_cfg_function(post_cfg_function) + + return (m,) + +NODE_CLASS_MAPPINGS = { + "PerturbedAttentionGuidance": PerturbedAttentionGuidance, +} diff --git a/comfy_extras/nodes_perpneg.py b/comfy_extras/nodes_perpneg.py new file mode 100644 index 0000000000000000000000000000000000000000..762c402202d4fe245864ba67c6f68a8954a1d0ec --- /dev/null +++ b/comfy_extras/nodes_perpneg.py @@ -0,0 +1,129 @@ +import torch +import comfy.model_management +import comfy.sampler_helpers +import comfy.samplers +import comfy.utils +import node_helpers + +def perp_neg(x, noise_pred_pos, noise_pred_neg, noise_pred_nocond, neg_scale, cond_scale): + pos = noise_pred_pos - noise_pred_nocond + neg = noise_pred_neg - noise_pred_nocond + + perp = neg - ((torch.mul(neg, pos).sum())/(torch.norm(pos)**2)) * pos + perp_neg = perp * neg_scale + cfg_result = noise_pred_nocond + cond_scale*(pos - perp_neg) + return cfg_result + +#TODO: This node should be removed, it has been replaced with PerpNegGuider +class PerpNeg: + @classmethod + def INPUT_TYPES(s): + return {"required": {"model": ("MODEL", ), + "empty_conditioning": ("CONDITIONING", ), + "neg_scale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01}), + }} + RETURN_TYPES = ("MODEL",) + FUNCTION = "patch" + + CATEGORY = "_for_testing" + DEPRECATED = True + + def patch(self, model, empty_conditioning, neg_scale): + m = model.clone() + nocond = comfy.sampler_helpers.convert_cond(empty_conditioning) + + def cfg_function(args): + model = args["model"] + noise_pred_pos = args["cond_denoised"] + noise_pred_neg = args["uncond_denoised"] + cond_scale = args["cond_scale"] + x = args["input"] + sigma = args["sigma"] + model_options = args["model_options"] + nocond_processed = comfy.samplers.encode_model_conds(model.extra_conds, nocond, x, x.device, "negative") + + (noise_pred_nocond,) = comfy.samplers.calc_cond_batch(model, [nocond_processed], x, sigma, model_options) + + cfg_result = x - perp_neg(x, noise_pred_pos, noise_pred_neg, noise_pred_nocond, neg_scale, cond_scale) + return cfg_result + + m.set_model_sampler_cfg_function(cfg_function) + + return (m, ) + + +class Guider_PerpNeg(comfy.samplers.CFGGuider): + def set_conds(self, positive, negative, empty_negative_prompt): + empty_negative_prompt = node_helpers.conditioning_set_values(empty_negative_prompt, {"prompt_type": "negative"}) + self.inner_set_conds({"positive": 
positive, "empty_negative_prompt": empty_negative_prompt, "negative": negative}) + + def set_cfg(self, cfg, neg_scale): + self.cfg = cfg + self.neg_scale = neg_scale + + def predict_noise(self, x, timestep, model_options={}, seed=None): + # in CFGGuider.predict_noise, we call sampling_function(), which uses cfg_function() to compute pos & neg + # but we'd rather do a single batch of sampling pos, neg, and empty, so we call calc_cond_batch([pos,neg,empty]) directly + + positive_cond = self.conds.get("positive", None) + negative_cond = self.conds.get("negative", None) + empty_cond = self.conds.get("empty_negative_prompt", None) + + (noise_pred_pos, noise_pred_neg, noise_pred_empty) = \ + comfy.samplers.calc_cond_batch(self.inner_model, [positive_cond, negative_cond, empty_cond], x, timestep, model_options) + cfg_result = perp_neg(x, noise_pred_pos, noise_pred_neg, noise_pred_empty, self.neg_scale, self.cfg) + + # normally this would be done in cfg_function, but we skipped + # that for efficiency: we can compute the noise predictions in + # a single call to calc_cond_batch() (rather than two) + # so we replicate the hook here + for fn in model_options.get("sampler_post_cfg_function", []): + args = { + "denoised": cfg_result, + "cond": positive_cond, + "uncond": negative_cond, + "model": self.inner_model, + "uncond_denoised": noise_pred_neg, + "cond_denoised": noise_pred_pos, + "sigma": timestep, + "model_options": model_options, + "input": x, + # not in the original call in samplers.py:cfg_function, but made available for future hooks + "empty_cond": empty_cond, + "empty_cond_denoised": noise_pred_empty,} + cfg_result = fn(args) + + return cfg_result + +class PerpNegGuider: + @classmethod + def INPUT_TYPES(s): + return {"required": + {"model": ("MODEL",), + "positive": ("CONDITIONING", ), + "negative": ("CONDITIONING", ), + "empty_conditioning": ("CONDITIONING", ), + "cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}), + "neg_scale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01}), + } + } + + RETURN_TYPES = ("GUIDER",) + + FUNCTION = "get_guider" + CATEGORY = "_for_testing" + + def get_guider(self, model, positive, negative, empty_conditioning, cfg, neg_scale): + guider = Guider_PerpNeg(model) + guider.set_conds(positive, negative, empty_conditioning) + guider.set_cfg(cfg, neg_scale) + return (guider,) + +NODE_CLASS_MAPPINGS = { + "PerpNeg": PerpNeg, + "PerpNegGuider": PerpNegGuider, +} + +NODE_DISPLAY_NAME_MAPPINGS = { + "PerpNeg": "Perp-Neg (DEPRECATED by PerpNegGuider)", +} diff --git a/comfy_extras/nodes_photomaker.py b/comfy_extras/nodes_photomaker.py new file mode 100644 index 0000000000000000000000000000000000000000..29d127d74afc92131d72c64864186833a5f6ea5d --- /dev/null +++ b/comfy_extras/nodes_photomaker.py @@ -0,0 +1,187 @@ +import torch +import torch.nn as nn +import folder_paths +import comfy.clip_model +import comfy.clip_vision +import comfy.ops + +# code for model from: https://github.com/TencentARC/PhotoMaker/blob/main/photomaker/model.py under Apache License Version 2.0 +VISION_CONFIG_DICT = { + "hidden_size": 1024, + "image_size": 224, + "intermediate_size": 4096, + "num_attention_heads": 16, + "num_channels": 3, + "num_hidden_layers": 24, + "patch_size": 14, + "projection_dim": 768, + "hidden_act": "quick_gelu", +} + +class MLP(nn.Module): + def __init__(self, in_dim, out_dim, hidden_dim, use_residual=True, operations=comfy.ops): + super().__init__() + if use_residual: + assert in_dim == out_dim + self.layernorm = 
operations.LayerNorm(in_dim) + self.fc1 = operations.Linear(in_dim, hidden_dim) + self.fc2 = operations.Linear(hidden_dim, out_dim) + self.use_residual = use_residual + self.act_fn = nn.GELU() + + def forward(self, x): + residual = x + x = self.layernorm(x) + x = self.fc1(x) + x = self.act_fn(x) + x = self.fc2(x) + if self.use_residual: + x = x + residual + return x + + +class FuseModule(nn.Module): + def __init__(self, embed_dim, operations): + super().__init__() + self.mlp1 = MLP(embed_dim * 2, embed_dim, embed_dim, use_residual=False, operations=operations) + self.mlp2 = MLP(embed_dim, embed_dim, embed_dim, use_residual=True, operations=operations) + self.layer_norm = operations.LayerNorm(embed_dim) + + def fuse_fn(self, prompt_embeds, id_embeds): + stacked_id_embeds = torch.cat([prompt_embeds, id_embeds], dim=-1) + stacked_id_embeds = self.mlp1(stacked_id_embeds) + prompt_embeds + stacked_id_embeds = self.mlp2(stacked_id_embeds) + stacked_id_embeds = self.layer_norm(stacked_id_embeds) + return stacked_id_embeds + + def forward( + self, + prompt_embeds, + id_embeds, + class_tokens_mask, + ) -> torch.Tensor: + # id_embeds shape: [b, max_num_inputs, 1, 2048] + id_embeds = id_embeds.to(prompt_embeds.dtype) + num_inputs = class_tokens_mask.sum().unsqueeze(0) # TODO: check for training case + batch_size, max_num_inputs = id_embeds.shape[:2] + # seq_length: 77 + seq_length = prompt_embeds.shape[1] + # flat_id_embeds shape: [b*max_num_inputs, 1, 2048] + flat_id_embeds = id_embeds.view( + -1, id_embeds.shape[-2], id_embeds.shape[-1] + ) + # valid_id_mask [b*max_num_inputs] + valid_id_mask = ( + torch.arange(max_num_inputs, device=flat_id_embeds.device)[None, :] + < num_inputs[:, None] + ) + valid_id_embeds = flat_id_embeds[valid_id_mask.flatten()] + + prompt_embeds = prompt_embeds.view(-1, prompt_embeds.shape[-1]) + class_tokens_mask = class_tokens_mask.view(-1) + valid_id_embeds = valid_id_embeds.view(-1, valid_id_embeds.shape[-1]) + # slice out the image token embeddings + image_token_embeds = prompt_embeds[class_tokens_mask] + stacked_id_embeds = self.fuse_fn(image_token_embeds, valid_id_embeds) + assert class_tokens_mask.sum() == stacked_id_embeds.shape[0], f"{class_tokens_mask.sum()} != {stacked_id_embeds.shape[0]}" + prompt_embeds.masked_scatter_(class_tokens_mask[:, None], stacked_id_embeds.to(prompt_embeds.dtype)) + updated_prompt_embeds = prompt_embeds.view(batch_size, seq_length, -1) + return updated_prompt_embeds + +class PhotoMakerIDEncoder(comfy.clip_model.CLIPVisionModelProjection): + def __init__(self): + self.load_device = comfy.model_management.text_encoder_device() + offload_device = comfy.model_management.text_encoder_offload_device() + dtype = comfy.model_management.text_encoder_dtype(self.load_device) + + super().__init__(VISION_CONFIG_DICT, dtype, offload_device, comfy.ops.manual_cast) + self.visual_projection_2 = comfy.ops.manual_cast.Linear(1024, 1280, bias=False) + self.fuse_module = FuseModule(2048, comfy.ops.manual_cast) + + def forward(self, id_pixel_values, prompt_embeds, class_tokens_mask): + b, num_inputs, c, h, w = id_pixel_values.shape + id_pixel_values = id_pixel_values.view(b * num_inputs, c, h, w) + + shared_id_embeds = self.vision_model(id_pixel_values)[2] + id_embeds = self.visual_projection(shared_id_embeds) + id_embeds_2 = self.visual_projection_2(shared_id_embeds) + + id_embeds = id_embeds.view(b, num_inputs, 1, -1) + id_embeds_2 = id_embeds_2.view(b, num_inputs, 1, -1) + + id_embeds = torch.cat((id_embeds, id_embeds_2), dim=-1) + updated_prompt_embeds = 
self.fuse_module(prompt_embeds, id_embeds, class_tokens_mask) + + return updated_prompt_embeds + + +class PhotoMakerLoader: + @classmethod + def INPUT_TYPES(s): + return {"required": { "photomaker_model_name": (folder_paths.get_filename_list("photomaker"), )}} + + RETURN_TYPES = ("PHOTOMAKER",) + FUNCTION = "load_photomaker_model" + + CATEGORY = "_for_testing/photomaker" + + def load_photomaker_model(self, photomaker_model_name): + photomaker_model_path = folder_paths.get_full_path("photomaker", photomaker_model_name) + photomaker_model = PhotoMakerIDEncoder() + data = comfy.utils.load_torch_file(photomaker_model_path, safe_load=True) + if "id_encoder" in data: + data = data["id_encoder"] + photomaker_model.load_state_dict(data) + return (photomaker_model,) + + +class PhotoMakerEncode: + @classmethod + def INPUT_TYPES(s): + return {"required": { "photomaker": ("PHOTOMAKER",), + "image": ("IMAGE",), + "clip": ("CLIP", ), + "text": ("STRING", {"multiline": True, "dynamicPrompts": True, "default": "photograph of photomaker"}), + }} + + RETURN_TYPES = ("CONDITIONING",) + FUNCTION = "apply_photomaker" + + CATEGORY = "_for_testing/photomaker" + + def apply_photomaker(self, photomaker, image, clip, text): + special_token = "photomaker" + pixel_values = comfy.clip_vision.clip_preprocess(image.to(photomaker.load_device)).float() + try: + index = text.split(" ").index(special_token) + 1 + except ValueError: + index = -1 + tokens = clip.tokenize(text, return_word_ids=True) + out_tokens = {} + for k in tokens: + out_tokens[k] = [] + for t in tokens[k]: + f = list(filter(lambda x: x[2] != index, t)) + while len(f) < len(t): + f.append(t[-1]) + out_tokens[k].append(f) + + cond, pooled = clip.encode_from_tokens(out_tokens, return_pooled=True) + + if index > 0: + token_index = index - 1 + num_id_images = 1 + class_tokens_mask = [True if token_index <= i < token_index+num_id_images else False for i in range(77)] + out = photomaker(id_pixel_values=pixel_values.unsqueeze(0), prompt_embeds=cond.to(photomaker.load_device), + class_tokens_mask=torch.tensor(class_tokens_mask, dtype=torch.bool, device=photomaker.load_device).unsqueeze(0)) + else: + out = cond + + return ([[out, {"pooled_output": pooled}]], ) + + +NODE_CLASS_MAPPINGS = { + "PhotoMakerLoader": PhotoMakerLoader, + "PhotoMakerEncode": PhotoMakerEncode, +} + diff --git a/comfy_extras/nodes_post_processing.py b/comfy_extras/nodes_post_processing.py new file mode 100644 index 0000000000000000000000000000000000000000..68f6ef51e791c78c16c179a292b70ee4ad39b656 --- /dev/null +++ b/comfy_extras/nodes_post_processing.py @@ -0,0 +1,279 @@ +import numpy as np +import torch +import torch.nn.functional as F +from PIL import Image +import math + +import comfy.utils +import comfy.model_management + + +class Blend: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "image1": ("IMAGE",), + "image2": ("IMAGE",), + "blend_factor": ("FLOAT", { + "default": 0.5, + "min": 0.0, + "max": 1.0, + "step": 0.01 + }), + "blend_mode": (["normal", "multiply", "screen", "overlay", "soft_light", "difference"],), + }, + } + + RETURN_TYPES = ("IMAGE",) + FUNCTION = "blend_images" + + CATEGORY = "image/postprocessing" + + def blend_images(self, image1: torch.Tensor, image2: torch.Tensor, blend_factor: float, blend_mode: str): + image2 = image2.to(image1.device) + if image1.shape != image2.shape: + image2 = image2.permute(0, 3, 1, 2) + image2 = comfy.utils.common_upscale(image2, image1.shape[2], image1.shape[1], upscale_method='bicubic', 
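# IMAGE tensors are (B, H, W, C), so after the permute above width is image1.shape[2] and height is image1.shape[1]; 'center' crops image2 to image1's aspect ratio before scaling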
crop='center') + image2 = image2.permute(0, 2, 3, 1) + + blended_image = self.blend_mode(image1, image2, blend_mode) + blended_image = image1 * (1 - blend_factor) + blended_image * blend_factor + blended_image = torch.clamp(blended_image, 0, 1) + return (blended_image,) + + def blend_mode(self, img1, img2, mode): + if mode == "normal": + return img2 + elif mode == "multiply": + return img1 * img2 + elif mode == "screen": + return 1 - (1 - img1) * (1 - img2) + elif mode == "overlay": + return torch.where(img1 <= 0.5, 2 * img1 * img2, 1 - 2 * (1 - img1) * (1 - img2)) + elif mode == "soft_light": + return torch.where(img2 <= 0.5, img1 - (1 - 2 * img2) * img1 * (1 - img1), img1 + (2 * img2 - 1) * (self.g(img1) - img1)) + elif mode == "difference": + return img1 - img2 + else: + raise ValueError(f"Unsupported blend mode: {mode}") + + def g(self, x): + return torch.where(x <= 0.25, ((16 * x - 12) * x + 4) * x, torch.sqrt(x)) + +def gaussian_kernel(kernel_size: int, sigma: float, device=None): + x, y = torch.meshgrid(torch.linspace(-1, 1, kernel_size, device=device), torch.linspace(-1, 1, kernel_size, device=device), indexing="ij") + d = torch.sqrt(x * x + y * y) + g = torch.exp(-(d * d) / (2.0 * sigma * sigma)) + return g / g.sum() + +class Blur: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "image": ("IMAGE",), + "blur_radius": ("INT", { + "default": 1, + "min": 1, + "max": 31, + "step": 1 + }), + "sigma": ("FLOAT", { + "default": 1.0, + "min": 0.1, + "max": 10.0, + "step": 0.1 + }), + }, + } + + RETURN_TYPES = ("IMAGE",) + FUNCTION = "blur" + + CATEGORY = "image/postprocessing" + + def blur(self, image: torch.Tensor, blur_radius: int, sigma: float): + if blur_radius == 0: + return (image,) + + image = image.to(comfy.model_management.get_torch_device()) + batch_size, height, width, channels = image.shape + + kernel_size = blur_radius * 2 + 1 + kernel = gaussian_kernel(kernel_size, sigma, device=image.device).repeat(channels, 1, 1).unsqueeze(1) + + image = image.permute(0, 3, 1, 2) # Torch wants (B, C, H, W) we use (B, H, W, C) + padded_image = F.pad(image, (blur_radius,blur_radius,blur_radius,blur_radius), 'reflect') + blurred = F.conv2d(padded_image, kernel, padding=kernel_size // 2, groups=channels)[:,:,blur_radius:-blur_radius, blur_radius:-blur_radius] + blurred = blurred.permute(0, 2, 3, 1) + + return (blurred.to(comfy.model_management.intermediate_device()),) + +class Quantize: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "image": ("IMAGE",), + "colors": ("INT", { + "default": 256, + "min": 1, + "max": 256, + "step": 1 + }), + "dither": (["none", "floyd-steinberg", "bayer-2", "bayer-4", "bayer-8", "bayer-16"],), + }, + } + + RETURN_TYPES = ("IMAGE",) + FUNCTION = "quantize" + + CATEGORY = "image/postprocessing" + + def bayer(im, pal_im, order): + def normalized_bayer_matrix(n): + if n == 0: + return np.zeros((1,1), "float32") + else: + q = 4 ** n + m = q * normalized_bayer_matrix(n - 1) + return np.bmat(((m-1.5, m+0.5), (m+1.5, m-0.5))) / q + + num_colors = len(pal_im.getpalette()) // 3 + spread = 2 * 256 / num_colors + bayer_n = int(math.log2(order)) + bayer_matrix = torch.from_numpy(spread * normalized_bayer_matrix(bayer_n) + 0.5) + + result = torch.from_numpy(np.array(im).astype(np.float32)) + tw = math.ceil(result.shape[0] / bayer_matrix.shape[0]) + th = math.ceil(result.shape[1] / bayer_matrix.shape[1]) + tiled_matrix = bayer_matrix.tile(tw, th).unsqueeze(-1) + 
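# Ordered (Bayer) dithering: add the tiled threshold matrix to the image, clamp to [0, 255], then palette-quantize with dithering disabled; the added thresholds alone produce the dither pattern. +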
result.add_(tiled_matrix[:result.shape[0],:result.shape[1]]).clamp_(0, 255) + result = result.to(dtype=torch.uint8) + + im = Image.fromarray(result.cpu().numpy()) + im = im.quantize(palette=pal_im, dither=Image.Dither.NONE) + return im + + def quantize(self, image: torch.Tensor, colors: int, dither: str): + batch_size, height, width, _ = image.shape + result = torch.zeros_like(image) + + for b in range(batch_size): + im = Image.fromarray((image[b] * 255).to(torch.uint8).numpy(), mode='RGB') + + pal_im = im.quantize(colors=colors) # Required as described in https://github.com/python-pillow/Pillow/issues/5836 + + if dither == "none": + quantized_image = im.quantize(palette=pal_im, dither=Image.Dither.NONE) + elif dither == "floyd-steinberg": + quantized_image = im.quantize(palette=pal_im, dither=Image.Dither.FLOYDSTEINBERG) + elif dither.startswith("bayer"): + order = int(dither.split('-')[-1]) + quantized_image = Quantize.bayer(im, pal_im, order) + + quantized_array = torch.tensor(np.array(quantized_image.convert("RGB"))).float() / 255 + result[b] = quantized_array + + return (result,) + +class Sharpen: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "image": ("IMAGE",), + "sharpen_radius": ("INT", { + "default": 1, + "min": 1, + "max": 31, + "step": 1 + }), + "sigma": ("FLOAT", { + "default": 1.0, + "min": 0.1, + "max": 10.0, + "step": 0.01 + }), + "alpha": ("FLOAT", { + "default": 1.0, + "min": 0.0, + "max": 5.0, + "step": 0.01 + }), + }, + } + + RETURN_TYPES = ("IMAGE",) + FUNCTION = "sharpen" + + CATEGORY = "image/postprocessing" + + def sharpen(self, image: torch.Tensor, sharpen_radius: int, sigma:float, alpha: float): + if sharpen_radius == 0: + return (image,) + + batch_size, height, width, channels = image.shape + image = image.to(comfy.model_management.get_torch_device()) + + kernel_size = sharpen_radius * 2 + 1 + kernel = gaussian_kernel(kernel_size, sigma, device=image.device) * -(alpha*10) + center = kernel_size // 2 + kernel[center, center] = kernel[center, center] - kernel.sum() + 1.0 + kernel = kernel.repeat(channels, 1, 1).unsqueeze(1) + + tensor_image = image.permute(0, 3, 1, 2) # Torch wants (B, C, H, W) we use (B, H, W, C) + tensor_image = F.pad(tensor_image, (sharpen_radius,sharpen_radius,sharpen_radius,sharpen_radius), 'reflect') + sharpened = F.conv2d(tensor_image, kernel, padding=center, groups=channels)[:,:,sharpen_radius:-sharpen_radius, sharpen_radius:-sharpen_radius] + sharpened = sharpened.permute(0, 2, 3, 1) + + result = torch.clamp(sharpened, 0, 1) + + return (result.to(comfy.model_management.intermediate_device()),) + +class ImageScaleToTotalPixels: + upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "lanczos"] + crop_methods = ["disabled", "center"] + + @classmethod + def INPUT_TYPES(s): + return {"required": { "image": ("IMAGE",), "upscale_method": (s.upscale_methods,), + "megapixels": ("FLOAT", {"default": 1.0, "min": 0.01, "max": 16.0, "step": 0.01}), + }} + RETURN_TYPES = ("IMAGE",) + FUNCTION = "upscale" + + CATEGORY = "image/upscaling" + + def upscale(self, image, upscale_method, megapixels): + samples = image.movedim(-1,1) + total = int(megapixels * 1024 * 1024) + + scale_by = math.sqrt(total / (samples.shape[3] * samples.shape[2])) + width = round(samples.shape[3] * scale_by) + height = round(samples.shape[2] * scale_by) + + s = comfy.utils.common_upscale(samples, width, height, upscale_method, "disabled") + s = s.movedim(1,-1) + return (s,) + +NODE_CLASS_MAPPINGS = { + "ImageBlend": 
Blend, + "ImageBlur": Blur, + "ImageQuantize": Quantize, + "ImageSharpen": Sharpen, + "ImageScaleToTotalPixels": ImageScaleToTotalPixels, +} diff --git a/comfy_extras/nodes_rebatch.py b/comfy_extras/nodes_rebatch.py new file mode 100644 index 0000000000000000000000000000000000000000..3010fbd4b69034399390894d74c1b8cc415b61f0 --- /dev/null +++ b/comfy_extras/nodes_rebatch.py @@ -0,0 +1,138 @@ +import torch + +class LatentRebatch: + @classmethod + def INPUT_TYPES(s): + return {"required": { "latents": ("LATENT",), + "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), + }} + RETURN_TYPES = ("LATENT",) + INPUT_IS_LIST = True + OUTPUT_IS_LIST = (True, ) + + FUNCTION = "rebatch" + + CATEGORY = "latent/batch" + + @staticmethod + def get_batch(latents, list_ind, offset): + '''prepare a batch out of the list of latents''' + samples = latents[list_ind]['samples'] + shape = samples.shape + mask = latents[list_ind]['noise_mask'] if 'noise_mask' in latents[list_ind] else torch.ones((shape[0], 1, shape[2]*8, shape[3]*8), device='cpu') + # masks are kept at pixel resolution (8x the latent size); resize when they don't match + if mask.shape[-1] != shape[-1] * 8 or mask.shape[-2] != shape[-2] * 8: + mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(shape[-2]*8, shape[-1]*8), mode="bilinear") + if mask.shape[0] < samples.shape[0]: + mask = mask.repeat((shape[0] - 1) // mask.shape[0] + 1, 1, 1, 1)[:shape[0]] + if 'batch_index' in latents[list_ind]: + batch_inds = latents[list_ind]['batch_index'] + else: + batch_inds = [x+offset for x in range(shape[0])] + return samples, mask, batch_inds + + @staticmethod + def get_slices(indexable, num, batch_size): + '''divides an indexable object into num slices of length batch_size, and a remainder''' + slices = [] + for i in range(num): + slices.append(indexable[i*batch_size:(i+1)*batch_size]) + if num * batch_size < len(indexable): + return slices, indexable[num * batch_size:] + else: + return slices, None + + @staticmethod + def slice_batch(batch, num, batch_size): + result = [LatentRebatch.get_slices(x, num, batch_size) for x in batch] + return list(zip(*result)) + + @staticmethod + def cat_batch(batch1, batch2): + if batch1[0] is None: + return batch2 + result = [torch.cat((b1, b2)) if torch.is_tensor(b1) else b1 + b2 for b1, b2 in zip(batch1, batch2)] + return result + + def rebatch(self, latents, batch_size): + batch_size = batch_size[0] + + output_list = [] + current_batch = (None, None, None) + processed = 0 + + for i in range(len(latents)): + # fetch new entry of list + #samples, masks, indices = self.get_batch(latents, i) + next_batch = self.get_batch(latents, i, processed) + processed += len(next_batch[2]) + # set to current if current is None + if current_batch[0] is None: + current_batch = next_batch + # add previous to list if dimensions do not match + elif next_batch[0].shape[-1] != current_batch[0].shape[-1] or next_batch[0].shape[-2] != current_batch[0].shape[-2]: + sliced, _ = self.slice_batch(current_batch, 1, batch_size) + output_list.append({'samples': sliced[0][0], 'noise_mask': sliced[1][0], 'batch_index': sliced[2][0]}) + current_batch = next_batch + # cat if everything checks out + else: + current_batch = self.cat_batch(current_batch, next_batch) + + # add to list if the batch has grown above the target batch size + if current_batch[0].shape[0] > batch_size: + num = current_batch[0].shape[0] // batch_size + sliced, remainder = self.slice_batch(current_batch, num, batch_size) + + for i in range(num): + output_list.append({'samples': sliced[0][i], 'noise_mask': sliced[1][i], 'batch_index': 
sliced[2][i]}) + + current_batch = remainder + + #add remainder + if current_batch[0] is not None: + sliced, _ = self.slice_batch(current_batch, 1, batch_size) + output_list.append({'samples': sliced[0][0], 'noise_mask': sliced[1][0], 'batch_index': sliced[2][0]}) + + #get rid of empty masks + for s in output_list: + if s['noise_mask'].mean() == 1.0: + del s['noise_mask'] + + return (output_list,) + +class ImageRebatch: + @classmethod + def INPUT_TYPES(s): + return {"required": { "images": ("IMAGE",), + "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), + }} + RETURN_TYPES = ("IMAGE",) + INPUT_IS_LIST = True + OUTPUT_IS_LIST = (True, ) + + FUNCTION = "rebatch" + + CATEGORY = "image/batch" + + def rebatch(self, images, batch_size): + batch_size = batch_size[0] + + output_list = [] + all_images = [] + for img in images: + for i in range(img.shape[0]): + all_images.append(img[i:i+1]) + + for i in range(0, len(all_images), batch_size): + output_list.append(torch.cat(all_images[i:i+batch_size], dim=0)) + + return (output_list,) + +NODE_CLASS_MAPPINGS = { + "RebatchLatents": LatentRebatch, + "RebatchImages": ImageRebatch, +} + +NODE_DISPLAY_NAME_MAPPINGS = { + "RebatchLatents": "Rebatch Latents", + "RebatchImages": "Rebatch Images", +} diff --git a/comfy_extras/nodes_sag.py b/comfy_extras/nodes_sag.py new file mode 100644 index 0000000000000000000000000000000000000000..5e15b99e56712f05a64b12e41bbd5304dae1d665 --- /dev/null +++ b/comfy_extras/nodes_sag.py @@ -0,0 +1,169 @@ +import torch +from torch import einsum +import torch.nn.functional as F +import math + +from einops import rearrange, repeat +from comfy.ldm.modules.attention import optimized_attention +import comfy.samplers + +# from comfy/ldm/modules/attention.py +# but modified to return attention scores as well as output +def attention_basic_with_sim(q, k, v, heads, mask=None, attn_precision=None): + b, _, dim_head = q.shape + dim_head //= heads + scale = dim_head ** -0.5 + + h = heads + q, k, v = map( + lambda t: t.unsqueeze(3) + .reshape(b, -1, heads, dim_head) + .permute(0, 2, 1, 3) + .reshape(b * heads, -1, dim_head) + .contiguous(), + (q, k, v), + ) + + # force cast to fp32 to avoid overflowing + if attn_precision == torch.float32: + sim = einsum('b i d, b j d -> b i j', q.float(), k.float()) * scale + else: + sim = einsum('b i d, b j d -> b i j', q, k) * scale + + del q, k + + if mask is not None: + mask = rearrange(mask, 'b ... 
-> b (...)') + max_neg_value = -torch.finfo(sim.dtype).max + mask = repeat(mask, 'b j -> (b h) () j', h=h) + sim.masked_fill_(~mask, max_neg_value) + + # attention, what we cannot get enough of + sim = sim.softmax(dim=-1) + + out = einsum('b i j, b j d -> b i d', sim.to(v.dtype), v) + out = ( + out.unsqueeze(0) + .reshape(b, heads, -1, dim_head) + .permute(0, 2, 1, 3) + .reshape(b, -1, heads * dim_head) + ) + return (out, sim) + +def create_blur_map(x0, attn, sigma=3.0, threshold=1.0): + # reshape and GAP the attention map + _, hw1, hw2 = attn.shape + b, _, lh, lw = x0.shape + attn = attn.reshape(b, -1, hw1, hw2) + # Global Average Pool + mask = attn.mean(1, keepdim=False).sum(1, keepdim=False) > threshold + ratio = 2**(math.ceil(math.sqrt(lh * lw / hw1)) - 1).bit_length() + mid_shape = [math.ceil(lh / ratio), math.ceil(lw / ratio)] + + # Reshape + mask = ( + mask.reshape(b, *mid_shape) + .unsqueeze(1) + .type(attn.dtype) + ) + # Upsample + mask = F.interpolate(mask, (lh, lw)) + + blurred = gaussian_blur_2d(x0, kernel_size=9, sigma=sigma) + blurred = blurred * mask + x0 * (1 - mask) + return blurred + +def gaussian_blur_2d(img, kernel_size, sigma): + ksize_half = (kernel_size - 1) * 0.5 + + x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size) + + pdf = torch.exp(-0.5 * (x / sigma).pow(2)) + + x_kernel = pdf / pdf.sum() + x_kernel = x_kernel.to(device=img.device, dtype=img.dtype) + + kernel2d = torch.mm(x_kernel[:, None], x_kernel[None, :]) + kernel2d = kernel2d.expand(img.shape[-3], 1, kernel2d.shape[0], kernel2d.shape[1]) + + padding = [kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2] + + img = F.pad(img, padding, mode="reflect") + img = F.conv2d(img, kernel2d, groups=img.shape[-3]) + return img + +class SelfAttentionGuidance: + @classmethod + def INPUT_TYPES(s): + return {"required": { "model": ("MODEL",), + "scale": ("FLOAT", {"default": 0.5, "min": -2.0, "max": 5.0, "step": 0.01}), + "blur_sigma": ("FLOAT", {"default": 2.0, "min": 0.0, "max": 10.0, "step": 0.1}), + }} + RETURN_TYPES = ("MODEL",) + FUNCTION = "patch" + + CATEGORY = "_for_testing" + + def patch(self, model, scale, blur_sigma): + m = model.clone() + + attn_scores = None + + # TODO: make this work properly with chunked batches + # currently, we can only save the attn from one UNet call + def attn_and_record(q, k, v, extra_options): + nonlocal attn_scores + # if uncond, save the attention scores + heads = extra_options["n_heads"] + cond_or_uncond = extra_options["cond_or_uncond"] + b = q.shape[0] // len(cond_or_uncond) + if 1 in cond_or_uncond: + uncond_index = cond_or_uncond.index(1) + # do the entire attention operation, but save the attention scores to attn_scores + (out, sim) = attention_basic_with_sim(q, k, v, heads=heads, attn_precision=extra_options["attn_precision"]) + # when using a higher batch size, I BELIEVE the result batch dimension is [uc1, ... ucn, c1, ... 
cn] + n_slices = heads * b + attn_scores = sim[n_slices * uncond_index:n_slices * (uncond_index+1)] + return out + else: + return optimized_attention(q, k, v, heads=heads, attn_precision=extra_options["attn_precision"]) + + def post_cfg_function(args): + nonlocal attn_scores + uncond_attn = attn_scores + + sag_scale = scale + sag_sigma = blur_sigma + sag_threshold = 1.0 + model = args["model"] + uncond_pred = args["uncond_denoised"] + uncond = args["uncond"] + cfg_result = args["denoised"] + sigma = args["sigma"] + model_options = args["model_options"] + x = args["input"] + if min(cfg_result.shape[2:]) <= 4: #skip when too small to add padding + return cfg_result + + # create the adversarially blurred image + degraded = create_blur_map(uncond_pred, uncond_attn, sag_sigma, sag_threshold) + degraded_noised = degraded + x - uncond_pred + # call into the UNet + (sag,) = comfy.samplers.calc_cond_batch(model, [uncond], degraded_noised, sigma, model_options) + return cfg_result + (degraded - sag) * sag_scale + + m.set_model_sampler_post_cfg_function(post_cfg_function, disable_cfg1_optimization=True) + + # from diffusers: + # unet.mid_block.attentions[0].transformer_blocks[0].attn1.patch + m.set_model_attn1_replace(attn_and_record, "middle", 0, 0) + + return (m, ) + +NODE_CLASS_MAPPINGS = { + "SelfAttentionGuidance": SelfAttentionGuidance, +} + +NODE_DISPLAY_NAME_MAPPINGS = { + "SelfAttentionGuidance": "Self-Attention Guidance", +} diff --git a/comfy_extras/nodes_sd3.py b/comfy_extras/nodes_sd3.py new file mode 100644 index 0000000000000000000000000000000000000000..046096cba8a3c55fd32b6b3a9da4fdec22c8576a --- /dev/null +++ b/comfy_extras/nodes_sd3.py @@ -0,0 +1,107 @@ +import folder_paths +import comfy.sd +import comfy.model_management +import nodes +import torch + +class TripleCLIPLoader: + @classmethod + def INPUT_TYPES(s): + return {"required": { "clip_name1": (folder_paths.get_filename_list("clip"), ), "clip_name2": (folder_paths.get_filename_list("clip"), ), "clip_name3": (folder_paths.get_filename_list("clip"), ) + }} + RETURN_TYPES = ("CLIP",) + FUNCTION = "load_clip" + + CATEGORY = "advanced/loaders" + + def load_clip(self, clip_name1, clip_name2, clip_name3): + clip_path1 = folder_paths.get_full_path("clip", clip_name1) + clip_path2 = folder_paths.get_full_path("clip", clip_name2) + clip_path3 = folder_paths.get_full_path("clip", clip_name3) + clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2, clip_path3], embedding_directory=folder_paths.get_folder_paths("embeddings")) + return (clip,) + +class EmptySD3LatentImage: + def __init__(self): + self.device = comfy.model_management.intermediate_device() + + @classmethod + def INPUT_TYPES(s): + return {"required": { "width": ("INT", {"default": 1024, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), + "height": ("INT", {"default": 1024, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), + "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})}} + RETURN_TYPES = ("LATENT",) + FUNCTION = "generate" + + CATEGORY = "latent/sd3" + + def generate(self, width, height, batch_size=1): + latent = torch.ones([batch_size, 16, height // 8, width // 8], device=self.device) * 0.0609 + return ({"samples":latent}, ) + +class CLIPTextEncodeSD3: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "clip": ("CLIP", ), + "clip_l": ("STRING", {"multiline": True, "dynamicPrompts": True}), + "clip_g": ("STRING", {"multiline": True, "dynamicPrompts": True}), + "t5xxl": ("STRING", {"multiline": True, "dynamicPrompts": True}), + 
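# "none" drops the tokens of an empty prompt entirely, while "empty_prompt" keeps the standard empty-prompt padding (see encode below) +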
"empty_padding": (["none", "empty_prompt"], ) + }} + RETURN_TYPES = ("CONDITIONING",) + FUNCTION = "encode" + + CATEGORY = "advanced/conditioning" + + def encode(self, clip, clip_l, clip_g, t5xxl, empty_padding): + no_padding = empty_padding == "none" + + tokens = clip.tokenize(clip_g) + if len(clip_g) == 0 and no_padding: + tokens["g"] = [] + + if len(clip_l) == 0 and no_padding: + tokens["l"] = [] + else: + tokens["l"] = clip.tokenize(clip_l)["l"] + + if len(t5xxl) == 0 and no_padding: + tokens["t5xxl"] = [] + else: + tokens["t5xxl"] = clip.tokenize(t5xxl)["t5xxl"] + if len(tokens["l"]) != len(tokens["g"]): + empty = clip.tokenize("") + while len(tokens["l"]) < len(tokens["g"]): + tokens["l"] += empty["l"] + while len(tokens["l"]) > len(tokens["g"]): + tokens["g"] += empty["g"] + cond, pooled = clip.encode_from_tokens(tokens, return_pooled=True) + return ([[cond, {"pooled_output": pooled}]], ) + + +class ControlNetApplySD3(nodes.ControlNetApplyAdvanced): + @classmethod + def INPUT_TYPES(s): + return {"required": {"positive": ("CONDITIONING", ), + "negative": ("CONDITIONING", ), + "control_net": ("CONTROL_NET", ), + "vae": ("VAE", ), + "image": ("IMAGE", ), + "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), + "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}), + "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}) + }} + CATEGORY = "conditioning/controlnet" + +NODE_CLASS_MAPPINGS = { + "TripleCLIPLoader": TripleCLIPLoader, + "EmptySD3LatentImage": EmptySD3LatentImage, + "CLIPTextEncodeSD3": CLIPTextEncodeSD3, + "ControlNetApplySD3": ControlNetApplySD3, +} + +NODE_DISPLAY_NAME_MAPPINGS = { + # Sampling + "ControlNetApplySD3": "ControlNetApply SD3 and HunyuanDiT", +} diff --git a/comfy_extras/nodes_sdupscale.py b/comfy_extras/nodes_sdupscale.py new file mode 100644 index 0000000000000000000000000000000000000000..bba67e8ddff8064a90ec0f8e71e953ca4e56c4c6 --- /dev/null +++ b/comfy_extras/nodes_sdupscale.py @@ -0,0 +1,46 @@ +import torch +import comfy.utils + +class SD_4XUpscale_Conditioning: + @classmethod + def INPUT_TYPES(s): + return {"required": { "images": ("IMAGE",), + "positive": ("CONDITIONING",), + "negative": ("CONDITIONING",), + "scale_ratio": ("FLOAT", {"default": 4.0, "min": 0.0, "max": 10.0, "step": 0.01}), + "noise_augmentation": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}), + }} + RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") + RETURN_NAMES = ("positive", "negative", "latent") + + FUNCTION = "encode" + + CATEGORY = "conditioning/upscale_diffusion" + + def encode(self, images, positive, negative, scale_ratio, noise_augmentation): + width = max(1, round(images.shape[-2] * scale_ratio)) + height = max(1, round(images.shape[-3] * scale_ratio)) + + pixels = comfy.utils.common_upscale((images.movedim(-1,1) * 2.0) - 1.0, width // 4, height // 4, "bilinear", "center") + + out_cp = [] + out_cn = [] + + for t in positive: + n = [t[0], t[1].copy()] + n[1]['concat_image'] = pixels + n[1]['noise_augmentation'] = noise_augmentation + out_cp.append(n) + + for t in negative: + n = [t[0], t[1].copy()] + n[1]['concat_image'] = pixels + n[1]['noise_augmentation'] = noise_augmentation + out_cn.append(n) + + latent = torch.zeros([images.shape[0], 4, height // 4, width // 4]) + return (out_cp, out_cn, {"samples":latent}) + +NODE_CLASS_MAPPINGS = { + "SD_4XUpscale_Conditioning": SD_4XUpscale_Conditioning, +} diff --git a/comfy_extras/nodes_stable3d.py 
b/comfy_extras/nodes_stable3d.py new file mode 100644 index 0000000000000000000000000000000000000000..be2e34c28f49f160a21703c313305193ed00546f --- /dev/null +++ b/comfy_extras/nodes_stable3d.py @@ -0,0 +1,143 @@ +import torch +import nodes +import comfy.utils + +def camera_embeddings(elevation, azimuth): + elevation = torch.as_tensor([elevation]) + azimuth = torch.as_tensor([azimuth]) + embeddings = torch.stack( + [ + torch.deg2rad( + (90 - elevation) - (90) + ), # Zero123 polar is 90-elevation + torch.sin(torch.deg2rad(azimuth)), + torch.cos(torch.deg2rad(azimuth)), + torch.deg2rad( + 90 - torch.full_like(elevation, 0) + ), + ], dim=-1).unsqueeze(1) + + return embeddings + + +class StableZero123_Conditioning: + @classmethod + def INPUT_TYPES(s): + return {"required": { "clip_vision": ("CLIP_VISION",), + "init_image": ("IMAGE",), + "vae": ("VAE",), + "width": ("INT", {"default": 256, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}), + "height": ("INT", {"default": 256, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}), + "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), + "elevation": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0, "step": 0.1, "round": False}), + "azimuth": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0, "step": 0.1, "round": False}), + }} + RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") + RETURN_NAMES = ("positive", "negative", "latent") + + FUNCTION = "encode" + + CATEGORY = "conditioning/3d_models" + + def encode(self, clip_vision, init_image, vae, width, height, batch_size, elevation, azimuth): + output = clip_vision.encode_image(init_image) + pooled = output.image_embeds.unsqueeze(0) + pixels = comfy.utils.common_upscale(init_image.movedim(-1,1), width, height, "bilinear", "center").movedim(1,-1) + encode_pixels = pixels[:,:,:,:3] + t = vae.encode(encode_pixels) + cam_embeds = camera_embeddings(elevation, azimuth) + cond = torch.cat([pooled, cam_embeds.to(pooled.device).repeat((pooled.shape[0], 1, 1))], dim=-1) + + positive = [[cond, {"concat_latent_image": t}]] + negative = [[torch.zeros_like(pooled), {"concat_latent_image": torch.zeros_like(t)}]] + latent = torch.zeros([batch_size, 4, height // 8, width // 8]) + return (positive, negative, {"samples":latent}) + +class StableZero123_Conditioning_Batched: + @classmethod + def INPUT_TYPES(s): + return {"required": { "clip_vision": ("CLIP_VISION",), + "init_image": ("IMAGE",), + "vae": ("VAE",), + "width": ("INT", {"default": 256, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}), + "height": ("INT", {"default": 256, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}), + "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), + "elevation": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0, "step": 0.1, "round": False}), + "azimuth": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0, "step": 0.1, "round": False}), + "elevation_batch_increment": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0, "step": 0.1, "round": False}), + "azimuth_batch_increment": ("FLOAT", {"default": 0.0, "min": -180.0, "max": 180.0, "step": 0.1, "round": False}), + }} + RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") + RETURN_NAMES = ("positive", "negative", "latent") + + FUNCTION = "encode" + + CATEGORY = "conditioning/3d_models" + + def encode(self, clip_vision, init_image, vae, width, height, batch_size, elevation, azimuth, elevation_batch_increment, azimuth_batch_increment): + output = clip_vision.encode_image(init_image) + pooled = 
output.image_embeds.unsqueeze(0) + pixels = comfy.utils.common_upscale(init_image.movedim(-1,1), width, height, "bilinear", "center").movedim(1,-1) + encode_pixels = pixels[:,:,:,:3] + t = vae.encode(encode_pixels) + + cam_embeds = [] + for i in range(batch_size): + cam_embeds.append(camera_embeddings(elevation, azimuth)) + elevation += elevation_batch_increment + azimuth += azimuth_batch_increment + + cam_embeds = torch.cat(cam_embeds, dim=0) + cond = torch.cat([comfy.utils.repeat_to_batch_size(pooled, batch_size), cam_embeds], dim=-1) + + positive = [[cond, {"concat_latent_image": t}]] + negative = [[torch.zeros_like(pooled), {"concat_latent_image": torch.zeros_like(t)}]] + latent = torch.zeros([batch_size, 4, height // 8, width // 8]) + return (positive, negative, {"samples":latent, "batch_index": [0] * batch_size}) + +class SV3D_Conditioning: + @classmethod + def INPUT_TYPES(s): + return {"required": { "clip_vision": ("CLIP_VISION",), + "init_image": ("IMAGE",), + "vae": ("VAE",), + "width": ("INT", {"default": 576, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}), + "height": ("INT", {"default": 576, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}), + "video_frames": ("INT", {"default": 21, "min": 1, "max": 4096}), + "elevation": ("FLOAT", {"default": 0.0, "min": -90.0, "max": 90.0, "step": 0.1, "round": False}), + }} + RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") + RETURN_NAMES = ("positive", "negative", "latent") + + FUNCTION = "encode" + + CATEGORY = "conditioning/3d_models" + + def encode(self, clip_vision, init_image, vae, width, height, video_frames, elevation): + output = clip_vision.encode_image(init_image) + pooled = output.image_embeds.unsqueeze(0) + pixels = comfy.utils.common_upscale(init_image.movedim(-1,1), width, height, "bilinear", "center").movedim(1,-1) + encode_pixels = pixels[:,:,:,:3] + t = vae.encode(encode_pixels) + + azimuth = 0 + azimuth_increment = 360 / (max(video_frames, 2) - 1) + + elevations = [] + azimuths = [] + for i in range(video_frames): + elevations.append(elevation) + azimuths.append(azimuth) + azimuth += azimuth_increment + + positive = [[pooled, {"concat_latent_image": t, "elevation": elevations, "azimuth": azimuths}]] + negative = [[torch.zeros_like(pooled), {"concat_latent_image": torch.zeros_like(t), "elevation": elevations, "azimuth": azimuths}]] + latent = torch.zeros([video_frames, 4, height // 8, width // 8]) + return (positive, negative, {"samples":latent}) + + +NODE_CLASS_MAPPINGS = { + "StableZero123_Conditioning": StableZero123_Conditioning, + "StableZero123_Conditioning_Batched": StableZero123_Conditioning_Batched, + "SV3D_Conditioning": SV3D_Conditioning, +} diff --git a/comfy_extras/nodes_stable_cascade.py b/comfy_extras/nodes_stable_cascade.py new file mode 100644 index 0000000000000000000000000000000000000000..fcbbeb27fa96eb31514eb93773c50228dcb06d4e --- /dev/null +++ b/comfy_extras/nodes_stable_cascade.py @@ -0,0 +1,140 @@ +""" + This file is part of ComfyUI. + Copyright (C) 2024 Stability AI + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. 
+ + You should have received a copy of the GNU General Public License + along with this program. If not, see <https://www.gnu.org/licenses/>. +""" + +import torch +import nodes +import comfy.utils + + +class StableCascade_EmptyLatentImage: + def __init__(self, device="cpu"): + self.device = device + + @classmethod + def INPUT_TYPES(s): + return {"required": { + "width": ("INT", {"default": 1024, "min": 256, "max": nodes.MAX_RESOLUTION, "step": 8}), + "height": ("INT", {"default": 1024, "min": 256, "max": nodes.MAX_RESOLUTION, "step": 8}), + "compression": ("INT", {"default": 42, "min": 4, "max": 128, "step": 1}), + "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}) + }} + RETURN_TYPES = ("LATENT", "LATENT") + RETURN_NAMES = ("stage_c", "stage_b") + FUNCTION = "generate" + + CATEGORY = "latent/stable_cascade" + + def generate(self, width, height, compression, batch_size=1): + c_latent = torch.zeros([batch_size, 16, height // compression, width // compression]) + b_latent = torch.zeros([batch_size, 4, height // 4, width // 4]) + return ({ + "samples": c_latent, + }, { + "samples": b_latent, + }) + +class StableCascade_StageC_VAEEncode: + def __init__(self, device="cpu"): + self.device = device + + @classmethod + def INPUT_TYPES(s): + return {"required": { + "image": ("IMAGE",), + "vae": ("VAE", ), + "compression": ("INT", {"default": 42, "min": 4, "max": 128, "step": 1}), + }} + RETURN_TYPES = ("LATENT", "LATENT") + RETURN_NAMES = ("stage_c", "stage_b") + FUNCTION = "generate" + + CATEGORY = "latent/stable_cascade" + + def generate(self, image, vae, compression): + width = image.shape[-2] + height = image.shape[-3] + out_width = (width // compression) * vae.downscale_ratio + out_height = (height // compression) * vae.downscale_ratio + + s = comfy.utils.common_upscale(image.movedim(-1,1), out_width, out_height, "bicubic", "center").movedim(1,-1) + + c_latent = vae.encode(s[:,:,:,:3]) + b_latent = torch.zeros([c_latent.shape[0], 4, (height // 8) * 2, (width // 8) * 2]) + return ({ + "samples": c_latent, + }, { + "samples": b_latent, + }) + +class StableCascade_StageB_Conditioning: + @classmethod + def INPUT_TYPES(s): + return {"required": { "conditioning": ("CONDITIONING",), + "stage_c": ("LATENT",), + }} + RETURN_TYPES = ("CONDITIONING",) + + FUNCTION = "set_prior" + + CATEGORY = "conditioning/stable_cascade" + + def set_prior(self, conditioning, stage_c): + c = [] + for t in conditioning: + d = t[1].copy() + d['stable_cascade_prior'] = stage_c['samples'] + n = [t[0], d] + c.append(n) + return (c, ) + +class StableCascade_SuperResolutionControlnet: + def __init__(self, device="cpu"): + self.device = device + + @classmethod + def INPUT_TYPES(s): + return {"required": { + "image": ("IMAGE",), + "vae": ("VAE", ), + }} + RETURN_TYPES = ("IMAGE", "LATENT", "LATENT") + RETURN_NAMES = ("controlnet_input", "stage_c", "stage_b") + FUNCTION = "generate" + + CATEGORY = "_for_testing/stable_cascade" + + def generate(self, image, vae): + width = image.shape[-2] + height = image.shape[-3] + batch_size = image.shape[0] + controlnet_input = vae.encode(image[:,:,:,:3]).movedim(1, -1) + + c_latent = torch.zeros([batch_size, 16, height // 16, width // 16]) + b_latent = torch.zeros([batch_size, 4, height // 2, width // 2]) + return (controlnet_input, { + "samples": c_latent, + }, { + "samples": b_latent, + }) + +NODE_CLASS_MAPPINGS = { + "StableCascade_EmptyLatentImage": StableCascade_EmptyLatentImage, + "StableCascade_StageB_Conditioning": StableCascade_StageB_Conditioning, + "StableCascade_StageC_VAEEncode": StableCascade_StageC_VAEEncode, + "StableCascade_SuperResolutionControlnet": StableCascade_SuperResolutionControlnet, +}
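+ +# For example (illustrative numbers, not part of the node definitions above): +# at the default 1024x1024 with compression 42, StableCascade_EmptyLatentImage +# yields a 16-channel 24x24 stage C latent (1024 // 42 == 24) and a +# 4-channel 256x256 stage B latent (1024 // 4 == 256).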
diff --git a/comfy_extras/nodes_tomesd.py b/comfy_extras/nodes_tomesd.py new file mode 100644 index 0000000000000000000000000000000000000000..df0485063e68d3a38eac4755600748e096426231 --- /dev/null +++ b/comfy_extras/nodes_tomesd.py @@ -0,0 +1,177 @@ +#Taken from: https://github.com/dbolya/tomesd + +import torch +from typing import Tuple, Callable +import math + +def do_nothing(x: torch.Tensor, mode:str=None): + return x + + +def mps_gather_workaround(input, dim, index): + if input.shape[-1] == 1: + return torch.gather( + input.unsqueeze(-1), + dim - 1 if dim < 0 else dim, + index.unsqueeze(-1) + ).squeeze(-1) + else: + return torch.gather(input, dim, index) + + +def bipartite_soft_matching_random2d(metric: torch.Tensor, + w: int, h: int, sx: int, sy: int, r: int, + no_rand: bool = False) -> Tuple[Callable, Callable]: + """ + Partitions the tokens into src and dst and merges r tokens from src to dst. + Dst tokens are partitioned by choosing one randomly in each (sx, sy) region. + Args: + - metric [B, N, C]: metric to use for similarity + - w: image width in tokens + - h: image height in tokens + - sx: stride in the x dimension for dst, must divide w + - sy: stride in the y dimension for dst, must divide h + - r: number of tokens to remove (by merging) + - no_rand: if true, disable randomness (use top left corner only) + """ + B, N, _ = metric.shape + + if r <= 0 or w == 1 or h == 1: + return do_nothing, do_nothing + + gather = mps_gather_workaround if metric.device.type == "mps" else torch.gather + + with torch.no_grad(): + + hsy, wsx = h // sy, w // sx + + # For each sy by sx kernel, randomly assign one token to be dst and the rest src + if no_rand: + rand_idx = torch.zeros(hsy, wsx, 1, device=metric.device, dtype=torch.int64) + else: + rand_idx = torch.randint(sy*sx, size=(hsy, wsx, 1), device=metric.device) + + # The image might not be divisible by sx and sy, so we need to work on a view of the top left of the idx buffer instead + idx_buffer_view = torch.zeros(hsy, wsx, sy*sx, device=metric.device, dtype=torch.int64) + idx_buffer_view.scatter_(dim=2, index=rand_idx, src=-torch.ones_like(rand_idx, dtype=rand_idx.dtype)) + idx_buffer_view = idx_buffer_view.view(hsy, wsx, sy, sx).transpose(1, 2).reshape(hsy * sy, wsx * sx) + + # Image is not divisible by sx or sy so we need to move it into a new buffer + if (hsy * sy) < h or (wsx * sx) < w: + idx_buffer = torch.zeros(h, w, device=metric.device, dtype=torch.int64) + idx_buffer[:(hsy * sy), :(wsx * sx)] = idx_buffer_view + else: + idx_buffer = idx_buffer_view + + # We set dst tokens to be -1 and src to be 0, so an argsort gives us dst|src indices + rand_idx = idx_buffer.reshape(1, -1, 1).argsort(dim=1) + + # We're finished with these + del idx_buffer, idx_buffer_view + + # rand_idx is currently dst|src, so split them + num_dst = hsy * wsx + a_idx = rand_idx[:, num_dst:, :] # src + b_idx = rand_idx[:, :num_dst, :] # dst + + def split(x): + C = x.shape[-1] + src = gather(x, dim=1, index=a_idx.expand(B, N - num_dst, C)) + dst = gather(x, dim=1, index=b_idx.expand(B, num_dst, C)) + return src, dst + + # Cosine similarity between A and B + metric = metric / metric.norm(dim=-1, keepdim=True) + a, b = split(metric) + scores = a @ b.transpose(-1, -2) + + # Can't reduce more than the # tokens in src + r = min(a.shape[1], r) + + # Find the most similar greedily + node_max, node_idx = scores.max(dim=-1) + edge_idx = node_max.argsort(dim=-1, 
descending=True)[..., None] + + unm_idx = edge_idx[..., r:, :] # Unmerged Tokens + src_idx = edge_idx[..., :r, :] # Merged Tokens + dst_idx = gather(node_idx[..., None], dim=-2, index=src_idx) + + def merge(x: torch.Tensor, mode="mean") -> torch.Tensor: + src, dst = split(x) + n, t1, c = src.shape + + unm = gather(src, dim=-2, index=unm_idx.expand(n, t1 - r, c)) + src = gather(src, dim=-2, index=src_idx.expand(n, r, c)) + dst = dst.scatter_reduce(-2, dst_idx.expand(n, r, c), src, reduce=mode) + + return torch.cat([unm, dst], dim=1) + + def unmerge(x: torch.Tensor) -> torch.Tensor: + unm_len = unm_idx.shape[1] + unm, dst = x[..., :unm_len, :], x[..., unm_len:, :] + _, _, c = unm.shape + + src = gather(dst, dim=-2, index=dst_idx.expand(B, r, c)) + + # Combine back to the original shape + out = torch.zeros(B, N, c, device=x.device, dtype=x.dtype) + out.scatter_(dim=-2, index=b_idx.expand(B, num_dst, c), src=dst) + out.scatter_(dim=-2, index=gather(a_idx.expand(B, a_idx.shape[1], 1), dim=1, index=unm_idx).expand(B, unm_len, c), src=unm) + out.scatter_(dim=-2, index=gather(a_idx.expand(B, a_idx.shape[1], 1), dim=1, index=src_idx).expand(B, r, c), src=src) + + return out + + return merge, unmerge + + +def get_functions(x, ratio, original_shape): + b, c, original_h, original_w = original_shape + original_tokens = original_h * original_w + downsample = int(math.ceil(math.sqrt(original_tokens // x.shape[1]))) + stride_x = 2 + stride_y = 2 + max_downsample = 1 + + if downsample <= max_downsample: + w = int(math.ceil(original_w / downsample)) + h = int(math.ceil(original_h / downsample)) + r = int(x.shape[1] * ratio) + no_rand = False + m, u = bipartite_soft_matching_random2d(x, w, h, stride_x, stride_y, r, no_rand) + return m, u + + nothing = lambda y: y + return nothing, nothing + + + +class TomePatchModel: + @classmethod + def INPUT_TYPES(s): + return {"required": { "model": ("MODEL",), + "ratio": ("FLOAT", {"default": 0.3, "min": 0.0, "max": 1.0, "step": 0.01}), + }} + RETURN_TYPES = ("MODEL",) + FUNCTION = "patch" + + CATEGORY = "_for_testing" + + def patch(self, model, ratio): + self.u = None + def tomesd_m(q, k, v, extra_options): + #NOTE: In the reference code get_functions takes x (input of the transformer block) as the argument instead of q + #however from my basic testing it seems that using q instead gives better results + m, self.u = get_functions(q, ratio, extra_options["original_shape"]) + return m(q), k, v + def tomesd_u(n, extra_options): + return self.u(n) + + m = model.clone() + m.set_model_attn1_patch(tomesd_m) + m.set_model_attn1_output_patch(tomesd_u) + return (m, ) + + +NODE_CLASS_MAPPINGS = { + "TomePatchModel": TomePatchModel, +} diff --git a/comfy_extras/nodes_torch_compile.py b/comfy_extras/nodes_torch_compile.py new file mode 100644 index 0000000000000000000000000000000000000000..1d914fa93e5ce68d85c90f285830e3141822b6a0 --- /dev/null +++ b/comfy_extras/nodes_torch_compile.py @@ -0,0 +1,21 @@ +import torch + +class TorchCompileModel: + @classmethod + def INPUT_TYPES(s): + return {"required": { "model": ("MODEL",), + }} + RETURN_TYPES = ("MODEL",) + FUNCTION = "patch" + + CATEGORY = "_for_testing" + EXPERIMENTAL = True + + def patch(self, model): + m = model.clone() + m.add_object_patch("diffusion_model", torch.compile(model=m.get_model_object("diffusion_model"))) + return (m, ) + +NODE_CLASS_MAPPINGS = { + "TorchCompileModel": TorchCompileModel, +} diff --git a/comfy_extras/nodes_upscale_model.py b/comfy_extras/nodes_upscale_model.py new file mode 100644 index 
0000000000000000000000000000000000000000..bca79ef2e1395471e52032f079f1edacf396d6d0 --- /dev/null +++ b/comfy_extras/nodes_upscale_model.py @@ -0,0 +1,84 @@ +import os +import logging +from spandrel import ModelLoader, ImageModelDescriptor +from comfy import model_management +import torch +import comfy.utils +import folder_paths + +try: + from spandrel_extra_arches import EXTRA_REGISTRY + from spandrel import MAIN_REGISTRY + MAIN_REGISTRY.add(*EXTRA_REGISTRY) + logging.info("Successfully imported spandrel_extra_arches: support for non-commercial upscale models.") +except: + pass + +class UpscaleModelLoader: + @classmethod + def INPUT_TYPES(s): + return {"required": { "model_name": (folder_paths.get_filename_list("upscale_models"), ), + }} + RETURN_TYPES = ("UPSCALE_MODEL",) + FUNCTION = "load_model" + + CATEGORY = "loaders" + + def load_model(self, model_name): + model_path = folder_paths.get_full_path("upscale_models", model_name) + sd = comfy.utils.load_torch_file(model_path, safe_load=True) + if "module.layers.0.residual_group.blocks.0.norm1.weight" in sd: + sd = comfy.utils.state_dict_prefix_replace(sd, {"module.":""}) + out = ModelLoader().load_from_state_dict(sd).eval() + + if not isinstance(out, ImageModelDescriptor): + raise Exception("Upscale model must be a single-image model.") + + return (out, ) + + +class ImageUpscaleWithModel: + @classmethod + def INPUT_TYPES(s): + return {"required": { "upscale_model": ("UPSCALE_MODEL",), + "image": ("IMAGE",), + }} + RETURN_TYPES = ("IMAGE",) + FUNCTION = "upscale" + + CATEGORY = "image/upscaling" + + def upscale(self, upscale_model, image): + device = model_management.get_torch_device() + + memory_required = model_management.module_size(upscale_model.model) + memory_required += (512 * 512 * 3) * image.element_size() * max(upscale_model.scale, 1.0) * 384.0 #The 384.0 is an estimate of how much some of these models take, TODO: make it more accurate + memory_required += image.nelement() * image.element_size() + model_management.free_memory(memory_required, device) + + upscale_model.to(device) + in_img = image.movedim(-1,-3).to(device) + + tile = 512 + overlap = 32 + + oom = True + while oom: + try: + steps = in_img.shape[0] * comfy.utils.get_tiled_scale_steps(in_img.shape[3], in_img.shape[2], tile_x=tile, tile_y=tile, overlap=overlap) + pbar = comfy.utils.ProgressBar(steps) + s = comfy.utils.tiled_scale(in_img, lambda a: upscale_model(a), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar) + oom = False + except model_management.OOM_EXCEPTION as e: + tile //= 2 + if tile < 128: + raise e + + upscale_model.to("cpu") + s = torch.clamp(s.movedim(-3,-1), min=0, max=1.0) + return (s,) + +NODE_CLASS_MAPPINGS = { + "UpscaleModelLoader": UpscaleModelLoader, + "ImageUpscaleWithModel": ImageUpscaleWithModel +} diff --git a/comfy_extras/nodes_video_model.py b/comfy_extras/nodes_video_model.py new file mode 100644 index 0000000000000000000000000000000000000000..1a0189ed4a8b6eb774cb66dcc02d8e687afd349f --- /dev/null +++ b/comfy_extras/nodes_video_model.py @@ -0,0 +1,134 @@ +import nodes +import torch +import comfy.utils +import comfy.sd +import folder_paths +import comfy_extras.nodes_model_merging + + +class ImageOnlyCheckpointLoader: + @classmethod + def INPUT_TYPES(s): + return {"required": { "ckpt_name": (folder_paths.get_filename_list("checkpoints"), ), + }} + RETURN_TYPES = ("MODEL", "CLIP_VISION", "VAE") + FUNCTION = "load_checkpoint" + + CATEGORY = "loaders/video_models" + + def load_checkpoint(self, ckpt_name, 
output_vae=True, output_clip=True): + ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name) + out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=False, output_clipvision=True, embedding_directory=folder_paths.get_folder_paths("embeddings")) + return (out[0], out[3], out[2]) + + +class SVD_img2vid_Conditioning: + @classmethod + def INPUT_TYPES(s): + return {"required": { "clip_vision": ("CLIP_VISION",), + "init_image": ("IMAGE",), + "vae": ("VAE",), + "width": ("INT", {"default": 1024, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}), + "height": ("INT", {"default": 576, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}), + "video_frames": ("INT", {"default": 14, "min": 1, "max": 4096}), + "motion_bucket_id": ("INT", {"default": 127, "min": 1, "max": 1023}), + "fps": ("INT", {"default": 6, "min": 1, "max": 1024}), + "augmentation_level": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 10.0, "step": 0.01}) + }} + RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") + RETURN_NAMES = ("positive", "negative", "latent") + + FUNCTION = "encode" + + CATEGORY = "conditioning/video_models" + + def encode(self, clip_vision, init_image, vae, width, height, video_frames, motion_bucket_id, fps, augmentation_level): + output = clip_vision.encode_image(init_image) + pooled = output.image_embeds.unsqueeze(0) + pixels = comfy.utils.common_upscale(init_image.movedim(-1,1), width, height, "bilinear", "center").movedim(1,-1) + encode_pixels = pixels[:,:,:,:3] + if augmentation_level > 0: + encode_pixels += torch.randn_like(pixels) * augmentation_level + t = vae.encode(encode_pixels) + positive = [[pooled, {"motion_bucket_id": motion_bucket_id, "fps": fps, "augmentation_level": augmentation_level, "concat_latent_image": t}]] + negative = [[torch.zeros_like(pooled), {"motion_bucket_id": motion_bucket_id, "fps": fps, "augmentation_level": augmentation_level, "concat_latent_image": torch.zeros_like(t)}]] + latent = torch.zeros([video_frames, 4, height // 8, width // 8]) + return (positive, negative, {"samples":latent}) + +class VideoLinearCFGGuidance: + @classmethod + def INPUT_TYPES(s): + return {"required": { "model": ("MODEL",), + "min_cfg": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.5, "round": 0.01}), + }} + RETURN_TYPES = ("MODEL",) + FUNCTION = "patch" + + CATEGORY = "sampling/video_models" + + def patch(self, model, min_cfg): + def linear_cfg(args): + cond = args["cond"] + uncond = args["uncond"] + cond_scale = args["cond_scale"] + + scale = torch.linspace(min_cfg, cond_scale, cond.shape[0], device=cond.device).reshape((cond.shape[0], 1, 1, 1)) + return uncond + scale * (cond - uncond) + + m = model.clone() + m.set_model_sampler_cfg_function(linear_cfg) + return (m, ) + +class VideoTriangleCFGGuidance: + @classmethod + def INPUT_TYPES(s): + return {"required": { "model": ("MODEL",), + "min_cfg": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.5, "round": 0.01}), + }} + RETURN_TYPES = ("MODEL",) + FUNCTION = "patch" + + CATEGORY = "sampling/video_models" + + def patch(self, model, min_cfg): + def linear_cfg(args): + cond = args["cond"] + uncond = args["uncond"] + cond_scale = args["cond_scale"] + period = 1.0 + values = torch.linspace(0, 1, cond.shape[0], device=cond.device) + values = 2 * (values / period - torch.floor(values / period + 0.5)).abs() + scale = (values * (cond_scale - min_cfg) + min_cfg).reshape((cond.shape[0], 1, 1, 1)) + + return uncond + scale * (cond - uncond) + + m = model.clone() + 
m.set_model_sampler_cfg_function(linear_cfg) + return (m, ) + +class ImageOnlyCheckpointSave(comfy_extras.nodes_model_merging.CheckpointSave): + CATEGORY = "_for_testing" + + @classmethod + def INPUT_TYPES(s): + return {"required": { "model": ("MODEL",), + "clip_vision": ("CLIP_VISION",), + "vae": ("VAE",), + "filename_prefix": ("STRING", {"default": "checkpoints/ComfyUI"}),}, + "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},} + + def save(self, model, clip_vision, vae, filename_prefix, prompt=None, extra_pnginfo=None): + comfy_extras.nodes_model_merging.save_checkpoint(model, clip_vision=clip_vision, vae=vae, filename_prefix=filename_prefix, output_dir=self.output_dir, prompt=prompt, extra_pnginfo=extra_pnginfo) + return {} + +NODE_CLASS_MAPPINGS = { + "ImageOnlyCheckpointLoader": ImageOnlyCheckpointLoader, + "SVD_img2vid_Conditioning": SVD_img2vid_Conditioning, + "VideoLinearCFGGuidance": VideoLinearCFGGuidance, + "VideoTriangleCFGGuidance": VideoTriangleCFGGuidance, + "ImageOnlyCheckpointSave": ImageOnlyCheckpointSave, +} + +NODE_DISPLAY_NAME_MAPPINGS = { + "ImageOnlyCheckpointLoader": "Image Only Checkpoint Loader (img2vid model)", +} diff --git a/comfy_extras/nodes_webcam.py b/comfy_extras/nodes_webcam.py new file mode 100644 index 0000000000000000000000000000000000000000..32a0ba2f67b1e9f6a84dc806d47dfcce61ab86b3 --- /dev/null +++ b/comfy_extras/nodes_webcam.py @@ -0,0 +1,33 @@ +import nodes +import folder_paths + +MAX_RESOLUTION = nodes.MAX_RESOLUTION + + +class WebcamCapture(nodes.LoadImage): + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "image": ("WEBCAM", {}), + "width": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), + "height": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), + "capture_on_queue": ("BOOLEAN", {"default": True}), + } + } + RETURN_TYPES = ("IMAGE",) + FUNCTION = "load_capture" + + CATEGORY = "image" + + def load_capture(s, image, **kwargs): + return super().load_image(folder_paths.get_annotated_filepath(image)) + + +NODE_CLASS_MAPPINGS = { + "WebcamCapture": WebcamCapture, +} + +NODE_DISPLAY_NAME_MAPPINGS = { + "WebcamCapture": "Webcam Capture", +} \ No newline at end of file diff --git a/comfy_version.py b/comfy_version.py new file mode 100644 index 0000000000000000000000000000000000000000..25ae4eaed93d11fa648af22d120915c80a89cdbe --- /dev/null +++ b/comfy_version.py @@ -0,0 +1 @@ +version = '1dba801' diff --git a/comfyui_screenshot.png b/comfyui_screenshot.png new file mode 100644 index 0000000000000000000000000000000000000000..bde0decd2aebac342c6e12a161f8e2bdfdc15aed --- /dev/null +++ b/comfyui_screenshot.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:95d812d1c6696b8816800f657f238473faff7b76f36cacba51fc1cbd51d6ac28 +size 118577 diff --git a/cuda_malloc.py b/cuda_malloc.py new file mode 100644 index 0000000000000000000000000000000000000000..eb2857c5fe2db47c8e86fa64d3c53dcd7574408f --- /dev/null +++ b/cuda_malloc.py @@ -0,0 +1,90 @@ +import os +import importlib.util +from comfy.cli_args import args +import subprocess + +#Can't use pytorch to get the GPU names because the cuda malloc has to be set before the first import. 
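+# In practice this means cuda_malloc must run (and set PYTORCH_CUDA_ALLOC_CONF, +# see the bottom of this file) before the first `import torch` anywhere in the +# process: the allocator backend is fixed once torch initializes CUDA.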
+def get_gpu_names(): + if os.name == 'nt': + import ctypes + + # Define necessary C structures and types + class DISPLAY_DEVICEA(ctypes.Structure): + _fields_ = [ + ('cb', ctypes.c_ulong), + ('DeviceName', ctypes.c_char * 32), + ('DeviceString', ctypes.c_char * 128), + ('StateFlags', ctypes.c_ulong), + ('DeviceID', ctypes.c_char * 128), + ('DeviceKey', ctypes.c_char * 128) + ] + + # Load user32.dll + user32 = ctypes.windll.user32 + + # Call EnumDisplayDevicesA + def enum_display_devices(): + device_info = DISPLAY_DEVICEA() + device_info.cb = ctypes.sizeof(device_info) + device_index = 0 + gpu_names = set() + + while user32.EnumDisplayDevicesA(None, device_index, ctypes.byref(device_info), 0): + device_index += 1 + gpu_names.add(device_info.DeviceString.decode('utf-8')) + return gpu_names + return enum_display_devices() + else: + gpu_names = set() + out = subprocess.check_output(['nvidia-smi', '-L']) + for l in out.split(b'\n'): + if len(l) > 0: + gpu_names.add(l.decode('utf-8').split(' (UUID')[0]) + return gpu_names + +blacklist = {"GeForce GTX TITAN X", "GeForce GTX 980", "GeForce GTX 970", "GeForce GTX 960", "GeForce GTX 950", "GeForce 945M", + "GeForce 940M", "GeForce 930M", "GeForce 920M", "GeForce 910M", "GeForce GTX 750", "GeForce GTX 745", "Quadro K620", + "Quadro K1200", "Quadro K2200", "Quadro M500", "Quadro M520", "Quadro M600", "Quadro M620", "Quadro M1000", + "Quadro M1200", "Quadro M2000", "Quadro M2200", "Quadro M3000", "Quadro M4000", "Quadro M5000", "Quadro M5500", "Quadro M6000", + "GeForce MX110", "GeForce MX130", "GeForce 830M", "GeForce 840M", "GeForce GTX 850M", "GeForce GTX 860M", + "GeForce GTX 1650", "GeForce GTX 1630", "Tesla M4", "Tesla M6", "Tesla M10", "Tesla M40", "Tesla M60" + } + +def cuda_malloc_supported(): + try: + names = get_gpu_names() + except: + names = set() + for x in names: + if "NVIDIA" in x: + for b in blacklist: + if b in x: + return False + return True + + +if not args.cuda_malloc: + try: + version = "" + torch_spec = importlib.util.find_spec("torch") + for folder in torch_spec.submodule_search_locations: + ver_file = os.path.join(folder, "version.py") + if os.path.isfile(ver_file): + spec = importlib.util.spec_from_file_location("torch_version_import", ver_file) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + version = module.__version__ + if int(version[0]) >= 2: #enable by default for torch version 2.0 and up + args.cuda_malloc = cuda_malloc_supported() + except: + pass + + +if args.cuda_malloc and not args.disable_cuda_malloc: + env_var = os.environ.get('PYTORCH_CUDA_ALLOC_CONF', None) + if env_var is None: + env_var = "backend:cudaMallocAsync" + else: + env_var += ",backend:cudaMallocAsync" + + os.environ['PYTORCH_CUDA_ALLOC_CONF'] = env_var
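+ +# Example of the resulting setting (illustrative values): with no prior value this +# yields PYTORCH_CUDA_ALLOC_CONF="backend:cudaMallocAsync"; an existing value such as +# "max_split_size_mb:512" becomes "max_split_size_mb:512,backend:cudaMallocAsync".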